| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import os
- import tensorflow as tf
- import horovod.tensorflow as hvd
- from model import resnet
# Command-line flags controlling which model is built and how it is exported.
tf.app.flags.DEFINE_string(
    'model_name', 'resnet50',
    'The name of the architecture to save. The default name was being '
    'used to train the model')

tf.app.flags.DEFINE_integer(
    'image_size', 224,
    'The height and width of the square input image.')

tf.app.flags.DEFINE_integer(
    'num_classes', 1001,
    'The number of classes to predict.')

tf.app.flags.DEFINE_integer(
    'batch_size', None,
    'Batch size for the exported model. Defaulted to "None" so batch size can '
    'be specified at model runtime.')

# input_format decides the placeholder layout; compute_format is forwarded to
# the model together with it.
tf.app.flags.DEFINE_string(
    'input_format', 'NCHW',
    'The data format of the input placeholder: "NCHW" gives '
    '[batch, 3, height, width]; any other value gives '
    '[batch, height, width, 3].')

tf.app.flags.DEFINE_string(
    'compute_format', 'NCHW',
    'The dataformat used by the layers in the model')

tf.app.flags.DEFINE_string(
    'checkpoint', '',
    'The trained model checkpoint.')

tf.app.flags.DEFINE_string(
    'output_file', '', 'Where to save the resulting file to.')

tf.app.flags.DEFINE_bool(
    'quantize', False, 'whether to use quantized graph or not.')

tf.app.flags.DEFINE_bool(
    'symmetric', False, 'Using symmetric quantization or not.')

tf.app.flags.DEFINE_bool(
    'use_qdq', False,
    'Use quantize and dequantize op instead of fake quant op')

tf.app.flags.DEFINE_bool(
    'use_final_conv', False,
    'Whether to build the model with use_final_conv enabled '
    '(forwarded to build_model).')

tf.app.flags.DEFINE_bool(
    'write_text_graphdef', False,
    'Whether to write a text version of graphdef.')

FLAGS = tf.app.flags.FLAGS
def main(_):
    """Build the inference graph, load its weights and write a frozen GraphDef.

    The model selected by --model_name is built in inference mode on top of a
    float32 placeholder named 'input'. Variables are restored from --checkpoint
    when one is given, otherwise they are freshly initialized. The variables
    are then folded into constants and the resulting graph is written to
    --output_file.

    Args:
        _: Unused positional arguments handed over by tf.app.run.

    Raises:
        ValueError: If --output_file is empty.
    """
    # Validate the arguments first so a bad invocation fails before any
    # initialization work is done.
    if not FLAGS.output_file:
        raise ValueError('You must supply the path to save to with --output_file')

    # Initialize Horovod (TODO: Remove dependency of horovod for freezing graphs)
    hvd.init()
    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default() as graph:
        # The placeholder layout follows --input_format; batch_size may be
        # None, which leaves the batch dimension unspecified.
        if FLAGS.input_format == 'NCHW':
            input_shape = [FLAGS.batch_size, 3, FLAGS.image_size, FLAGS.image_size]
        else:
            input_shape = [FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3]
        input_images = tf.placeholder(
            name='input', dtype=tf.float32, shape=input_shape)

        model_config = resnet.model_architectures[FLAGS.model_name]
        network = resnet.ResnetModel(FLAGS.model_name,
                                     FLAGS.num_classes,
                                     model_config['layers'],
                                     model_config['widths'],
                                     model_config['expansions'],
                                     FLAGS.compute_format,
                                     FLAGS.input_format)
        # Only the first returned tensor is needed: it names the output node
        # of the frozen graph.
        probs, _ = network.build_model(
            input_images,
            training=False,
            reuse=False,
            use_final_conv=FLAGS.use_final_conv)

        if FLAGS.quantize:
            tf.contrib.quantize.experimental_create_eval_graph(
                symmetric=FLAGS.symmetric,
                use_qdq=FLAGS.use_qdq)

        # Define the saver and restore the checkpoint
        saver = tf.train.Saver()
        with tf.Session() as sess:
            if FLAGS.checkpoint:
                saver.restore(sess, FLAGS.checkpoint)
            else:
                # Make it visible that the export does not hold trained weights.
                tf.logging.warning(
                    'No --checkpoint supplied; exporting freshly initialized '
                    '(untrained) variables.')
                sess.run(tf.global_variables_initializer())
            graph_def = graph.as_graph_def()
            frozen_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, graph_def, [probs.op.name])

        # Write out the frozen graph
        tf.io.write_graph(
            frozen_graph_def,
            os.path.dirname(FLAGS.output_file),
            os.path.basename(FLAGS.output_file),
            as_text=FLAGS.write_text_graphdef)
# Script entry point: let TensorFlow parse the flags and then invoke main.
if __name__ == "__main__":
    tf.app.run(main=main)
|