training_hooks.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import time
  17. import tensorflow as tf
  18. import dllogger
  19. import signal
  20. from utils import hvd_wrapper as hvd
  21. __all__ = ['TrainingLoggingHook', 'TrainingPartitionHook']
  22. class MeanAccumulator:
  23. def __init__(self):
  24. self.sum = 0
  25. self.count = 0
  26. def consume(self, value):
  27. self.sum += value
  28. self.count += 1
  29. def value(self):
  30. if self.count:
  31. return self.sum / self.count
  32. else:
  33. return 0
  34. class TrainingLoggingHook(tf.estimator.SessionRunHook):
  35. def __init__(
  36. self, global_batch_size, num_steps, num_samples, num_epochs, steps_per_epoch, warmup_steps=20, logging_steps=1
  37. ):
  38. self.global_batch_size = global_batch_size
  39. self.num_steps = num_steps
  40. self.num_samples = num_samples
  41. self.num_epochs = num_epochs
  42. self.steps_per_epoch = steps_per_epoch
  43. self.warmup_steps = warmup_steps
  44. self.logging_steps = logging_steps
  45. self.current_step = 0
  46. self.current_epoch = 0
  47. self.t0 = None
  48. self.mean_throughput = MeanAccumulator()
  49. # Determines if its the last step of the epoch
  50. def _last_step_of_epoch(self, global_step):
  51. return (global_step + 1) // self.steps_per_epoch > (global_step // self.steps_per_epoch)
  52. def before_run(self, run_context):
  53. run_args = tf.train.SessionRunArgs(
  54. fetches=[
  55. tf.train.get_global_step(), 'cross_entropy_loss_ref:0', 'l2_loss_ref:0', 'total_loss_ref:0',
  56. 'learning_rate_ref:0'
  57. ]
  58. )
  59. self.t0 = time.time()
  60. return run_args
  61. def after_run(self, run_context, run_values):
  62. global_step, cross_entropy, l2_loss, total_loss, learning_rate = run_values.results
  63. batch_time = time.time() - self.t0
  64. ips = self.global_batch_size / batch_time
  65. metrics = {
  66. "imgs_per_sec": ips,
  67. "cross_entropy": cross_entropy,
  68. "l2_loss": l2_loss,
  69. "total_loss": total_loss,
  70. "learning_rate": learning_rate
  71. }
  72. if self.current_step >= self.warmup_steps:
  73. self.mean_throughput.consume(metrics['imgs_per_sec'])
  74. if (self.current_step % self.logging_steps) == 0:
  75. metrics = {k: float(v) for k, v in metrics.items()}
  76. dllogger.log(data=metrics, step=(int(global_step // self.steps_per_epoch), int(global_step)))
  77. self.current_step += 1
  78. if self._last_step_of_epoch(global_step):
  79. metrics = {
  80. "cross_entropy": cross_entropy,
  81. "l2_loss": l2_loss,
  82. "total_loss": total_loss,
  83. "learning_rate": learning_rate
  84. }
  85. metrics = {k: float(v) for k, v in metrics.items()}
  86. dllogger.log(data=metrics, step=(int(global_step // self.steps_per_epoch), ))
  87. self.current_epoch += 1
  88. class TrainingPartitionHook(tf.estimator.SessionRunHook):
  89. def __init__(self, sync_freq=10):
  90. super().__init__()
  91. self.signal_recieved = False
  92. self.sync_freq = sync_freq
  93. self.global_step = 0
  94. signal.signal(signal.SIGUSR1, self._signal_handler)
  95. signal.signal(signal.SIGTERM, self._signal_handler)
  96. def begin(self):
  97. if hvd.size() > 1:
  98. with tf.device("/cpu:0"):
  99. self.input_op = tf.placeholder(tf.int32, shape=())
  100. self.allreduce_op = hvd.hvd_global_object.allreduce(
  101. self.input_op, op=hvd.hvd_global_object.Sum, name="signal_handler_all_reduce")
  102. def before_run(self, run_context):
  103. fetches = [tf.train.get_global_step()]
  104. feed_dict = None
  105. if hvd.size() > 1 and (self.global_step % self.sync_freq) == 0:
  106. fetches += [self.allreduce_op]
  107. feed_dict = {self.input_op: int(self.signal_recieved)}
  108. return tf.train.SessionRunArgs(fetches, feed_dict=feed_dict)
  109. def after_run(self, run_context, run_values):
  110. self.global_step = run_values.results[0] + 1
  111. if hvd.size() > 1 and len(run_values.results) == 2:
  112. if run_values.results[1] > 0:
  113. run_context.request_stop()
  114. elif self.signal_recieved:
  115. run_context.request_stop()
  116. def _signal_handler(self, signum, frame):
  117. print("Stop signal received")
  118. self.signal_recieved = True