train.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. # Copyright 2020 Google Research. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """The main training script."""
  16. import os
  17. import time
  18. from mpi4py import MPI
  19. from absl import app
  20. from absl import flags
  21. from absl import logging
  22. import tensorflow as tf
  23. import horovod.tensorflow.keras as hvd
  24. from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
  25. import dllogger as DLLogger
  26. from model import anchors, callback_builder, coco_metric, dataloader
  27. from model import efficientdet_keras, label_util, optimizer_builder, postprocess
  28. from utils import hparams_config, model_utils, setup, train_lib, util_keras
  29. from utils.horovod_utils import is_main_process, get_world_size, get_rank
  30. # Model specific paramenters
  31. flags.DEFINE_string('training_mode', 'traineval', '(train/train300/traineval)')
  32. flags.DEFINE_string(
  33. 'training_file_pattern', None,
  34. 'Glob for training data files (e.g., COCO train - minival set)')
  35. flags.DEFINE_string('model_name', 'efficientdet-d0', 'Model name.')
  36. flags.DEFINE_string('model_dir', None, 'Location of model_dir')
  37. flags.DEFINE_integer('batch_size', 64, 'training local batch size')
  38. flags.DEFINE_integer('eval_batch_size', 64, 'evaluation local batch size')
  39. flags.DEFINE_integer('num_examples_per_epoch', 120000,
  40. 'Number of examples in one epoch (coco default is 117266)')
  41. flags.DEFINE_integer('num_epochs', None, 'Number of epochs for training')
  42. flags.DEFINE_bool('benchmark', False, 'Train for a fixed number of steps for performance')
  43. flags.DEFINE_integer('benchmark_steps', 100, 'Train for these many steps to benchmark training performance')
  44. flags.DEFINE_bool('use_fake_data', False, 'Use fake input.')
  45. flags.DEFINE_bool('use_xla', True, 'Use XLA')
  46. flags.DEFINE_bool('amp', True, 'Enable mixed precision training')
  47. flags.DEFINE_bool('set_num_threads', True, 'Set inter-op and intra-op parallelism threads')
  48. flags.DEFINE_string('log_filename', 'time_log.txt', 'Filename for dllogger logs')
  49. flags.DEFINE_integer('log_steps', 1, 'Interval of steps between logging of batch level stats')
  50. flags.DEFINE_bool('lr_tb', False, 'Log learning rate at each step to TB')
  51. flags.DEFINE_bool('enable_map_parallelization', True, 'Parallelize stateless map transformations in dataloader')
  52. flags.DEFINE_integer('checkpoint_period', 10, 'Save ema model weights after every X epochs for eval')
  53. flags.DEFINE_string('pretrained_ckpt', None,
  54. 'Start training from this EfficientDet checkpoint.')
  55. flags.DEFINE_string('backbone_init', None,
  56. 'Initialize backbone weights from checkpoint in this directory.')
  57. flags.DEFINE_string(
  58. 'hparams', '', 'Comma separated k=v pairs of hyperparameters or a module'
  59. ' containing attributes to use as hyperparameters.')
  60. flags.DEFINE_float('lr', None, 'Learning rate')
  61. flags.DEFINE_float('warmup_value', 0.0001, 'Initial warmup value')
  62. flags.DEFINE_float('warmup_epochs', None, 'Number of warmup epochs')
  63. flags.DEFINE_integer('seed', None, 'Random seed')
  64. flags.DEFINE_bool('debug', False, 'Enable debug mode')
  65. flags.DEFINE_bool('time_history', True, 'Get time history')
  66. flags.DEFINE_bool('validate', False, 'Get validation loss after each epoch')
  67. flags.DEFINE_string('val_file_pattern', None,
  68. 'Glob for eval tfrecords, e.g. coco/val-*.tfrecord.')
  69. flags.DEFINE_string(
  70. 'val_json_file', None,
  71. 'COCO validation JSON containing golden bounding boxes. If None, use the '
  72. 'ground truth from the dataloader. Ignored if testdev_dir is not None.')
  73. flags.DEFINE_string('testdev_dir', None,
  74. 'COCO testdev dir. If not None, ignorer val_json_file.')
  75. flags.DEFINE_integer('eval_samples', 5000, 'The number of samples for '
  76. 'evaluation.')
  77. FLAGS = flags.FLAGS
def main(_):
  """Run end-to-end EfficientDet training and (optionally) final COCO eval.

  Flow: init Horovod -> build train/eval configs from flags -> build input
  pipelines -> compile EfficientDetNetTrain with detection losses -> fit ->
  save/restore final (EMA) checkpoint -> evaluate mAP -> log stats via
  DLLogger. Runs under MPI/Horovod; only the main process writes logs and
  checkpoints.
  """
  # Record wall-clock start so e2e_training_time can be reported at the end.
  begin = time.time()
  logging.info("Training started at: {}".format(time.asctime()))

  hvd.init()

  # Parse and override hparams. Flag values, when given, take precedence
  # over the model's default detection config.
  config = hparams_config.get_detection_config(FLAGS.model_name)
  config.override(FLAGS.hparams)
  if FLAGS.num_epochs:  # NOTE: remove this flag after updating all docs.
    config.num_epochs = FLAGS.num_epochs
  if FLAGS.lr:
    config.learning_rate = FLAGS.lr
  if FLAGS.warmup_value:
    config.lr_warmup_init = FLAGS.warmup_value
  if FLAGS.warmup_epochs:
    config.lr_warmup_epoch = FLAGS.warmup_epochs
  config.backbone_init = FLAGS.backbone_init
  config.mixed_precision = FLAGS.amp
  config.image_size = model_utils.parse_image_size(config.image_size)

  # Get a separate eval config: same model/hparams, but with validation data
  # paths, the NMS input cap, and no remainder dropping.
  eval_config = hparams_config.get_detection_config(FLAGS.model_name)
  eval_config.override(FLAGS.hparams)
  eval_config.val_json_file = FLAGS.val_json_file
  eval_config.val_file_pattern = FLAGS.val_file_pattern
  eval_config.nms_configs.max_nms_inputs = anchors.MAX_DETECTION_POINTS
  eval_config.drop_remainder = False  # eval all examples w/o drop.
  eval_config.image_size = model_utils.parse_image_size(eval_config['image_size'])

  # Apply global TF/session flags (XLA, AMP, threading) for training.
  setup.set_flags(FLAGS, config, training=True)

  if FLAGS.debug:
    # Debug mode: eager tf.functions, device placement logging, fixed seed,
    # verbose logging — all for reproducible step-through debugging.
    tf.config.experimental_run_functions_eagerly(True)
    tf.debugging.set_log_device_placement(True)
    tf.random.set_seed(111111)
    logging.set_verbosity(logging.DEBUG)

  # Check data path: all three are required for train + eval.
  if FLAGS.training_file_pattern is None or FLAGS.val_file_pattern is None or FLAGS.val_json_file is None:
    raise RuntimeError('You must specify --training_file_pattern, --val_file_pattern and --val_json_file for training.')

  # Ceiling division: global batch = local batch * world size.
  steps_per_epoch = (FLAGS.num_examples_per_epoch + (FLAGS.batch_size * get_world_size()) - 1) // (FLAGS.batch_size * get_world_size())
  if FLAGS.benchmark == True:
    # For ci perf training runs, run for a fixed number of iterations per epoch
    steps_per_epoch = FLAGS.benchmark_steps

  # Flatten config plus run-specific settings into a single params dict that
  # downstream builders (optimizer, callbacks, dataloader) consume.
  params = dict(
      config.as_dict(),
      model_name=FLAGS.model_name,
      model_dir=FLAGS.model_dir,
      steps_per_epoch=steps_per_epoch,
      checkpoint_period=FLAGS.checkpoint_period,
      batch_size=FLAGS.batch_size,
      num_shards=get_world_size(),
      val_json_file=FLAGS.val_json_file,
      testdev_dir=FLAGS.testdev_dir,
      mode='train')
  logging.info('Training params: {}'.format(params))

  # Make output dir if it does not exist.
  tf.io.gfile.makedirs(FLAGS.model_dir)

  # DLLogger setup: only the main process gets backends; other ranks init
  # with an empty backend list so DLLogger calls are no-ops there.
  backends = []
  if is_main_process():
    log_path = os.path.join(FLAGS.model_dir, FLAGS.log_filename)
    backends += [
        JSONStreamBackend(verbosity=Verbosity.VERBOSE, filename=log_path),
        StdOutBackend(verbosity=Verbosity.DEFAULT)]

  DLLogger.init(backends=backends)
  DLLogger.metadata('avg_fps_training', {'unit': 'images/s'})
  DLLogger.metadata('avg_fps_training_per_GPU', {'unit': 'images/s'})
  DLLogger.metadata('avg_latency_training', {'unit': 's'})
  DLLogger.metadata('training_loss', {'unit': None})
  DLLogger.metadata('e2e_training_time', {'unit': 's'})

  def get_dataset(is_training, params):
    """Build the train (or validation) input pipeline from file patterns."""
    file_pattern = (
        FLAGS.training_file_pattern
        if is_training else FLAGS.val_file_pattern)
    if not file_pattern:
      raise ValueError('No matching files.')

    return dataloader.InputReader(
        file_pattern,
        is_training=is_training,
        use_fake_data=FLAGS.use_fake_data,
        max_instances_per_image=config.max_instances_per_image,
        enable_map_parallelization=FLAGS.enable_map_parallelization)(
            params)

  # Per-rank eval sample count, then converted to per-rank batch count
  # (two ceiling divisions: first across ranks, then across batches).
  num_samples = (FLAGS.eval_samples + get_world_size() - 1) // get_world_size()
  num_samples = (num_samples + FLAGS.eval_batch_size - 1) // FLAGS.eval_batch_size
  eval_config.num_samples = num_samples

  def get_eval_dataset(eval_config):
    """Build the eval pipeline, sharded across ranks and capped per rank."""
    dataset = dataloader.InputReader(
        FLAGS.val_file_pattern,
        is_training=False,
        max_instances_per_image=eval_config.max_instances_per_image)(
            eval_config, batch_size=FLAGS.eval_batch_size)
    dataset = dataset.shard(get_world_size(), get_rank())
    dataset = dataset.take(num_samples)
    return dataset

  eval_dataset = get_eval_dataset(eval_config)

  # Pick focal loss implementation (numerically stable variant).
  focal_loss = train_lib.StableFocalLoss(
      params['alpha'],
      params['gamma'],
      label_smoothing=params['label_smoothing'],
      reduction=tf.keras.losses.Reduction.NONE)

  model = train_lib.EfficientDetNetTrain(params['model_name'], config)
  # Build with dynamic batch dimension and the configured (H, W) image size.
  model.build((None, *config.image_size, 3))
  # Reduction.NONE everywhere: the training loop handles per-example
  # reduction itself.
  model.compile(
      optimizer=optimizer_builder.get_optimizer(params),
      loss={
          'box_loss':
              train_lib.BoxLoss(
                  params['delta'], reduction=tf.keras.losses.Reduction.NONE),
          'box_iou_loss':
              train_lib.BoxIouLoss(
                  params['iou_loss_type'],
                  params['min_level'],
                  params['max_level'],
                  params['num_scales'],
                  params['aspect_ratios'],
                  params['anchor_scale'],
                  params['image_size'],
                  reduction=tf.keras.losses.Reduction.NONE),
          'class_loss': focal_loss,
          'seg_loss':
              tf.keras.losses.SparseCategoricalCrossentropy(
                  from_logits=True,
                  reduction=tf.keras.losses.Reduction.NONE)
      })

  # Resume from the latest checkpoint in model_dir if present; returns the
  # epoch to resume from (0 when starting fresh).
  train_from_epoch = util_keras.restore_ckpt(model, params['model_dir'],
      config.moving_average_decay, steps_per_epoch=steps_per_epoch)

  print("training_mode: {}".format(FLAGS.training_mode))
  callbacks = callback_builder.get_callbacks(params, FLAGS.training_mode, eval_config, eval_dataset,
      DLLogger, FLAGS.time_history, FLAGS.log_steps, FLAGS.lr_tb, FLAGS.benchmark)

  history = model.fit(
      get_dataset(True, params=params),
      epochs=params['num_epochs'],
      steps_per_epoch=steps_per_epoch,
      initial_epoch=train_from_epoch,
      callbacks=callbacks,
      # Progress bar only on the main process to keep multi-rank logs clean.
      verbose=1 if is_main_process() else 0,
      validation_data=get_dataset(False, params=params) if FLAGS.validate else None,
      validation_steps=(FLAGS.eval_samples // FLAGS.eval_batch_size) if FLAGS.validate else None)

  if is_main_process():
    model.save_weights(os.path.join(FLAGS.model_dir, 'ckpt-final'))

  # Log final stats: throughput/latency come from the TimeHistory callback.
  stats = {}
  for callback in callbacks:
    if isinstance(callback, callback_builder.TimeHistory):
      if callback.epoch_runtime_log:
        stats['avg_fps_training'] = callback.average_examples_per_second
        stats['avg_fps_training_per_GPU'] = callback.average_examples_per_second / get_world_size()
        stats['avg_latency_training'] = callback.average_time_per_iteration

  if history and history.history:
    train_hist = history.history
    # Gets final loss from training, averaged across all ranks.
    stats['training_loss'] = float(hvd.allreduce(tf.constant(train_hist['loss'][-1], dtype=tf.float32), average=True))

  if os.path.exists(os.path.join(FLAGS.model_dir, 'ema_weights')):
    # EMA checkpoints exist: pick the highest epoch number from files named
    # like 'emackpt-<epoch>.<suffix>' and restore those weights for eval.
    ckpt_epoch = "%02d" % sorted(set([int(f.rsplit('.')[0].rsplit('-')[1])
                                      for f in os.listdir(os.path.join(FLAGS.model_dir, 'ema_weights'))
                                      if 'emackpt' in f]), reverse=True)[0]
    ckpt = os.path.join(FLAGS.model_dir, 'ema_weights', 'emackpt-' + str(ckpt_epoch))
    util_keras.restore_ckpt(model, ckpt, eval_config.moving_average_decay,
        steps_per_epoch=0, skip_mismatch=False, expect_partial=True)
    if is_main_process():
      model.save(os.path.join(FLAGS.model_dir, 'emackpt-final'))
  else:
    # No EMA weights: evaluate the final raw weights saved above.
    ckpt_epoch = 'final'
    ckpt = os.path.join(FLAGS.model_dir, 'ckpt-' + ckpt_epoch)
    if is_main_process():
      model.save(os.path.join(FLAGS.model_dir, 'ckpt-' + ckpt_epoch))

  # Start evaluation of final ema checkpoint.
  logging.set_verbosity(logging.WARNING)

  @tf.function
  def model_fn(images, labels):
    """One eval step: detect, post-process, feed results to the evaluator.

    NOTE: closes over `evaluator`, which is defined in the eval branch below
    before this function is first called.
    """
    cls_outputs, box_outputs = model(images, training=False)
    detections = postprocess.generate_detections(eval_config, cls_outputs, box_outputs,
                                                 labels['image_scales'],
                                                 labels['source_ids'])
    # COCO metric update runs in numpy (outside the TF graph).
    tf.numpy_function(evaluator.update_state,
                      [labels['groundtruth_data'],
                       postprocess.transform_detections(detections)], [])

  # Skip eval for benchmark runs; also skip for convergence-length
  # 'traineval' runs (>= 200 epochs), which evaluate via callbacks instead.
  if FLAGS.benchmark == False and (FLAGS.training_mode == 'train' or FLAGS.num_epochs < 200):
    # Evaluator for AP calculation.
    label_map = label_util.get_label_map(eval_config.label_map)
    evaluator = coco_metric.EvaluationMetric(
        filename=eval_config.val_json_file, label_map=label_map)
    evaluator.reset_states()

    # Evaluate all images (each rank processes its own shard).
    pbar = tf.keras.utils.Progbar(num_samples)
    for i, (images, labels) in enumerate(eval_dataset):
      model_fn(images, labels)
      if is_main_process():
        pbar.update(i)

    # Gather detections from all ranks.
    evaluator.gather()

    if is_main_process():
      # Compute the final eval results.
      metrics = evaluator.result()
      metric_dict = {}
      for i, name in enumerate(evaluator.metric_names):
        metric_dict[name] = metrics[i]

      if label_map:
        # Per-class AP entries follow the aggregate metrics in `metrics`.
        for i, cid in enumerate(sorted(label_map.keys())):
          name = 'AP_/%s' % label_map[cid]
          metric_dict[name] = metrics[i + len(evaluator.metric_names)]

      # CSV format: epoch followed by the headline COCO metrics (in %).
      csv_metrics = ['AP', 'AP50', 'AP75', 'APs', 'APm', 'APl']
      csv_format = ",".join([str(ckpt_epoch)] + [str(round(metric_dict[key] * 100, 2)) for key in csv_metrics])
      print(FLAGS.model_name, metric_dict, "csv format:", csv_format)

      DLLogger.log(step=(), data={'epoch': ckpt_epoch,
                                  'validation_accuracy_mAP': round(metric_dict['AP'] * 100, 2)})
      DLLogger.flush()

  # Synchronize all ranks before reporting final timing.
  MPI.COMM_WORLD.Barrier()

  if is_main_process():
    stats['e2e_training_time'] = time.time() - begin
    DLLogger.log(step=(), data=stats)
    DLLogger.flush()
if __name__ == '__main__':
  # Default to INFO verbosity before absl takes over flag parsing and
  # dispatches to main().
  logging.set_verbosity(logging.INFO)
  app.run(main)