|
|
@@ -187,12 +187,9 @@ class ResnetModel(object):
|
|
|
reuse=False,
|
|
|
use_final_conv=params['use_final_conv']
|
|
|
)
|
|
|
-
|
|
|
- if mode!=tf.estimator.ModeKeys.PREDICT:
|
|
|
- logits = tf.squeeze(logits)
|
|
|
|
|
|
- if mode!=tf.estimator.ModeKeys.PREDICT:
|
|
|
- logits = tf.squeeze(logits)
|
|
|
+ if params['use_final_conv']:
|
|
|
+ logits = tf.squeeze(logits, axis=[-2, -1])
|
|
|
|
|
|
y_preds = tf.argmax(logits, axis=1, output_type=tf.int32)
|
|
|
|