export.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import argparse
  2. import tensorflow as tf
  3. from dlexport.tensorflow import to_savedmodel, to_onnx, to_tensorrt
  4. from utils.data_loader import Dataset
  5. from utils.model_fn import unet_fn
  # Command-line interface for exporting the U-Net model. main() only reads the
  # flags relevant to the selected --to target (see each flag's help text).
  PARSER = argparse.ArgumentParser(description="U-Net medical")
  PARSER.add_argument('--to',
                      dest='to',
                      choices=['savedmodel', 'tensorrt', 'onnx'],
                      required=True,
                      help="""Export target format""")
  PARSER.add_argument('--use_amp',
                      dest='use_amp',
                      action='store_true',
                      default=False,
                      help="""Enable automatic mixed precision (savedmodel export only)""")
  PARSER.add_argument('--use_xla',
                      dest='use_xla',
                      action='store_true',
                      default=False,
                      help="""Enable XLA (savedmodel export only)""")
  PARSER.add_argument('--compress',
                      dest='compress',
                      action='store_true',
                      default=False,
                      help="""Pass compress=True to the selected exporter""")
  PARSER.add_argument('--input_shape',
                      nargs='+',
                      type=int,
                      help="""Model input shape as space-separated integers (savedmodel export only)""")
  PARSER.add_argument('--data_dir',
                      type=str,
                      help="""Dataset directory used to feed sample inputs (tensorrt export only)""")
  PARSER.add_argument('--checkpoint_dir',
                      type=str,
                      help="""Directory of the trained checkpoint to export (savedmodel export only)""")
  PARSER.add_argument('--savedmodel_dir',
                      type=str,
                      help="""Directory of an existing SavedModel to convert (tensorrt and onnx exports)""")
  PARSER.add_argument('--precision',
                      type=str,
                      choices=['FP32', 'FP16', 'INT8'],
                      help="""Numeric precision of the TensorRT-converted model (tensorrt export only)""")
  def main():
      """Parse the command line and run the export selected by ``--to``.

      * ``savedmodel``: calls ``to_savedmodel`` with ``unet_fn`` on
        ``--checkpoint_dir`` and writes to ``./saved_model``. Only this target
        reads ``--input_shape``, ``--use_amp`` and ``--use_xla``.
      * ``tensorrt``: calls ``to_tensorrt`` on ``--savedmodel_dir`` at
        ``--precision`` and writes to ``./tf_trt_model``. Sample inputs come
        from ``Dataset(data_dir=--data_dir).test_fn(count=1)`` with batch
        size 1.
      * ``onnx``: calls ``to_onnx`` on ``--savedmodel_dir`` and writes to
        ``./onnx_model``.

      ``--compress`` is forwarded to whichever exporter runs. The output
      directories are fixed paths relative to the current working directory.
      """
      flags = PARSER.parse_args()

      # ``--to`` is restricted by ``choices``, so exactly one branch runs.
      if flags.to == 'savedmodel':
          to_savedmodel(input_shape=flags.input_shape,
                        model_fn=unet_fn,
                        src_dir=flags.checkpoint_dir,
                        dst_dir='./saved_model',
                        # NOTE(review): graph node names handed to the exporter;
                        # presumably they must match what unet_fn builds — confirm.
                        input_names=['IteratorGetNext'],
                        output_names=['total_loss_ref'],
                        use_amp=flags.use_amp,
                        use_xla=flags.use_xla,
                        compress=flags.compress)
      elif flags.to == 'tensorrt':
          ds = Dataset(data_dir=flags.data_dir,
                       batch_size=1,
                       augment=False,
                       gpu_id=0,
                       num_gpus=1,
                       seed=42)
          iterator = ds.test_fn(count=1).make_one_shot_iterator()
          features = iterator.get_next()

          # Close the session explicitly once the export finishes. A plain
          # try/finally is used instead of ``with tf.Session()`` because the
          # context manager would also install this session and its graph as
          # defaults while to_tensorrt runs, changing the surrounding state.
          sess = tf.Session()
          try:
              def input_data():
                  # Evaluate the next test sample and map it onto the
                  # 'input_tensor:0' feed.
                  return {'input_tensor:0': sess.run(features)}

              # to_tensorrt receives input_data as feed_dict_fn and may call it
              # at any point before returning, so sess must stay open until then.
              # NOTE(review): 'Softmax:0' is assumed to be the model's output
              # tensor in the SavedModel — confirm against the exported graph.
              to_tensorrt(src_dir=flags.savedmodel_dir,
                          dst_dir='./tf_trt_model',
                          precision=flags.precision,
                          feed_dict_fn=input_data,
                          num_runs=1,
                          output_tensor_names=['Softmax:0'],
                          compress=flags.compress)
          finally:
              sess.close()
      elif flags.to == 'onnx':
          to_onnx(src_dir=flags.savedmodel_dir,
                  dst_dir='./onnx_model',
                  compress=flags.compress)
  if __name__ == '__main__':
      # Run the export only when this file is executed as a script, not when
      # it is imported as a module.
      main()