export.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import tensorflow as tf
  16. from utils.data_loader import MSDDataset
  17. from utils.model_fn import vnet_v2
  18. from utils.tf_export import to_savedmodel, to_tf_trt, to_onnx
  19. PARSER = argparse.ArgumentParser(description="V-Net")
  20. PARSER.add_argument('--to', dest='to', choices=['savedmodel', 'tftrt', 'onnx'], required=True)
  21. PARSER.add_argument('--use_amp', dest='use_amp', action='store_true', default=False)
  22. PARSER.add_argument('--use_xla', dest='use_xla', action='store_true', default=False)
  23. PARSER.add_argument('--compress', dest='compress', action='store_true', default=False)
  24. PARSER.add_argument('--input_shape',
  25. nargs='+',
  26. type=int,
  27. help="""Model's input shape""")
  28. PARSER.add_argument('--data_dir',
  29. type=str,
  30. help="""Directory where the dataset is located""")
  31. PARSER.add_argument('--checkpoint_dir',
  32. type=str,
  33. help="""Directory where the checkpoint is located""")
  34. PARSER.add_argument('--savedmodel_dir',
  35. type=str,
  36. help="""Directory where the savedModel is located""")
  37. PARSER.add_argument('--precision',
  38. type=str,
  39. choices=['FP32', 'FP16', 'INT8'],
  40. help="""Precision for the model""")
  41. def main():
  42. """
  43. Starting point of the application
  44. """
  45. flags = PARSER.parse_args()
  46. if flags.to == 'savedmodel':
  47. params = {
  48. 'labels': ['0', '1', '2'],
  49. 'batch_size': 1,
  50. 'input_shape': flags.input_shape,
  51. 'convolution_size': 3,
  52. 'downscale_blocks': [3, 3, 3],
  53. 'upscale_blocks': [3, 3],
  54. 'upsampling': 'transposed_conv',
  55. 'pooling': 'conv_pool',
  56. 'normalization_layer': 'batchnorm',
  57. 'activation': 'relu'
  58. }
  59. to_savedmodel(input_shape=flags.input_shape,
  60. model_fn=vnet_v2,
  61. checkpoint_dir=flags.checkpoint_dir,
  62. output_dir='./saved_model',
  63. input_names=['IteratorGetNext'],
  64. output_names=['vnet/loss/total_loss_ref'],
  65. use_amp=flags.use_amp,
  66. use_xla=flags.use_xla,
  67. compress=flags.compress,
  68. params=argparse.Namespace(**params))
  69. if flags.to == 'tftrt':
  70. ds = MSDDataset(json_path=flags.data_dir + "/dataset.json",
  71. interpolator='linear')
  72. iterator = ds.test_fn(count=1).make_one_shot_iterator()
  73. features = iterator.get_next()
  74. sess = tf.Session()
  75. def input_data():
  76. return {'input_tensor:0': sess.run(features)}
  77. to_tf_trt(savedmodel_dir=flags.savedmodel_dir,
  78. output_dir='./tf_trt_model',
  79. precision=flags.precision,
  80. feed_dict_fn=input_data,
  81. num_runs=1,
  82. output_tensor_names=['vnet/Softmax:0'],
  83. compress=flags.compress)
  84. if flags.to == 'onnx':
  85. raise NotImplementedError('Currently ONNX not supported for 3D models')
  86. if __name__ == '__main__':
  87. main()