Преглед изворни кода

Add QAT instructions for RN50

Dheeraj Peri пре 5 година
родитељ
комит
c4f90be499

+ 104 - 0
TensorFlow/Classification/ConvNets/export_frozen_graph.py

@@ -0,0 +1,104 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import os
+
+import tensorflow as tf
+
+import horovod.tensorflow as hvd
+from model import resnet_v1_5
+
+tf.app.flags.DEFINE_string(
+    'model_name', 'resnet50_v1.5', 'The name of the architecture to save. The default name was being ' 
+     'used to train the model')
+
+tf.app.flags.DEFINE_integer(
+    'image_size', 224,
+    'The image size to use, otherwise use the model default_image_size.')
+
+tf.app.flags.DEFINE_integer(
+    'num_classes', 1001,
+    'The number of classes to predict.')
+
+tf.app.flags.DEFINE_integer(
+    'batch_size', None,
+    'Batch size for the exported model. Defaulted to "None" so batch size can '
+    'be specified at model runtime.')
+
+
+tf.app.flags.DEFINE_string('input_format', 'NHWC',
+                           'The dataformat used by the layers in the model')
+
+tf.app.flags.DEFINE_string('compute_format', 'NHWC',
+                           'The dataformat used by the layers in the model')
+
+tf.app.flags.DEFINE_string('checkpoint', '',
+                           'The trained model checkpoint.')
+
+tf.app.flags.DEFINE_string(
+    'output_file', '', 'Where to save the resulting file to.')
+
+tf.app.flags.DEFINE_bool(
+    'quantize', False, 'whether to use quantized graph or not.')
+
+tf.app.flags.DEFINE_bool(
+    'symmetric', False, 'Using symmetric quantization or not.')
+
+
+tf.app.flags.DEFINE_bool(
+    'use_qdq', False, 'Use quantize and dequantize op instead of fake quant op')
+
+tf.app.flags.DEFINE_bool(
+    'use_final_conv', False, 'whether to use quantized graph or not.')
+
+tf.app.flags.DEFINE_bool('write_text_graphdef', False,
+                         'Whether to write a text version of graphdef.')
+
+FLAGS = tf.app.flags.FLAGS
+
+
+def main(_):
+  
+  # Initialize Horovod (TODO: Remove dependency of horovod for freezing graphs)
+  hvd.init()
+
+  if not FLAGS.output_file:
+    raise ValueError('You must supply the path to save to with --output_file')
+
+  tf.logging.set_verbosity(tf.logging.INFO)
+  with tf.Graph().as_default() as graph:
+    if FLAGS.input_format=='NCHW':
+        input_shape = [FLAGS.batch_size, 3, FLAGS.image_size, FLAGS.image_size]
+    else:
+        input_shape = [FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3]
+    input_images = tf.placeholder(name='input', dtype=tf.float32, shape=input_shape)
+    network = resnet_v1_5.ResnetModel(FLAGS.model_name, FLAGS.num_classes, FLAGS.compute_format, FLAGS.input_format)
+    probs, logits = network.build_model(
+                input_images,
+                training=False,
+                reuse=False,
+                use_final_conv=FLAGS.use_final_conv)
+
+    if FLAGS.quantize:
+        tf.contrib.quantize.experimental_create_eval_graph(symmetric=FLAGS.symmetric, use_qdq=FLAGS.use_qdq)
+
+    # Define the saver and restore the checkpoint
+    saver = tf.train.Saver()
+    with tf.Session() as sess:
+        if FLAGS.checkpoint:
+            saver.restore(sess, FLAGS.checkpoint)
+        else:
+            sess.run(tf.global_variables_initializer())
+        graph_def = graph.as_graph_def()
+        frozen_graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, [probs.op.name])
+
+    # Write out the frozen graph
+    tf.io.write_graph(
+        frozen_graph_def,
+        os.path.dirname(FLAGS.output_file),
+        os.path.basename(FLAGS.output_file),
+        as_text=FLAGS.write_text_graphdef)
+
+
+if __name__ == '__main__':
+  tf.app.run()

+ 12 - 2
TensorFlow/Classification/ConvNets/main.py

@@ -94,6 +94,12 @@ if __name__ == "__main__":
             use_static_loss_scaling=FLAGS.use_static_loss_scaling,
             use_cosine_lr=FLAGS.use_cosine_lr,
             is_benchmark=FLAGS.mode == 'training_benchmark',
+            use_final_conv=FLAGS.use_final_conv,
+            quantize=FLAGS.quantize,
+            symmetric=FLAGS.symmetric,
+            quant_delay = FLAGS.quant_delay,
+            use_qdq = FLAGS.use_qdq,
+            finetune_checkpoint = FLAGS.finetune_checkpoint,
         )
 
     if FLAGS.mode in ["train_and_evaluate", 'evaluate', 'inference_benchmark']:
@@ -110,7 +116,11 @@ if __name__ == "__main__":
                 batch_size=FLAGS.batch_size,
                 log_every_n_steps=FLAGS.display_every,
                 is_benchmark=FLAGS.mode == 'inference_benchmark',
-                export_dir=FLAGS.export_dir
+                export_dir=FLAGS.export_dir,
+                quantize=FLAGS.quantize,
+                symmetric=FLAGS.symmetric,
+                use_final_conv=FLAGS.use_final_conv,
+                use_qdq=FLAGS.use_qdq
             )
 
     if FLAGS.mode == 'predict':
@@ -124,4 +134,4 @@ if __name__ == "__main__":
             raise NotImplementedError("Only single GPU inference is implemented.")
 
         elif not hvd_utils.is_using_hvd() or hvd.rank() == 0:
-            runner.predict(FLAGS.to_predict)
+            runner.predict(FLAGS.to_predict, quantize=FLAGS.quantize, symmetric=FLAGS.symmetric, use_qdq=FLAGS.use_qdq, use_final_conv=FLAGS.use_final_conv)

+ 54 - 23
TensorFlow/Classification/ConvNets/model/resnet.py

@@ -21,6 +21,7 @@ from __future__ import print_function
 import tensorflow as tf
 
 import horovod.tensorflow as hvd
+import dllogger
 
 from model import layers
 from model import blocks
@@ -183,8 +184,12 @@ class ResnetModel(object):
             probs, logits = self.build_model(
                 features,
                 training=mode == tf.estimator.ModeKeys.TRAIN,
-                reuse=False
+                reuse=False,
+                use_final_conv=params['use_final_conv']
             )
+            
+            if mode!=tf.estimator.ModeKeys.PREDICT:
+                logits = tf.squeeze(logits)
 
             y_preds = tf.argmax(logits, axis=1, output_type=tf.int32)
 
@@ -196,16 +201,25 @@ class ResnetModel(object):
             tf.identity(logits, name="logits_ref")
             tf.identity(probs, name="probs_ref")
             tf.identity(y_preds, name="y_preds_ref")
-
-            #if mode == tf.estimator.ModeKeys.TRAIN:
-            #    
-            #    assert (len(tf.trainable_variables()) == 161)
-            #
-            #else:
-            #    
-            #    assert (len(tf.trainable_variables()) == 0)
-
-
+            
+            if mode == tf.estimator.ModeKeys.TRAIN and params['quantize']:
+                dllogger.log(data={"QUANTIZATION AWARE TRAINING ENABLED": True}, step=tuple())
+                if params['symmetric']:
+                    dllogger.log(data={"MODE":"USING SYMMETRIC MODE"}, step=tuple())
+                    tf.contrib.quantize.experimental_create_training_graph(tf.get_default_graph(), symmetric=True, use_qdq=params['use_qdq'] ,quant_delay=params['quant_delay'])
+                else:
+                    dllogger.log(data={"MODE":"USING ASSYMETRIC MODE"}, step=tuple())
+                    tf.contrib.quantize.create_training_graph(tf.get_default_graph(), quant_delay=params['quant_delay'], use_qdq=params['use_qdq'])
+            
+            # Fix for restoring variables during fine-tuning of Resnet-50
+            if 'finetune_checkpoint' in params.keys():
+                train_vars = tf.trainable_variables()
+                train_var_dict = {}
+                for var in train_vars:
+                    train_var_dict[var.op.name] = var
+                dllogger.log(data={"Restoring variables from checkpoint": params['finetune_checkpoint']}, step=tuple())
+                tf.train.init_from_checkpoint(params['finetune_checkpoint'], train_var_dict)
+                
         if mode == tf.estimator.ModeKeys.PREDICT:
 
             predictions = {'classes': y_preds, 'probabilities': probs}
@@ -352,7 +366,7 @@ class ResnetModel(object):
 
 
 
-    def build_model(self, inputs, training=True, reuse=False):
+    def build_model(self, inputs, training=True, reuse=False, use_final_conv=False):
         
         with var_storage.model_variable_scope(
             self.model_hparams.model_name,
@@ -416,20 +430,37 @@ class ResnetModel(object):
 
             with tf.variable_scope("output"):
                 net = layers.reduce_mean(
-                    net, keepdims=False, data_format=self.model_hparams.compute_format, name='spatial_mean')
-
-                logits = layers.dense(
-                    inputs=net,
-                    units=self.model_hparams.n_classes,
-                    use_bias=True,
-                    trainable=training,
-                    kernel_initializer=self.dense_hparams.kernel_initializer,
-                    bias_initializer=self.dense_hparams.bias_initializer)
+                    net, keepdims=use_final_conv, data_format=self.model_hparams.compute_format, name='spatial_mean')
+
+                if use_final_conv:
+                    logits = layers.conv2d(
+                                    net,
+                                    n_channels=self.model_hparams.n_classes,
+                                    kernel_size=(1, 1),
+                                    strides=(1, 1),
+                                    padding='SAME',
+                                    data_format=self.model_hparams.compute_format,
+                                    dilation_rate=(1, 1),
+                                    use_bias=True,
+                                    kernel_initializer=self.dense_hparams.kernel_initializer,
+                                    bias_initializer=self.dense_hparams.bias_initializer,
+                                    trainable=training,
+                                    name='dense'
+                                )
+                else:
+                    logits = layers.dense(
+                        inputs=net,
+                        units=self.model_hparams.n_classes,
+                        use_bias=True,
+                        trainable=training,
+                        kernel_initializer=self.dense_hparams.kernel_initializer,
+                        bias_initializer=self.dense_hparams.bias_initializer)
 
                 if logits.dtype != tf.float32:
                     logits = tf.cast(logits, tf.float32)
-
-                probs = layers.softmax(logits, name="softmax", axis=1)
+                    
+                axis = 3 if self.model_hparams.compute_format=="NHWC" and use_final_conv else 1
+                probs = layers.softmax(logits, name="softmax", axis=axis)
 
             return probs, logits
 

+ 51 - 0
TensorFlow/Classification/ConvNets/postprocess_ckpt.py

@@ -0,0 +1,51 @@
+import tensorflow as tf
+import pdb
+import numpy as np
+import argparse
+import os
+import shutil
+
+def main(args):
+    with tf.Session() as sess:
+        ckpt = args.ckpt
+        new_ckpt=args.out
+        output_dir = "./new_ckpt_dir"
+        if os.path.isdir(output_dir):
+            shutil.rmtree(output_dir)
+        # Create an output directory
+        os.mkdir(output_dir)
+
+        new_ckpt_path = os.path.join(output_dir, new_ckpt)
+        with open(os.path.join(output_dir, "checkpoint"), 'w') as file:
+            file.write("model_checkpoint_path: "+ "\"" + new_ckpt + "\"")
+        file.close()
+        # Load all the variables
+        all_vars = tf.train.list_variables(ckpt)
+        ckpt_reader = tf.train.load_checkpoint(ckpt)
+        # Capture the dense layer weights and reshape them to a 4D tensor which would be 
+        # the weights of a 1x1 convolution layer. This code replaces the dense (FC) layer
+        # to a 1x1 conv layer. 
+        dense_layer = 'resnet50_v1.5/output/dense/kernel'
+        dense_layer_value=0.
+        new_var_list=[]
+        for var in all_vars:
+            curr_var = tf.train.load_variable(ckpt, var[0])
+            if var[0]==dense_layer:
+                dense_layer_value = curr_var
+            else:
+                new_var_list.append(tf.Variable(curr_var, name=var[0]))
+
+        new_var_value = np.reshape(dense_layer_value, [1, 1, 2048, 1001])
+        new_var = tf.Variable(new_var_value, name=dense_layer)
+        new_var_list.append(new_var)
+        
+        sess.run(tf.global_variables_initializer())
+        tf.train.Saver(var_list=new_var_list).save(sess, new_ckpt_path, write_meta_graph=False, write_state=False)
+        print ("Rewriting checkpoints completed")
+
+if __name__=='__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--ckpt', type=str, required=True)
+    parser.add_argument('--out', type=str, default='./new.ckpt')
+    args = parser.parse_args()
+    main(args)

+ 29 - 0
TensorFlow/Classification/ConvNets/resnet50v1.5/README.md

@@ -18,6 +18,7 @@ This repository provides a script and recipe to train the ResNet-50 v1.5 model t
     * [Scripts and sample code](#scripts-and-sample-code)
     * [Parameters](#parameters)
         * [The `main.py` script](#the-mainpy-script)
+    * [Quantization Aware Training](#quantization-aware-training)
     * [Inference process](#inference-process)
 * [Performance](#performance)
     * [Benchmarking](#benchmarking)
@@ -351,8 +352,36 @@ optional arguments:
                         Limit memory fraction used by the training script for DALI
   --gpu_id GPU_ID       Specify the ID of the target GPU on a multi-device platform.
                         Effective only for single-GPU mode.
+  --quantize            Used to add quantization nodes in the graph (Default: Asymmetric quantization)
+  --symmetric           If --quantize mode is used, this option enables symmetric quantization
+  --use_qdq             Use quantize_and_dequantize (QDQ) op instead of FakeQuantWithMinMaxVars op for quantization. QDQ does only scaling.
+  --finetune_checkpoint Path to pre-trained checkpoint which can be used for fine-tuning
+  --quant_delay         Number of steps to be run before quantization starts to happen
 ```
 
+### Quantization Aware Training
+Quantization Aware training (QAT) simulates quantization during training by quantizing weights and activation layers. This will help reduce the loss in accuracy when we convert the network
+trained in FP32 to INT8 for faster inference. QAT introduces additional nodes in the graph which will be used to learn the dynamic ranges of weights and activation layers. TensorFlow provides
+a <a href="https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/contrib/quantize">quantization tool</a> which automatically adds these nodes in-place. Typical workflow
+for training QAT networks is to train a model until convergence and then finetune with the quantization layers. It is recommended that QAT is performed on a single GPU.
+
+* For 1 GPU
+    * Command: `sh resnet50v1.5/training/QAT/GPU1_RN50_QAT.sh <path to pre-trained ckpt dir> <path to dataset directory> <result_directory>`
+        
+It is recommended to finetune a model with quantization nodes rather than train a QAT model from scratch. The latter can also be performed by setting the `quant_delay` parameter.
+`quant_delay` is the number of steps after which quantization nodes are added for QAT. If we are fine-tuning, `quant_delay` is set to 0. 
+        
+For QAT network, we use <a href="https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/quantization/quantize_and_dequantize">tf.quantization.quantize_and_dequantize operation</a>.
+These operations are automatically added at weights and activation layers in the RN50 by using `tf.contrib.quantize.experimental_create_training_graph` utility. Support for using `tf.quantization.quantize_and_dequantize` 
+operations for `tf.contrib.quantize.experimental_create_training_graph` has been added in <a href="https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow">TensorFlow 20.01-py3 NGC container</a> and later versions, which is required for this task.
+
+#### Post process checkpoint
+  * `postprocess_ckpt.py` is a utility to convert the final classification FC layer into a 1x1 convolution layer using the same weights. This is required to ensure TensorRT can parse QAT models successfully.
+  This script should be used after performing QAT to reshape the FC layer weights in the final checkpoint.
+  Arguments:
+     * `--ckpt` : Path to the trained checkpoint of RN50.
+     * `--out` : Name of the new checkpoint file which has the FC layer weights reshaped into 1x1 conv layer weights.
+
 ### Inference process
 To run inference on a single example with a checkpoint and a model script, use: 
 

+ 4 - 0
TensorFlow/Classification/ConvNets/resnet50v1.5/training/QAT/GPU1_RN50_QAT.sh

@@ -0,0 +1,4 @@
+# This script does Quantization aware training of Resnet-50 by finetuning on the pre-trained model using 1 GPU and a batch size of 32.
+# Usage ./GPU1_RN50_QAT.sh <path to the pre-trained model> <path to dataset> <path to results directory>
+
+python main.py --mode=train_and_evaluate --batch_size=32 --lr_warmup_epochs=1 --label_smoothing 0.1 --lr_init=0.00005 --momentum=0.875 --weight_decay=3.0517578125e-05 --finetune_checkpoint=$1 --data_dir=$2 --results_dir=$3 --quantize --symmetric --num_iter 10 --data_format NHWC

+ 27 - 6
TensorFlow/Classification/ConvNets/runtime/runner.py

@@ -337,7 +337,13 @@ class Runner(object):
         mixup=0.0,
         use_cosine_lr=False,
         use_static_loss_scaling=False,
-        is_benchmark=False
+        is_benchmark=False,
+        quantize=False,
+        symmetric=False,
+        quant_delay=0,
+        finetune_checkpoint=None,
+        use_final_conv=False,
+        use_qdq=False
     ):
 
         if iter_unit not in ["epoch", "batch"]:
@@ -432,9 +438,17 @@ class Runner(object):
             'label_smoothing': label_smoothing,
             'mixup': mixup,
             'num_decay_steps': num_decay_steps,
-            'use_cosine_lr': use_cosine_lr
+            'use_cosine_lr': use_cosine_lr,
+            'use_final_conv': use_final_conv,
+            'quantize': quantize,
+            'use_qdq': use_qdq,
+            'symmetric': symmetric,
+            'quant_delay': quant_delay
         }
-
+        
+        if finetune_checkpoint:
+           estimator_params['finetune_checkpoint']=finetune_checkpoint
+        
         image_classifier = self._get_estimator(
             mode='train',
             run_params=estimator_params,
@@ -511,6 +525,10 @@ class Runner(object):
         log_every_n_steps=1,
         is_benchmark=False,
         export_dir=None,
+        quantize=False,
+        symmetric=False,
+        use_qdq=False,
+        use_final_conv=False,
     ):
 
         if iter_unit not in ["epoch", "batch"]:
@@ -522,7 +540,10 @@ class Runner(object):
         if hvd_utils.is_using_hvd() and hvd.rank() != 0:
             raise RuntimeError('Multi-GPU inference is not supported')
 
-        estimator_params = {}
+        estimator_params = {'quantize': quantize,
+                            'symmetric': symmetric,
+                            'use_qdq': use_qdq,
+                            'use_final_conv': use_final_conv}
 
         image_classifier = self._get_estimator(
             mode='validation',
@@ -649,9 +670,9 @@ class Runner(object):
 
         print('Model evaluation finished')
 
-    def predict(self, to_predict):
+    def predict(self, to_predict, quantize=False, symmetric=False, use_qdq=False, use_final_conv=False):
 
-        estimator_params = {}
+        estimator_params = {'quantize': quantize, 'symmetric': symmetric, 'use_qdq': use_qdq, 'use_final_conv': use_final_conv}
 
         if to_predict is not None:
             filenames = runner_utils.parse_inference_input(to_predict)

+ 44 - 0
TensorFlow/Classification/ConvNets/utils/cmdline_helper.py

@@ -128,6 +128,50 @@ def parse_cmdline(available_arch):
         default='.',
         help="""Directory in which to write training logs, summaries and checkpoints."""
     )
+    
+    p.add_argument(
+        '--finetune_checkpoint',
+        required=False,
+        default=None,
+        type=str,
+        help="Path to pre-trained checkpoint which will be used for fine-tuning"
+    )
+    
+    _add_bool_argument(
+        parser=p, name="use_final_conv", default=False, required=False, help="Use cosine learning rate schedule."
+    )
+
+    p.add_argument(
+        '--quant_delay',
+        type=int,
+        default=0,
+        required=False,
+        help="Number of steps to be run before quantization starts to happen"
+    )
+
+    _add_bool_argument(
+        parser=p,
+        name="quantize",
+        default=False,
+        required=False,
+        help="Quantize weights and activations during training. (Defaults to Assymmetric quantization)"
+    )
+
+    _add_bool_argument(
+        parser=p,
+        name="use_qdq",
+        default=False,
+        required=False,
+        help="Use QDQV3 op instead of FakeQuantWithMinMaxVars op for quantization. QDQv3 does only scaling"
+    )
+
+    _add_bool_argument(
+        parser=p,
+        name="symmetric",
+        default=False,
+        required=False,
+        help="Quantize weights and activations during training using symmetric quantization."
+    )
 
     p.add_argument(
         '--log_filename',