Просмотр исходного кода

[UNet/TF2] Add tf-trt and SavedModel tests. Remove profiling tests.

Michal Marcinkiewicz 4 лет назад
Родитель
Commit
ddbcd54056

+ 10 - 4
TensorFlow2/Segmentation/UNet_Medical/data_loading/data_loader.py

@@ -25,12 +25,13 @@ from PIL import Image, ImageSequence
 class Dataset:
     """Load, separate and prepare the data for training and prediction"""
 
-    def __init__(self, data_dir, batch_size, fold, augment=False, gpu_id=0, num_gpus=1, seed=0):
+    def __init__(self, data_dir, batch_size, fold, augment=False, gpu_id=0, num_gpus=1, seed=0, amp=False):
         if not os.path.exists(data_dir):
             raise FileNotFoundError('Cannot find data dir: {}'.format(data_dir))
         self._data_dir = data_dir
         self._batch_size = batch_size
         self._augment = augment
+        self.precision = tf.float16 if amp else tf.float32
 
         self._seed = seed
 
@@ -149,7 +150,7 @@ class Dataset:
         cond = tf.less(labels, 0.5 * tf.ones(tf.shape(input=labels)))
         labels = tf.where(cond, tf.zeros(tf.shape(input=labels)), tf.ones(tf.shape(input=labels)))
 
-        return inputs, labels
+        return tf.cast(inputs, self.precision), labels
 
     @tf.function
     def _preproc_eval_samples(self, inputs, labels):
@@ -162,7 +163,12 @@ class Dataset:
         cond = tf.less(labels, 0.5 * tf.ones(tf.shape(input=labels)))
         labels = tf.where(cond, tf.zeros(tf.shape(input=labels)), tf.ones(tf.shape(input=labels)))
 
-        return (inputs, labels)
+        return tf.cast(inputs, self.precision), labels
+
+    @tf.function
+    def _preproc_test_samples(self, inputs):
+        inputs = self._normalize_inputs(inputs)
+        return tf.cast(inputs, self.precision)
 
     def train_fn(self, drop_remainder=False):
         """Input function for training"""
@@ -195,7 +201,7 @@ class Dataset:
         dataset = tf.data.Dataset.from_tensor_slices(
             self._test_images)
         dataset = dataset.repeat(count=count)
-        dataset = dataset.map(self._normalize_inputs)
+        dataset = dataset.map(self._preproc_test_samples)
         dataset = dataset.batch(self._batch_size, drop_remainder=drop_remainder)
         dataset = dataset.prefetch(self._batch_size)
 

+ 2 - 1
TensorFlow2/Segmentation/UNet_Medical/main.py

@@ -51,7 +51,8 @@ def main():
                       augment=params.augment,
                       gpu_id=hvd.rank(),
                       num_gpus=hvd.size(),
-                      seed=params.seed)
+                      seed=params.seed,
+                      amp=params.use_amp)
 
     if 'train' in params.exec_mode:
         train(params, model, dataset, logger)

+ 56 - 0
TensorFlow2/Segmentation/UNet_Medical/model/tf_trt.py

@@ -0,0 +1,56 @@
+import os
+from operator import itemgetter
+
+import tensorflow as tf
+from tensorflow.python.compiler.tensorrt import trt_convert as trt
+from tensorflow.compat.v1.saved_model import tag_constants, signature_constants
+
+
def export_model(model_dir, prec, tf_trt_model_dir=None):
    """Convert a Keras SavedModel to a TF-TRT optimized SavedModel.

    Args:
        model_dir: Directory containing a ``saved_model_{prec}`` subdirectory.
        prec: Precision tag, ``"fp32"`` or ``"amp"``; any value other than
            ``"fp32"`` selects FP16 TensorRT kernels.
        tf_trt_model_dir: Destination directory for the converted model;
            defaults to ``/tmp/tf-trt_model_{prec}``.
    """
    saved_model_dir = os.path.join(model_dir, f'saved_model_{prec}')

    # Sanity check: confirm the SavedModel loads and accepts the expected
    # input shape/dtype before spending time on the TensorRT conversion.
    # (The converter itself reads the model from disk, not from `model`.)
    model = tf.keras.models.load_model(saved_model_dir)
    input_shape = [1, 572, 572, 1]
    # tf.zeros already returns a tensor; no tf.constant wrapper is needed.
    dummy_input = tf.zeros(input_shape, dtype=tf.float32 if prec == "fp32" else tf.float16)
    _ = model(dummy_input, training=False)

    trt_prec = trt.TrtPrecisionMode.FP32 if prec == "fp32" else trt.TrtPrecisionMode.FP16
    converter = trt.TrtGraphConverterV2(
        input_saved_model_dir=saved_model_dir,
        conversion_params=trt.TrtConversionParams(precision_mode=trt_prec),
    )
    converter.convert()
    tf_trt_model_dir = tf_trt_model_dir or f'/tmp/tf-trt_model_{prec}'
    converter.save(tf_trt_model_dir)
    print(f"TF-TRT model saved at {tf_trt_model_dir}")
+
+
def _force_gpu_resync(func):
    """Decorator that forces a GPU sync after every call of *func*.

    Reading a tensor value back to the host (``.numpy()``) blocks until the
    queued GPU work has finished, so wrapped calls reflect true device time
    when used for benchmarking.
    """
    from functools import wraps

    p = tf.constant(0.)  # Small tensor used only to trigger the device sync.

    @wraps(func)  # Preserve the wrapped function's name and docstring.
    def wrapper(*args, **kwargs):
        rslt = func(*args, **kwargs)
        (p + 1.).numpy()  # Host readback: blocks until the GPU is idle.
        return rslt

    return wrapper
+
+
class TFTRTModel:
    """Wraps a TF-TRT converted SavedModel behind a plain callable interface."""

    def __init__(self, model_dir, precision, output_tensor_name="output_1"):
        # Convert on the fly into a temporary location, then load the result.
        temp_tftrt_dir = f"/tmp/tf-trt_model_{precision}"
        export_model(model_dir, precision, temp_tftrt_dir)
        loaded = tf.saved_model.load(temp_tftrt_dir, tags=[tag_constants.SERVING])
        print(f"TF-TRT model loaded from {temp_tftrt_dir}")
        self.graph_func = loaded.signatures[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
        self.output_tensor_name = output_tensor_name
        # "amp" checkpoints run in half precision; everything else in float32.
        self.precision = tf.float16 if precision == "amp" else tf.float32

    def __call__(self, x, **kwargs):
        # Extra keyword args (e.g. training=False from Keras-style callers)
        # are accepted for drop-in compatibility and ignored.
        return self.infer_step(x)

    #@_force_gpu_resync
    @tf.function(jit_compile=False)
    def infer_step(self, batch_x):
        """Run one inference step, casting the batch to the model precision."""
        if batch_x.dtype != self.precision:
            batch_x = tf.cast(batch_x, self.precision)
        output = self.graph_func(batch_x)
        return output[self.output_tensor_name]

+ 6 - 2
TensorFlow2/Segmentation/UNet_Medical/runtime/arguments.py

@@ -100,9 +100,12 @@ PARSER.add_argument('--use_amp', '--amp', dest='use_amp', action='store_true',
 PARSER.add_argument('--use_xla', '--xla', dest='use_xla', action='store_true',
                     help="""Train using XLA""")
 
-PARSER.add_argument('--use_trt', dest='use_trt', action='store_true',
+PARSER.add_argument('--use_tftrt', dest='use_tftrt', action='store_true',
                     help="""Use TF-TRT""")
 
+PARSER.add_argument('--use_savedmodel', dest='use_savedmodel', action='store_true',
+                    help="""Use SavedModel""")
+
 PARSER.add_argument('--resume_training', dest='resume_training', action='store_true',
                     help="""Resume training from a checkpoint""")
 
@@ -125,7 +128,8 @@ def parse_args(flags):
         'benchmark': flags.benchmark,
         'seed': flags.seed,
         'use_amp': flags.use_amp,
-        'use_trt': flags.use_trt,
+        'use_tftrt': flags.use_tftrt,
+        'use_savedmodel': flags.use_savedmodel,
         'use_xla': flags.use_xla,
         'resume_training': flags.resume_training,
     })

+ 22 - 4
TensorFlow2/Segmentation/UNet_Medical/runtime/run.py

@@ -21,6 +21,7 @@ import tensorflow as tf
 
 from runtime.losses import partial_losses
 from runtime.parse_results import process_performance_stats
+from model.tf_trt import export_model, TFTRTModel
 
 
 def train(params, model, dataset, logger):
@@ -101,6 +102,11 @@ def train(params, model, dataset, logger):
                 break
         if hvd.rank() == 0:
             checkpoint.save(file_prefix=os.path.join(params.model_dir, "checkpoint"))
+            if params.use_savedmodel:
+                prec = 'amp' if params.use_amp else 'fp32'
+                model.save(os.path.join(params.model_dir, f'saved_model_{prec}'))
+                if params.use_tftrt:
+                    export_model(params.model_dir, prec, os.path.join(params.model_dir, f'tf-trt_model_{prec}'))
 
     logger.flush()
 
@@ -110,9 +116,15 @@ def evaluate(params, model, dataset, logger, restore_checkpoint=True):
         print("No fold specified for evaluation. Please use --fold [int] to select a fold.")
     ce_loss = tf.keras.metrics.Mean(name='ce_loss')
     f1_loss = tf.keras.metrics.Mean(name='dice_loss')
-    checkpoint = tf.train.Checkpoint(model=model)
     if params.model_dir and restore_checkpoint:
-        checkpoint.restore(tf.train.latest_checkpoint(params.model_dir)).expect_partial()
+        prec = 'amp' if params.use_amp else 'fp32'
+        if params.use_savedmodel:
+            model = tf.keras.models.load_model(os.path.join(params.model_dir, f'saved_model_{prec}'))
+        elif params.use_tftrt:
+            model = TFTRTModel(model_dir=params.model_dir, precision=prec)
+        else:
+            checkpoint = tf.train.Checkpoint(model=model)
+            checkpoint.restore(tf.train.latest_checkpoint(params.model_dir)).expect_partial()
 
     def validation_step(features, labels):
         output_map = model(features, training=False)
@@ -135,9 +147,15 @@ def evaluate(params, model, dataset, logger, restore_checkpoint=True):
 
 
 def predict(params, model, dataset, logger):
-    checkpoint = tf.train.Checkpoint(model=model)
+    prec = 'amp' if params.use_amp else 'fp32'
     if params.model_dir:
-        checkpoint.restore(tf.train.latest_checkpoint(params.model_dir)).expect_partial()
+        if params.use_savedmodel:
+            model = tf.keras.models.load_model(os.path.join(params.model_dir, f'saved_model_{prec}'))
+        elif params.use_tftrt:
+            model = TFTRTModel(model_dir=params.model_dir, precision=prec)
+        else:
+            checkpoint = tf.train.Checkpoint(model=model)
+            checkpoint.restore(tf.train.latest_checkpoint(params.model_dir)).expect_partial()
 
     @tf.function
     def prediction_step(features):

+ 2 - 2
TensorFlow2/Segmentation/UNet_Medical/runtime/setup.py

@@ -57,8 +57,8 @@ def set_flags(params):
 
 
 def prepare_model_dir(params):
-    model_dir = os.path.join(params.model_dir, "model_checkpoint")
-    model_dir = model_dir if (hvd.rank() == 0 and not params.benchmark) else None
+    # model_dir = os.path.join(params.model_dir, "model_checkpoint")
+    model_dir = params.model_dir if (hvd.rank() == 0 and not params.benchmark) else None
     if model_dir is not None:
         os.makedirs(model_dir, exist_ok=True)
         if ('train' in params.exec_mode) and (not params.resume_training):