"""
Evaluation script used to calculate accuracy of trained model
"""
import os

import hydra
import tensorflow as tf
from omegaconf import DictConfig
from tensorflow.keras import mixed_precision

from data_generators import data_generator
from losses.loss import DiceCoefficient
from losses.unet_loss import unet3p_hybrid_loss
from models.model import prepare_model
from utils.general_utils import join_paths, set_gpus, suppress_warnings
  14. def evaluate(cfg: DictConfig):
  15. """
  16. Evaluate or calculate accuracy of given model
  17. """
  18. # suppress TensorFlow and DALI warnings
  19. suppress_warnings()
  20. if cfg.USE_MULTI_GPUS.VALUE:
  21. # change number of visible gpus for evaluation
  22. set_gpus(cfg.USE_MULTI_GPUS.GPU_IDS)
  23. # update batch size according to available gpus
  24. data_generator.update_batch_size(cfg)
  25. if cfg.OPTIMIZATION.AMP:
  26. print("Enabling Automatic Mixed Precision(AMP) training")
  27. policy = mixed_precision.Policy('mixed_float16')
  28. mixed_precision.set_global_policy(policy)
  29. if cfg.OPTIMIZATION.XLA:
  30. print("Enabling Automatic Mixed Precision(XLA) training")
  31. tf.config.optimizer.set_jit(True)
  32. # create model
  33. strategy = None
  34. if cfg.USE_MULTI_GPUS.VALUE:
  35. # multi gpu training using tensorflow mirrored strategy
  36. strategy = tf.distribute.MirroredStrategy(
  37. cross_device_ops=tf.distribute.HierarchicalCopyAllReduce()
  38. )
  39. print('Number of visible gpu devices: {}'.format(strategy.num_replicas_in_sync))
  40. with strategy.scope():
  41. optimizer = tf.keras.optimizers.Adam(
  42. learning_rate=cfg.HYPER_PARAMETERS.LEARNING_RATE
  43. ) # optimizer
  44. if cfg.OPTIMIZATION.AMP:
  45. optimizer = mixed_precision.LossScaleOptimizer(
  46. optimizer,
  47. dynamic=True
  48. )
  49. dice_coef = DiceCoefficient(post_processed=True, classes=cfg.OUTPUT.CLASSES)
  50. dice_coef = tf.keras.metrics.MeanMetricWrapper(name="dice_coef", fn=dice_coef)
  51. model = prepare_model(cfg, training=True)
  52. else:
  53. optimizer = tf.keras.optimizers.Adam(
  54. learning_rate=cfg.HYPER_PARAMETERS.LEARNING_RATE
  55. ) # optimizer
  56. if cfg.OPTIMIZATION.AMP:
  57. optimizer = mixed_precision.LossScaleOptimizer(
  58. optimizer,
  59. dynamic=True
  60. )
  61. dice_coef = DiceCoefficient(post_processed=True, classes=cfg.OUTPUT.CLASSES)
  62. dice_coef = tf.keras.metrics.MeanMetricWrapper(name="dice_coef", fn=dice_coef)
  63. model = prepare_model(cfg, training=True)
  64. model.compile(
  65. optimizer=optimizer,
  66. loss=unet3p_hybrid_loss,
  67. metrics=[dice_coef],
  68. )
  69. # weights model path
  70. checkpoint_path = join_paths(
  71. cfg.WORK_DIR,
  72. cfg.CALLBACKS.MODEL_CHECKPOINT.PATH,
  73. f"{cfg.MODEL.WEIGHTS_FILE_NAME}.hdf5"
  74. )
  75. assert os.path.exists(checkpoint_path), \
  76. f"Model weight's file does not exist at \n{checkpoint_path}"
  77. # TODO: verify without augment it produces same results
  78. # load model weights
  79. model.load_weights(checkpoint_path, by_name=True, skip_mismatch=True)
  80. model.summary()
  81. # data generators
  82. val_generator = data_generator.get_data_generator(cfg, "VAL", strategy)
  83. validation_steps = data_generator.get_iterations(cfg, mode="VAL")
  84. # evaluation metric
  85. evaluation_metric = "dice_coef"
  86. if len(model.outputs) > 1:
  87. evaluation_metric = f"{model.output_names[0]}_dice_coef"
  88. result = model.evaluate(
  89. x=val_generator,
  90. steps=validation_steps,
  91. workers=cfg.DATALOADER_WORKERS,
  92. return_dict=True,
  93. )
  94. # return computed loss, validation accuracy, and it's metric name
  95. return result, evaluation_metric
  96. @hydra.main(version_base=None, config_path="configs", config_name="config")
  97. def main(cfg: DictConfig):
  98. """
  99. Read config file and pass to evaluate method
  100. """
  101. result, evaluation_metric = evaluate(cfg)
  102. print(result)
  103. print(f"Validation dice coefficient: {result[evaluation_metric]}")
  104. if __name__ == "__main__":
  105. main()