Explore the code

Merge: [ConvNets/PyT] Enable logging gradient scale

Krzysztof Kudrynski, 3 years ago
parent
commit
bd1fb86919

+ 7 - 0
PyTorch/Classification/ConvNets/image_classification/logger.py

@@ -370,6 +370,7 @@ class TrainingMetrics(Metrics):
                 "data_time": ["train.data_time"],
                 "compute_time": ["train.compute_time"],
                 "lr": ["train.lr"],
+                "grad_scale": ["train.grad_scale"],
             }
             logger.register_metric(
                 "train.loss",
@@ -406,6 +407,12 @@ class TrainingMetrics(Metrics):
                 LR_METER(),
                 verbosity=dllogger.Verbosity.DEFAULT,
             )
+            logger.register_metric(
+                "train.grad_scale",
+                PERF_METER(),
+                verbosity=dllogger.Verbosity.DEFAULT,
+                metadata=Metrics.LOSS_METADATA,
+            )
 
 
 class ValidationMetrics(Metrics):

+ 3 - 0
PyTorch/Classification/ConvNets/image_classification/training.py

@@ -206,6 +206,7 @@ def train(
     train_step,
     train_loader,
     lr_scheduler,
+    grad_scale_fn,
     log_fn,
     timeout_handler,
     prof=-1,
@@ -238,6 +239,7 @@ def train(
             compute_time=it_time - data_time,
             lr=lr,
             loss=reduced_loss.item(),
+            grad_scale=grad_scale_fn(),
         )
 
         end = time.time()
@@ -364,6 +366,7 @@ def train_loop(
                     training_step,
                     data_iter,
                     lambda i: lr_scheduler(trainer.optimizer, i, epoch),
+                    trainer.executor.scaler.get_scale,
                     train_metrics.log,
                     timeout_handler,
                     prof=prof,

+ 3 - 2
PyTorch/Classification/ConvNets/main.py

@@ -416,6 +416,7 @@ def prepare_for_training(args, model_args, model_arch):
         print("BSM: {}".format(batch_size_multiplier))
 
     start_epoch = 0
+    best_prec1 = 0
     # optionally resume from a checkpoint
     if args.resume is not None:
         if os.path.isfile(args.resume):
@@ -603,13 +604,12 @@ def prepare_for_training(args, model_args, model_arch):
         val_loader,
         logger,
         start_epoch,
+        best_prec1,
     )
 
 
 def main(args, model_args, model_arch):
     exp_start_time = time.time()
-    global best_prec1
-    best_prec1 = 0
 
     (
         trainer,
@@ -619,6 +619,7 @@ def main(args, model_args, model_arch):
         val_loader,
         logger,
         start_epoch,
+        best_prec1,
     ) = prepare_for_training(args, model_args, model_arch)
 
     train_loop(