# export_frozen_graph.py (3.9 KB)
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. import os
  5. import tensorflow as tf
  6. import horovod.tensorflow as hvd
  7. from model import resnet
  8. tf.app.flags.DEFINE_string(
  9. 'model_name', 'resnet50', 'The name of the architecture to save. The default name was being '
  10. 'used to train the model')
  11. tf.app.flags.DEFINE_integer(
  12. 'image_size', 224,
  13. 'The image size to use, otherwise use the model default_image_size.')
  14. tf.app.flags.DEFINE_integer(
  15. 'num_classes', 1001,
  16. 'The number of classes to predict.')
  17. tf.app.flags.DEFINE_integer(
  18. 'batch_size', None,
  19. 'Batch size for the exported model. Defaulted to "None" so batch size can '
  20. 'be specified at model runtime.')
  21. tf.app.flags.DEFINE_string('input_format', 'NCHW',
  22. 'The dataformat used by the layers in the model')
  23. tf.app.flags.DEFINE_string('compute_format', 'NCHW',
  24. 'The dataformat used by the layers in the model')
  25. tf.app.flags.DEFINE_string('checkpoint', '',
  26. 'The trained model checkpoint.')
  27. tf.app.flags.DEFINE_string(
  28. 'output_file', '', 'Where to save the resulting file to.')
  29. tf.app.flags.DEFINE_bool(
  30. 'quantize', False, 'whether to use quantized graph or not.')
  31. tf.app.flags.DEFINE_bool(
  32. 'symmetric', False, 'Using symmetric quantization or not.')
  33. tf.app.flags.DEFINE_bool(
  34. 'use_qdq', False, 'Use quantize and dequantize op instead of fake quant op')
  35. tf.app.flags.DEFINE_bool(
  36. 'use_final_conv', False, 'whether to use quantized graph or not.')
  37. tf.app.flags.DEFINE_bool('write_text_graphdef', False,
  38. 'Whether to write a text version of graphdef.')
  39. FLAGS = tf.app.flags.FLAGS
  40. def main(_):
  41. # Initialize Horovod (TODO: Remove dependency of horovod for freezing graphs)
  42. hvd.init()
  43. if not FLAGS.output_file:
  44. raise ValueError('You must supply the path to save to with --output_file')
  45. tf.logging.set_verbosity(tf.logging.INFO)
  46. with tf.Graph().as_default() as graph:
  47. if FLAGS.input_format=='NCHW':
  48. input_shape = [FLAGS.batch_size, 3, FLAGS.image_size, FLAGS.image_size]
  49. else:
  50. input_shape = [FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3]
  51. input_images = tf.placeholder(name='input', dtype=tf.float32, shape=input_shape)
  52. resnet50_config = resnet.model_architectures[FLAGS.model_name]
  53. network = resnet.ResnetModel(FLAGS.model_name,
  54. FLAGS.num_classes,
  55. resnet50_config['layers'],
  56. resnet50_config['widths'],
  57. resnet50_config['expansions'],
  58. FLAGS.compute_format,
  59. FLAGS.input_format)
  60. probs, logits = network.build_model(
  61. input_images,
  62. training=False,
  63. reuse=False,
  64. use_final_conv=FLAGS.use_final_conv)
  65. if FLAGS.quantize:
  66. tf.contrib.quantize.experimental_create_eval_graph(symmetric=FLAGS.symmetric,
  67. use_qdq=FLAGS.use_qdq)
  68. # Define the saver and restore the checkpoint
  69. saver = tf.train.Saver()
  70. with tf.Session() as sess:
  71. if FLAGS.checkpoint:
  72. saver.restore(sess, FLAGS.checkpoint)
  73. else:
  74. sess.run(tf.global_variables_initializer())
  75. graph_def = graph.as_graph_def()
  76. frozen_graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, [probs.op.name])
  77. # Write out the frozen graph
  78. tf.io.write_graph(
  79. frozen_graph_def,
  80. os.path.dirname(FLAGS.output_file),
  81. os.path.basename(FLAGS.output_file),
  82. as_text=FLAGS.write_text_graphdef)
  83. if __name__ == '__main__':
  84. tf.app.run()