# callback_builder.py
  1. # Copyright 2020 Google Research. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Callback related utils."""
  16. from concurrent import futures
  17. import os
  18. from mpi4py import MPI
  19. import time
  20. import numpy as np
  21. import tensorflow as tf
  22. import horovod.tensorflow.keras.callbacks as hvd_callbacks
  23. from tensorflow_addons.optimizers import MovingAverage
  24. from typeguard import typechecked
  25. from typing import Any, List, MutableMapping, Text
  26. from model import inference, optimizer_builder
  27. from utils import model_utils
  28. from model import efficientdet_keras, coco_metric, label_util, postprocess
  29. from utils.horovod_utils import get_world_size, is_main_process
  30. class DisplayCallback(tf.keras.callbacks.Callback):
  31. """Display inference result callback."""
  32. def __init__(self, sample_image, output_dir, update_freq=1):
  33. super().__init__()
  34. image_file = tf.io.read_file(sample_image)
  35. self.sample_image = tf.expand_dims(
  36. tf.image.decode_jpeg(image_file, channels=3), axis=0)
  37. self.executor = futures.ThreadPoolExecutor(max_workers=1)
  38. self.update_freq = update_freq
  39. self.output_dir = output_dir
  40. def set_model(self, model: tf.keras.Model):
  41. self.train_model = model
  42. with tf.device('/cpu:0'):
  43. self.model = efficientdet_keras.EfficientDetModel(config=model.config)
  44. height, width = model_utils.parse_image_size(model.config.image_size)
  45. self.model.build((1, height, width, 3))
  46. self.file_writer = tf.summary.create_file_writer(self.output_dir)
  47. self.min_score_thresh = self.model.config.nms_configs['score_thresh'] or 0.4
  48. self.max_boxes_to_draw = (
  49. self.model.config.nms_configs['max_output_size'] or 100)
  50. def on_epoch_end(self, epoch, logs=None):
  51. if epoch % self.update_freq == 0:
  52. self.executor.submit(self.draw_inference, epoch)
  53. @tf.function
  54. def inference(self):
  55. return self.model(self.sample_image, training=False)
  56. def draw_inference(self, epoch):
  57. self.model.set_weights(self.train_model.get_weights())
  58. boxes, scores, classes, valid_len = self.inference()
  59. length = valid_len[0]
  60. image = inference.visualize_image(
  61. self.sample_image[0],
  62. boxes[0].numpy()[:length],
  63. classes[0].numpy().astype(np.int)[:length],
  64. scores[0].numpy()[:length],
  65. label_map=self.model.config.label_map,
  66. min_score_thresh=self.min_score_thresh,
  67. max_boxes_to_draw=self.max_boxes_to_draw)
  68. with self.file_writer.as_default():
  69. tf.summary.image('Test image', tf.expand_dims(image, axis=0), step=epoch)
  70. class BatchTimestamp(object):
  71. """A structure to store batch time stamp."""
  72. def __init__(self, batch_index, timestamp):
  73. self.batch_index = batch_index
  74. self.timestamp = timestamp
  75. def __repr__(self):
  76. return "'BatchTimestamp<batch_index: {}, timestamp: {}>'".format(
  77. self.batch_index, self.timestamp)
class TimeHistory(tf.keras.callbacks.Callback):
    """Callback for Keras models."""

    def __init__(self, batch_size, logger, log_steps=1, logdir=None):
        """Callback for logging performance.

        Args:
          batch_size: Total batch size.
          log_steps: Interval of steps between logging of batch level stats.
          logdir: Optional directory to write TensorBoard summaries.
        """
        # TODO(wcromar): remove this parameter and rely on `logs` parameter of
        # on_train_batch_end()
        self.batch_size = batch_size
        super(TimeHistory, self).__init__()
        self.log_steps = log_steps
        self.last_log_step = 0
        self.steps_before_epoch = 0
        self.steps_in_epoch = 0
        # Wall-clock time at the start of the current logging interval;
        # None means "interval not started yet".
        self.start_time = None
        self.logger = logger
        # Number of steps in the first epoch; used to exclude epoch 0 from
        # the steps-per-second average below.
        self.step_per_epoch = 0
        if logdir:
            self.summary_writer = tf.summary.create_file_writer(logdir)
        else:
            self.summary_writer = None
        # Logs start of step 1 then end of each step based on log_steps interval.
        self.timestamp_log = []
        # Records the time each epoch takes to run from start to finish of epoch.
        self.epoch_runtime_log = []
        # Per logging-interval latency (seconds) and throughput (examples/sec).
        self.latency = []
        self.throughput = []

    @property
    def global_steps(self):
        """The current 1-indexed global step."""
        return self.steps_before_epoch + self.steps_in_epoch

    @property
    def average_steps_per_second(self):
        """The average training steps per second across all epochs."""
        # Epoch 0 (warmup) is excluded from both numerator and denominator.
        return (self.global_steps - self.step_per_epoch) / sum(self.epoch_runtime_log[1:])

    @property
    def average_examples_per_second(self):
        """The average number of training examples per second across all epochs."""
        # return self.average_steps_per_second * self.batch_size
        # Drops the first 10% of measurements as warmup before averaging.
        ind = int(0.1*len(self.throughput))
        return sum(self.throughput[ind:])/(len(self.throughput[ind:]))

    @property
    def average_time_per_iteration(self):
        """The average time per iteration in seconds across all epochs."""
        # Drops the first 10% of measurements as warmup before averaging.
        ind = int(0.1*len(self.latency))
        return sum(self.latency[ind:])/(len(self.latency[ind:]))

    def on_train_end(self, logs=None):
        self.train_finish_time = time.time()
        if self.summary_writer:
            self.summary_writer.flush()

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start = time.time()

    def on_batch_begin(self, batch, logs=None):
        if not self.start_time:
            self.start_time = time.time()

        # Record the timestamp of the first global step
        if not self.timestamp_log:
            self.timestamp_log.append(BatchTimestamp(self.global_steps,
                                                     self.start_time))

    def on_batch_end(self, batch, logs=None):
        """Records elapse time of the batch and calculates examples per second."""
        self.steps_in_epoch = batch + 1
        steps_since_last_log = self.global_steps - self.last_log_step
        if steps_since_last_log >= self.log_steps:
            now = time.time()
            elapsed_time = now - self.start_time
            steps_per_second = steps_since_last_log / elapsed_time
            examples_per_second = steps_per_second * self.batch_size
            self.timestamp_log.append(BatchTimestamp(self.global_steps, now))
            elapsed_time_str='{:.2f} seconds'.format(elapsed_time)
            self.logger.log(step='PARAMETER', data={'Latency': elapsed_time_str, 'fps': examples_per_second, 'steps': (self.last_log_step, self.global_steps)})
            self.logger.flush()
            if self.summary_writer:
                with self.summary_writer.as_default():
                    tf.summary.scalar('global_step/sec', steps_per_second,
                                      self.global_steps)
                    tf.summary.scalar('examples/sec', examples_per_second,
                                      self.global_steps)
            self.last_log_step = self.global_steps
            # Reset so the next interval starts timing at its first batch.
            self.start_time = None
            self.latency.append(elapsed_time)
            self.throughput.append(examples_per_second)

    def on_epoch_end(self, epoch, logs=None):
        if epoch == 0:
            self.step_per_epoch = self.steps_in_epoch
        epoch_run_time = time.time() - self.epoch_start
        self.epoch_runtime_log.append(epoch_run_time)
        self.steps_before_epoch += self.steps_in_epoch
        self.steps_in_epoch = 0
  170. class LRTensorBoard(tf.keras.callbacks.Callback):
  171. def __init__(self, log_dir, **kwargs):
  172. super().__init__(**kwargs)
  173. self.summary_writer = tf.summary.create_file_writer(log_dir)
  174. self.steps_before_epoch = 0
  175. self.steps_in_epoch = 0
  176. @property
  177. def global_steps(self):
  178. """The current 1-indexed global step."""
  179. return self.steps_before_epoch + self.steps_in_epoch
  180. def on_batch_end(self, batch, logs=None):
  181. self.steps_in_epoch = batch + 1
  182. lr = self.model.optimizer.lr(self.global_steps)
  183. with self.summary_writer.as_default():
  184. summary = tf.summary.scalar('learning_rate', lr, self.global_steps)
  185. def on_epoch_end(self, epoch, logs=None):
  186. self.steps_before_epoch += self.steps_in_epoch
  187. self.steps_in_epoch = 0
  188. def on_train_end(self, logs=None):
  189. self.summary_writer.flush()
  190. class LoggingCallback(tf.keras.callbacks.Callback):
  191. def on_train_batch_end(self, batch, logs=None):
  192. print("Iter: {}".format(batch))
  193. for var in self.model.variables:
  194. # if 'dense' in var.name:
  195. # continue
  196. print("Var: {} {}".format(var.name, var.value))
  197. try:
  198. slot = self.model.optimizer.get_slot(var, "average")
  199. print("Avg: {}".format(slot))
  200. except KeyError as e:
  201. print("{} does not have ema average slot".format(var.name))
  202. def fetch_optimizer(model,opt_type) -> tf.keras.optimizers.Optimizer:
  203. """Get the base optimizer used by the current model."""
  204. # this is the case where our target optimizer is not wrapped by any other optimizer(s)
  205. if isinstance(model.optimizer,opt_type):
  206. return model.optimizer
  207. # Dive into nested optimizer object until we reach the target opt
  208. opt = model.optimizer
  209. while hasattr(opt, '_optimizer'):
  210. opt = opt._optimizer
  211. if isinstance(opt,opt_type):
  212. return opt
  213. raise TypeError(f'Failed to find {opt_type} in the nested optimizer object')
  214. class MovingAverageCallback(tf.keras.callbacks.Callback):
  215. """A Callback to be used with a `MovingAverage` optimizer.
  216. Applies moving average weights to the model during validation time to test
  217. and predict on the averaged weights rather than the current model weights.
  218. Once training is complete, the model weights will be overwritten with the
  219. averaged weights (by default).
  220. Attributes:
  221. overwrite_weights_on_train_end: Whether to overwrite the current model
  222. weights with the averaged weights from the moving average optimizer.
  223. **kwargs: Any additional callback arguments.
  224. """
  225. def __init__(self,
  226. overwrite_weights_on_train_end: bool = False,
  227. **kwargs):
  228. super(MovingAverageCallback, self).__init__(**kwargs)
  229. self.overwrite_weights_on_train_end = overwrite_weights_on_train_end
  230. self.ema_opt = None
  231. def set_model(self, model: tf.keras.Model):
  232. super(MovingAverageCallback, self).set_model(model)
  233. self.ema_opt = fetch_optimizer(model, MovingAverage)
  234. self.ema_opt.shadow_copy(self.model.weights)
  235. def on_test_begin(self, logs: MutableMapping[Text, Any] = None):
  236. self.ema_opt.swap_weights()
  237. def on_test_end(self, logs: MutableMapping[Text, Any] = None):
  238. self.ema_opt.swap_weights()
  239. def on_train_end(self, logs: MutableMapping[Text, Any] = None):
  240. if self.overwrite_weights_on_train_end:
  241. self.ema_opt.assign_average_vars(self.model.variables)
class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
    """Saves and, optionally, assigns the averaged weights.

    Taken from tfa.callbacks.AverageModelCheckpoint [original class].
    NOTE1: The original class has a type check decorator, which prevents passing non-string save_freq (fix: removed)
    NOTE2: The original class may not properly handle layered (nested) optimizer objects (fix: use fetch_optimizer)

    Attributes:
      update_weights: If True, assign the moving average weights
        to the model, and save them. If False, keep the old
        non-averaged weights, but the saved model uses the
        average weights.
      See `tf.keras.callbacks.ModelCheckpoint` for the other args.
    """

    def __init__(
        self,
        update_weights: bool,
        filepath: str,
        monitor: str = 'val_loss',
        verbose: int = 0,
        save_best_only: bool = False,
        save_weights_only: bool = False,
        mode: str = 'auto',
        save_freq: str = 'epoch',
        **kwargs):
        # Forwarded positionally: the order must stay in sync with the
        # tf.keras ModelCheckpoint signature.
        super().__init__(
            filepath,
            monitor,
            verbose,
            save_best_only,
            save_weights_only,
            mode,
            save_freq,
            **kwargs)
        self.update_weights = update_weights
        self.ema_opt = None

    def set_model(self, model):
        # Resolve the MovingAverage optimizer even when nested inside wrappers.
        self.ema_opt = fetch_optimizer(model, MovingAverage)
        return super().set_model(model)

    def _save_model(self, epoch, batch, logs):
        """Saves a checkpoint containing the averaged weights.

        Overrides the private base-class hook; the (epoch, batch, logs)
        signature must match the tf.keras version in use.
        """
        assert isinstance(self.ema_opt, MovingAverage)
        if self.update_weights:
            self.ema_opt.assign_average_vars(self.model.variables)
            return super()._save_model(epoch, batch, logs)
        else:
            # Note: `model.get_weights()` gives us the weights (non-ref)
            # whereas `model.variables` returns references to the variables.
            non_avg_weights = self.model.get_weights()
            self.ema_opt.assign_average_vars(self.model.variables)
            # result is currently None, since `super._save_model` doesn't
            # return anything, but this may change in the future.
            result = super()._save_model(epoch, batch, logs)
            # Restore the non-averaged weights so training continues unchanged.
            self.model.set_weights(non_avg_weights)
            return result
  294. class StopEarlyCallback(tf.keras.callbacks.Callback):
  295. def __init__(self, num_epochs, stop_75, **kwargs):
  296. super(StopEarlyCallback, self).__init__(**kwargs)
  297. self.num_epochs = num_epochs
  298. self.stop_75 = stop_75
  299. def on_epoch_end(self, epoch, logs=None):
  300. if ((epoch + 1) > (0.75 * self.num_epochs) and self.stop_75) or ((epoch + 1) == 300):
  301. self.model.stop_training = True
class COCOEvalCallback(tf.keras.callbacks.Callback):
    """Evaluates COCO detection metrics (mAP) at the end of selected epochs.

    Every rank runs inference over `eval_dataset`; detections are gathered
    across ranks and the final metrics are computed and logged on the main
    process only.
    """

    def __init__(self, eval_dataset, eval_freq, start_eval_epoch, eval_params, logger, **kwargs):
        """Initializes the callback.

        Args:
          eval_dataset: dataset yielding (images, labels) evaluation batches.
          eval_freq: evaluate every `eval_freq` epochs.
          start_eval_epoch: first 1-indexed epoch at which evaluation may run.
          eval_params: dict of eval configuration ('label_map', 'val_json_file',
            'num_samples', 'moving_average_decay', ...).
          logger: logger exposing `log(step=..., data=...)`.
          **kwargs: additional base Callback arguments.
        """
        super(COCOEvalCallback, self).__init__(**kwargs)
        self.dataset = eval_dataset
        self.eval_freq = eval_freq
        self.start_eval_epoch = start_eval_epoch
        self.eval_params = eval_params
        self.ema_opt = None
        self.logger = logger
        label_map = label_util.get_label_map(eval_params['label_map'])
        self.evaluator = coco_metric.EvaluationMetric(
            filename=eval_params['val_json_file'], label_map=label_map)
        self.pbar = tf.keras.utils.Progbar(eval_params['num_samples'])

    def set_model(self, model):
        # Resolve the MovingAverage optimizer (possibly nested) for weight swaps.
        self.ema_opt = fetch_optimizer(model, MovingAverage)
        return super().set_model(model)

    @tf.function
    def eval_model_fn(self, images, labels):
        """Runs one inference step and feeds detections into the evaluator."""
        cls_outputs, box_outputs = self.model(images, training=False)
        detections = postprocess.generate_detections(self.eval_params, cls_outputs, box_outputs,
                                                     labels['image_scales'],
                                                     labels['source_ids'])
        # The evaluator update runs outside the graph via tf.numpy_function.
        tf.numpy_function(self.evaluator.update_state,
                          [labels['groundtruth_data'],
                           postprocess.transform_detections(detections)], [])

    def evaluate(self, epoch):
        """Runs one full evaluation pass and logs metrics on the main process."""
        if self.eval_params['moving_average_decay'] > 0:
            self.ema_opt.swap_weights()  # get ema weights
        self.evaluator.reset_states()
        # evaluate all images.
        for i, (images, labels) in enumerate(self.dataset):
            self.eval_model_fn(images, labels)
            if is_main_process():
                self.pbar.update(i)
        # gather detections from all ranks
        self.evaluator.gather()
        # compute the final eval results.
        if is_main_process():
            metrics = self.evaluator.result()
            metric_dict = {}
            for i, name in enumerate(self.evaluator.metric_names):
                metric_dict[name] = metrics[i]
            # csv format
            csv_metrics = ['AP','AP50','AP75','APs','APm','APl']
            csv_format = ",".join([str(epoch+1)] + [str(round(metric_dict[key] * 100, 2)) for key in csv_metrics])
            print(metric_dict, "csv format:", csv_format)
            self.logger.log(step=(), data={'epoch': epoch+1,
                                           'validation_accuracy_mAP': round(metric_dict['AP'] * 100, 2)})
        if self.eval_params['moving_average_decay'] > 0:
            self.ema_opt.swap_weights()  # get base weights
        # Keep all ranks in lockstep before training resumes.
        MPI.COMM_WORLD.Barrier()

    def on_epoch_end(self, epoch, logs=None):
        # `epoch` is 0-indexed; evaluation is gated on the 1-indexed count.
        if (epoch + 1) >= self.start_eval_epoch and (epoch + 1) % self.eval_freq == 0:
            self.evaluate(epoch)
def get_callbacks(
    params, training_mode, eval_params, eval_dataset, logger,
    time_history=True, log_steps=1, lr_tb=True, benchmark=False
):
    """Get callbacks for given params."""
    callbacks = []
    # Callbacks that write checkpoints/summaries run only on the main rank.
    if is_main_process():
        if benchmark == False:
            tb_callback = tf.keras.callbacks.TensorBoard(
                log_dir=params['model_dir'], profile_batch=0, histogram_freq = 1)
            callbacks.append(tb_callback)
        if params['moving_average_decay']:
            # Checkpoints of the EMA (moving-average) weights.
            emackpt_callback = AverageModelCheckpoint(
                filepath=os.path.join(params['model_dir'], 'ema_weights', 'emackpt-{epoch:02d}'),
                update_weights=False,
                amp=params['mixed_precision'],
                verbose=1,
                save_freq='epoch',
                save_weights_only=True,
                period=params['checkpoint_period'])
            callbacks.append(emackpt_callback)
        # Regular checkpoint of the raw training weights.
        ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
            os.path.join(params['model_dir'], 'ckpt'),
            verbose=1,
            save_freq='epoch',
            save_weights_only=True,
            period=params['checkpoint_period'])
        callbacks.append(ckpt_callback)
        if time_history:
            # Throughput/latency logging; the batch size passed in is global
            # (per-rank batch size times world size).
            time_callback = TimeHistory(params['batch_size'] * get_world_size(),
                                        logger=logger,
                                        logdir=params['model_dir'],
                                        log_steps=log_steps)
            callbacks.append(time_callback)
        # log LR in tensorboard
        if lr_tb == True and benchmark == False:
            callbacks.append(LRTensorBoard(log_dir=params['model_dir']))
    # Horovod: broadcast initial variables from rank 0 to all ranks.
    hvd_callback = hvd_callbacks.BroadcastGlobalVariablesCallback(0)
    callbacks.append(hvd_callback)
    # for large batch sizes training schedule of 350/400 epochs gives better mAP
    # but the best mAP is generally reached after 75% of the training schedule.
    # So we can stop training at that point or continue to train until 300 epochs
    stop_75 = False if 'eval' in training_mode or '300' in training_mode else True
    early_stopping = StopEarlyCallback(params['num_epochs'], stop_75=stop_75)
    callbacks.append(early_stopping)
    if 'eval' in training_mode:
        cocoeval = COCOEvalCallback(eval_dataset,
                                    eval_freq=params['checkpoint_period'],
                                    start_eval_epoch=200,
                                    eval_params=eval_params,
                                    logger=logger)
        callbacks.append(cocoeval)
    if params['moving_average_decay']:
        callbacks.append(MovingAverageCallback())
    if params.get('sample_image', None):
        # Periodically visualize detections on a sample image in TensorBoard.
        display_callback = DisplayCallback(
            params.get('sample_image', None),
            os.path.join(params['model_dir'], 'train'))
        callbacks.append(display_callback)
    return callbacks