# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
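
"""Distributed training of wav2vec 2.0: pre-training or CTC fine-tuning."""

# Kept above the other imports so the warning filters defined in
# common/filter_warnings.py are registered before anything else runs.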
import common.filter_warnings

import argparse
import copy
import io
import os
import random
import sys
from contextlib import suppress as empty_context
from functools import partial
from itertools import cycle, islice
from pathlib import Path

import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel

import wav2vec2.arg_parser
from common import tb_dllogger as logger
from common.dataset import adjust_max_tokens, get_batch_iterator
from common.fairseq.data import Dictionary
from common.fairseq.dist import ModuleProxyWrapper
from common.fairseq.utils import multiply_grads
from common.helpers import (Checkpointer, num_weights, to_gpu,
                            init_multi_tensor_ema, apply_multi_tensor_ema)
from common.optimizers import get_optimizer, lr_exp_policy, lr_poly_policy
from common.utils import print_once, set_torch_seed, setup_distributed
from wav2vec2.criterion import Wav2vecCriterion, CTCCriterion
from wav2vec2.logging import init_logger, W2v2Metrics, W2v2FineTuningMetrics
from wav2vec2.utils import build_model, load_dataset
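

# Evaluate the model and, when enabled, its EMA copy on the validation set.
# Sharded loaders may yield empty batches on some ranks; those are logged
# with ignore=1 and still take part in all_reduce to keep ranks in sync.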
@torch.no_grad()
def validate(epoch, step, valid_loader, model, ema_model, criterion,
             val_metrics, val_ema_metrics, world_size, fp16, bf16):
    val_losses = []
    val_wer = []
    for model, metrics, scope in [(model, val_metrics, 'val'),
                                  (ema_model, val_ema_metrics, 'val_ema')]:
        if model is None:
            continue

        model.eval()
        criterion.eval()
        metrics._start_accumulating(None, True, scope=scope)

        output_keys = None
        assert len(valid_loader) > 1, (
            'Validation needs at least 2 iterations to handle empty batches.')

        for batch in valid_loader:
            is_empty_batch = len(batch) == 0
            if not is_empty_batch:
                to_gpu(batch, fp16=fp16, bf16=bf16)
                loss, _, logging_output = criterion(model, batch)
                if output_keys is None:
                    output_keys = logging_output.keys()
            else:
                assert output_keys is not None, (
                    f'Invalid iters num: {len(valid_loader)}')
                logging_output = {k: 0 for k in output_keys}

            logging_output['ignore'] = int(is_empty_batch)
            metrics.log_scalars(logging_output)
            metrics.all_reduce(world_size)
            metrics.accumulate()

        metrics.finish_val(scope=scope)
        logger.log(() if epoch is None else (epoch,), metrics, scope=scope,
                   tb_iter=step)

        val_losses.append(metrics.metrics[scope]['loss'])
        if 'wer' in metrics.metrics[scope]:
            val_wer.append(metrics.metrics[scope]['wer'])

        model.train()
        criterion.train()

    return val_losses, val_wer


def main():
    parser = argparse.ArgumentParser(
        description='wav2vec 2.0 Deep Learning Example')
    wav2vec2.arg_parser.populate(parser)
    args = parser.parse_args()

    assert not args.bf16 or args.fp32_pos_conv, (
        "bfloat16 requires casting positional convolutions to float32")

    if args.mode == 'finetune':
        wav2vec2.utils.update_args_for_finetuning(args, args.w2v_path)

    head = lambda list_: list_[0]  # fairseq compat, scalars wrapped w/ lists
    args.lr = head(args.lr)
    args.update_freq = head(args.update_freq)

    assert torch.cuda.is_available()
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    world_size = setup_distributed(args.local_rank)
    args.world_size = world_size  # For FP16Optimizer
    print_once(f"World size: {world_size}")

    assert args.seed is not None, (
        "Random seed is used to ensure same model weights across all devices. "
        "To allow None, draw a seed and synchronize across devices")

    set_torch_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    random.seed(args.seed + args.local_rank)
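
    # Fine-tuning trains with a CTC criterion over an output dictionary
    # (loaded from the checkpoint or from dict.{labels}.txt); pre-training
    # uses the contrastive wav2vec 2.0 criterion and no target dictionary.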
    pre_training = (args.mode == 'pretrain')
    checkpointer = Checkpointer(args, 'wav2vec2')

    if not pre_training:
        assert args.labels or checkpointer.last_state, \
            "Supply output labels or resume from a checkpoint."

        if checkpointer.last_state is not None:
            f = io.StringIO(checkpointer.last_state["output_labels"])
        else:
            f = open(Path(args.data, f"dict.{args.labels}.txt"))

        target_dictionary = Dictionary.load(f)
        f.seek(0)
        checkpointer.output_labels = f.read()
        f.close()

        Metrics = W2v2FineTuningMetrics
        criterion = CTCCriterion(target_dictionary, post_process='letter')
    else:
        target_dictionary = None
        Metrics = W2v2Metrics
        criterion = Wav2vecCriterion(args)

    kw = {'benchmark_epochs': args.benchmark_epochs_num, 'cuda': not args.cpu}
    metrics = Metrics(**kw)
    val_metrics = Metrics(scopes=['val'], **kw)
    val_ema_metrics = Metrics(scopes=['val_ema'], **kw)

    init_logger(args.output_dir, args.log_file, args.ema)
    logger.log_parameters(vars(args), tb_subset='train')

    assert args.update_freq >= 1

    model, seq_gen, tokenizer = build_model(args, args.mode, target_dictionary)
    model.cuda()
    print_once(f'Model size: {num_weights(model) / 10 ** 6:.1f}M params\n')

    print_once('Setting up datasets...')
    train_dataset = load_dataset(args.train_subset, args, target_dictionary,
                                 with_labels=not pre_training, training=True)
    valid_dataset = load_dataset(args.valid_subset, args, target_dictionary,
                                 with_labels=not pre_training, training=False)

    # Future-proof for adoption of native AMP
    scaler = torch.cuda.amp.GradScaler(enabled=False)
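
    # LR schedule: warmup for warmup_updates, an optional hold phase, then
    # polynomial or exponential decay towards final_lr_scale * lr, reaching
    # it at max_update steps.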
    lr_kw = {'initial_lr_scale': args.initial_lr_scale,
             'final_lr_scale': args.final_lr_scale,
             'warmup_steps': args.warmup_updates,
             'hold_steps': args.hold_updates,
             'num_steps': args.max_update,
             'lr': args.lr}
    if args.lr_policy == 'poly':
        adjust_lr = partial(lr_poly_policy, power=args.lr_poly_power, **lr_kw)
    elif args.lr_policy == 'exp':
        adjust_lr = partial(lr_exp_policy, decay=args.lr_exp_decay, **lr_kw)
    else:
        raise ValueError(f'Unknown lr_policy: {args.lr_policy}')

    assert args.fp16 + args.bf16 <= 1, (
        "Select a single mechanism for mixed precision training.")

    checkpointer.maybe_load_state(model=model)
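
    # Cast weights to the selected half precision; with --fp32_pos_conv the
    # positional convolutions stay in float32 (required for bf16, see the
    # assert at the top of main).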
    if args.bf16:
        model.to(dtype=torch.bfloat16)
    if args.fp16:
        model.half()
    if (args.fp16 or args.bf16) and args.fp32_pos_conv:
        w2v = model.w2v_encoder.w2v_model if args.mode == 'finetune' else model
        w2v.encoder.pos_conv.to(dtype=torch.float32)
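
    # DDP syncs gradients across ranks; ModuleProxyWrapper forwards attribute
    # lookups (e.g., set_num_updates) to the module wrapped by DDP.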
    multi_gpu = world_size > 1
    if multi_gpu:
        model = DistributedDataParallel(model, device_ids=[args.local_rank],
                                        output_device=args.local_rank,
                                        find_unused_parameters=True)
        model = ModuleProxyWrapper(model)

    args.bf16_disable_loss_scaler = False  # TODO Add support in the future
    optim = get_optimizer(model, args)
    adjust_lr(1, optim)

    if args.ema > 0.0:
        raise NotImplementedError(
            "EMA disabled, see https://github.com/pytorch/pytorch/issues/28594"
        )
    else:
        ema_model = None

    train_state = {'step': 0, 'epoch': 1, 'best_val_loss': float('inf'),
                   'best_val_wer': float('inf')}

    checkpointer.maybe_load_state(ema_model=ema_model, optimizer=optim,
                                  scaler=scaler, train_state=train_state)
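
    # Every rank iterates over its own shard of the data (num_shards ==
    # world_size); batch sizes are capped by max_tokens and max_sentences.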
    shard_id = int(os.getenv("RANK", args.local_rank))

    train_loader, sampler = get_batch_iterator(
        train_dataset,
        True,
        max_tokens=args.max_tokens,
        max_sentences=args.batch_size,
        max_positions=(args.max_tokens, args.max_tokens),
        ignore_invalid_inputs=True,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=world_size,
        shard_id=shard_id,
        num_workers=args.num_workers,
        num_concat_batches=args.num_concat_batches)

    valid_loader, _ = get_batch_iterator(
        valid_dataset,
        False,
        max_tokens=args.max_tokens_valid,
        max_sentences=args.batch_size_valid,
        max_positions=(sys.maxsize, sys.maxsize),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=world_size,
        shard_id=shard_id,
        num_workers=args.num_workers,
        num_concat_batches=args.num_concat_batches)

    steps_per_epoch = len(train_loader) // args.update_freq
    checkpointer.maybe_load_state(train_loader=train_loader)
    checkpointer.last_state = None

    print_once(model)
    model.train()

    step, epoch = train_state['step'], train_state['epoch']
    start_step = step
    start_epoch = epoch
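
    # Training loop: one `step` is one optimizer update, i.e., `update_freq`
    # accumulated micro-batches.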
    while step < args.max_update:
        set_torch_seed(args.seed + step)  # reproducibility after resuming
        metrics.start_epoch(epoch)
        sampler.set_epoch(epoch)
        optim.zero_grad()

        itr = islice(train_loader, steps_per_epoch * args.update_freq)
        for batch, accum_batches in zip(itr, cycle(range(args.update_freq))):
            if accum_batches == 0:
                step += 1
                model.set_num_updates(step)
                metrics.start_iter(accum_batches)

            to_gpu(batch, fp16=args.fp16, bf16=args.bf16)

            # use context manager to prevent redundant sync of gradients
            if multi_gpu and accum_batches + 1 < args.update_freq:
                ctx = model.no_sync()
            else:
                ctx = empty_context()

            with ctx:
                loss, _, logging_output = criterion(model, batch)
                if args.fp16 or args.bf16:
                    optim.backward(loss)
                else:
                    scaler.scale(loss).backward()

            # at this point, loss is scaled by loss_scale
            # and averaged over different devices (because of DDP) (*)
            metrics.log_scalars(logging_output)

            if (accum_batches + 1) % args.update_freq == 0:
                metrics.all_reduce(world_size)

                # scale gradients by world_size (to restore the sum of
                # gradients - see (*)), divided by sample_size to average
                # over tokens
                grads_mult_factor = world_size / metrics.partials['sample_size']

                if args.optimizer == 'adam' and not (args.fp16 or args.bf16):
                    # adam and non-amp optimizer - can use 'scale' kwarg for
                    # step and defer grad multiplication
                    pass
                elif args.fp16 or args.bf16:
                    optim.multiply_grads(grads_mult_factor)
                else:
                    multiply_grads(optim, grads_mult_factor)

                try:
                    if args.fp16 or args.bf16:
                        # calculate grad norm, maybe clip
                        grad_norm = optim.clip_grad_norm(args.clip_norm)

                    if args.optimizer == 'adam' and not (args.fp16 or args.bf16):
                        scaler.step(optim, scale=1. / grads_mult_factor)
                    else:
                        scaler.step(optim)

                    scaler.update()
                    model.set_num_updates(step)

                except OverflowError as e:
                    print_once(f"Grad overflow, ignoring grad. {str(e)}")
                    grad_norm = torch.tensor(0.0).cuda()

                optim.zero_grad()
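
                # NOTE unreachable while args.ema > 0.0 raises
                # NotImplementedError above; mt_ema_params would come from
                # init_multi_tensor_ema once EMA support is restored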
                if args.ema > 0.0:
                    apply_multi_tensor_ema(args.ema, *mt_ema_params)

                if args.fp16 or args.bf16:
                    metrics['loss_scale'] = optim.scaler.loss_scale

                metrics['lr'] = optim.param_groups[0]['lr']
                metrics.accumulate()
                metrics.finish_iter()

                if step % args.log_frequency == 0:
                    metrics.finish_logging_interval()
                    epoch_step = step % steps_per_epoch or steps_per_epoch
                    logger.log((epoch, epoch_step, steps_per_epoch),
                               metrics, scope='train', tb_iter=step)

                adjust_lr(step, optim)

                if step >= args.max_update:
                    break

                # NOTE this will break when resuming training on a different
                # dataset
                assert step <= steps_per_epoch * epoch

            # end of iter

        metrics.finish_epoch()
        logger.log((epoch,), metrics, scope='train_avg', flush_log=True,
                   tb_iter=step)

        print_once('Validating...')
        val_losses, val_wer = validate(
            epoch, step, valid_loader, model, ema_model, criterion,
            val_metrics, val_ema_metrics, world_size, args.fp16, args.bf16)

        # save best ckpt based on non-EMA val results
        checkpointer.maybe_save(model, ema_model, optim, scaler, train_state,
                                step, epoch, val_losses, val_wer, args)

        if 0 < args.epochs_this_job <= epoch + 1 - start_epoch:
            print_once(f'Reached {args.epochs_this_job} epochs in this run.')
            break

        if step >= args.max_update:
            print_once(f'Reached {step} total updates.')
            break

        epoch += 1  # end of epoch

    # finished training
    if step > start_step:
        logger.log((), metrics, scope='train_benchmark')
        logger.log((), val_metrics, scope='val')
        logger.log((), val_ema_metrics, scope='val_ema', flush_log=True)

    print_once(f'Finished after reaching update {step}.')


if __name__ == "__main__":
    main()