# run_pretraining.py
  1. # coding=utf-8
  2. # Copyright (c) 2019-2021 NVIDIA CORPORATION. All rights reserved.
  3. # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """BERT finetuning runner."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. # ==================
  20. import csv
  21. import os
  22. import time
  23. import argparse
  24. import random
  25. import logging
  26. import h5py
  27. from tqdm import tqdm, trange
  28. from typing import Final, Any, Callable
  29. import os
  30. import numpy as np
  31. import torch
  32. from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Dataset
  33. from torch.utils.data.distributed import DistributedSampler
  34. import torch.distributed as dist
  35. import math
  36. import modeling
  37. from schedulers import PolyWarmUpScheduler
  38. from lamb_amp_opt.fused_lamb import FusedLAMBAMP
  39. from file_utils import PYTORCH_PRETRAINED_BERT_CACHE
  40. from utils import is_main_process, format_step, get_world_size, get_rank
  41. from torch.nn.parallel import DistributedDataParallel as DDP
  42. from schedulers import LinearWarmUpScheduler
  43. import dllogger
  44. import lddl.torch
# Enabling the TorchScript Runtime Backend NVFuser
# Route TorchScript fusion through NVFuser and turn off the competing fusers
# (TensorExpr fuser and the legacy CPU/GPU fusers) so they do not claim the
# same graphs.  NOTE(review): these are private torch._C APIs — behavior may
# change between torch versions; verify against the pinned torch release.
torch._C._jit_set_nvfuser_enabled(True)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
# Allow up to 20 bailouts/re-specializations before the JIT gives up on the
# optimized graph (private API; presumably tuned for the phase2 seq-len bins).
torch._C._jit_set_bailout_depth(20)
# Track whether a SIGTERM (cluster time up) has been handled
timeout_sent = False

import signal
# handle SIGTERM sent from the scheduler and mark so we
# can gracefully save & exit
def signal_handler(sig, frame):
    # Only flips the module-level flag; the training loop in main() polls
    # `timeout_sent` and performs the final checkpoint + clean exit itself.
    global timeout_sent
    timeout_sent = True

signal.signal(signal.SIGTERM, signal_handler)
  60. class BertPretrainingCriterion(torch.nn.Module):
  61. sequence_output_is_dense: Final[bool]
  62. def __init__(self, vocab_size, sequence_output_is_dense=False):
  63. super(BertPretrainingCriterion, self).__init__()
  64. self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1)
  65. self.vocab_size = vocab_size
  66. self.sequence_output_is_dense = sequence_output_is_dense
  67. def forward(self, prediction_scores, seq_relationship_score, masked_lm_labels, next_sentence_labels):
  68. if self.sequence_output_is_dense:
  69. # prediction_scores are already dense
  70. masked_lm_labels_flat = masked_lm_labels.view(-1)
  71. mlm_labels = masked_lm_labels_flat[masked_lm_labels_flat != -1]
  72. masked_lm_loss = self.loss_fn(prediction_scores.view(-1, self.vocab_size), mlm_labels.view(-1))
  73. else:
  74. masked_lm_loss = self.loss_fn(prediction_scores.view(-1, self.vocab_size), masked_lm_labels.view(-1))
  75. next_sentence_loss = self.loss_fn(seq_relationship_score.view(-1, 2), next_sentence_labels.view(-1))
  76. total_loss = masked_lm_loss + next_sentence_loss
  77. return total_loss
  78. class SyncFreeStats :
  79. def __init__(self) :
  80. self.host_stats = {}
  81. self.device_stats = {}
  82. self.device_funcs = {}
  83. def add_stat(self, name, dtype=torch.int32, device_tensor=None, device_func=None) :
  84. if device_tensor is not None :
  85. assert dtype == device_tensor.dtype, "Error: dtype do not match: {} {}".format(dtype, device_tensor.dtype)
  86. self.host_stats[name] = torch.zeros(1, dtype=dtype).pin_memory()
  87. self.device_stats[name] = device_tensor
  88. self.device_funcs[name] = device_func
  89. def copy_from_device(self) :
  90. for name in self.host_stats.keys() :
  91. # Apply device function to device stat
  92. if self.device_stats[name] is not None and self.device_funcs[name] is not None:
  93. self.host_stats[name].copy_(self.device_funcs[name](self.device_stats[name]), non_blocking=True)
  94. elif self.device_stats[name] is not None :
  95. self.host_stats[name].copy_(self.device_stats[name], non_blocking=True)
  96. elif self.device_funcs[name] is not None :
  97. self.host_stats[name].copy_(self.device_funcs[name](), non_blocking=True)
  98. def host_stat(self, name) :
  99. assert name in self.host_stats
  100. return self.host_stats[name]
  101. def host_stat_value(self, name) :
  102. assert name in self.host_stats
  103. return self.host_stats[name].item()
  104. def update_host_stat(self, name, tensor) :
  105. self.host_stats[name] = tensor
  106. def device_stat(self, name) :
  107. assert self.device_stats[name] is not None
  108. return self.device_stats[name]
  109. def update_device_stat(self, name, tensor) :
  110. self.device_stats[name] = tensor
def parse_arguments() -> argparse.Namespace:
    """Build and parse the command-line arguments for BERT pre-training.

    Post-processing: ``--amp`` is folded into ``args.fp16``, and a negative
    ``--steps_this_run`` falls back to ``--max_steps``.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--input_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain .parquet files for the task.")
    parser.add_argument("--config_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The BERT model config")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model checkpoints will be written.")
    parser.add_argument('--vocab_file',
                        type=str,
                        default=None,
                        required=True,
                        help="Vocabulary mapping/file BERT was pretrainined on")

    ## Other parameters
    parser.add_argument("--init_checkpoint",
                        default=None,
                        type=str,
                        help="The initial checkpoint to start training from.")
    parser.add_argument("--max_seq_length",
                        default=512,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--max_predictions_per_seq",
                        default=80,
                        type=int,
                        help="The maximum total of masked tokens in input sequence")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    # NOTE(review): type=float (not int) — downstream code compares and
    # divides with this value; confirm before changing the type.
    parser.add_argument("--max_steps",
                        default=1000,
                        type=float,
                        help="Total number of training steps to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.01,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=os.getenv('LOCAL_RANK', -1),
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumualte before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        default=False,
                        action='store_true',
                        help="Mixed precision training")
    # Alias of --fp16; merged below via args.fp16 = args.fp16 or args.amp.
    parser.add_argument('--amp',
                        default=False,
                        action='store_true',
                        help="Mixed precision training")
    parser.add_argument('--loss_scale',
                        type=float, default=0.0,
                        help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
    parser.add_argument('--log_freq',
                        type=float, default=1.0,
                        help='frequency of logging loss.')
    parser.add_argument('--checkpoint_activations',
                        default=False,
                        action='store_true',
                        help="Whether to use gradient checkpointing")
    parser.add_argument("--resume_from_checkpoint",
                        default=False,
                        action='store_true',
                        help="Whether to resume training from checkpoint.")
    parser.add_argument('--resume_step',
                        type=int,
                        default=-1,
                        help="Step to resume training from.")
    parser.add_argument('--num_steps_per_checkpoint',
                        type=int,
                        default=100,
                        help="Number of update steps until a model checkpoint is saved to disk.")
    parser.add_argument('--skip_checkpoint',
                        default=False,
                        action='store_true',
                        help="Whether to save checkpoints")
    parser.add_argument('--phase2',
                        default=False,
                        action='store_true',
                        help="Whether to train with seq len 512")
    parser.add_argument('--resume_phase2',
                        default=False,
                        action='store_true',
                        help="Whether to resume training with seq len 512")
    parser.add_argument('--allreduce_post_accumulation',
                        default=False,
                        action='store_true',
                        help="Whether to do allreduces during gradient accumulation steps.")
    parser.add_argument('--allreduce_post_accumulation_fp16',
                        default=False,
                        action='store_true',
                        help="Whether to do fp16 allreduce post accumulation.")
    parser.add_argument('--phase1_end_step',
                        type=int,
                        default=7038,
                        help="Number of training steps in Phase1 - seq len 128")
    parser.add_argument('--init_loss_scale',
                        type=int,
                        default=2**20,
                        help="Initial loss scaler value")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument('--json-summary', type=str, default="results/dllogger.json",
                        help='If provided, the json summary will be written to'
                             'the specified file.')
    parser.add_argument("--use_env",
                        action='store_true',
                        help="Whether to read local rank from ENVVAR")
    parser.add_argument('--disable_progress_bar',
                        default=False,
                        action='store_true',
                        help='Disable tqdm progress bar')
    parser.add_argument('--steps_this_run', type=int, default=-1,
                        help='If provided, only run this many steps before exiting')
    parser.add_argument("--profile",
                        default=False,
                        action='store_true',
                        help="Whether to profile model.")
    parser.add_argument("--profile-start",
                        default=0,
                        type=int,
                        help="Delay profiling to start step.")
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='number of DataLoader worker processes per rank')

    # optimizations controlled by command line arguments
    parser.add_argument("--no_dense_sequence_output",
                        default=False,
                        action='store_true',
                        help="Disable dense sequence output")
    parser.add_argument("--disable_jit_fusions",
                        default=False,
                        action='store_true',
                        help="Disable jit fusions.")
    parser.add_argument("--cuda_graphs",
                        default=False,
                        action='store_true',
                        help="Enable Cuda Graphs.")

    args = parser.parse_args()

    # --amp and --fp16 are interchangeable.
    args.fp16 = args.fp16 or args.amp
    if args.steps_this_run < 0:
        args.steps_this_run = args.max_steps

    return args
def setup_training(args):
    """Select the CUDA device, initialize distributed training and logging.

    Also validates/normalizes training arguments (splits train_batch_size by
    gradient_accumulation_steps, checks the output directory).  Returns
    ``(device, args)`` with ``args`` mutated in place.
    """
    assert (torch.cuda.is_available())

    if args.local_rank == -1:
        # Single-process run: always use GPU 0 and disable the
        # post-accumulation all-reduce options (they only apply under DDP).
        device = torch.device("cuda", 0)
        args.n_gpu = 1 # torch.cuda.device_count()
        args.allreduce_post_accumulation = False
        args.allreduce_post_accumulation_fp16 = False
    else:
        # set_device must precede init_process_group so NCCL binds the rank
        # to the right GPU.
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        if args.cuda_graphs :
            # NOTE(review): async NCCL error handling is presumably
            # incompatible with graph capture — confirm against NCCL docs.
            os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.n_gpu = 1

    # Only the main process writes JSON/stdout logs; others get a no-op logger.
    if is_main_process():
        dllogger.init(backends=[dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
                                                           filename=args.json_summary),
                                dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE, step_format=format_step)])
    else:
        dllogger.init(backends=[])

    dllogger.metadata("e2e_train_time", {"unit": "s"})
    dllogger.metadata("training_sequences_per_second", {"unit": "sequences/s"})
    dllogger.metadata("final_loss", {"unit": None})
    dllogger.metadata("raw_train_time", {"unit": "s"})

    print("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, args.n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))
    if args.train_batch_size % args.gradient_accumulation_steps != 0:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible".format(
            args.gradient_accumulation_steps, args.train_batch_size))

    # From here on train_batch_size means the per-micro-batch size.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    if not args.do_train:
        raise ValueError(" `do_train` must be True.")

    # Refuse to clobber an output dir that already holds checkpoints unless
    # we are explicitly resuming.
    if not args.resume_from_checkpoint and os.path.exists(args.output_dir) and (
            os.listdir(args.output_dir) and any([i.startswith('ckpt') for i in os.listdir(args.output_dir)])):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))

    if (not args.resume_from_checkpoint or not os.path.exists(args.output_dir)) and is_main_process():
        os.makedirs(args.output_dir, exist_ok=True)

    return device, args
def prepare_model_and_optimizer(args, device, sequence_output_is_dense):
    """Build the BERT model, LAMB optimizer, LR scheduler, grad scaler and loss.

    Handles checkpoint resume (phase1/phase2 and --init_checkpoint variants),
    optional TorchScript-ing, and DDP wrapping with a gradient-accumulation
    comm hook.  Returns
    (model, optimizer, grad_scaler, lr_scheduler, checkpoint, global_step,
    criterion, start_epoch).
    """
    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    model = modeling.BertForPreTraining(config, sequence_output_is_dense=sequence_output_is_dense)

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        # Auto-detect the newest checkpoint step from file names
        # "ckpt_<step>.pt" when no explicit resume step was given.
        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [f for f in os.listdir(args.output_dir) if f.endswith(".pt")]
            args.resume_step = max([int(x.split('.pt')[0].split('_')[1].strip()) for x in model_names])

        global_step = args.resume_step if not args.init_checkpoint else 0

        if not args.init_checkpoint:
            checkpoint = torch.load(os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step)), map_location=device)
        else:
            checkpoint = torch.load(args.init_checkpoint, map_location=device)

        # strict=False: vocab padding above may change embedding shapes.
        model.load_state_dict(checkpoint['model'], strict=False)

        if args.phase2 and not args.init_checkpoint:
            # Phase2 step counting restarts after phase1.
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)

    # If allreduce_post_accumulation_fp16 is not set, Native AMP Autocast is
    # used along with FP32 gradient accumulation and all-reduce
    if args.fp16 and args.allreduce_post_accumulation_fp16:
        model.half()

    if not args.disable_jit_fusions :
        model = torch.jit.script(model)

    # LAMB with weight decay on everything except biases and norm params.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    optimizer = FusedLAMBAMP(optimizer_grouped_parameters,
                             lr=args.learning_rate)
    lr_scheduler = PolyWarmUpScheduler(optimizer,
                                       warmup=args.warmup_proportion,
                                       total_steps=args.max_steps,
                                       base_lr=args.learning_rate,
                                       device=device)
    grad_scaler = torch.cuda.amp.GradScaler(init_scale=args.init_loss_scale, enabled=args.fp16)

    model.checkpoint_activations(args.checkpoint_activations)

    if args.resume_from_checkpoint:
        # For phase2 from scratch, need to reset the learning rate and step count in the checkpoint. Else restore values in checkpoint.
        if (args.phase2 and not args.resume_phase2) or args.init_checkpoint :
            for group in checkpoint['optimizer']['param_groups'] :
                group['step'].zero_()
                group['lr'].fill_(args.learning_rate)
        else :
            if 'grad_scaler' in checkpoint and (not args.phase2 or args.resume_phase2):
                grad_scaler.load_state_dict(checkpoint['grad_scaler'])
        optimizer.load_state_dict(checkpoint['optimizer'])  # , strict=False)

    if args.local_rank != -1:
        # Cuda Graphs requires that DDP is captured on a side stream
        # It is important to synchronize the streams after the DDP initialization
        # so anything after sees properly initialized model weights across GPUs
        side_stream = torch.cuda.Stream()
        with torch.cuda.stream(side_stream) :
            model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, bucket_cap_mb=torch.cuda.get_device_properties(device).total_memory, gradient_as_bucket_view=True)
        torch.cuda.current_stream().wait_stream(side_stream)

        from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook

        def scale_by_grad_accum_steps_wrapper(hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
            # Wrap a DDP comm hook so each bucket is divided by the number of
            # accumulation steps before the all-reduce.
            def scale_by_grad_accum_steps_wrapper_hook(
                hook_state, bucket: dist.GradBucket
            ) -> torch.futures.Future[torch.Tensor]:
                bucket.set_buffer(bucket.buffer().div_(args.gradient_accumulation_steps))
                fut = hook(hook_state, bucket)
                return fut

            return scale_by_grad_accum_steps_wrapper_hook

        # With gradient accumulation, the DDP comm hook divides the gradients by the number
        # gradient accumulation steps
        if args.gradient_accumulation_steps > 1:
            model.register_comm_hook(None, scale_by_grad_accum_steps_wrapper(allreduce_hook))

    optimizer.setup_fp32_params()

    criterion = BertPretrainingCriterion(config.vocab_size, sequence_output_is_dense=sequence_output_is_dense)

    if args.resume_from_checkpoint and args.init_checkpoint:
        start_epoch = checkpoint['epoch']
    else:
        start_epoch = 0

    return model, optimizer, grad_scaler, lr_scheduler, checkpoint, global_step, criterion, start_epoch
  411. def checkpoint_step(args, epoch, global_step, model, optimizer, grad_scaler, last3_checkpoint_paths) :
  412. torch.cuda.synchronize()
  413. if is_main_process() and not args.skip_checkpoint:
  414. # Save a trained model
  415. dllogger.log(step="PARAMETER", data={"checkpoint_step": global_step})
  416. model_to_save = model.module if hasattr(model,
  417. 'module') else model # Only save the model it-self
  418. if args.resume_step < 0 or not args.phase2:
  419. output_save_file = os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step))
  420. else:
  421. output_save_file = os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step + args.phase1_end_step))
  422. if args.do_train:
  423. torch.save({'model': model_to_save.state_dict(),
  424. 'optimizer': optimizer.state_dict(),
  425. 'grad_scaler': grad_scaler.state_dict(),
  426. 'epoch': epoch}, output_save_file)
  427. # The new checkpoint could have a name already in
  428. # last3_checkpoint_paths. In this case, torch.save will overwrite
  429. # the old file; thus, we need to take the name out of
  430. # last3_checkpoint_paths and append it to the last.
  431. if output_save_file in last3_checkpoint_paths:
  432. last3_checkpoint_paths.remove(output_save_file)
  433. last3_checkpoint_paths.append(output_save_file)
  434. if len(last3_checkpoint_paths) > 3:
  435. ckpt_to_be_removed = last3_checkpoint_paths.pop(0)
  436. os.remove(ckpt_to_be_removed)
  437. def take_training_step(args, grad_scaler, model, criterion, batch, stats):
  438. with torch.cuda.amp.autocast(enabled=(args.fp16 and not args.allreduce_post_accumulation_fp16)) :
  439. prediction_scores, seq_relationship_score = model(input_ids=batch['input_ids'], token_type_ids=batch['token_type_ids'], attention_mask=batch['attention_mask'], masked_lm_labels=batch['labels'])
  440. loss = criterion(prediction_scores, seq_relationship_score, batch['labels'], batch['next_sentence_labels'])
  441. stats.device_stat('average_loss').add_(loss.detach())
  442. grad_scaler.scale(loss).backward()
def take_optimizer_step(args, lr_scheduler, optimizer, grad_scaler, device, stats):
    """Apply one optimizer update: scheduler step, scaled step, stat copy,
    scaler update, and gradient reset.

    The statement order is load-bearing: stats must be copied before
    GradScaler.update() clears the found-inf flags that the skipped-step
    stats read.
    """
    lr_scheduler.step()  # learning rate warmup
    # Skips optimizer.step() internally if the scaler found inf/nan grads.
    grad_scaler.step(optimizer)

    # Stats copying is located here prior to the infinity check being reset
    # in GradScaler::update()
    stats.copy_from_device()

    grad_scaler.update()
    # set_to_none frees gradient memory instead of zero-filling it.
    optimizer.zero_grad(set_to_none=True)
def main():
    """Entry point: set up training, run the pre-training loop, and return
    ``(args, train_time_raw, stats, skip_fwd_bwd_for_perf)`` for final
    reporting.

    Supports two execution modes: eager (with optional DDP no_sync gradient
    accumulation) and CUDA Graphs (captured full-step and accumulation-only
    graphs replayed per batch).  Exits when the target optimizer-step count
    is reached or a SIGTERM set `timeout_sent`.
    """
    global timeout_sent

    args = parse_arguments()

    # Per-rank seeding so data-independent randomness differs across ranks.
    random.seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    torch.manual_seed(args.seed + args.local_rank)
    torch.cuda.manual_seed(args.seed + args.local_rank)

    device, args = setup_training(args)
    dllogger.log(step="PARAMETER", data={"Config": [str(args)]})

    # Prepare optimizer
    model, optimizer, grad_scaler, lr_scheduler, checkpoint, global_step, criterion, epoch = prepare_model_and_optimizer(args, device, sequence_output_is_dense=not args.no_dense_sequence_output)

    # Prepare the data loader.
    if is_main_process():
        tic = time.perf_counter()
    train_dataloader = lddl.torch.get_bert_pretrain_data_loader(
        args.input_dir,
        local_rank=max(args.local_rank, 0),
        vocab_file=args.vocab_file,
        data_loader_kwargs={
            'batch_size': args.train_batch_size * args.n_gpu,
            'num_workers': args.num_workers,
            'pin_memory': True,
        },
        base_seed=args.seed,
        log_dir=None if args.output_dir is None else os.path.join(args.output_dir, 'lddl_log'),
        log_level=logging.WARNING,
        start_epoch=epoch,
    )
    if is_main_process():
        print('get_bert_pretrain_data_loader took {} s!'.format(time.perf_counter() - tic))

    if is_main_process():
        dllogger.log(step="PARAMETER", data={"SEED": args.seed})
        dllogger.log(step="PARAMETER", data={"train_start": True})
        dllogger.log(step="PARAMETER", data={"batch_size_per_gpu": args.train_batch_size})
        dllogger.log(step="PARAMETER", data={"learning_rate": args.learning_rate})

    model.train()
    most_recent_ckpts_paths = []

    stats = SyncFreeStats()
    # Host Only Stats
    stats.add_stat('model_step')
    # Device/Host Sync-ed Stats
    stats.add_stat('optimizer_step', dtype=torch.int32, device_func=(lambda: optimizer.param_groups[0]['step']))
    stats.add_stat('average_loss', dtype=torch.float32, device_tensor=torch.zeros(1, dtype=torch.float32, device=device))
    stats.add_stat('learning_rate', dtype=torch.float32, device_func=(lambda: optimizer.param_groups[0]['lr']))
    if grad_scaler.is_enabled():
        # This stat only indicates a skipped step occurred. It does not accumulate the number of skipped steps
        stats.add_stat('skip_optimizer_step', dtype=torch.float32, device_func=(lambda: grad_scaler._found_inf_per_device(optimizer)[device]))
        stats.add_stat('skipped_optimizer_steps', dtype=torch.float32, device_tensor=torch.zeros(1, dtype=torch.float32, device=device),
                       device_func=(lambda x: x.add_(grad_scaler._found_inf_per_device(optimizer)[device])))
    else:
        stats.add_stat('skip_optimizer_step', dtype=torch.float32)
        stats.add_stat('skipped_optimizer_steps', dtype=torch.float32)

    static_gpu_batch = None
    full_cudagraph = None
    grad_accum_cudagraph = None
    if args.cuda_graphs:
        # Static input buffers: graph replays always read from these; each
        # batch is copied into them before replay.
        static_gpu_batch = {
            'input_ids': torch.ones(args.train_batch_size, args.max_seq_length, dtype=torch.int64, device=device),
            'token_type_ids': torch.ones(args.train_batch_size, args.max_seq_length, dtype=torch.int64, device=device),
            'attention_mask': torch.ones(args.train_batch_size, args.max_seq_length, dtype=torch.int64, device=device),
            'labels': torch.ones(args.train_batch_size, args.max_seq_length, dtype=torch.int64, device=device),
            'next_sentence_labels': torch.ones(args.train_batch_size, dtype=torch.int64, device=device),
        }

        side_stream = torch.cuda.Stream()

        # Warmup Steps - includes jitting fusions
        side_stream = torch.cuda.Stream()
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            for _ in range(11):
                take_training_step(args, grad_scaler, model, criterion, static_gpu_batch, stats)
                take_optimizer_step(args, lr_scheduler, optimizer, grad_scaler, device, stats)
        torch.cuda.current_stream().wait_stream(side_stream)

        # Capture Graph
        # Full step: forward + backward + optimizer update.
        full_cudagraph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(full_cudagraph):
            take_training_step(args, grad_scaler, model, criterion, static_gpu_batch, stats)
            take_optimizer_step(args, lr_scheduler, optimizer, grad_scaler, device, stats)

        # Warmup Steps - includes jitting fusions
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            for _ in range(3):
                with model.no_sync():
                    take_training_step(args, grad_scaler, model, criterion, static_gpu_batch, stats)
        torch.cuda.current_stream().wait_stream(side_stream)

        # Capture Graph
        # Accumulation-only step: forward + backward without DDP all-reduce.
        grad_accum_cudagraph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(grad_accum_cudagraph):
            with model.no_sync():
                take_training_step(args, grad_scaler, model, criterion, static_gpu_batch, stats)

    train_iter = tqdm(
        train_dataloader,
        desc="Iteration",
        disable=args.disable_progress_bar,
        total=len(train_dataloader),
    ) if is_main_process() else train_dataloader

    raw_train_start = None

    # avoid nvfuser compilation times in measuring perf with phase2 binning
    # ideally skip > 3 * num_bins fwd+bwd iterations to start measuring perf
    skip_fwd_bwd_for_perf = 4
    if args.phase2: #we use 8 bins with phase2
        skip_fwd_bwd_for_perf = 50

    while True:
        for step, batch in enumerate(train_iter):
            # The first training step is 1 and not 0 when gradient accumulating
            # in order to avoid an optimizer step on the very first step
            stats.host_stat('model_step').add_(1)
            grad_accumulation_step = (stats.host_stat_value('model_step') % args.gradient_accumulation_steps) != 0

            if raw_train_start is None and step == skip_fwd_bwd_for_perf:
                raw_train_start = time.time()

            # Execute Model Step
            if args.cuda_graphs:
                # Refresh the static buffers, then replay the captured graph.
                for k in batch.keys():
                    static_gpu_batch[k].copy_(batch[k], non_blocking=True)
                if grad_accumulation_step:
                    grad_accum_cudagraph.replay()
                else:
                    full_cudagraph.replay()
            else:
                batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}

                if args.allreduce_post_accumulation and grad_accumulation_step:
                    # Skip DDP gradient all-reduce on accumulation micro-steps.
                    with model.no_sync():
                        take_training_step(args, grad_scaler, model, criterion, batch, stats)
                else:
                    take_training_step(args, grad_scaler, model, criterion, batch, stats)

                if not grad_accumulation_step:
                    take_optimizer_step(args, lr_scheduler, optimizer, grad_scaler, device, stats)

            # Log Optimizer Step
            if (not grad_accumulation_step) or timeout_sent:
                static_optimizer_step = stats.host_stat_value('model_step') // args.gradient_accumulation_steps
                dynamic_optimizer_step = static_optimizer_step - int(stats.host_stat_value('skipped_optimizer_steps'))
                no_log_steps = static_optimizer_step % args.log_freq

                # Log Final Step (MAYBE)
                # Since the stats are asynchronously pushed from the GPU to CPU, they are not always reliable
                # Therefore, a synchronization is required to guarantee you see the intended value.
                # Without a synchronization, it is possible for some GPUs to go through the exit conditional
                # and others to not because they accidentally see a different value for `skipped_optimizer_steps`.
                # In order to remove most device syncs, synchronizations only begin in the last few steps
                # where the skipped step count matters.
                if static_optimizer_step >= args.steps_this_run or timeout_sent:
                    torch.cuda.synchronize()
                    dynamic_optimizer_step = static_optimizer_step - int(stats.host_stat_value('skipped_optimizer_steps'))
                    if dynamic_optimizer_step >= args.steps_this_run or timeout_sent:
                        train_time_raw = time.time() - raw_train_start

                        last_num_steps = args.log_freq if no_log_steps == 0 else no_log_steps
                        stats.device_stat('average_loss').div_(last_num_steps * args.gradient_accumulation_steps)
                        if (torch.distributed.is_initialized()):
                            stats.device_stat('average_loss').div_(get_world_size())
                            torch.distributed.all_reduce(stats.device_stat('average_loss'))

                        # We block on this copy to insure the final value
                        stats.host_stat('average_loss').copy_(stats.device_stat('average_loss'))

                        if is_main_process():
                            dllogger.log(step=(epoch, dynamic_optimizer_step,), data={"final_loss": stats.host_stat_value('average_loss')})

                        checkpoint_step(args, epoch, dynamic_optimizer_step, model, optimizer, grad_scaler, most_recent_ckpts_paths)

                        return args, train_time_raw, stats, skip_fwd_bwd_for_perf

                if no_log_steps == 0:
                    if is_main_process():
                        dllogger.log(step=(epoch, dynamic_optimizer_step,),
                                     data={"average_loss": stats.host_stat_value('average_loss') / (args.log_freq * args.gradient_accumulation_steps),
                                           "learning_rate": stats.host_stat_value('learning_rate'),
                                           "skipped_steps": int(stats.host_stat_value('skipped_optimizer_steps'))})
                        if stats.host_stat_value('skip_optimizer_step') > 0.:
                            dllogger.log(step="PARAMETER", data={"loss_scale": grad_scaler._get_scale_async().item()})

                    stats.device_stat('average_loss').zero_()

                    if not args.skip_checkpoint and (dynamic_optimizer_step % args.num_steps_per_checkpoint == 0):
                        checkpoint_step(args, epoch, dynamic_optimizer_step, model, optimizer, grad_scaler, most_recent_ckpts_paths)

        # Data loader exhausted: start the next epoch.
        epoch += 1
if __name__ == "__main__":
    now = time.time()
    args, train_time_raw, stats, skip_fwd_bwd_for_perf = main()
    # Sequences/s is computed over all participating GPUs, excluding the
    # warm-up iterations that were skipped from the raw-time measurement.
    gpu_count = args.n_gpu
    if torch.distributed.is_initialized():
        gpu_count = get_world_size()
    if is_main_process():
        e2e_time = time.time() - now
        training_perf = args.train_batch_size * gpu_count * (stats.host_stat_value('model_step') - skip_fwd_bwd_for_perf) / train_time_raw
        dllogger.log(step=tuple(), data={"e2e_train_time": e2e_time,
                                         "training_sequences_per_second": training_perf,
                                         "final_loss": stats.host_stat_value('average_loss'),
                                         "raw_train_time": train_time_raw })
    dllogger.flush()