|
|
@@ -21,6 +21,7 @@ from __future__ import print_function
|
|
|
import tensorflow as tf
|
|
|
|
|
|
import horovod.tensorflow as hvd
|
|
|
+import dllogger
|
|
|
|
|
|
from model import layers
|
|
|
from model import blocks
|
|
|
@@ -183,8 +184,12 @@ class ResnetModel(object):
|
|
|
probs, logits = self.build_model(
|
|
|
features,
|
|
|
training=mode == tf.estimator.ModeKeys.TRAIN,
|
|
|
- reuse=False
|
|
|
+ reuse=False,
|
|
|
+ use_final_conv=params['use_final_conv']
|
|
|
)
|
|
|
+
|
|
|
+ if mode!=tf.estimator.ModeKeys.PREDICT:
|
|
|
+ logits = tf.squeeze(logits)
|
|
|
|
|
|
y_preds = tf.argmax(logits, axis=1, output_type=tf.int32)
|
|
|
|
|
|
@@ -196,16 +201,25 @@ class ResnetModel(object):
|
|
|
tf.identity(logits, name="logits_ref")
|
|
|
tf.identity(probs, name="probs_ref")
|
|
|
tf.identity(y_preds, name="y_preds_ref")
|
|
|
-
|
|
|
- #if mode == tf.estimator.ModeKeys.TRAIN:
|
|
|
- #
|
|
|
- # assert (len(tf.trainable_variables()) == 161)
|
|
|
- #
|
|
|
- #else:
|
|
|
- #
|
|
|
- # assert (len(tf.trainable_variables()) == 0)
|
|
|
-
|
|
|
-
|
|
|
+
|
|
|
+ if mode == tf.estimator.ModeKeys.TRAIN and params['quantize']:
|
|
|
+ dllogger.log(data={"QUANTIZATION AWARE TRAINING ENABLED": True}, step=tuple())
|
|
|
+ if params['symmetric']:
|
|
|
+ dllogger.log(data={"MODE":"USING SYMMETRIC MODE"}, step=tuple())
|
|
|
+ tf.contrib.quantize.experimental_create_training_graph(tf.get_default_graph(), symmetric=True, use_qdq=params['use_qdq'] ,quant_delay=params['quant_delay'])
|
|
|
+ else:
|
|
|
+ dllogger.log(data={"MODE":"USING ASSYMETRIC MODE"}, step=tuple())
|
|
|
+ tf.contrib.quantize.create_training_graph(tf.get_default_graph(), quant_delay=params['quant_delay'], use_qdq=params['use_qdq'])
|
|
|
+
|
|
|
+ # Fix for restoring variables during fine-tuning of Resnet-50
|
|
|
+ if 'finetune_checkpoint' in params.keys():
|
|
|
+ train_vars = tf.trainable_variables()
|
|
|
+ train_var_dict = {}
|
|
|
+ for var in train_vars:
|
|
|
+ train_var_dict[var.op.name] = var
|
|
|
+ dllogger.log(data={"Restoring variables from checkpoint": params['finetune_checkpoint']}, step=tuple())
|
|
|
+ tf.train.init_from_checkpoint(params['finetune_checkpoint'], train_var_dict)
|
|
|
+
|
|
|
if mode == tf.estimator.ModeKeys.PREDICT:
|
|
|
|
|
|
predictions = {'classes': y_preds, 'probabilities': probs}
|
|
|
@@ -352,7 +366,7 @@ class ResnetModel(object):
|
|
|
|
|
|
|
|
|
|
|
|
- def build_model(self, inputs, training=True, reuse=False):
|
|
|
+ def build_model(self, inputs, training=True, reuse=False, use_final_conv=False):
|
|
|
|
|
|
with var_storage.model_variable_scope(
|
|
|
self.model_hparams.model_name,
|
|
|
@@ -416,20 +430,37 @@ class ResnetModel(object):
|
|
|
|
|
|
with tf.variable_scope("output"):
|
|
|
net = layers.reduce_mean(
|
|
|
- net, keepdims=False, data_format=self.model_hparams.compute_format, name='spatial_mean')
|
|
|
-
|
|
|
- logits = layers.dense(
|
|
|
- inputs=net,
|
|
|
- units=self.model_hparams.n_classes,
|
|
|
- use_bias=True,
|
|
|
- trainable=training,
|
|
|
- kernel_initializer=self.dense_hparams.kernel_initializer,
|
|
|
- bias_initializer=self.dense_hparams.bias_initializer)
|
|
|
+ net, keepdims=use_final_conv, data_format=self.model_hparams.compute_format, name='spatial_mean')
|
|
|
+
|
|
|
+ if use_final_conv:
|
|
|
+ logits = layers.conv2d(
|
|
|
+ net,
|
|
|
+ n_channels=self.model_hparams.n_classes,
|
|
|
+ kernel_size=(1, 1),
|
|
|
+ strides=(1, 1),
|
|
|
+ padding='SAME',
|
|
|
+ data_format=self.model_hparams.compute_format,
|
|
|
+ dilation_rate=(1, 1),
|
|
|
+ use_bias=True,
|
|
|
+ kernel_initializer=self.dense_hparams.kernel_initializer,
|
|
|
+ bias_initializer=self.dense_hparams.bias_initializer,
|
|
|
+ trainable=training,
|
|
|
+ name='dense'
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ logits = layers.dense(
|
|
|
+ inputs=net,
|
|
|
+ units=self.model_hparams.n_classes,
|
|
|
+ use_bias=True,
|
|
|
+ trainable=training,
|
|
|
+ kernel_initializer=self.dense_hparams.kernel_initializer,
|
|
|
+ bias_initializer=self.dense_hparams.bias_initializer)
|
|
|
|
|
|
if logits.dtype != tf.float32:
|
|
|
logits = tf.cast(logits, tf.float32)
|
|
|
-
|
|
|
- probs = layers.softmax(logits, name="softmax", axis=1)
|
|
|
+
|
|
|
+ axis = 3 if self.model_hparams.compute_format=="NHWC" and use_final_conv else 1
|
|
|
+ probs = layers.softmax(logits, name="softmax", axis=axis)
|
|
|
|
|
|
return probs, logits
|
|
|
|