main.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import os
  17. import warnings
  18. warnings.simplefilter("ignore")
  19. import tensorflow as tf
  20. import horovod.tensorflow as hvd
  21. import dllogger
  22. from utils import hvd_utils
  23. from runtime import Runner
  24. from model.resnet import model_architectures
  25. from utils.cmdline_helper import parse_cmdline
  26. if __name__ == "__main__":
  27. tf.logging.set_verbosity(tf.logging.ERROR)
  28. FLAGS = parse_cmdline(model_architectures.keys())
  29. hvd.init()
  30. if hvd.rank() == 0:
  31. log_path = os.path.join(FLAGS.results_dir, FLAGS.log_filename)
  32. os.makedirs(FLAGS.results_dir, exist_ok=True)
  33. dllogger.init(
  34. backends=[
  35. dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE, filename=log_path),
  36. dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE)
  37. ]
  38. )
  39. else:
  40. dllogger.init(backends=[])
  41. dllogger.log(data=vars(FLAGS), step='PARAMETER')
  42. runner = Runner(
  43. # ========= Model HParams ========= #
  44. n_classes=1001,
  45. architecture=FLAGS.arch,
  46. input_format='NHWC',
  47. compute_format=FLAGS.data_format,
  48. dtype=tf.float32 if FLAGS.precision == 'fp32' else tf.float16,
  49. n_channels=3,
  50. height=224,
  51. width=224,
  52. distort_colors=False,
  53. log_dir=FLAGS.results_dir,
  54. model_dir=FLAGS.model_dir if FLAGS.model_dir is not None else FLAGS.results_dir,
  55. data_dir=FLAGS.data_dir,
  56. data_idx_dir=FLAGS.data_idx_dir,
  57. weight_init=FLAGS.weight_init,
  58. use_xla=FLAGS.use_xla,
  59. use_tf_amp=FLAGS.use_tf_amp,
  60. use_dali=FLAGS.use_dali,
  61. gpu_memory_fraction=FLAGS.gpu_memory_fraction,
  62. gpu_id=FLAGS.gpu_id,
  63. seed=FLAGS.seed
  64. )
  65. if FLAGS.mode in ["train", "train_and_evaluate", "training_benchmark"]:
  66. runner.train(
  67. iter_unit=FLAGS.iter_unit,
  68. num_iter=FLAGS.num_iter,
  69. run_iter=FLAGS.run_iter,
  70. batch_size=FLAGS.batch_size,
  71. warmup_steps=FLAGS.warmup_steps,
  72. log_every_n_steps=FLAGS.display_every,
  73. weight_decay=FLAGS.weight_decay,
  74. lr_init=FLAGS.lr_init,
  75. lr_warmup_epochs=FLAGS.lr_warmup_epochs,
  76. momentum=FLAGS.momentum,
  77. loss_scale=FLAGS.loss_scale,
  78. label_smoothing=FLAGS.label_smoothing,
  79. mixup=FLAGS.mixup,
  80. use_static_loss_scaling=FLAGS.use_static_loss_scaling,
  81. use_cosine_lr=FLAGS.use_cosine_lr,
  82. is_benchmark=FLAGS.mode == 'training_benchmark',
  83. use_final_conv=FLAGS.use_final_conv,
  84. quantize=FLAGS.quantize,
  85. symmetric=FLAGS.symmetric,
  86. quant_delay = FLAGS.quant_delay,
  87. use_qdq = FLAGS.use_qdq,
  88. finetune_checkpoint=FLAGS.finetune_checkpoint,
  89. )
  90. if FLAGS.mode in ["train_and_evaluate", 'evaluate', 'inference_benchmark']:
  91. if FLAGS.mode == 'inference_benchmark' and hvd_utils.is_using_hvd():
  92. raise NotImplementedError("Only single GPU inference is implemented.")
  93. elif not hvd_utils.is_using_hvd() or hvd.rank() == 0:
  94. runner.evaluate(
  95. iter_unit=FLAGS.iter_unit if FLAGS.mode != "train_and_evaluate" else "epoch",
  96. num_iter=FLAGS.num_iter if FLAGS.mode != "train_and_evaluate" else 1,
  97. warmup_steps=FLAGS.warmup_steps,
  98. batch_size=FLAGS.batch_size,
  99. log_every_n_steps=FLAGS.display_every,
  100. is_benchmark=FLAGS.mode == 'inference_benchmark',
  101. export_dir=FLAGS.export_dir,
  102. quantize=FLAGS.quantize,
  103. symmetric=FLAGS.symmetric,
  104. use_final_conv=FLAGS.use_final_conv,
  105. use_qdq=FLAGS.use_qdq
  106. )
  107. if FLAGS.mode == 'predict':
  108. if FLAGS.to_predict is None:
  109. raise ValueError("No data to predict on.")
  110. if not os.path.isfile(FLAGS.to_predict):
  111. raise ValueError("Only prediction on single images is supported!")
  112. if hvd_utils.is_using_hvd():
  113. raise NotImplementedError("Only single GPU inference is implemented.")
  114. elif not hvd_utils.is_using_hvd() or hvd.rank() == 0:
  115. runner.predict(FLAGS.to_predict, quantize=FLAGS.quantize, symmetric=FLAGS.symmetric, use_qdq=FLAGS.use_qdq, use_final_conv=FLAGS.use_final_conv)