main.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """ Entry point of the application.
  15. This file serves as entry point to the implementation of UNet3D for
  16. medical image segmentation.
  17. Example usage:
  18. $ python main.py --exec_mode train --data_dir ./data --batch_size 2
  19. --max_steps 1600 --amp
  20. All arguments are listed under `python main.py -h`.
  21. Full argument definition can be found in `arguments.py`.
  22. """
  23. import os
  24. import numpy as np
  25. import horovod.tensorflow as hvd
  26. from model.model_fn import unet_3d
  27. from dataset.data_loader import Dataset, CLASSES
  28. from runtime.hooks import get_hooks
  29. from runtime.arguments import PARSER
  30. from runtime.setup import build_estimator, set_flags, get_logger
  31. def parse_evaluation_results(result, logger, step=()):
  32. """
  33. Parse DICE scores from the evaluation results
  34. :param result: Dictionary with metrics collected by the optimizer
  35. :param logger: Logger object
  36. :return:
  37. """
  38. data = {CLASSES[i]: float(result[CLASSES[i]]) for i in range(len(CLASSES))}
  39. data['mean_dice'] = sum([result[CLASSES[i]] for i in range(len(CLASSES))]) / len(CLASSES)
  40. data['whole_tumor'] = float(result['whole_tumor'])
  41. if hvd.rank() == 0:
  42. logger.log(step=step, data=data)
  43. return data
  44. def main():
  45. """ Starting point of the application """
  46. hvd.init()
  47. set_flags()
  48. params = PARSER.parse_args()
  49. logger = get_logger(params)
  50. dataset = Dataset(data_dir=params.data_dir,
  51. batch_size=params.batch_size,
  52. fold_idx=params.fold,
  53. n_folds=params.num_folds,
  54. input_shape=params.input_shape,
  55. params=params)
  56. estimator = build_estimator(params=params, model_fn=unet_3d)
  57. hooks = get_hooks(params, logger)
  58. if 'train' in params.exec_mode:
  59. max_steps = params.max_steps // (1 if params.benchmark else hvd.size())
  60. estimator.train(
  61. input_fn=dataset.train_fn,
  62. steps=max_steps,
  63. hooks=hooks)
  64. if 'evaluate' in params.exec_mode:
  65. result = estimator.evaluate(input_fn=dataset.eval_fn, steps=dataset.eval_size)
  66. _ = parse_evaluation_results(result, logger)
  67. if params.exec_mode == 'predict':
  68. if hvd.rank() == 0:
  69. predictions = estimator.predict(
  70. input_fn=dataset.test_fn, hooks=hooks)
  71. for idx, pred in enumerate(predictions):
  72. volume = pred['predictions']
  73. if not params.benchmark:
  74. np.save(os.path.join(params.model_dir, "vol_{}.npy".format(idx)), volume)
  75. if __name__ == '__main__':
  76. main()