|
|
@@ -21,6 +21,7 @@ import tensorflow as tf
|
|
|
|
|
|
from runtime.losses import partial_losses
|
|
|
from runtime.parse_results import process_performance_stats
|
|
|
+from model.tf_trt import export_model, TFTRTModel
|
|
|
|
|
|
|
|
|
def train(params, model, dataset, logger):
|
|
|
@@ -101,6 +102,11 @@ def train(params, model, dataset, logger):
|
|
|
break
|
|
|
if hvd.rank() == 0:
|
|
|
checkpoint.save(file_prefix=os.path.join(params.model_dir, "checkpoint"))
|
|
|
+ if params.use_savedmodel:
|
|
|
+ prec = 'amp' if params.use_amp else 'fp32'
|
|
|
+ model.save(os.path.join(params.model_dir, f'saved_model_{prec}'))
|
|
|
+ if params.use_tftrt:
|
|
|
+ export_model(params.model_dir, prec, os.path.join(params.model_dir, f'tf-trt_model_{prec}'))
|
|
|
|
|
|
logger.flush()
|
|
|
|
|
|
@@ -110,9 +116,15 @@ def evaluate(params, model, dataset, logger, restore_checkpoint=True):
|
|
|
print("No fold specified for evaluation. Please use --fold [int] to select a fold.")
|
|
|
ce_loss = tf.keras.metrics.Mean(name='ce_loss')
|
|
|
f1_loss = tf.keras.metrics.Mean(name='dice_loss')
|
|
|
- checkpoint = tf.train.Checkpoint(model=model)
|
|
|
if params.model_dir and restore_checkpoint:
|
|
|
- checkpoint.restore(tf.train.latest_checkpoint(params.model_dir)).expect_partial()
|
|
|
+ prec = 'amp' if params.use_amp else 'fp32'
|
|
|
+ if params.use_savedmodel:
|
|
|
+ model = tf.keras.models.load_model(os.path.join(params.model_dir, f'saved_model_{prec}'))
|
|
|
+ elif params.use_tftrt:
|
|
|
+ model = TFTRTModel(model_dir=params.model_dir, precision=prec)
|
|
|
+ else:
|
|
|
+ checkpoint = tf.train.Checkpoint(model=model)
|
|
|
+ checkpoint.restore(tf.train.latest_checkpoint(params.model_dir)).expect_partial()
|
|
|
|
|
|
def validation_step(features, labels):
|
|
|
output_map = model(features, training=False)
|
|
|
@@ -135,9 +147,15 @@ def evaluate(params, model, dataset, logger, restore_checkpoint=True):
|
|
|
|
|
|
|
|
|
def predict(params, model, dataset, logger):
|
|
|
- checkpoint = tf.train.Checkpoint(model=model)
|
|
|
+ prec = 'amp' if params.use_amp else 'fp32'
|
|
|
if params.model_dir:
|
|
|
- checkpoint.restore(tf.train.latest_checkpoint(params.model_dir)).expect_partial()
|
|
|
+ if params.use_savedmodel:
|
|
|
+ model = tf.keras.models.load_model(os.path.join(params.model_dir, f'saved_model_{prec}'))
|
|
|
+ elif params.use_tftrt:
|
|
|
+ model = TFTRTModel(model_dir=params.model_dir, precision=prec)
|
|
|
+ else:
|
|
|
+ checkpoint = tf.train.Checkpoint(model=model)
|
|
|
+ checkpoint.restore(tf.train.latest_checkpoint(params.model_dir)).expect_partial()
|
|
|
|
|
|
@tf.function
|
|
|
def prediction_step(features):
|