export_frozen_graph.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import os
  20. import tensorflow as tf
  21. from utils import hvd_wrapper as hvd
  22. from model import resnet
  23. tf.app.flags.DEFINE_string(
  24. 'model_name', 'resnet50', 'The name of the architecture to save. The default name was being '
  25. 'used to train the model')
  26. tf.app.flags.DEFINE_integer(
  27. 'image_size', 224,
  28. 'The image size to use, otherwise use the model default_image_size.')
  29. tf.app.flags.DEFINE_integer(
  30. 'num_classes', 1001,
  31. 'The number of classes to predict.')
  32. tf.app.flags.DEFINE_integer(
  33. 'batch_size', None,
  34. 'Batch size for the exported model. Defaulted to "None" so batch size can '
  35. 'be specified at model runtime.')
  36. tf.app.flags.DEFINE_string('input_format', 'NCHW',
  37. 'The dataformat used by the layers in the model')
  38. tf.app.flags.DEFINE_string('compute_format', 'NCHW',
  39. 'The dataformat used by the layers in the model')
  40. tf.app.flags.DEFINE_string('checkpoint', '',
  41. 'The trained model checkpoint.')
  42. tf.app.flags.DEFINE_string(
  43. 'output_file', '', 'Where to save the resulting file to.')
  44. tf.app.flags.DEFINE_bool(
  45. 'quantize', False, 'whether to use quantized graph or not.')
  46. tf.app.flags.DEFINE_bool(
  47. 'symmetric', False, 'Using symmetric quantization or not.')
  48. tf.app.flags.DEFINE_bool(
  49. 'use_qdq', False, 'Use quantize and dequantize op instead of fake quant op')
  50. tf.app.flags.DEFINE_bool(
  51. 'use_final_conv', False, 'whether to use quantized graph or not.')
  52. tf.app.flags.DEFINE_bool('write_text_graphdef', False,
  53. 'Whether to write a text version of graphdef.')
  54. FLAGS = tf.app.flags.FLAGS
  55. def main(_):
  56. hvd.init()
  57. if not FLAGS.output_file:
  58. raise ValueError('You must supply the path to save to with --output_file')
  59. tf.logging.set_verbosity(tf.logging.INFO)
  60. with tf.Graph().as_default() as graph:
  61. if FLAGS.input_format=='NCHW':
  62. input_shape = [FLAGS.batch_size, 3, FLAGS.image_size, FLAGS.image_size]
  63. else:
  64. input_shape = [FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3]
  65. input_images = tf.placeholder(name='input', dtype=tf.float32, shape=input_shape)
  66. resnet50_config = resnet.model_architectures[FLAGS.model_name]
  67. network = resnet.ResnetModel(FLAGS.model_name,
  68. FLAGS.num_classes,
  69. resnet50_config['layers'],
  70. resnet50_config['widths'],
  71. resnet50_config['expansions'],
  72. FLAGS.compute_format,
  73. FLAGS.input_format)
  74. probs, logits = network.build_model(
  75. input_images,
  76. training=False,
  77. reuse=False,
  78. use_final_conv=FLAGS.use_final_conv)
  79. if FLAGS.quantize:
  80. tf.contrib.quantize.experimental_create_eval_graph(symmetric=FLAGS.symmetric,
  81. use_qdq=FLAGS.use_qdq)
  82. # Define the saver and restore the checkpoint
  83. saver = tf.train.Saver()
  84. with tf.Session() as sess:
  85. if FLAGS.checkpoint:
  86. saver.restore(sess, FLAGS.checkpoint)
  87. else:
  88. sess.run(tf.global_variables_initializer())
  89. graph_def = graph.as_graph_def()
  90. frozen_graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, [probs.op.name])
  91. # Write out the frozen graph
  92. tf.io.write_graph(
  93. frozen_graph_def,
  94. os.path.dirname(FLAGS.output_file),
  95. os.path.basename(FLAGS.output_file),
  96. as_text=FLAGS.write_text_graphdef)
  97. if __name__ == '__main__':
  98. tf.app.run()