Explorar o código

[Jasper/PyT] Switch to native AMP

Mikolaj Blaz hai 4 anos
pai
achega
e6f507c84a

+ 2 - 2
PyTorch/SpeechRecognition/Jasper/common/dali/iterator.py

@@ -49,10 +49,10 @@ class DaliJasperIterator(object):
         from nvidia.dali.plugin.pytorch import DALIGenericIterator
         from nvidia.dali.plugin.base_iterator import LastBatchPolicy
 
-        # in train pipeline shard_size is set to divisable by batch_size, so PARTIAL policy is safe
         self.dali_it = DALIGenericIterator(
             dali_pipelines, ["audio", "label", "audio_shape"], reader_name=reader_name,
-            dynamic_shape=True, auto_reset=True, last_batch_policy=LastBatchPolicy.PARTIAL)
+            dynamic_shape=True, auto_reset=True,
+            last_batch_policy=(LastBatchPolicy.DROP if train_iterator else LastBatchPolicy.PARTIAL))
 
     @staticmethod
     def _str2list(s: str):

+ 2 - 8
PyTorch/SpeechRecognition/Jasper/common/features.py

@@ -5,8 +5,6 @@ import librosa
 import torch
 import torch.nn as nn
 
-from apex import amp
-
 
 class BaseFeatures(nn.Module):
     """Base class for GPU accelerated audio preprocessing."""
@@ -42,14 +40,10 @@ class BaseFeatures(nn.Module):
     def calculate_features(self, audio, audio_lens):
         return audio, audio_lens
 
-    def __call__(self, audio, audio_lens, optim_level=0):
+    def __call__(self, audio, audio_lens):
         dtype = audio.dtype
         audio = audio.float()
-        if optim_level == 1:
-            with amp.disable_casts():
-                feat, feat_lens = self.calculate_features(audio, audio_lens)
-        else:
-            feat, feat_lens = self.calculate_features(audio, audio_lens)
+        feat, feat_lens = self.calculate_features(audio, audio_lens)
 
         feat = self.apply_padding(feat)
 

+ 5 - 11
PyTorch/SpeechRecognition/Jasper/common/helpers.py

@@ -17,8 +17,6 @@ import os
 import re
 from collections import OrderedDict
 
-from apex import amp
-
 import torch
 import torch.distributed as dist
 
@@ -187,11 +185,9 @@ def convert_v1_state_dict(state_dict):
 
 class Checkpointer(object):
 
-    def __init__(self, save_dir, model_name, keep_milestones=[100,200,300],
-                 use_amp=False):
+    def __init__(self, save_dir, model_name, keep_milestones=[100, 200, 300]):
         self.save_dir = save_dir
         self.keep_milestones = keep_milestones
-        self.use_amp = use_amp
         self.model_name = model_name
 
         tracked = [
@@ -200,7 +196,7 @@ class Checkpointer(object):
         tracked = sorted(tracked, key=lambda t: t[0])
         self.tracked = OrderedDict(tracked)
 
-    def save(self, model, ema_model, optimizer, epoch, step, best_wer,
+    def save(self, model, ema_model, optimizer, scaler, epoch, step, best_wer,
              is_best=False):
         """Saves model checkpoint for inference/resuming training.
 
@@ -234,7 +230,7 @@ class Checkpointer(object):
             'state_dict': unwrap_ddp(model).state_dict(),
             'ema_state_dict': unwrap_ddp(ema_model).state_dict() if ema_model is not None else None,
             'optimizer': optimizer.state_dict(),
-            'amp': amp.state_dict() if self.use_amp else None,
+            'scaler': scaler.state_dict(),
         }
 
         if is_best:
@@ -272,7 +268,7 @@ class Checkpointer(object):
         else:
             return None
 
-    def load(self, fpath, model, ema_model, optimizer, meta):
+    def load(self, fpath, model, ema_model, optimizer, scaler, meta):
 
         print_once(f'Loading model from {fpath}')
         checkpoint = torch.load(fpath, map_location="cpu")
@@ -292,9 +288,7 @@ class Checkpointer(object):
             unwrap_ddp(ema_model).load_state_dict(state_dict, strict=True)
 
         optimizer.load_state_dict(checkpoint['optimizer'])
-
-        if self.use_amp:
-            amp.load_state_dict(checkpoint['amp'])
+        scaler.load_state_dict(checkpoint['scaler'])
 
         meta['start_epoch'] = checkpoint.get('epoch')
         meta['best_wer'] = checkpoint.get('best_wer', meta['best_wer'])

+ 1 - 1
PyTorch/SpeechRecognition/Jasper/platform/DGX1-16GB_Jasper_AMP_8GPU.sh

@@ -1,3 +1,3 @@
 #!/bin/bash
 
-NUM_GPUS=8 AMP=true BATCH_SIZE=64 GRADIENT_ACCUMULATION_STEPS=4 bash scripts/train.sh "$@"
+NUM_GPUS=8 AMP=true BATCH_SIZE=64 GRAD_ACCUMULATION_STEPS=4 bash scripts/train.sh "$@"

+ 1 - 1
PyTorch/SpeechRecognition/Jasper/platform/DGX1-16GB_Jasper_FP32_8GPU.sh

@@ -1,3 +1,3 @@
 #!/bin/bash
 
-NUM_GPUS=8 BATCH_SIZE=64 GRADIENT_ACCUMULATION_STEPS=4 bash scripts/train.sh "$@"
+NUM_GPUS=8 BATCH_SIZE=64 GRAD_ACCUMULATION_STEPS=4 bash scripts/train.sh "$@"

+ 1 - 1
PyTorch/SpeechRecognition/Jasper/platform/DGX1-32GB_Jasper_AMP_8GPU.sh

@@ -1,3 +1,3 @@
 #!/bin/bash
 
-NUM_GPUS=8 AMP=true BATCH_SIZE=64 GRADIENT_ACCUMULATION_STEPS=1 bash scripts/train.sh "$@"
+NUM_GPUS=8 AMP=true BATCH_SIZE=64 GRAD_ACCUMULATION_STEPS=1 bash scripts/train.sh "$@"

+ 1 - 1
PyTorch/SpeechRecognition/Jasper/platform/DGX1-32GB_Jasper_FP32_8GPU.sh

@@ -1,3 +1,3 @@
 #!/bin/bash
 
-NUM_GPUS=8 BATCH_SIZE=64 GRADIENT_ACCUMULATION_STEPS=2 bash scripts/train.sh "$@"
+NUM_GPUS=8 BATCH_SIZE=64 GRAD_ACCUMULATION_STEPS=2 bash scripts/train.sh "$@"

+ 1 - 1
PyTorch/SpeechRecognition/Jasper/platform/DGX2_Jasper_AMP_16GPU.sh

@@ -1,3 +1,3 @@
 #!/bin/bash
 
-NUM_GPUS=16 AMP=true BATCH_SIZE=64 GRADIENT_ACCUMULATION_STEPS=1 bash scripts/train.sh "$@"
+NUM_GPUS=16 AMP=true BATCH_SIZE=64 GRAD_ACCUMULATION_STEPS=1 bash scripts/train.sh "$@"

+ 1 - 1
PyTorch/SpeechRecognition/Jasper/platform/DGX2_Jasper_AMP_8GPU.sh

@@ -1,3 +1,3 @@
 #!/bin/bash
 
-NUM_GPUS=8 AMP=true BATCH_SIZE=64 GRADIENT_ACCUMULATION_STEPS=1 bash scripts/train.sh "$@"
+NUM_GPUS=8 AMP=true BATCH_SIZE=64 GRAD_ACCUMULATION_STEPS=1 bash scripts/train.sh "$@"

+ 1 - 1
PyTorch/SpeechRecognition/Jasper/platform/DGX2_Jasper_FP32_16GPU.sh

@@ -1,3 +1,3 @@
 #!/bin/bash
 
-NUM_GPUS=16 BATCH_SIZE=64 GRADIENT_ACCUMULATION_STEPS=1 bash scripts/train.sh "$@"
+NUM_GPUS=16 BATCH_SIZE=64 GRAD_ACCUMULATION_STEPS=1 bash scripts/train.sh "$@"

+ 1 - 1
PyTorch/SpeechRecognition/Jasper/platform/DGX2_Jasper_FP32_8GPU.sh

@@ -1,3 +1,3 @@
 #!/bin/bash
 
-NUM_GPUS=8 AMP=true BATCH_SIZE=64 GRADIENT_ACCUMULATION_STEPS=2 bash scripts/train.sh "$@"
+NUM_GPUS=8 AMP=true BATCH_SIZE=64 GRAD_ACCUMULATION_STEPS=2 bash scripts/train.sh "$@"

+ 1 - 1
PyTorch/SpeechRecognition/Jasper/platform/DGXA100_Jasper_AMP_8GPU.sh

@@ -1,3 +1,3 @@
 #!/bin/bash
 
-NUM_GPUS=8 AMP=true BATCH_SIZE=64 GRADIENT_ACCUMULATION_STEPS=1 bash scripts/train.sh "$@"
+NUM_GPUS=8 AMP=true BATCH_SIZE=64 GRAD_ACCUMULATION_STEPS=1 bash scripts/train.sh "$@"

+ 1 - 1
PyTorch/SpeechRecognition/Jasper/platform/DGXA100_Jasper_TF32_8GPU.sh

@@ -1,3 +1,3 @@
 #!/bin/bash
 
-NUM_GPUS=8 BATCH_SIZE=64 GRADIENT_ACCUMULATION_STEPS=2 bash scripts/train.sh "$@"
+NUM_GPUS=8 BATCH_SIZE=64 GRAD_ACCUMULATION_STEPS=2 bash scripts/train.sh "$@"

+ 58 - 44
PyTorch/SpeechRecognition/Jasper/train.py

@@ -27,8 +27,7 @@ import torch
 import numpy as np
 import torch.cuda.profiler as profiler
 import torch.distributed as dist
-from apex import amp
-from apex.parallel import DistributedDataParallel
+from contextlib import suppress as empty_context
 
 from common import helpers
 from common.dali.data_loader import DaliDataLoader
@@ -59,7 +58,7 @@ def parse_args():
     training.add_argument('--cudnn_benchmark', action='store_true', default=True,
                           help='Enable cudnn benchmark')
     training.add_argument('--amp', '--fp16', action='store_true', default=False,
-                          help='Use mixed precision training')
+                          help='Use pytorch native mixed precision training')
     training.add_argument('--seed', default=42, type=int, help='Random seed')
     training.add_argument('--local_rank', default=os.getenv('LOCAL_RANK', 0),
                           type=int, help='GPU id used for distributed training')
@@ -158,15 +157,16 @@ def evaluate(epoch, step, val_loader, val_feat_proc, labels, model,
                 # with DALI, the data is already on GPU
                 feat, feat_lens, txt, txt_lens = batch
                 if val_feat_proc is not None:
-                    feat, feat_lens = val_feat_proc(feat, feat_lens, use_amp)
+                    feat, feat_lens = val_feat_proc(feat, feat_lens)
             else:
                 batch = [t.cuda(non_blocking=True) for t in batch]
                 audio, audio_lens, txt, txt_lens = batch
-                feat, feat_lens = val_feat_proc(audio, audio_lens, use_amp)
+                feat, feat_lens = val_feat_proc(audio, audio_lens)
 
-            log_probs, enc_lens = model.forward(feat, feat_lens)
-            loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
-            pred = greedy_decoder(log_probs)
+            with torch.cuda.amp.autocast(enabled=use_amp):
+                log_probs, enc_lens = model(feat, feat_lens)
+                loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
+                pred = greedy_decoder(log_probs)
 
             agg['losses'] += helpers.gather_losses([loss])
             agg['preds'] += helpers.gather_predictions([pred], labels)
@@ -323,37 +323,34 @@ def main():
     else:
         raise ValueError(f'Invalid optimizer "{args.optimizer}"')
 
+    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
+
     adjust_lr = lambda step, epoch, optimizer: lr_policy(
         step, epoch, args.lr, optimizer, steps_per_epoch=steps_per_epoch,
         warmup_epochs=args.warmup_epochs, hold_epochs=args.hold_epochs,
         num_epochs=args.epochs, policy=args.lr_policy, min_lr=args.min_lr,
         exp_gamma=args.lr_exp_gamma)
 
-    if args.amp:
-        model, optimizer = amp.initialize(
-            min_loss_scale=1.0, models=model, optimizers=optimizer,
-            opt_level='O1', max_loss_scale=512.0)
-
     if args.ema > 0:
         ema_model = copy.deepcopy(model)
     else:
         ema_model = None
 
     if multi_gpu:
-        model = DistributedDataParallel(model)
-
+        model = torch.nn.parallel.DistributedDataParallel(
+            model, device_ids=[args.local_rank], output_device=args.local_rank)
     if args.pyprof:
         pyprof.init(enable_function_stack=True)
 
     # load checkpoint
     meta = {'best_wer': 10**6, 'start_epoch': 0}
     checkpointer = Checkpointer(args.output_dir, 'Jasper',
-                                args.keep_milestones, args.amp)
+                                args.keep_milestones)
     if args.resume:
         args.ckpt = checkpointer.last_checkpoint() or args.ckpt
 
     if args.ckpt is not None:
-        checkpointer.load(args.ckpt, model, ema_model, optimizer, meta)
+        checkpointer.load(args.ckpt, model, ema_model, optimizer, scaler, meta)
 
     start_epoch = meta['start_epoch']
     best_wer = meta['best_wer']
@@ -380,11 +377,13 @@ def main():
             txt = torch.randint(high=len(symbols)-1, size=(batch_size, 100),
                                 device='cuda')
             txt_lens = torch.ones(batch_size, device='cuda').fill_(100)
-            log_probs, enc_lens = model(feat, feat_lens)
-            del feat
-            loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
+            with torch.cuda.amp.autocast(enabled=args.amp):
+                log_probs, enc_lens = model(feat, feat_lens)
+                del feat
+                loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
             loss.backward()
             model.zero_grad()
+    torch.cuda.empty_cache()
 
     bmark_stats = BenchmarkStats()
 
@@ -396,12 +395,11 @@ def main():
         epoch_loss = 0
         accumulated_batches = 0
         epoch_start_time = time.time()
+        epoch_eval_time = 0
 
         for batch in train_loader:
 
             if accumulated_batches == 0:
-                adjust_lr(step, epoch, optimizer)
-                optimizer.zero_grad()
                 step_loss = 0
                 step_utts = 0
                 step_start_time = time.time()
@@ -410,37 +408,49 @@ def main():
                 # with DALI, the data is already on GPU
                 feat, feat_lens, txt, txt_lens = batch
                 if train_feat_proc is not None:
-                    feat, feat_lens = train_feat_proc(feat, feat_lens, args.amp)
+                    feat, feat_lens = train_feat_proc(feat, feat_lens)
             else:
                 batch = [t.cuda(non_blocking=True) for t in batch]
                 audio, audio_lens, txt, txt_lens = batch
-                feat, feat_lens = train_feat_proc(audio, audio_lens, args.amp)
+                feat, feat_lens = train_feat_proc(audio, audio_lens)
 
-            log_probs, enc_lens = model(feat, feat_lens)
+            # Use context manager to prevent redundant accumulation of gradients
+            if (multi_gpu and accumulated_batches + 1 < args.grad_accumulation_steps):
+                ctx = model.no_sync()
+            else:
+                ctx = empty_context()
 
-            loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
-            loss /= args.grad_accumulation_steps
+            with ctx:
+                with torch.cuda.amp.autocast(enabled=args.amp):
+                    log_probs, enc_lens = model(feat, feat_lens)
+
+                    loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
+                    loss /= args.grad_accumulation_steps
 
-            if torch.isnan(loss).any():
-                print_once(f'WARNING: loss is NaN; skipping update')
-            else:
                 if multi_gpu:
-                    step_loss += reduce_tensor(loss.data, world_size).item()
+                    reduced_loss = reduce_tensor(loss.data, world_size)
                 else:
-                    step_loss += loss.item()
+                    reduced_loss = loss
 
-                if args.amp:
-                    with amp.scale_loss(loss, optimizer) as scaled_loss:
-                        scaled_loss.backward()
+                if torch.isnan(reduced_loss).any():
+                    print_once(f'WARNING: loss is NaN; skipping update')
+                    continue
                 else:
-                    loss.backward()
-                step_utts += batch[0].size(0) * world_size
-                epoch_utts += batch[0].size(0) * world_size
-                accumulated_batches += 1
+                    step_loss += reduced_loss.item()
+                    step_utts += batch[0].size(0) * world_size
+                    epoch_utts += batch[0].size(0) * world_size
+                    accumulated_batches += 1
+
+                    scaler.scale(loss).backward()
 
             if accumulated_batches % args.grad_accumulation_steps == 0:
                 epoch_loss += step_loss
-                optimizer.step()
+                scaler.step(optimizer)
+                scaler.update()
+
+                adjust_lr(step, epoch, optimizer)
+                optimizer.zero_grad()
+
                 apply_ema(model, ema_model, args.ema)
 
                 if step % args.log_frequency == 0:
@@ -463,14 +473,16 @@ def main():
                 step_start_time = time.time()
 
                 if step % args.eval_frequency == 0:
+                    tik = time.time()
                     wer = evaluate(epoch, step, val_loader, val_feat_proc,
                                    symbols, model, ema_model, ctc_loss,
                                    greedy_decoder, args.amp, use_dali)
 
                     if wer < best_wer and epoch >= args.save_best_from:
-                        checkpointer.save(model, ema_model, optimizer, epoch,
-                                          step, best_wer, is_best=True)
+                        checkpointer.save(model, ema_model, optimizer, scaler,
+                                          epoch, step, best_wer, is_best=True)
                         best_wer = wer
+                    epoch_eval_time += time.time() - tik
 
                 step += 1
                 accumulated_batches = 0
@@ -489,7 +501,8 @@ def main():
         bmark_stats.update(epoch_utts, epoch_time, epoch_loss)
 
         if epoch % args.save_frequency == 0 or epoch in args.keep_milestones:
-            checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer)
+            checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
+                              best_wer)
 
         if 0 < args.epochs_this_job <= epoch - start_epoch:
             print_once(f'Finished after {args.epochs_this_job} epochs.')
@@ -506,7 +519,8 @@ def main():
         evaluate(epoch, step, val_loader, val_feat_proc, symbols, model,
                  ema_model, ctc_loss, greedy_decoder, args.amp, use_dali)
 
-        checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer)
+        checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
+                          best_wer)
     flush_log()