Răsfoiți Sursa

Fixed wrong shapes with BS=1

Lukasz Pierscieniewski 5 ani în urmă
părinte
comite
43d061ae87
1 a modificat fișierele cu 2 adăugiri și 5 ștergeri
  1. 2 5
      TensorFlow/Classification/ConvNets/model/resnet.py

+ 2 - 5
TensorFlow/Classification/ConvNets/model/resnet.py

@@ -187,12 +187,9 @@ class ResnetModel(object):
                 reuse=False,
                 use_final_conv=params['use_final_conv']
             )
-            
-            if mode!=tf.estimator.ModeKeys.PREDICT:
-                logits = tf.squeeze(logits)
 
-            if mode!=tf.estimator.ModeKeys.PREDICT:
-                logits = tf.squeeze(logits)
+            if params['use_final_conv']:
+                logits = tf.squeeze(logits, axis=[-2, -1])
 
             y_preds = tf.argmax(logits, axis=1, output_type=tf.int32)