run.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from time import time
  15. import horovod.tensorflow as hvd
  16. import numpy as np
  17. import tensorflow as tf
  18. import tensorflow_addons as tfa
  19. from tensorflow.python.compiler.tensorrt import trt_convert as trt
  20. from runtime.checkpoint import CheckpointManager
  21. from runtime.losses import DiceCELoss, WeightDecay
  22. from runtime.metrics import Dice, MetricAggregator, make_class_logger_metrics
  23. from runtime.utils import is_main_process, make_empty_dir, progress_bar
  24. def update_best_metrics(old, new, start_time, iteration, watch_metric=None):
  25. did_change = False
  26. for metric, value in new.items():
  27. if metric not in old or old[metric]["value"] < value:
  28. old[metric] = {"value": value, "timestamp": time() - start_time, "iter": int(iteration)}
  29. if watch_metric == metric:
  30. did_change = True
  31. return did_change
  32. def get_scheduler(args, total_steps):
  33. scheduler = {
  34. "poly": tf.keras.optimizers.schedules.PolynomialDecay(
  35. initial_learning_rate=args.learning_rate,
  36. end_learning_rate=args.end_learning_rate,
  37. decay_steps=total_steps,
  38. power=0.9,
  39. ),
  40. "cosine": tf.keras.optimizers.schedules.CosineDecay(
  41. initial_learning_rate=args.learning_rate, decay_steps=total_steps
  42. ),
  43. "cosine_annealing": tf.keras.optimizers.schedules.CosineDecayRestarts(
  44. initial_learning_rate=args.learning_rate,
  45. first_decay_steps=args.cosine_annealing_first_cycle_steps,
  46. m_mul=args.cosine_annealing_peak_decay,
  47. ),
  48. "none": args.learning_rate,
  49. }[args.scheduler.lower()]
  50. return scheduler
  51. def get_optimizer(args, scheduler):
  52. optimizer = {
  53. "sgd": tf.keras.optimizers.SGD(learning_rate=scheduler, momentum=args.momentum),
  54. "adam": tf.keras.optimizers.Adam(learning_rate=scheduler),
  55. "radam": tfa.optimizers.RectifiedAdam(learning_rate=scheduler),
  56. }[args.optimizer.lower()]
  57. if args.lookahead:
  58. optimizer = tfa.optimizers.Lookahead(optimizer)
  59. if args.amp:
  60. optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer, dynamic=True)
  61. return optimizer
  62. def get_epoch_size(args, batch_size, dataset_size):
  63. if args.steps_per_epoch:
  64. return args.steps_per_epoch
  65. div = args.gpus * (batch_size if args.dim == 3 else args.nvol)
  66. return (dataset_size + div - 1) // div
  67. def process_performance_stats(timestamps, batch_size, mode):
  68. deltas = np.diff(timestamps)
  69. deltas_ms = 1000 * deltas
  70. throughput_imgps = (1000.0 * batch_size / deltas_ms).mean()
  71. stats = {f"throughput_{mode}": throughput_imgps, f"latency_{mode}_mean": deltas_ms.mean()}
  72. for level in [90, 95, 99]:
  73. stats.update({f"latency_{mode}_{level}": np.percentile(deltas_ms, level)})
  74. return stats
  75. def benchmark(args, step_fn, data, steps, warmup_steps, logger, mode="train"):
  76. assert steps > warmup_steps, "Number of benchmarked steps has to be greater then number of warmup steps"
  77. timestamps = []
  78. wrapped_data = progress_bar(
  79. enumerate(data),
  80. quiet=args.quiet,
  81. desc=f"Benchmark ({mode})",
  82. unit="step",
  83. postfix={"phase": "warmup"},
  84. total=steps,
  85. )
  86. for step, (images, labels) in wrapped_data:
  87. output_map = step_fn(images, labels, warmup_batch=step == 0)
  88. if mode == "predict":
  89. with tf.device("/device:CPU:0"):
  90. output_map = tf.experimental.numpy.copy(output_map)
  91. if step >= warmup_steps:
  92. timestamps.append(time())
  93. if step == warmup_steps and is_main_process() and not args.quiet:
  94. wrapped_data.set_postfix(phase="benchmark")
  95. if step >= steps:
  96. break
  97. stats = process_performance_stats(timestamps, args.gpus * args.batch_size, mode=mode)
  98. logger.log_metrics(stats)
def train(args, model, dataset, logger):
    """Run the full training loop (or only a benchmark pass when requested).

    Builds the LR schedule, optimizer and Dice+CE loss, defines a tf.function
    training step with Horovod gradient aggregation, then iterates the train
    dataset for ``epochs * steps_per_epoch`` steps, evaluating and
    checkpointing at epoch boundaries.
    """
    train_data = dataset.train_dataset()
    epochs = args.epochs
    # 2D training consumes `nvol` volumes per step rather than `batch_size`.
    batch_size = args.batch_size if args.dim == 3 else args.nvol
    steps_per_epoch = get_epoch_size(args, batch_size, dataset.train_size())
    total_steps = epochs * steps_per_epoch
    scheduler = get_scheduler(args, total_steps)
    optimizer = get_optimizer(args, scheduler)
    loss_fn = DiceCELoss(
        y_one_hot=True,
        reduce_batch=args.reduce_batch,
        include_background=args.include_background,
    )
    wdecay = WeightDecay(factor=args.weight_decay)
    # Global step as a tf.Variable so it is saved/restored with the checkpoint.
    tstep = tf.Variable(0)

    @tf.function
    def train_step_fn(features, labels, warmup_batch=False):
        """One optimizer step; returns the (unscaled) Dice+CE loss."""
        features, labels = model.adjust_batch(features, labels)
        with tf.GradientTape() as tape:
            output_map = model(features)
            dice_loss = model.compute_loss(loss_fn, labels, output_map)
            loss = dice_loss + wdecay(model)
            if args.amp:
                # Scale the loss so small float16 gradients stay representable.
                loss = optimizer.get_scaled_loss(loss)
        tape = hvd.DistributedGradientTape(tape)
        gradients = tape.gradient(loss, model.trainable_variables)
        if args.amp:
            gradients = optimizer.get_unscaled_gradients(gradients)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        # Note: broadcast should be done after the first gradient step to ensure optimizer initialization.
        if warmup_batch:
            hvd.broadcast_variables(model.variables, root_rank=0)
            hvd.broadcast_variables(optimizer.variables(), root_rank=0)
        return dice_loss

    dice_metrics = MetricAggregator(name="dice")
    checkpoint = CheckpointManager(
        args.ckpt_dir,
        strategy=args.ckpt_strategy,
        resume_training=args.resume_training,
        variables={"model": model, "optimizer": optimizer, "step": tstep, **dice_metrics.checkpoint_metrics()},
    )
    if args.benchmark:
        benchmark(args, train_step_fn, train_data, args.bench_steps, args.warmup_steps, logger)
    else:
        wrapped_data = progress_bar(
            train_data,
            quiet=args.quiet,
            desc="Train",
            postfix={"epoch": 1},
            unit="step",
            # tstep may be non-zero after a checkpoint restore (resume).
            total=total_steps - int(tstep),
        )
        start_time = time()
        total_train_loss, dice_score = 0.0, 0.0
        for images, labels in wrapped_data:
            if tstep >= total_steps:
                break
            tstep.assign_add(1)
            loss = train_step_fn(images, labels, warmup_batch=tstep == 1)
            total_train_loss += float(loss)
            lr = scheduler(tstep) if callable(scheduler) else scheduler
            # Per-step metrics; only the epoch-boundary branch below logs them,
            # where `metrics` is rebuilt from the evaluation results.
            metrics = {"loss": float(loss), "learning_rate": float(lr)}
            if tstep % steps_per_epoch == 0:  # epoch boundary
                epoch = int(tstep // steps_per_epoch)
                if epoch > args.skip_eval:
                    dice = evaluate(args, model, dataset, logger)
                    # Mean over foreground classes (index 0 = background).
                    dice_score = tf.reduce_mean(dice[1:])
                    did_improve = dice_metrics.update(dice_score)
                    metrics = dice_metrics.logger_metrics()
                    metrics.update(make_class_logger_metrics(dice))
                    if did_improve:
                        metrics["time_to_train"] = time() - start_time
                    logger.log_metrics(metrics=metrics, step=int(tstep))
                    checkpoint.update(float(dice_score))
                    logger.flush()
                else:
                    # Evaluation skipped this epoch; no score for the strategy.
                    checkpoint.update(None)
                if is_main_process() and not args.quiet:
                    wrapped_data.set_postfix(epoch=epoch + 1)
            elif tstep % steps_per_epoch == 0:
                # NOTE(review): this condition duplicates the `if` above, so the
                # branch is unreachable and the per-epoch loss reset never runs.
                # A different interval (e.g. a log-every step count) was likely
                # intended — confirm against the upstream source.
                total_train_loss = 0.0
        # Final summary after the training loop completes.
        metrics = {
            "train_loss": round(total_train_loss / steps_per_epoch, 5),
            "val_loss": round(1 - float(dice_score), 5),
            "dice": round(float(dice_metrics.metrics["max"].result()), 5),
        }
        logger.log_metrics(metrics=metrics)
        logger.flush()
  187. def evaluate(args, model, dataset, logger):
  188. dice = Dice(n_class=model.n_class)
  189. data_size = dataset.val_size()
  190. wrapped_data = progress_bar(
  191. enumerate(dataset.val_dataset()),
  192. quiet=args.quiet,
  193. desc="Validation",
  194. unit="step",
  195. total=data_size,
  196. )
  197. for i, (features, labels) in wrapped_data:
  198. if args.dim == 2:
  199. features, labels = features[0], labels[0]
  200. output_map = model.inference(features)
  201. dice.update_state(output_map, labels)
  202. if i + 1 == data_size:
  203. break
  204. result = dice.result()
  205. if args.exec_mode == "evaluate":
  206. metrics = {
  207. "eval_dice": float(tf.reduce_mean(result)),
  208. "eval_dice_nobg": float(tf.reduce_mean(result[1:])),
  209. }
  210. logger.log_metrics(metrics)
  211. return result
  212. def predict(args, model, dataset, logger):
  213. if args.benchmark:
  214. def predict_bench_fn(features, labels, warmup_batch):
  215. if args.dim == 2:
  216. features = features[0]
  217. if args.sw_benchmark:
  218. output_map = model.inference(features)
  219. else:
  220. output_map = model(features, training=False)
  221. return output_map
  222. benchmark(
  223. args,
  224. predict_bench_fn,
  225. dataset.train_dataset(),
  226. args.bench_steps,
  227. args.warmup_steps,
  228. logger,
  229. mode="predict",
  230. )
  231. else:
  232. if args.save_preds:
  233. prec = "amp" if args.amp else "fp32"
  234. dir_name = f"preds_task_{args.task}_dim_{args.dim}_fold_{args.fold}_{prec}"
  235. if args.tta:
  236. dir_name += "_tta"
  237. save_dir = args.results / dir_name
  238. make_empty_dir(save_dir)
  239. data_size = dataset.test_size()
  240. wrapped_data = progress_bar(
  241. enumerate(dataset.test_dataset()),
  242. quiet=args.quiet,
  243. desc="Predict",
  244. unit="step",
  245. total=data_size,
  246. )
  247. for i, (images, meta) in wrapped_data:
  248. features, _ = model.adjust_batch(images, None)
  249. pred = model.inference(features, training=False)
  250. if args.save_preds:
  251. model.save_pred(pred, meta, idx=i, data_module=dataset, save_dir=save_dir)
  252. if i + 1 == data_size:
  253. break
  254. def export_model(args, model):
  255. checkpoint = tf.train.Checkpoint(model=model)
  256. checkpoint.restore(tf.train.latest_checkpoint(args.ckpt_dir)).expect_partial()
  257. input_shape = [1, *model.patch_size, model.n_class]
  258. dummy_input = tf.constant(tf.zeros(input_shape, dtype=tf.float32))
  259. _ = model(dummy_input, training=False)
  260. prec = "amp" if args.amp else "fp32"
  261. path = str(args.results / f"saved_model_task_{args.task}_dim_{args.dim}_{prec}")
  262. tf.keras.models.save_model(model, str(path))
  263. trt_prec = trt.TrtPrecisionMode.FP32 if prec == "fp32" else trt.TrtPrecisionMode.FP16
  264. converter = trt.TrtGraphConverterV2(
  265. input_saved_model_dir=path,
  266. conversion_params=trt.TrtConversionParams(precision_mode=trt_prec),
  267. )
  268. converter.convert()
  269. trt_path = str(args.results / f"trt_saved_model_task_{args.task}_dim_{args.dim}_{prec}")
  270. converter.save(trt_path)