cmdline_helper.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # ==============================================================================
  4. #
  5. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  6. #
  7. # Licensed under the Apache License, Version 2.0 (the "License");
  8. # you may not use this file except in compliance with the License.
  9. # You may obtain a copy of the License at
  10. #
  11. # http://www.apache.org/licenses/LICENSE-2.0
  12. #
  13. # Unless required by applicable law or agreed to in writing, software
  14. # distributed under the License is distributed on an "AS IS" BASIS,
  15. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. # See the License for the specific language governing permissions and
  17. # limitations under the License.
  18. #
  19. # ==============================================================================
  20. import argparse
  21. from datasets import known_datasets
  22. from model.unet import UNet_v1
  23. from model.blocks.activation_blck import authorized_activation_fn
  24. def _add_bool_argument(parser, name=None, default=False, required=False, help=None):
  25. if not isinstance(default, bool):
  26. raise ValueError()
  27. feature_parser = parser.add_mutually_exclusive_group(required=required)
  28. feature_parser.add_argument('--' + name, dest=name, action='store_true', help=help, default=default)
  29. feature_parser.add_argument('--no' + name, dest=name, action='store_false')
  30. feature_parser.set_defaults(name=default)
  31. def parse_cmdline():
  32. p = argparse.ArgumentParser(description="JoC-UNet_v1-TF")
  33. p.add_argument(
  34. '--unet_variant',
  35. default="tinyUNet",
  36. choices=UNet_v1.authorized_models_variants,
  37. type=str,
  38. required=False,
  39. help="""Which model size is used. This parameter control directly the size and the number of parameters"""
  40. )
  41. p.add_argument(
  42. '--activation_fn',
  43. choices=authorized_activation_fn,
  44. type=str,
  45. default="relu",
  46. required=False,
  47. help="""Which activation function is used after the convolution layers"""
  48. )
  49. p.add_argument(
  50. '--exec_mode',
  51. choices=['train', 'train_and_evaluate', 'evaluate', 'training_benchmark', 'inference_benchmark'],
  52. type=str,
  53. required=True,
  54. help="""Which execution mode to run the model into"""
  55. )
  56. p.add_argument(
  57. '--iter_unit',
  58. choices=['epoch', 'batch'],
  59. type=str,
  60. required=True,
  61. help="""Will the model be run for X batches or X epochs ?"""
  62. )
  63. p.add_argument('--num_iter', type=int, required=True, help="""Number of iterations to run.""")
  64. p.add_argument('--batch_size', type=int, required=True, help="""Size of each minibatch per GPU.""")
  65. p.add_argument(
  66. '--warmup_step',
  67. default=200,
  68. type=int,
  69. required=False,
  70. help="""Number of steps considered as warmup and not taken into account for performance measurements."""
  71. )
  72. p.add_argument(
  73. '--results_dir',
  74. type=str,
  75. required=True,
  76. help="""Directory in which to write training logs, summaries and checkpoints."""
  77. )
  78. p.add_argument(
  79. '--log_dir',
  80. type=str,
  81. required=False,
  82. default="dlloger_out.json",
  83. help="""Directory in which to write logs."""
  84. )
  85. _add_bool_argument(
  86. parser=p,
  87. name="save_eval_results_to_json",
  88. default=False,
  89. required=False,
  90. help="Whether to save evaluation results in JSON format."
  91. )
  92. p.add_argument('--data_dir', required=False, default=None, type=str, help="Path to dataset directory")
  93. p.add_argument(
  94. '--dataset_name',
  95. choices=list(known_datasets.keys()),
  96. type=str,
  97. required=True,
  98. help="""Name of the dataset used in this run (only DAGM2007 is supported atm.)"""
  99. )
  100. p.add_argument(
  101. '--dataset_classID',
  102. default=None,
  103. type=int,
  104. required=False,
  105. help="""ClassID to consider to train or evaluate the network (used for DAGM)."""
  106. )
  107. p.add_argument(
  108. '--data_format',
  109. choices=['NHWC', 'NCHW'],
  110. type=str,
  111. default="NCHW",
  112. required=False,
  113. help="""Which Tensor format is used for computation inside the mode"""
  114. )
  115. _add_bool_argument(
  116. parser=p,
  117. name="amp",
  118. default=False,
  119. required=False,
  120. help="Enable Automatic Mixed Precision to speedup FP32 computation using tensor cores"
  121. )
  122. _add_bool_argument(
  123. parser=p, name="xla", default=False, required=False, help="Enable Tensorflow XLA to maximise performance."
  124. )
  125. p.add_argument(
  126. '--weight_init_method',
  127. choices=UNet_v1.authorized_weight_init_methods,
  128. default="he_normal",
  129. type=str,
  130. required=False,
  131. help="""Which initialisation method is used to randomly intialize the model during training"""
  132. )
  133. p.add_argument('--learning_rate', default=1e-4, type=float, required=False, help="""Learning rate value.""")
  134. p.add_argument(
  135. '--learning_rate_decay_factor',
  136. default=0.8,
  137. type=float,
  138. required=False,
  139. help="""Decay factor to decrease the learning rate."""
  140. )
  141. p.add_argument(
  142. '--learning_rate_decay_steps',
  143. default=500,
  144. type=int,
  145. required=False,
  146. help="""Decay factor to decrease the learning rate."""
  147. )
  148. p.add_argument('--rmsprop_decay', default=0.9, type=float, required=False, help="""RMSProp - Decay value.""")
  149. p.add_argument('--rmsprop_momentum', default=0.8, type=float, required=False, help="""RMSProp - Momentum value.""")
  150. p.add_argument('--weight_decay', default=1e-5, type=float, required=False, help="""Weight Decay scale factor""")
  151. _add_bool_argument(
  152. parser=p, name="use_auto_loss_scaling", default=False, required=False, help="Use AutoLossScaling with TF-AMP"
  153. )
  154. p.add_argument(
  155. '--loss_fn_name',
  156. type=str,
  157. default="adaptive_loss",
  158. required=False,
  159. help="""Loss function Name to use to train the network"""
  160. )
  161. _add_bool_argument(
  162. parser=p, name="augment_data", default=True, required=False, help="Choose whether to use data augmentation"
  163. )
  164. p.add_argument(
  165. '--display_every',
  166. type=int,
  167. default=50,
  168. required=False,
  169. help="""How often (in batches) to print out debug information."""
  170. )
  171. p.add_argument(
  172. '--debug_verbosity',
  173. choices=[0, 1, 2],
  174. default=0,
  175. type=int,
  176. required=False,
  177. help="""Verbosity Level: 0 minimum, 1 with layer creation debug info, 2 with layer + var creation debug info."""
  178. )
  179. p.add_argument('--seed', type=int, default=None, help="""Random seed.""")
  180. FLAGS, unknown_args = p.parse_known_args()
  181. if len(unknown_args) > 0:
  182. for bad_arg in unknown_args:
  183. print("ERROR: Unknown command line arg: %s" % bad_arg)
  184. raise ValueError("Invalid command line arg(s)")
  185. return FLAGS