Răsfoiți Sursa

[Convnets/TF] Basic CPU model support

Lukasz Pierscieniewski 4 ani în urmă
părinte
comite
33110132cc
17 fișiere modificate cu 78 adăugiri și 72 ștergeri
  1. 1 0
      TensorFlow/Classification/ConvNets/main.py
  2. 13 13
      TensorFlow/Classification/ConvNets/model/resnet.py
  3. 1 1
      TensorFlow/Classification/ConvNets/resnet50v1.5/training/DGX1_RN50_AMP_250E.sh
  4. 1 1
      TensorFlow/Classification/ConvNets/resnet50v1.5/training/DGX1_RN50_FP32_250E.sh
  5. 1 1
      TensorFlow/Classification/ConvNets/resnet50v1.5/training/DGX2_RN50_AMP_250E.sh
  6. 1 1
      TensorFlow/Classification/ConvNets/resnet50v1.5/training/DGX2_RN50_FP32_250E.sh
  7. 1 1
      TensorFlow/Classification/ConvNets/resnext101-32x4d/training/DGX1_RNxt101-32x4d_AMP_250E.sh
  8. 1 1
      TensorFlow/Classification/ConvNets/resnext101-32x4d/training/DGX1_RNxt101-32x4d_FP32_250E.sh
  9. 1 1
      TensorFlow/Classification/ConvNets/resnext101-32x4d/training/DGX2_RNxt101-32x4d_AMP_250E.sh
  10. 1 1
      TensorFlow/Classification/ConvNets/resnext101-32x4d/training/DGX2_RNxt101-32x4d_FP32_250E.sh
  11. 27 19
      TensorFlow/Classification/ConvNets/runtime/runner.py
  12. 1 1
      TensorFlow/Classification/ConvNets/se-resnext101-32x4d/training/DGX1_SE-RNxt101-32x4d_AMP_250E.sh
  13. 1 1
      TensorFlow/Classification/ConvNets/se-resnext101-32x4d/training/DGX1_SE-RNxt101-32x4d_FP32_250E.sh
  14. 1 1
      TensorFlow/Classification/ConvNets/se-resnext101-32x4d/training/DGX2_SE-RNxt101-32x4d_AMP_250E.sh
  15. 1 1
      TensorFlow/Classification/ConvNets/se-resnext101-32x4d/training/DGX2_SE-RNxt101-32x4d_FP32_250E.sh
  16. 7 0
      TensorFlow/Classification/ConvNets/utils/cmdline_helper.py
  17. 18 28
      TensorFlow/Classification/ConvNets/utils/hooks/training_hooks.py

+ 1 - 0
TensorFlow/Classification/ConvNets/main.py

@@ -69,6 +69,7 @@ if __name__ == "__main__":
         use_xla=FLAGS.xla,
         use_tf_amp=FLAGS.amp,
         use_dali=FLAGS.dali,
+        use_cpu=FLAGS.cpu,
         gpu_memory_fraction=FLAGS.gpu_memory_fraction,
         gpu_id=FLAGS.gpu_id,
         seed=FLAGS.seed)

+ 13 - 13
TensorFlow/Classification/ConvNets/model/resnet.py

@@ -53,6 +53,7 @@ class ResnetModel(object):
         weight_init='fan_out',
         dtype=tf.float32,
         use_dali=False,
+        use_cpu=False,
         cardinality=1,
         use_se=False,
         se_ratio=1,
@@ -68,6 +69,7 @@ class ResnetModel(object):
             expansions=expansions,
             model_name=model_name,
             use_dali=use_dali,
+            use_cpu=use_cpu,
             cardinality=cardinality,
             use_se=use_se,
             se_ratio=se_ratio
@@ -124,11 +126,13 @@ class ResnetModel(object):
                 # Stage inputs on the host
                 cpu_prefetch_op, (features, labels) = self._stage([features, labels])
 
-            with tf.device('/gpu:0'):
-                # Stage inputs to the device
-                gpu_prefetch_op, (features, labels) = self._stage([features, labels])
+            if not self.model_hparams.use_cpu:
+                with tf.device('/gpu:0'):
+                    # Stage inputs to the device
+                    gpu_prefetch_op, (features, labels) = self._stage([features, labels])
 
-        with tf.device("/gpu:0"):
+        main_device = "/gpu:0" if not self.model_hparams.use_cpu else "/cpu:0"
+        with tf.device(main_device):
 
             if features.dtype != self.model_hparams.dtype:
                 features = tf.cast(features, self.model_hparams.dtype)
@@ -237,14 +241,6 @@ class ResnetModel(object):
                 dllogger.log(data={"Restoring variables from checkpoint": params['finetune_checkpoint']}, step=tuple())
                 tf.train.init_from_checkpoint(params['finetune_checkpoint'], train_var_dict)
 
-        with tf.device("/cpu:0"):
-            if hvd_utils.is_using_hvd():
-                sync_var = tf.Variable(initial_value=[0], dtype=tf.int32, name="signal_handler_var",
-                                       trainable=False)
-                sync_var_assing = sync_var.assign([1], name="signal_handler_var_set")
-                sync_var_reset = sync_var.assign([0], name="signal_handler_var_reset")
-                sync_op = hvd.allreduce(sync_var, op=hvd.Sum, name="signal_handler_all_reduce")
-
         if mode == tf.estimator.ModeKeys.PREDICT:
 
             predictions = {'classes': y_preds, 'probabilities': probs}
@@ -257,7 +253,7 @@ class ResnetModel(object):
 
         else:
 
-            with tf.device("/gpu:0"):
+            with tf.device(main_device):
 
                 if mode == tf.estimator.ModeKeys.TRAIN:
                     acc_top1 = tf.nn.in_top_k(predictions=logits, targets=labels, k=1)
@@ -355,6 +351,10 @@ class ResnetModel(object):
 
                     if self.model_hparams.use_dali:
                         train_ops = tf.group(backprop_op, update_ops, name='train_ops')
+                    elif self.model_hparams.use_cpu:
+                        train_ops = tf.group(
+                            backprop_op, cpu_prefetch_op, update_ops, name='train_ops'
+                        )
                     else:
                         train_ops = tf.group(
                             backprop_op, cpu_prefetch_op, gpu_prefetch_op, update_ops, name='train_ops'

+ 1 - 1
TensorFlow/Classification/ConvNets/resnet50v1.5/training/DGX1_RN50_AMP_250E.sh

@@ -24,7 +24,7 @@ if [[ ! -z "${BIND_TO_SOCKET}" ]]; then
 fi
 
 mpiexec --allow-run-as-root ${BIND_TO_SOCKET} -np 8 python3 main.py --arch=resnet50 \
-    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --muxup=0.2 \
+    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --mixup=0.2 \
     --batch_size=256 --warmup_steps=100 --cosine_lr --label_smoothing 0.1 \
     --lr_init=0.256 --lr_warmup_epochs=8 --momentum=0.875 --weight_decay=3.0517578125e-05 \
     --amp --static_loss_scale 128 \

+ 1 - 1
TensorFlow/Classification/ConvNets/resnet50v1.5/training/DGX1_RN50_FP32_250E.sh

@@ -24,7 +24,7 @@ if [[ ! -z "${BIND_TO_SOCKET}" ]]; then
 fi
 
 mpiexec --allow-run-as-root ${BIND_TO_SOCKET} -np 8 python3 main.py --arch=resnet50 \
-    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --muxup=0.2 \
+    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --mixup=0.2 \
     --batch_size=128 --warmup_steps=100 --cosine_lr --label_smoothing 0.1 \
     --lr_init=0.256 --lr_warmup_epochs=8 --momentum=0.875 --weight_decay=3.0517578125e-05 \
     --data_dir=${DATA_DIR}/tfrecords --data_idx_dir=${DATA_DIR}/dali_idx \

+ 1 - 1
TensorFlow/Classification/ConvNets/resnet50v1.5/training/DGX2_RN50_AMP_250E.sh

@@ -24,7 +24,7 @@ if [[ ! -z "${BIND_TO_SOCKET}" ]]; then
 fi
 
 mpiexec --allow-run-as-root ${BIND_TO_SOCKET} -np 8 python3 main.py --arch=resnet50 \
-    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --muxup=0.2 \
+    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --mixup=0.2 \
     --batch_size=256 --warmup_steps=100 --cosine_lr --label_smoothing 0.1 \
     --lr_init=0.256 --lr_warmup_epochs=8 --momentum=0.875 --weight_decay=3.0517578125e-05 \
     --amp --static_loss_scale 128 \

+ 1 - 1
TensorFlow/Classification/ConvNets/resnet50v1.5/training/DGX2_RN50_FP32_250E.sh

@@ -24,7 +24,7 @@ if [[ ! -z "${BIND_TO_SOCKET}" ]]; then
 fi
 
 mpiexec --allow-run-as-root ${BIND_TO_SOCKET} -np 8 python3 main.py --arch=resnet50 \
-    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --muxup=0.2 \
+    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --mixup=0.2 \
     --batch_size=128 --warmup_steps=100 --cosine_lr --label_smoothing 0.1 \
     --lr_init=0.256 --lr_warmup_epochs=8 --momentum=0.875 --weight_decay=3.0517578125e-05 \
     --data_dir=${DATA_DIR}/tfrecords --data_idx_dir=${DATA_DIR}/dali_idx \

+ 1 - 1
TensorFlow/Classification/ConvNets/resnext101-32x4d/training/DGX1_RNxt101-32x4d_AMP_250E.sh

@@ -24,7 +24,7 @@ if [[ ! -z "${BIND_TO_SOCKET}" ]]; then
 fi
 
 mpiexec --allow-run-as-root ${BIND_TO_SOCKET} -np 8 python3 main.py --arch=resnext101-32x4d \
-    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --muxup=0.2 \
+    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --mixup=0.2 \
     --batch_size=128 --warmup_steps=100 --cosine_lr --label_smoothing 0.1 \
     --lr_init=0.256 --lr_warmup_epochs=8 --momentum=0.875 --weight_decay=6.103515625e-05 \
     --amp --static_loss_scale 128 \

+ 1 - 1
TensorFlow/Classification/ConvNets/resnext101-32x4d/training/DGX1_RNxt101-32x4d_FP32_250E.sh

@@ -24,7 +24,7 @@ if [[ ! -z "${BIND_TO_SOCKET}" ]]; then
 fi
 
 mpiexec --allow-run-as-root ${BIND_TO_SOCKET} -np 8 python3 main.py --arch=resnext101-32x4d \
-    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --muxup=0.2 \
+    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --mixup=0.2 \
     --batch_size=64 --warmup_steps=100 --cosine_lr --label_smoothing 0.1 \
     --lr_init=0.256 --lr_warmup_epochs=8 --momentum=0.875 --weight_decay=6.103515625e-05 \
     --data_dir=${DATA_DIR}/tfrecords --data_idx_dir=${DATA_DIR}/dali_idx \

+ 1 - 1
TensorFlow/Classification/ConvNets/resnext101-32x4d/training/DGX2_RNxt101-32x4d_AMP_250E.sh

@@ -24,7 +24,7 @@ if [[ ! -z "${BIND_TO_SOCKET}" ]]; then
 fi
 
 mpiexec --allow-run-as-root ${BIND_TO_SOCKET} -np 8 python3 main.py --arch=resnext101-32x4d \
-    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --muxup=0.2 \
+    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --mixup=0.2 \
     --batch_size=128 --warmup_steps=100 --cosine_lr --label_smoothing 0.1 \
     --lr_init=0.256 --lr_warmup_epochs=8 --momentum=0.875 --weight_decay=6.103515625e-05 \
     --amp --static_loss_scale 128 \

+ 1 - 1
TensorFlow/Classification/ConvNets/resnext101-32x4d/training/DGX2_RNxt101-32x4d_FP32_250E.sh

@@ -24,7 +24,7 @@ if [[ ! -z "${BIND_TO_SOCKET}" ]]; then
 fi
 
 mpiexec --allow-run-as-root ${BIND_TO_SOCKET} -np 8 python3 main.py --arch=resnext101-32x4d \
-    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --muxup=0.2 \
+    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --mixup=0.2 \
     --batch_size=64 --warmup_steps=100 --cosine_lr --label_smoothing 0.1 \
     --lr_init=0.256 --lr_warmup_epochs=8 --momentum=0.875 --weight_decay=6.103515625e-05 \
     --data_dir=${DATA_DIR}/tfrecords --data_idx_dir=${DATA_DIR}/dali_idx \

+ 27 - 19
TensorFlow/Classification/ConvNets/runtime/runner.py

@@ -61,6 +61,7 @@ class Runner(object):
             use_xla=False,
             use_tf_amp=False,
             use_dali=False,
+            use_cpu=False,
             gpu_memory_fraction=1.0,
             gpu_id=0,
 
@@ -136,6 +137,7 @@ class Runner(object):
                                                              use_tf_amp=use_tf_amp,
                                                              use_xla=use_xla,
                                                              use_dali=use_dali,
+                                                             use_cpu=use_cpu,
                                                              gpu_memory_fraction=gpu_memory_fraction,
                                                              gpu_id=gpu_id)
 
@@ -161,6 +163,7 @@ class Runner(object):
                                          dtype=model_hparams.dtype,
                                          weight_init=weight_init,
                                          use_dali=use_dali,
+                                         use_cpu=use_cpu,
                                          cardinality=architecture['cardinality'] if 'cardinality' in architecture else 1,
                                          use_se=architecture['use_se'] if 'use_se' in architecture else False,
                                          se_ratio=architecture['se_ratio'] if 'se_ratio' in architecture else 1)
@@ -200,42 +203,45 @@ class Runner(object):
             return worker_batch_size
 
     @staticmethod
-    def _get_session_config(mode, use_xla, use_dali, gpu_memory_fraction, gpu_id=0):
+    def _get_session_config(mode, use_xla, use_dali, use_cpu, gpu_memory_fraction, gpu_id=0):
 
         if mode not in ["train", 'validation', 'benchmark', 'inference']:
             raise ValueError("Unknown mode received: %s (allowed: 'train', 'validation', 'benchmark', 'inference')" %
                              mode)
 
-        # Limit available GPU memory (tune the size)
-        if use_dali:
-            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_memory_fraction)
-            config = tf.ConfigProto(gpu_options=gpu_options)
-            config.gpu_options.allow_growth = False
-        else:
-            config = tf.ConfigProto()
-            config.gpu_options.allow_growth = True
+        config = tf.ConfigProto()
+        if not use_cpu:
+            # Limit available GPU memory (tune the size)
+            if use_dali:
+                gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_memory_fraction)
+                config = tf.ConfigProto(gpu_options=gpu_options)
+                config.gpu_options.allow_growth = False
+            else:
+                config.gpu_options.allow_growth = True
 
-        config.allow_soft_placement = True
-        config.log_device_placement = False
+            config.allow_soft_placement = True
+            config.log_device_placement = False
 
-        config.gpu_options.visible_device_list = str(gpu_id)
+            config.gpu_options.visible_device_list = str(gpu_id)
+            config.gpu_options.force_gpu_compatible = True  # Force pinned memory
 
-        if hvd_utils.is_using_hvd():
-            config.gpu_options.visible_device_list = str(hvd.local_rank())
+            if hvd_utils.is_using_hvd():
+                config.gpu_options.visible_device_list = str(hvd.local_rank())
+
+            config.gpu_options.force_gpu_compatible = True  # Force pinned memory
 
         if use_xla:
             config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
 
-        config.gpu_options.force_gpu_compatible = True  # Force pinned memory
-
         if mode == 'train':
-            config.intra_op_parallelism_threads = 1  # Avoid pool of Eigen threads
-            config.inter_op_parallelism_threads = max(2, (multiprocessing.cpu_count() // max(hvd.size(), 8) - 2))
+            if not use_cpu:
+                config.intra_op_parallelism_threads = 1  # Avoid pool of Eigen threads
+                config.inter_op_parallelism_threads = max(2, (multiprocessing.cpu_count() // max(hvd.size(), 8) - 2))
 
         return config
 
     @staticmethod
-    def _get_run_config(mode, model_dir, use_xla, use_dali, gpu_memory_fraction, gpu_id=0, seed=None):
+    def _get_run_config(mode, model_dir, use_xla, use_dali, use_cpu, gpu_memory_fraction, gpu_id=0, seed=None):
 
         if mode not in ["train", 'validation', 'benchmark', 'inference']:
             raise ValueError("Unknown mode received: %s (allowed: 'train', 'validation', 'benchmark', 'inference')" %
@@ -258,6 +264,7 @@ class Runner(object):
             session_config=Runner._get_session_config(mode=mode,
                                                       use_xla=use_xla,
                                                       use_dali=use_dali,
+                                                      use_cpu=use_cpu,
                                                       gpu_memory_fraction=gpu_memory_fraction,
                                                       gpu_id=gpu_id),
             keep_checkpoint_max=5,
@@ -288,6 +295,7 @@ class Runner(object):
                                             model_dir=self.run_hparams.model_dir,
                                             use_xla=use_xla,
                                             use_dali=use_dali,
+                                            use_cpu=self.run_hparams.use_cpu,
                                             gpu_memory_fraction=gpu_memory_fraction,
                                             gpu_id=gpu_id,
                                             seed=self.run_hparams.seed)

+ 1 - 1
TensorFlow/Classification/ConvNets/se-resnext101-32x4d/training/DGX1_SE-RNxt101-32x4d_AMP_250E.sh

@@ -24,7 +24,7 @@ if [[ ! -z "${BIND_TO_SOCKET}" ]]; then
 fi
 
 mpiexec --allow-run-as-root ${BIND_TO_SOCKET} -np 8 python3 main.py --arch=se-resnext101-32x4d \
-    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --muxup=0.2 \
+    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --mixup=0.2 \
     --batch_size=96 --warmup_steps=100 --cosine_lr --label_smoothing 0.1 \
     --lr_init=0.256 --lr_warmup_epochs=8 --momentum=0.875 --weight_decay=6.103515625e-05 \
     --amp --static_loss_scale 128 \

+ 1 - 1
TensorFlow/Classification/ConvNets/se-resnext101-32x4d/training/DGX1_SE-RNxt101-32x4d_FP32_250E.sh

@@ -24,7 +24,7 @@ if [[ ! -z "${BIND_TO_SOCKET}" ]]; then
 fi
 
 mpiexec --allow-run-as-root ${BIND_TO_SOCKET} -np 8 python3 main.py --arch=se-resnext101-32x4d \
-    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --muxup=0.2 \
+    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --mixup=0.2 \
     --batch_size=64 --warmup_steps=100 --cosine_lr --label_smoothing 0.1 \
     --lr_init=0.256 --lr_warmup_epochs=8 --momentum=0.875 --weight_decay=6.103515625e-05 \
     --data_dir=${DATA_DIR}/tfrecords --data_idx_dir=${DATA_DIR}/dali_idx \

+ 1 - 1
TensorFlow/Classification/ConvNets/se-resnext101-32x4d/training/DGX2_SE-RNxt101-32x4d_AMP_250E.sh

@@ -24,7 +24,7 @@ if [[ ! -z "${BIND_TO_SOCKET}" ]]; then
 fi
 
 mpiexec --allow-run-as-root ${BIND_TO_SOCKET} -np 8 python3 main.py --arch=resnext101-32x4d \
-    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --muxup=0.2 \
+    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --mixup=0.2 \
     --batch_size=96 --warmup_steps=100 --cosine_lr --label_smoothing 0.1 \
     --lr_init=0.256 --lr_warmup_epochs=8 --momentum=0.875 --weight_decay=6.103515625e-05 \
     --amp --static_loss_scale 128 \

+ 1 - 1
TensorFlow/Classification/ConvNets/se-resnext101-32x4d/training/DGX2_SE-RNxt101-32x4d_FP32_250E.sh

@@ -24,7 +24,7 @@ if [[ ! -z "${BIND_TO_SOCKET}" ]]; then
 fi
 
 mpiexec --allow-run-as-root ${BIND_TO_SOCKET} -np 8 python3 main.py --arch=se-resnext101-32x4d \
-    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --muxup=0.2 \
+    --mode=train_and_evaluate --iter_unit=epoch --num_iter=250 --mixup=0.2 \
     --batch_size=64 --warmup_steps=100 --cosine_lr --label_smoothing 0.1 \
     --lr_init=0.256 --lr_warmup_epochs=8 --momentum=0.875 --weight_decay=6.103515625e-05 \
     --data_dir=${DATA_DIR}/tfrecords --data_idx_dir=${DATA_DIR}/dali_idx \

+ 7 - 0
TensorFlow/Classification/ConvNets/utils/cmdline_helper.py

@@ -129,6 +129,13 @@ class ArgumentParserUtil(object):
                                   required=False,
                                   help="Enable Automatic Mixed Precision to speedup computation using tensor cores.")
 
+        goptim_group.add_argument("--cpu",
+                                  action="store_true",
+                                  dest="cpu",
+                                  default=False,
+                                  required=False,
+                                  help="Run model on CPU instead of GPU")
+
         amp_group = self.parser.add_argument_group("Automatic Mixed Precision arguments")
         amp_group.add_argument("--static_loss_scale",
                                "--loss_scale",

+ 18 - 28
TensorFlow/Classification/ConvNets/utils/hooks/training_hooks.py

@@ -118,47 +118,37 @@ class TrainingPartitionHook(tf.estimator.SessionRunHook):
     def __init__(self, sync_freq=10):
         super().__init__()
         self.signal_recieved = False
-        self.should_sync_params = False
         self.sync_freq = sync_freq
         self.global_step = 0
 
-        self.should_exit = False
-
         signal.signal(signal.SIGUSR1, self._signal_handler)
         signal.signal(signal.SIGTERM, self._signal_handler)
 
-    def before_run(self, run_context):
-        fetches = [tf.train.get_global_step()]
-
+    def begin(self):
         if is_using_hvd():
-            fetches.append(
-                "signal_handler_var_set:0" if self.signal_recieved else "signal_handler_var:0")
-
-
-            if self.should_exit:
-                fetches.append("signal_handler_var_reset:0")
-            elif self.signal_recieved:
-                fetches.append("signal_handler_var_set:0")
-            else:
-                fetches.append("signal_handler_var:0")
+            with tf.device("/cpu:0"):
+                self.input_op = tf.placeholder(tf.int32, shape=())
+                self.allreduce_op = hvd.allreduce(self.input_op, op=hvd.Sum, 
+                                                  name="signal_handler_all_reduce")
 
-            if ((self.global_step % self.sync_freq) == 0) and not self.should_exit:
-                fetches.append("signal_handler_all_reduce:0")
+    def before_run(self, run_context):
+        fetches = [tf.train.get_global_step()]
+        feed_dict = None
 
-        run_args = tf.train.SessionRunArgs(fetches)
-        return run_args
+        if is_using_hvd() and (self.global_step % self.sync_freq) == 0:
+            fetches += [self.allreduce_op]
+            feed_dict = {self.input_op: int(self.signal_recieved)}
+            
+        return tf.train.SessionRunArgs(fetches, feed_dict=feed_dict)
 
     def after_run(self, run_context, run_values):
-        self.global_step = run_values.results[0]
+        self.global_step = run_values.results[0] + 1
 
-        if self.should_exit:
+        if is_using_hvd() and len(run_values.results) == 2:
+            if run_values.results[1] > 0:
+                run_context.request_stop()
+        elif self.signal_recieved:
             run_context.request_stop()
-            return
-
-        if is_using_hvd() and len(run_values.results) == 3:
-            self.should_exit = (run_values.results[2][0] == hvd.size())
-        else:
-            self.should_exit = self.signal_recieved
 
     def _signal_handler(self, signum, frame):
         print("Stop signal received")