Răsfoiți Sursa

[Convnets/MX] Suspend resume support

Lukasz Pierscieniewski 3 ani în urmă
părinte
comite
135fbd91de
1 a modificat fișierele cu 155 adăugiri și 80 ștergeri
  1. 155 80
      MxNet/Classification/RN50v1.5/fit.py

+ 155 - 80
MxNet/Classification/RN50v1.5/fit.py

@@ -36,12 +36,17 @@
 """ train fit utility """
 import logging
 import math
+import glob
 import os
 import random
 import sys
 import time
+import re
 from itertools import starmap
 
+import signal
+import pickle
+
 import dllogger
 import horovod.mxnet as hvd
 import mxnet as mx
@@ -55,6 +60,32 @@ from benchmarking import BenchmarkingDataIter
 from global_metrics import CompositeMeter, MaxMeter, MinMeter, AvgMeter, PercentileMeter
 
 
+class PartitionSignalHandler():
+    def __init__(self, sync_freq: int = 10):
+        self.step = 0
+        self.freq = sync_freq
+
+        self.t = mx.nd.array([0])
+
+        signal.signal(signal.SIGUSR1, self._signal_handler)
+        signal.signal(signal.SIGTERM, self._signal_handler)
+
+    def sync(self) -> bool:
+        if self.step % self.freq == 0:
+            new_sync = hvd.allreduce(self.t, average=False)
+            if new_sync[0] > 0:
+                self.t[0] = 1
+        self.step += 1
+
+        return self.should_end()
+
+    def should_end(self) -> bool:
+        return bool(self.t[0] > 0)
+
+    def _signal_handler(self, signum, frame):
+        self.t[0] = 1
+
+
 def add_fit_args(parser):
     def int_list(x):
         return list(map(int, x.split(',')))
@@ -79,7 +110,7 @@ def add_fit_args(parser):
                        help='the batch size')
     train.add_argument('--num-epochs', type=int, default=90,
                        help='number of epochs')
-    train.add_argument('--run-epochs', type=int, default=-1, 
+    train.add_argument('--run-epochs', type=int, default=-1,
                        help='number of epochs to run in single run')
     train.add_argument('--lr', type=float, default=0.1,
                        help='initial learning rate')
@@ -134,7 +165,8 @@ def get_epoch_size(args, kv):
 
 def get_lr_scheduler(args):
     def multistep_schedule(x):
-        lr = args.lr * (args.lr_factor ** (len(list(filter(lambda step: step <= x, args.lr_steps)))))
+        lr = args.lr * \
+            (args.lr_factor ** (len(list(filter(lambda step: step <= x, args.lr_steps)))))
         warmup_coeff = min(1, x / args.warmup_epochs)
         return warmup_coeff * lr
 
@@ -164,33 +196,49 @@ def get_lr_scheduler(args):
 
 
 def load_model(args, model):
-    if args.load is None:
-        return False
-    model.load_parameters(args.load)
-    logging.info('Loaded model {}'.format(args.load))
-    return True
+    file = list(glob.glob(
+        f"{args.workspace}/{args.model_prefix}_*.params"))
+    if len(file) == 0:
+        return 0
 
+    file = [x for x in sorted(file) if "best.params" not in x][-1]
 
-def save_checkpoint(net, epoch, top1, best_acc, model_prefix, save_frequency, kvstore):
+    epoch = re.match(f".*{args.model_prefix}_([0-9]*)\.params", file)
+    if epoch is None:
+        return 0
+
+    epoch = int(epoch.group(1))
+    model.load_parameters(file)
+    logging.info('Loaded model {}'.format(file))
+    return epoch
+
+
+def save_checkpoint(net, epoch, top1, best_acc, model_prefix, workspace, save_frequency, kvstore, force_save=False):
     if model_prefix is None or save_frequency == 0 or ('horovod' in kvstore and hvd.rank() != 0):
         return
-    if save_frequency > 0 and (epoch + 1) % save_frequency == 0:
+    if (save_frequency > 0 and (epoch + 1) % save_frequency == 0) or force_save:
         fname = '{}_{:04}.params'.format(model_prefix, epoch)
+        fname = os.path.join(workspace, fname)
         net.save_parameters(fname)
-        logging.info('[Epoch {}] Saving checkpoint to {} with Accuracy: {:.4f}'.format(epoch, fname, top1))
+        logging.info('[Epoch {}] Saving checkpoint to {} with Accuracy: {:.4f}'.format(
+            epoch, fname, top1))
+
     if top1 > best_acc:
-        fname = '{}_best.params'.format(model_prefix)
+        fname = os.path.join(workspace, f'{model_prefix}_best.params')
         net.save_parameters(fname)
-        logging.info('[Epoch {}] Saving checkpoint to {} with Accuracy: {:.4f}'.format(epoch, fname, top1))
+        logging.info('[Epoch {}] Saving checkpoint to {} with Accuracy: {:.4f}'.format(
+            epoch, fname, top1))
 
 
 def model_pred(args, model, image):
     from imagenet_classes import classes
-    output = model(image.reshape(-1, *image.shape))[0].softmax().as_in_context(mx.cpu())
+    output = model(image.reshape(-1, *image.shape)
+                   )[0].softmax().as_in_context(mx.cpu())
     top = output.argsort(is_ascend=False)[:10]
     for i, ind in enumerate(top):
         ind = int(ind.asscalar())
-        logging.info('{:2d}. {:5.2f}% -> {}'.format(i + 1, output[ind].asscalar() * 100, classes[ind]))
+        logging.info('{:2d}. {:5.2f}% -> {}'.format(i + 1,
+                     output[ind].asscalar() * 100, classes[ind]))
 
 
 def reduce_metrics(args, metrics, kvstore):
@@ -214,7 +262,8 @@ def model_score(args, net, val_data, metric, kvstore):
 
     val_data.reset()
 
-    total_batch_size = val_data.batch_size * val_data._num_gpus * (hvd.size() if 'horovod' in kvstore else 1)
+    total_batch_size = val_data.batch_size * val_data._num_gpus * \
+        (hvd.size() if 'horovod' in kvstore else 1)
 
     durations = []
     tic = time.time()
@@ -225,9 +274,11 @@ def model_score(args, net, val_data, metric, kvstore):
             o.wait_to_read()
 
         data = [b.data[0] for b in batches]
-        label = [b.label[0][:len(b.data[0]) - b.pad] for b in batches if len(b.data[0]) != b.pad]
+        label = [b.label[0][:len(b.data[0]) - b.pad]
+                 for b in batches if len(b.data[0]) != b.pad]
         outputs = [net(X) for X, b in zip(data, batches)]
-        outputs = [o[:len(b.data[0]) - b.pad] for o, b in zip(outputs, batches) if len(b.data[0]) != b.pad]
+        outputs = [o[:len(b.data[0]) - b.pad]
+                   for o, b in zip(outputs, batches) if len(b.data[0]) != b.pad]
         metric.update(label, outputs)
 
         durations.append(time.time() - tic)
@@ -263,21 +314,24 @@ def model_fit(args, net, train_data, eval_metric, optimizer,
     loss_metric = ScalarMetric()
 
     if 'horovod' in kvstore:
-        trainer = hvd.DistributedTrainer(net.collect_params(), optimizer, optimizer_params)
+        trainer = hvd.DistributedTrainer(
+            net.collect_params(), optimizer, optimizer_params)
     else:
         trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params,
                                 kvstore=kv, update_on_kvstore=False)
 
     if args.amp:
         amp.init_trainer(trainer)
-    
+
+    partition_handler = PartitionSignalHandler(1)
 
     sparse_label_loss = (args.label_smoothing == 0 and args.mixup == 0)
     loss = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)
     loss.hybridize(static_shape=True, static_alloc=True)
 
     local_batch_size = train_data.batch_size
-    total_batch_size = local_batch_size * train_data._num_gpus * (hvd.size() if 'horovod' in kvstore else 1)
+    total_batch_size = local_batch_size * train_data._num_gpus * \
+        (hvd.size() if 'horovod' in kvstore else 1)
     durations = []
 
     epoch_size = get_epoch_size(args, kv)
@@ -287,16 +341,21 @@ def model_fit(args, net, train_data, eval_metric, optimizer,
         if args.mixup != 0:
             coeffs = mx.nd.array(np.random.beta(args.mixup, args.mixup, size=images.shape[0])).as_in_context(
                 images.context)
-            image_coeffs = coeffs.astype(images.dtype, copy=False).reshape(*coeffs.shape, 1, 1, 1)
-            ret_images = image_coeffs * images + (1 - image_coeffs) * images[::-1]
+            image_coeffs = coeffs.astype(
+                images.dtype, copy=False).reshape(*coeffs.shape, 1, 1, 1)
+            ret_images = image_coeffs * images + \
+                (1 - image_coeffs) * images[::-1]
 
-            ret_labels = label_smoothing(labels, args.num_classes, args.label_smoothing)
+            ret_labels = label_smoothing(
+                labels, args.num_classes, args.label_smoothing)
             label_coeffs = coeffs.reshape(*coeffs.shape, 1)
-            ret_labels = label_coeffs * ret_labels + (1 - label_coeffs) * ret_labels[::-1]
+            ret_labels = label_coeffs * ret_labels + \
+                (1 - label_coeffs) * ret_labels[::-1]
         else:
             ret_images = images
             if not sparse_label_loss:
-                ret_labels = label_smoothing(labels, args.num_classes, args.label_smoothing)
+                ret_labels = label_smoothing(
+                    labels, args.num_classes, args.label_smoothing)
             else:
                 ret_labels = labels
 
@@ -315,76 +374,87 @@ def model_fit(args, net, train_data, eval_metric, optimizer,
 
         logging.info('Starting epoch {}'.format(epoch))
         outputs = []
-        for i, batches in enumerate(train_data):
-            # synchronize to previous iteration
-            #for o in outputs:
-            #    o.wait_to_read()
-
-            trainer.set_learning_rate(lr_scheduler(epoch + i / epoch_size))
-
-            data = [b.data[0] for b in batches]
-            label = [b.label[0].as_in_context(b.data[0].context) for b in batches]
-            orig_label = label
-
-            data, label = zip(*starmap(transform_data, zip(data, label)))
-
-            outputs = []
-            Ls = []
-            with ag.record():
-                for x, y in zip(data, label):
-                    z = net(x)
-                    L = loss(z, y)
-                    # store the loss and do backward after we have done forward
-                    # on all GPUs for better speed on multiple GPUs.
-                    Ls.append(L)
-                    outputs.append(z)
-
-                if args.amp:
-                    with amp.scale_loss(Ls, trainer) as scaled_loss:
-                        ag.backward(scaled_loss)
+        if not partition_handler.should_end():
+            for i, batches in enumerate(train_data):
+                # synchronize to previous iteration
+                # for o in outputs:
+                #    o.wait_to_read()
+
+                trainer.set_learning_rate(lr_scheduler(epoch + i / epoch_size))
+
+                data = [b.data[0] for b in batches]
+                label = [b.label[0].as_in_context(
+                    b.data[0].context) for b in batches]
+                orig_label = label
+
+                data, label = zip(*starmap(transform_data, zip(data, label)))
+
+                outputs = []
+                Ls = []
+                with ag.record():
+                    for x, y in zip(data, label):
+                        z = net(x)
+                        L = loss(z, y)
+                        # store the loss and do backward after we have done forward
+                        # on all GPUs for better speed on multiple GPUs.
+                        Ls.append(L)
+                        outputs.append(z)
+
+                    if args.amp:
+                        with amp.scale_loss(Ls, trainer) as scaled_loss:
+                            ag.backward(scaled_loss)
+                    else:
+                        ag.backward(Ls)
+
+                if 'horovod' in kvstore:
+                    trainer.step(local_batch_size)
                 else:
-                    ag.backward(Ls)
-
-            if 'horovod' in kvstore:
-                trainer.step(local_batch_size)
-            else:
-                trainer.step(total_batch_size)
+                    trainer.step(total_batch_size)
 
-            loss_metric.update(..., np.mean([l.asnumpy() for l in Ls]).item())
+                loss_metric.update(..., np.mean(
+                    [l.asnumpy() for l in Ls]).item())
 
-            if args.disp_batches and not (i + 1) % args.disp_batches:
-                dllogger_it_data = {
-                    'train.loss': loss_metric.get()[1],
-                    'train.ips': args.disp_batches * total_batch_size / (time.time() - btic),
-                    'train.lr': trainer.learning_rate
-                }
-                dllogger.log((epoch, i), data=dllogger_it_data)
+                if args.disp_batches and not (i + 1) % args.disp_batches:
+                    dllogger_it_data = {
+                        'train.loss': loss_metric.get()[1],
+                        'train.ips': args.disp_batches * total_batch_size / (time.time() - btic),
+                        'train.lr': trainer.learning_rate
+                    }
+                    dllogger.log((epoch, i), data=dllogger_it_data)
 
-                loss_metric.reset_local()
-                btic = time.time()
+                    loss_metric.reset_local()
+                    btic = time.time()
 
-            durations.append(time.time() - tic)
-            tic = time.time()
+                durations.append(time.time() - tic)
+                tic = time.time()
 
         durations = durations[min(len(durations) // 10, 100):]
         dllogger_epoch_data = {
             'train.loss': loss_metric.get_global()[1],
             'train.ips': total_batch_size / np.mean(durations)
         }
+
+        should_break = partition_handler.sync()
         if args.mode == 'train_val':
             logging.info('Validating epoch {}'.format(epoch))
-            score, duration_stats, _ = model_score(args, net, eval_data, eval_metric, kvstore)
+            score, duration_stats, _ = model_score(
+                args, net, eval_data, eval_metric, kvstore)
 
             dllogger_epoch_data.update(
-                starmap(lambda key, val: ('val.{}'.format(key), val), zip(*score))
+                starmap(lambda key, val: (
+                    'val.{}'.format(key), val), zip(*score))
             )
             dllogger_epoch_data.update(
-                starmap(lambda key, val: ('val.{}'.format(key), val), duration_stats.items())
+                starmap(lambda key, val: ('val.{}'.format(key), val),
+                        duration_stats.items())
             )
 
             score = dict(zip(*score))
             accuracy = score.get('accuracy', -1)
-            save_checkpoint(net, epoch, accuracy, best_accuracy, model_prefix, args.save_frequency, kvstore)
+            save_checkpoint(net, epoch, accuracy, best_accuracy,
+                            model_prefix, args.workspace,
+                            args.save_frequency, kvstore,
+                            force_save=should_break)
             best_accuracy = max(best_accuracy, accuracy)
         global_metrics.update_dict(dllogger_epoch_data)
         dllogger.log(step=(epoch,), data=dllogger_epoch_data)
@@ -446,7 +516,8 @@ def fit(args, model, data_loader):
                 tic = time.time()
         return
 
-    if not load_model(args, model):
+    start_epoch = load_model(args, model)
+    if start_epoch == 0:
         # all initializers should be specified in the model definition.
         # if not, this will raise an error
         model.initialize(mx.init.Initializer())
@@ -516,7 +587,7 @@ def fit(args, model, data_loader):
             args,
             model,
             train,
-            begin_epoch=args.begin_epoch,
+            begin_epoch=start_epoch,
             num_epoch=args.num_epochs,
             run_epoch=args.run_epochs,
             eval_data=val,
@@ -531,15 +602,19 @@ def fit(args, model, data_loader):
         )
     elif args.mode == 'val':
         for epoch in range(args.num_epochs):  # loop for benchmarking
-            score, duration_stats, durations = model_score(args, model, val, eval_metrics, args.kv_store)
-            dllogger_data = dict(starmap(lambda key, val: ('val.{}'.format(key), val), zip(*score)))
+            score, duration_stats, durations = model_score(
+                args, model, val, eval_metrics, args.kv_store)
+            dllogger_data = dict(starmap(lambda key, val: (
+                'val.{}'.format(key), val), zip(*score)))
             dllogger_data.update(
-                starmap(lambda key, val: ('val.{}'.format(key), val), duration_stats.items())
+                starmap(lambda key, val: ('val.{}'.format(key), val),
+                        duration_stats.items())
             )
             global_metrics.update_dict(dllogger_data)
             for percentile in [50, 90, 95, 99, 100]:
                 metric_name = 'val.latency_{}'.format(percentile)
-                dllogger_data[metric_name] = np.percentile(durations, percentile)
+                dllogger_data[metric_name] = np.percentile(
+                    durations, percentile)
                 global_metrics.update_metric(metric_name, durations)
             dllogger.log(step=(epoch,), data=dllogger_data)
     else: