# main.py
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Entry point of the application.

This file serves as entry point to the training of UNet for segmentation
of neuronal processes.

Example:
    Training can be adjusted by modifying the arguments specified below::

        $ python main.py --exec_mode train --model_dir /datasets ...
"""
  20. import os
  21. import horovod.tensorflow as hvd
  22. import math
  23. import numpy as np
  24. import tensorflow as tf
  25. from PIL import Image
  26. from utils.setup import prepare_model_dir, get_logger, build_estimator, set_flags
  27. from utils.cmd_util import PARSER, parse_args
  28. from utils.data_loader import Dataset
  29. from utils.hooks.profiling_hook import ProfilingHook
  30. from utils.hooks.training_hook import TrainingHook
  31. def main(_):
  32. """
  33. Starting point of the application
  34. """
  35. hvd.init()
  36. set_flags()
  37. params = parse_args(PARSER.parse_args())
  38. model_dir = prepare_model_dir(params)
  39. logger = get_logger(params)
  40. estimator = build_estimator(params, model_dir)
  41. dataset = Dataset(data_dir=params.data_dir,
  42. batch_size=params.batch_size,
  43. fold=params.crossvalidation_idx,
  44. augment=params.augment,
  45. gpu_id=hvd.rank(),
  46. num_gpus=hvd.size(),
  47. seed=params.seed)
  48. if 'train' in params.exec_mode:
  49. max_steps = params.max_steps // (1 if params.benchmark else hvd.size())
  50. hooks = [hvd.BroadcastGlobalVariablesHook(0),
  51. TrainingHook(logger,
  52. max_steps=max_steps,
  53. log_every=params.log_every)]
  54. if params.benchmark and hvd.rank() == 0:
  55. hooks.append(ProfilingHook(logger,
  56. batch_size=params.batch_size,
  57. log_every=params.log_every,
  58. warmup_steps=params.warmup_steps,
  59. mode='train'))
  60. estimator.train(
  61. input_fn=dataset.train_fn,
  62. steps=max_steps,
  63. hooks=hooks)
  64. if 'evaluate' in params.exec_mode:
  65. if hvd.rank() == 0:
  66. results = estimator.evaluate(input_fn=dataset.eval_fn, steps=dataset.eval_size)
  67. logger.log(step=(),
  68. data={"eval_ce_loss": float(results["eval_ce_loss"]),
  69. "eval_dice_loss": float(results["eval_dice_loss"]),
  70. "eval_total_loss": float(results["eval_total_loss"]),
  71. "eval_dice_score": float(results["eval_dice_score"])})
  72. if 'predict' in params.exec_mode:
  73. if hvd.rank() == 0:
  74. predict_steps = dataset.test_size
  75. hooks = None
  76. if params.benchmark:
  77. hooks = [ProfilingHook(logger,
  78. batch_size=params.batch_size,
  79. log_every=params.log_every,
  80. warmup_steps=params.warmup_steps,
  81. mode="test")]
  82. predict_steps = params.warmup_steps * 2 * params.batch_size
  83. predictions = estimator.predict(
  84. input_fn=lambda: dataset.test_fn(count=math.ceil(predict_steps / dataset.test_size)),
  85. hooks=hooks)
  86. binary_masks = [np.argmax(p['logits'], axis=-1).astype(np.uint8) * 255 for p in predictions]
  87. if not params.benchmark:
  88. multipage_tif = [Image.fromarray(mask).resize(size=(512, 512), resample=Image.BILINEAR)
  89. for mask in binary_masks]
  90. output_dir = os.path.join(params.model_dir, 'pred')
  91. if not os.path.exists(output_dir):
  92. os.makedirs(output_dir)
  93. multipage_tif[0].save(os.path.join(output_dir, 'test-masks.tif'),
  94. compression="tiff_deflate",
  95. save_all=True,
  96. append_images=multipage_tif[1:])
  97. if __name__ == '__main__':
  98. tf.compat.v1.app.run()