|
|
@@ -6,10 +6,10 @@ import os
|
|
|
import tensorflow as tf
|
|
|
|
|
|
import horovod.tensorflow as hvd
|
|
|
-from model import resnet_v1_5
|
|
|
+from model import resnet
|
|
|
|
|
|
tf.app.flags.DEFINE_string(
|
|
|
- 'model_name', 'resnet50_v1.5', 'The name of the architecture to save. The default name was being '
|
|
|
+ 'model_name', 'resnet50', 'The name of the architecture to save. The default name was being '
|
|
|
'used to train the model')
|
|
|
|
|
|
tf.app.flags.DEFINE_integer(
|
|
|
@@ -26,10 +26,10 @@ tf.app.flags.DEFINE_integer(
|
|
|
'be specified at model runtime.')
|
|
|
|
|
|
|
|
|
-tf.app.flags.DEFINE_string('input_format', 'NHWC',
|
|
|
+tf.app.flags.DEFINE_string('input_format', 'NCHW',
|
|
|
'The dataformat used by the layers in the model')
|
|
|
|
|
|
-tf.app.flags.DEFINE_string('compute_format', 'NHWC',
|
|
|
+tf.app.flags.DEFINE_string('compute_format', 'NCHW',
|
|
|
'The dataformat used by the layers in the model')
|
|
|
|
|
|
tf.app.flags.DEFINE_string('checkpoint', '',
|
|
|
@@ -72,15 +72,24 @@ def main(_):
|
|
|
else:
|
|
|
input_shape = [FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3]
|
|
|
input_images = tf.placeholder(name='input', dtype=tf.float32, shape=input_shape)
|
|
|
- network = resnet_v1_5.ResnetModel(FLAGS.model_name, FLAGS.num_classes, FLAGS.compute_format, FLAGS.input_format)
|
|
|
+
|
|
|
+ resnet50_config = resnet.model_architectures[FLAGS.model_name]
|
|
|
+ network = resnet.ResnetModel(FLAGS.model_name,
|
|
|
+ FLAGS.num_classes,
|
|
|
+ resnet50_config['layers'],
|
|
|
+ resnet50_config['widths'],
|
|
|
+ resnet50_config['expansions'],
|
|
|
+ FLAGS.compute_format,
|
|
|
+ FLAGS.input_format)
|
|
|
probs, logits = network.build_model(
|
|
|
input_images,
|
|
|
training=False,
|
|
|
reuse=False,
|
|
|
use_final_conv=FLAGS.use_final_conv)
|
|
|
-
|
|
|
+
|
|
|
if FLAGS.quantize:
|
|
|
- tf.contrib.quantize.experimental_create_eval_graph(symmetric=FLAGS.symmetric, use_qdq=FLAGS.use_qdq)
|
|
|
+ tf.contrib.quantize.experimental_create_eval_graph(symmetric=FLAGS.symmetric,
|
|
|
+ use_qdq=FLAGS.use_qdq)
|
|
|
|
|
|
# Define the saver and restore the checkpoint
|
|
|
saver = tf.train.Saver()
|
|
|
@@ -101,4 +110,4 @@ def main(_):
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
- tf.app.run()
|
|
|
+ tf.app.run()
|