소스 검색

Update frozen graph script and instructions

Dheeraj Peri 5 년 전
부모
커밋
e0f399def4

+ 17 - 8
TensorFlow/Classification/ConvNets/export_frozen_graph.py

@@ -6,10 +6,10 @@ import os
 import tensorflow as tf
 
 import horovod.tensorflow as hvd
-from model import resnet_v1_5
+from model import resnet
 
 tf.app.flags.DEFINE_string(
-    'model_name', 'resnet50_v1.5', 'The name of the architecture to save. The default name was being ' 
+    'model_name', 'resnet50', 'The name of the architecture to save. The default name was being ' 
      'used to train the model')
 
 tf.app.flags.DEFINE_integer(
@@ -26,10 +26,10 @@ tf.app.flags.DEFINE_integer(
     'be specified at model runtime.')
 
 
-tf.app.flags.DEFINE_string('input_format', 'NHWC',
+tf.app.flags.DEFINE_string('input_format', 'NCHW',
                            'The dataformat used by the layers in the model')
 
-tf.app.flags.DEFINE_string('compute_format', 'NHWC',
+tf.app.flags.DEFINE_string('compute_format', 'NCHW',
                            'The dataformat used by the layers in the model')
 
 tf.app.flags.DEFINE_string('checkpoint', '',
@@ -72,15 +72,24 @@ def main(_):
     else:
         input_shape = [FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3]
     input_images = tf.placeholder(name='input', dtype=tf.float32, shape=input_shape)
-    network = resnet_v1_5.ResnetModel(FLAGS.model_name, FLAGS.num_classes, FLAGS.compute_format, FLAGS.input_format)
+    
+    resnet50_config = resnet.model_architectures[FLAGS.model_name]
+    network = resnet.ResnetModel(FLAGS.model_name, 
+                                 FLAGS.num_classes,
+                                 resnet50_config['layers'],
+                                 resnet50_config['widths'],
+                                 resnet50_config['expansions'],
+                                 FLAGS.compute_format,
+                                 FLAGS.input_format)
     probs, logits = network.build_model(
                 input_images,
                 training=False,
                 reuse=False,
                 use_final_conv=FLAGS.use_final_conv)
-
+    
     if FLAGS.quantize:
-        tf.contrib.quantize.experimental_create_eval_graph(symmetric=FLAGS.symmetric, use_qdq=FLAGS.use_qdq)
+        tf.contrib.quantize.experimental_create_eval_graph(symmetric=FLAGS.symmetric, 
+                                                           use_qdq=FLAGS.use_qdq)
 
     # Define the saver and restore the checkpoint
     saver = tf.train.Saver()
@@ -101,4 +110,4 @@ def main(_):
 
 
 if __name__ == '__main__':
-  tf.app.run()
+  tf.app.run()

+ 4 - 2
TensorFlow/Classification/ConvNets/model/layers/conv2d.py

@@ -31,7 +31,8 @@ def conv2d(
     use_bias=True,
     kernel_initializer=tf.variance_scaling_initializer(),
     bias_initializer=tf.zeros_initializer(),
-    trainable=True
+    trainable=True,
+    name=None
 ):
 
     if data_format not in ['NHWC', 'NCHW']:
@@ -52,7 +53,8 @@ def conv2d(
         kernel_initializer=kernel_initializer,
         bias_initializer=bias_initializer,
         trainable=trainable,
-        activation=None
+        activation=None,
+        name=name
     )
     
     return net

+ 14 - 1
TensorFlow/Classification/ConvNets/resnet50v1.5/README.md

@@ -373,7 +373,7 @@ It is recommended to finetune a model with quantization nodes rather than train
         
 For QAT network, we use <a href="https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/quantization/quantize_and_dequantize">tf.quantization.quantize_and_dequantize operation</a>.
 These operations are automatically added at weights and activation layers in the RN50 by using `tf.contrib.quantize.experimental_create_training_graph` utility. Support for using `tf.quantization.quantize_and_dequantize` 
-operations for `tf.contrib.quantize.experimental_create_training_graph has been added in <a href="https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow">TensorFlow 20.01-py3 NGC container</a> and later versions, which is required for this task.
+operations for `tf.contrib.quantize.experimental_create_training_graph` has been added in <a href="https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow">TensorFlow 20.01-py3 NGC container</a> and later versions, which is required for this task.
 
 #### Post process checkpoint
   * `post_process_ckpt.py` is a utility to convert the final classification FC layer into a 1x1 convolution layer using the same weights. This is required to ensure TensorRT can parse QAT models successfully.
@@ -382,6 +382,19 @@ operations for `tf.contrib.quantize.experimental_create_training_graph has been
      * `--ckpt` : Path to the trained checkpoint of RN50.
      * `--out` : Name of the new checkpoint file which has the FC layer weights reshaped into 1x1 conv layer weights.
 
+### Exporting Frozen graphs
+To export frozen graphs (which can be used for inference with <a href="https://developer.nvidia.com/tensorrt">TensorRT</a>), use:
+
+`python export_frozen_graph.py --checkpoint <path_to_checkpoint> --quantize --use_final_conv --use_qdq --symmetric --input_format NCHW --compute_format NCHW --output_file=<output_file_name>`
+
+Arguments:
+
+* `--checkpoint` : Optional argument to export the model with checkpoint weights.
+* `--quantize` : Optional flag to export quantized graphs.
* `--use_qdq` : Use the quantize_and_dequantize (QDQ) op instead of the FakeQuantWithMinMaxVars op for quantization. QDQ performs scaling only.
+* `--input_format` : Data format of input tensor (Default: NCHW). Use NCHW format to optimize the graph with TensorRT.
+* `--compute_format` : Data format of the operations in the network (Default: NCHW). Use NCHW format to optimize the graph with TensorRT.
+
 ### Inference process
 To run inference on a single example with a checkpoint and a model script, use: