|
|
@@ -0,0 +1,221 @@
|
|
|
+#!/usr/bin/env python
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+
|
|
|
+# ==============================================================================
|
|
|
+#
|
|
|
+# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
|
|
+#
|
|
|
+# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
+# you may not use this file except in compliance with the License.
|
|
|
+# You may obtain a copy of the License at
|
|
|
+#
|
|
|
+# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
+#
|
|
|
+# Unless required by applicable law or agreed to in writing, software
|
|
|
+# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
+# See the License for the specific language governing permissions and
|
|
|
+# limitations under the License.
|
|
|
+#
|
|
|
+# ==============================================================================
|
|
|
+
|
|
|
+"""
|
|
|
+Usage:
|
|
|
+
|
|
|
+ python export_saved_model.py \
|
|
|
+ --activation_fn='relu' \
|
|
|
+ --batch_size=16 \
|
|
|
+ --data_format='NCHW' \
|
|
|
+ --input_dtype="fp32" \
|
|
|
+ --export_dir="exported_models" \
|
|
|
+ --model_checkpoint_path="path/to/checkpoint/model.ckpt-2500" \
|
|
|
+ --unet_variant='tinyUNet' \
|
|
|
+ --use_xla \
|
|
|
+ --use_tf_amp
|
|
|
+"""
|
|
|
+
|
|
|
+from __future__ import absolute_import
|
|
|
+from __future__ import division
|
|
|
+from __future__ import print_function
|
|
|
+
|
|
|
+import os
|
|
|
+import argparse
|
|
|
+import pprint
|
|
|
+
|
|
|
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
|
+
|
|
|
+import tensorflow as tf
|
|
|
+
|
|
|
+from dllogger.logger import LOGGER
|
|
|
+
|
|
|
+from model.unet import UNet_v1
|
|
|
+from model.blocks.activation_blck import authorized_activation_fn
|
|
|
+
|
|
|
+from utils.cmdline_helper import _add_bool_argument
|
|
|
+
|
|
|
+
|
|
|
+def get_export_flags():
|
|
|
+ parser = argparse.ArgumentParser(description="JoC-UNet_v1-TF-ExportFlags")
|
|
|
+
|
|
|
+ parser.add_argument('--export_dir', default=None, required=True, type=str, help='The export directory.')
|
|
|
+ parser.add_argument('--model_checkpoint_path', default=None, required=True, help='Checkpoint path.')
|
|
|
+
|
|
|
+ parser.add_argument(
|
|
|
+ '--data_format',
|
|
|
+ choices=['NHWC', 'NCHW'],
|
|
|
+ type=str,
|
|
|
+ default="NCHW",
|
|
|
+ required=False,
|
|
|
+ help="""Which Tensor format is used for computation inside the mode"""
|
|
|
+ )
|
|
|
+
|
|
|
+ parser.add_argument(
|
|
|
+ '--input_dtype',
|
|
|
+ choices=['fp32', 'fp16'],
|
|
|
+ type=str,
|
|
|
+ default="fp32",
|
|
|
+ required=False,
|
|
|
+ help="""Tensorflow dtype of the input tensor"""
|
|
|
+ )
|
|
|
+
|
|
|
+ parser.add_argument(
|
|
|
+ '--unet_variant',
|
|
|
+ default="tinyUNet",
|
|
|
+ choices=UNet_v1.authorized_models_variants,
|
|
|
+ type=str,
|
|
|
+ required=False,
|
|
|
+ help="""Which model size is used. This parameter control directly the size and the number of parameters"""
|
|
|
+ )
|
|
|
+
|
|
|
+ parser.add_argument(
|
|
|
+ '--activation_fn',
|
|
|
+ choices=authorized_activation_fn,
|
|
|
+ type=str,
|
|
|
+ default="relu",
|
|
|
+ required=False,
|
|
|
+ help="""Which activation function is used after the convolution layers"""
|
|
|
+ )
|
|
|
+
|
|
|
+ _add_bool_argument(
|
|
|
+ parser=parser,
|
|
|
+ name="use_tf_amp",
|
|
|
+ default=False,
|
|
|
+ required=False,
|
|
|
+ help="Enable Automatic Mixed Precision Computation to maximise performance."
|
|
|
+ )
|
|
|
+
|
|
|
+ _add_bool_argument(
|
|
|
+ parser=parser,
|
|
|
+ name="use_xla",
|
|
|
+ default=False,
|
|
|
+ required=False,
|
|
|
+ help="Enable Tensorflow XLA to maximise performance."
|
|
|
+ )
|
|
|
+
|
|
|
+ parser.add_argument('--batch_size', default=16, type=int, help='Evaluation batch size.')
|
|
|
+
|
|
|
+ FLAGS, unknown_args = parser.parse_known_args()
|
|
|
+
|
|
|
+ if len(unknown_args) > 0:
|
|
|
+
|
|
|
+ for bad_arg in unknown_args:
|
|
|
+ print("ERROR: Unknown command line arg: %s" % bad_arg)
|
|
|
+
|
|
|
+ raise ValueError("Invalid command line arg(s)")
|
|
|
+
|
|
|
+ return FLAGS
|
|
|
+
|
|
|
+
|
|
|
+def export_model(RUNNING_CONFIG):
|
|
|
+
|
|
|
+ if RUNNING_CONFIG.use_tf_amp:
|
|
|
+ os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"
|
|
|
+
|
|
|
+ model = UNet_v1(
|
|
|
+ model_name="UNet_v1",
|
|
|
+ input_format="NHWC",
|
|
|
+ compute_format=RUNNING_CONFIG.data_format,
|
|
|
+ n_output_channels=1,
|
|
|
+ unet_variant=RUNNING_CONFIG.unet_variant,
|
|
|
+ weight_init_method="he_normal",
|
|
|
+ activation_fn=RUNNING_CONFIG.activation_fn
|
|
|
+ )
|
|
|
+
|
|
|
+ config_proto = tf.ConfigProto()
|
|
|
+
|
|
|
+ config_proto.allow_soft_placement = True
|
|
|
+ config_proto.log_device_placement = False
|
|
|
+
|
|
|
+ config_proto.gpu_options.allow_growth = True
|
|
|
+
|
|
|
+ if RUNNING_CONFIG.use_xla: # Only working on single GPU
|
|
|
+ LOGGER.log("XLA is activated - Experimental Feature")
|
|
|
+ config_proto.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
|
|
|
+
|
|
|
+ config_proto.gpu_options.force_gpu_compatible = True # Force pinned memory
|
|
|
+
|
|
|
+ run_config = tf.estimator.RunConfig(
|
|
|
+ model_dir=None,
|
|
|
+ tf_random_seed=None,
|
|
|
+ save_summary_steps=1e9, # disabled
|
|
|
+ save_checkpoints_steps=None,
|
|
|
+ save_checkpoints_secs=None,
|
|
|
+ session_config=config_proto,
|
|
|
+ keep_checkpoint_max=None,
|
|
|
+ keep_checkpoint_every_n_hours=1e9, # disabled
|
|
|
+ log_step_count_steps=1e9,
|
|
|
+ train_distribute=None,
|
|
|
+ device_fn=None,
|
|
|
+ protocol=None,
|
|
|
+ eval_distribute=None,
|
|
|
+ experimental_distribute=None
|
|
|
+ )
|
|
|
+
|
|
|
+ estimator = tf.estimator.Estimator(
|
|
|
+ model_fn=model,
|
|
|
+ model_dir=RUNNING_CONFIG.model_checkpoint_path,
|
|
|
+ config=run_config,
|
|
|
+ params={'debug_verbosity': 0}
|
|
|
+ )
|
|
|
+
|
|
|
+ LOGGER.log('[*] Exporting the model ...')
|
|
|
+
|
|
|
+ input_type = tf.float32 if RUNNING_CONFIG.input_dtype else tf.float16
|
|
|
+
|
|
|
+ def get_serving_input_receiver_fn():
|
|
|
+
|
|
|
+ input_shape = [RUNNING_CONFIG.batch_size, 512, 512, 1]
|
|
|
+
|
|
|
+ def serving_input_receiver_fn():
|
|
|
+ features = tf.placeholder(dtype=input_type, shape=input_shape, name='input_tensor')
|
|
|
+
|
|
|
+ return tf.estimator.export.TensorServingInputReceiver(features=features, receiver_tensors=features)
|
|
|
+
|
|
|
+ return serving_input_receiver_fn
|
|
|
+
|
|
|
+ export_path = estimator.export_saved_model(
|
|
|
+ export_dir_base=RUNNING_CONFIG.export_dir,
|
|
|
+ serving_input_receiver_fn=get_serving_input_receiver_fn(),
|
|
|
+ checkpoint_path=RUNNING_CONFIG.model_checkpoint_path
|
|
|
+ )
|
|
|
+
|
|
|
+ LOGGER.log('[*] Done! path: `%s`' % export_path.decode())
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+
|
|
|
+ tf.logging.set_verbosity(tf.logging.ERROR)
|
|
|
+ tf.disable_eager_execution()
|
|
|
+
|
|
|
+ flags = get_export_flags()
|
|
|
+
|
|
|
+ for endpattern in [".index", ".meta"]:
|
|
|
+ file_to_check = flags.model_checkpoint_path + endpattern
|
|
|
+ if not os.path.isfile(file_to_check):
|
|
|
+ raise FileNotFoundError("The checkpoint file `%s` does not exist" % file_to_check)
|
|
|
+
|
|
|
+ print(" ========================= Export Flags =========================\n")
|
|
|
+ pprint.pprint(dict(flags._get_kwargs()))
|
|
|
+ print("\n %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")
|
|
|
+
|
|
|
+ export_model(flags)
|