run_pretraining.py

# coding=utf-8
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT pretraining runner."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# ==================
import csv
import os
import time
import argparse
import random
import h5py
from tqdm import tqdm, trange
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Dataset
from torch.utils.data.distributed import DistributedSampler
import math
from apex import amp
import multiprocessing

from tokenization import BertTokenizer
import modeling
from apex.optimizers import FusedLAMB
from schedulers import PolyWarmUpScheduler
from file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from utils import is_main_process, format_step, get_world_size, get_rank
from apex.parallel import DistributedDataParallel as DDP
from schedulers import LinearWarmUpScheduler
from apex.parallel.distributed import flat_dist_call
import amp_C
import apex_C
from apex.amp import _amp_state

import dllogger
from concurrent.futures import ProcessPoolExecutor

torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
skipped_steps = 0

# Track whether a SIGTERM (cluster time up) has been handled
timeout_sent = False

import signal


# handle SIGTERM sent from the scheduler and mark so we
# can gracefully save & exit
def signal_handler(sig, frame):
    global timeout_sent
    timeout_sent = True


signal.signal(signal.SIGTERM, signal_handler)
# Workaround because python functions are not picklable
class WorkerInitObj(object):
    def __init__(self, seed):
        self.seed = seed

    def __call__(self, id):
        np.random.seed(seed=self.seed + id)
        random.seed(self.seed + id)
def create_pretraining_dataset(input_file, max_pred_length, shared_list, args, worker_init):
    train_data = pretraining_dataset(input_file=input_file, max_pred_length=max_pred_length)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler,
                                  batch_size=args.train_batch_size * args.n_gpu,
                                  num_workers=4, worker_init_fn=worker_init,
                                  pin_memory=True)
    return train_dataloader, input_file
class pretraining_dataset(Dataset):

    def __init__(self, input_file, max_pred_length):
        self.input_file = input_file
        self.max_pred_length = max_pred_length
        f = h5py.File(input_file, "r")
        keys = ['input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions', 'masked_lm_ids',
                'next_sentence_labels']
        self.inputs = [np.asarray(f[key][:]) for key in keys]
        f.close()

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.inputs[0])

    def __getitem__(self, index):
        [input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids, next_sentence_labels] = [
            torch.from_numpy(input[index].astype(np.int64)) if indice < 5 else torch.from_numpy(
                np.asarray(input[index].astype(np.int64))) for indice, input in enumerate(self.inputs)]

        masked_lm_labels = torch.ones(input_ids.shape, dtype=torch.long) * -1
        index = self.max_pred_length
        # store number of masked tokens in index
        padded_mask_indices = (masked_lm_positions == 0).nonzero()
        if len(padded_mask_indices) != 0:
            index = padded_mask_indices[0].item()
        masked_lm_labels[masked_lm_positions[:index]] = masked_lm_ids[:index]

        return [input_ids, segment_ids, input_mask,
                masked_lm_labels, next_sentence_labels]
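# Note on the sample layout (based on the keys read above; example values are
# illustrative only): each item is a list of five int64 tensors -- input_ids,
# segment_ids, input_mask (each of length max_seq_length) plus masked_lm_labels and
# next_sentence_labels. masked_lm_labels starts out filled with -1 and only masked
# positions receive a label; e.g. with masked_lm_positions = [3, 7, 0, ...] and
# masked_lm_ids = [1012, 2054, 0, ...], the result has masked_lm_labels[3] == 1012 and
# masked_lm_labels[7] == 2054. The -1 entries are skipped by the loss below via
# CrossEntropyLoss(ignore_index=-1).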
class BertPretrainingCriterion(torch.nn.Module):

    def __init__(self, vocab_size):
        super(BertPretrainingCriterion, self).__init__()
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1)
        self.vocab_size = vocab_size

    def forward(self, prediction_scores, seq_relationship_score, masked_lm_labels, next_sentence_labels):
        masked_lm_loss = self.loss_fn(prediction_scores.view(-1, self.vocab_size), masked_lm_labels.view(-1))
        next_sentence_loss = self.loss_fn(seq_relationship_score.view(-1, 2), next_sentence_labels.view(-1))
        total_loss = masked_lm_loss + next_sentence_loss
        return total_loss
def parse_arguments():

    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--input_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain .hdf5 files for the task.")
    parser.add_argument("--config_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The BERT model config")
    parser.add_argument("--bert_model", default="bert-large-uncased", type=str,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--init_checkpoint",
                        default=None,
                        type=str,
                        help="The initial checkpoint to start training from.")
    parser.add_argument("--max_seq_length",
                        default=512,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--max_predictions_per_seq",
                        default=80,
                        type=int,
                        help="The maximum number of masked tokens per input sequence.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for LAMB.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps",
                        default=1000,
                        type=float,
                        help="Total number of training steps to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.01,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=os.getenv('LOCAL_RANK', -1),
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        default=False,
                        action='store_true',
                        help="Mixed precision training")
    parser.add_argument('--amp',
                        default=False,
                        action='store_true',
                        help="Mixed precision training")
    parser.add_argument('--loss_scale',
                        type=float, default=0.0,
                        help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
    parser.add_argument('--log_freq',
                        type=float, default=1.0,
                        help='frequency of logging loss.')
    parser.add_argument('--checkpoint_activations',
                        default=False,
                        action='store_true',
                        help="Whether to use gradient checkpointing")
    parser.add_argument("--resume_from_checkpoint",
                        default=False,
                        action='store_true',
                        help="Whether to resume training from checkpoint.")
    parser.add_argument('--resume_step',
                        type=int,
                        default=-1,
                        help="Step to resume training from.")
    parser.add_argument('--num_steps_per_checkpoint',
                        type=int,
                        default=100,
                        help="Number of update steps until a model checkpoint is saved to disk.")
    parser.add_argument('--skip_checkpoint',
                        default=False,
                        action='store_true',
                        help="Whether to save checkpoints")
    parser.add_argument('--phase2',
                        default=False,
                        action='store_true',
                        help="Whether to train with seq len 512")
    parser.add_argument('--allreduce_post_accumulation',
                        default=False,
                        action='store_true',
                        help="Whether to do allreduces during gradient accumulation steps.")
    parser.add_argument('--allreduce_post_accumulation_fp16',
                        default=False,
                        action='store_true',
                        help="Whether to do fp16 allreduce post accumulation.")
    parser.add_argument('--phase1_end_step',
                        type=int,
                        default=7038,
                        help="Number of training steps in Phase1 - seq len 128")
    parser.add_argument('--init_loss_scale',
                        type=int,
                        default=2**20,
                        help="Initial loss scaler value")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument('--json-summary', type=str, default="results/dllogger.json",
                        help='If provided, the json summary will be written to '
                             'the specified file.')
    parser.add_argument("--use_env",
                        action='store_true',
                        help="Whether to read local rank from ENVVAR")
    parser.add_argument('--disable_progress_bar',
                        default=False,
                        action='store_true',
                        help='Disable tqdm progress bar')
    parser.add_argument('--steps_this_run', type=int, default=-1,
                        help='If provided, only run this many steps before exiting')

    args = parser.parse_args()
    args.fp16 = args.fp16 or args.amp

    if args.steps_this_run < 0:
        args.steps_this_run = args.max_steps

    return args
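# Example invocation (hypothetical paths and values; every flag used here is defined in
# parse_arguments above):
#
#   python run_pretraining.py \
#       --input_dir=/workspace/data/hdf5 \
#       --config_file=bert_config.json \
#       --output_dir=/workspace/results \
#       --do_train --fp16 \
#       --train_batch_size=32 \
#       --gradient_accumulation_steps=1 \
#       --max_steps=1000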
def setup_training(args):

    assert (torch.cuda.is_available())

    if args.local_rank == -1:
        device = torch.device("cuda")
        args.n_gpu = torch.cuda.device_count()
        args.allreduce_post_accumulation = False
        args.allreduce_post_accumulation_fp16 = False
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.n_gpu = 1

    if args.gradient_accumulation_steps == 1:
        args.allreduce_post_accumulation = False
        args.allreduce_post_accumulation_fp16 = False

    if is_main_process():
        dllogger.init(backends=[dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
                                                           filename=args.json_summary),
                                dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE, step_format=format_step)])
    else:
        dllogger.init(backends=[])

    print("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, args.n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))
    if args.train_batch_size % args.gradient_accumulation_steps != 0:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible".format(
            args.gradient_accumulation_steps, args.train_batch_size))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    if not args.do_train:
        raise ValueError(" `do_train` must be True.")

    if not args.resume_from_checkpoint and os.path.exists(args.output_dir) and (
            os.listdir(args.output_dir) and any([i.startswith('ckpt') for i in os.listdir(args.output_dir)])):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))

    if (not args.resume_from_checkpoint or not os.path.exists(args.output_dir)) and is_main_process():
        os.makedirs(args.output_dir, exist_ok=True)

    return device, args
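# Note: after setup_training, args.train_batch_size holds the per-GPU micro-batch size.
# With illustrative values --train_batch_size=64 and --gradient_accumulation_steps=8,
# each forward/backward pass sees 8 sequences and 8 such passes are accumulated before
# every optimizer step, so the effective per-GPU batch is still 64.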
def prepare_model_and_optimizer(args, device):

    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8 (e.g. a 30522-entry vocab becomes 30528)
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    modeling.ACT2FN["bias_gelu"] = modeling.bias_gelu_training
    model = modeling.BertForPreTraining(config)

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [f for f in os.listdir(args.output_dir) if f.endswith(".pt")]
            args.resume_step = max([int(x.split('.pt')[0].split('_')[1].strip()) for x in model_names])

        global_step = args.resume_step if not args.init_checkpoint else 0

        if not args.init_checkpoint:
            checkpoint = torch.load(os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step)), map_location="cpu")
        else:
            checkpoint = torch.load(args.init_checkpoint, map_location="cpu")

        model.load_state_dict(checkpoint['model'], strict=False)

        if args.phase2 and not args.init_checkpoint:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    optimizer = FusedLAMB(optimizer_grouped_parameters,
                          lr=args.learning_rate)
    lr_scheduler = PolyWarmUpScheduler(optimizer,
                                       warmup=args.warmup_proportion,
                                       total_steps=args.max_steps)
    if args.fp16:
        if args.loss_scale == 0:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic", cast_model_outputs=torch.float16)
        else:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=args.loss_scale, cast_model_outputs=torch.float16)
        amp._amp_state.loss_scalers[0]._loss_scale = args.init_loss_scale

    model.checkpoint_activations(args.checkpoint_activations)

    if args.resume_from_checkpoint:
        if args.phase2 or args.init_checkpoint:
            keys = list(checkpoint['optimizer']['state'].keys())
            # Override hyperparameters from previous checkpoint
            for key in keys:
                checkpoint['optimizer']['state'][key]['step'] = global_step
            for iter, item in enumerate(checkpoint['optimizer']['param_groups']):
                checkpoint['optimizer']['param_groups'][iter]['step'] = global_step
                checkpoint['optimizer']['param_groups'][iter]['t_total'] = args.max_steps
                checkpoint['optimizer']['param_groups'][iter]['warmup'] = args.warmup_proportion
                checkpoint['optimizer']['param_groups'][iter]['lr'] = args.learning_rate
        optimizer.load_state_dict(checkpoint['optimizer'])  # , strict=False)

        # Restore AMP master parameters
        if args.fp16:
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint['optimizer'])
            for param, saved_param in zip(amp.master_params(optimizer), checkpoint['master params']):
                param.data.copy_(saved_param.data)

    if args.local_rank != -1:
        if not args.allreduce_post_accumulation:
            model = DDP(model, message_size=250000000, gradient_predivide_factor=get_world_size())
        else:
            flat_dist_call([param.data for param in model.parameters()], torch.distributed.broadcast, (0,))
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    criterion = BertPretrainingCriterion(config.vocab_size)

    return model, optimizer, lr_scheduler, checkpoint, global_step, criterion
def take_optimizer_step(args, optimizer, model, overflow_buf, global_step):

    global skipped_steps
    if args.allreduce_post_accumulation:
        # manually allreduce gradients after all accumulation steps
        # check for Inf/NaN
        # 1. allocate an uninitialized buffer for flattened gradient
        loss_scale = _amp_state.loss_scalers[0].loss_scale() if args.fp16 else 1
        master_grads = [p.grad for p in amp.master_params(optimizer) if p.grad is not None]
        flat_grad_size = sum(p.numel() for p in master_grads)
        allreduce_dtype = torch.float16 if args.allreduce_post_accumulation_fp16 else torch.float32
        flat_raw = torch.empty(flat_grad_size, device='cuda', dtype=allreduce_dtype)
        # 2. combine unflattening and predivision of unscaled 'raw' gradient
        allreduced_views = apex_C.unflatten(flat_raw, master_grads)
        overflow_buf.zero_()
        amp_C.multi_tensor_scale(65536,
                                 overflow_buf,
                                 [master_grads, allreduced_views],
                                 loss_scale / (get_world_size() * args.gradient_accumulation_steps))
        # 3. sum gradient across ranks. Because of the predivision, this averages the gradient
        torch.distributed.all_reduce(flat_raw)
        # 4. combine unscaling and unflattening of allreduced gradient
        overflow_buf.zero_()
        amp_C.multi_tensor_scale(65536,
                                 overflow_buf,
                                 [allreduced_views, master_grads],
                                 1./loss_scale)
        # 5. update loss scale
        if args.fp16:
            scaler = _amp_state.loss_scalers[0]
            old_overflow_buf = scaler._overflow_buf
            scaler._overflow_buf = overflow_buf
            had_overflow = scaler.update_scale()
            scaler._overflow_buf = old_overflow_buf
        else:
            had_overflow = 0
        # 6. call optimizer step function
        if had_overflow == 0:
            optimizer.step()
            global_step += 1
        else:
            # Overflow detected, print message and clear gradients
            skipped_steps += 1
            if is_main_process():
                scaler = _amp_state.loss_scalers[0]
                dllogger.log(step="PARAMETER", data={"loss_scale": scaler.loss_scale()})
            if _amp_state.opt_properties.master_weights:
                for param in optimizer._amp_stash.all_fp32_from_fp16_params:
                    param.grad = None
        for param in model.parameters():
            param.grad = None
    else:
        optimizer.step()
        # optimizer.zero_grad()
        for param in model.parameters():
            param.grad = None
        global_step += 1

    return global_step
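# Note: with --allreduce_post_accumulation, gradients are reduced once per optimizer step
# (rather than once per micro-batch through DDP), and an fp16 overflow under dynamic loss
# scaling increments skipped_steps instead of advancing global_step.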
def main():
    global timeout_sent

    args = parse_arguments()

    random.seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    torch.manual_seed(args.seed + args.local_rank)
    torch.cuda.manual_seed(args.seed + args.local_rank)
    worker_init = WorkerInitObj(args.seed + args.local_rank)

    device, args = setup_training(args)
    dllogger.log(step="PARAMETER", data={"Config": [str(args)]})

    # Prepare optimizer
    model, optimizer, lr_scheduler, checkpoint, global_step, criterion = prepare_model_and_optimizer(args, device)

    if is_main_process():
        dllogger.log(step="PARAMETER", data={"SEED": args.seed})

    raw_train_start = None
    if args.do_train:
        if is_main_process():
            dllogger.log(step="PARAMETER", data={"train_start": True})
            dllogger.log(step="PARAMETER", data={"batch_size_per_gpu": args.train_batch_size})
            dllogger.log(step="PARAMETER", data={"learning_rate": args.learning_rate})

        model.train()
        most_recent_ckpts_paths = []
        average_loss = 0.0  # averaged loss every args.log_freq steps
        epoch = 0
        training_steps = 0

        pool = ProcessPoolExecutor(1)
        # Note: We loop infinitely over epochs, termination is handled via iteration count
        while True:
            thread = None
            restored_data_loader = None
            if not args.resume_from_checkpoint or epoch > 0 or (args.phase2 and global_step < 1) or args.init_checkpoint:
                files = [os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir) if
                         os.path.isfile(os.path.join(args.input_dir, f)) and 'training' in f]
                files.sort()
                num_files = len(files)
                random.Random(args.seed + epoch).shuffle(files)
                f_start_id = 0
            else:
                f_start_id = checkpoint['files'][0]
                files = checkpoint['files'][1:]
                args.resume_from_checkpoint = False
                num_files = len(files)
                # may not exist in all checkpoints
                epoch = checkpoint.get('epoch', 0)
                restored_data_loader = checkpoint.get('data_loader', None)

            shared_file_list = {}

            if torch.distributed.is_initialized() and get_world_size() > num_files:
                remainder = get_world_size() % num_files
                data_file = files[(f_start_id*get_world_size()+get_rank() + remainder*f_start_id)%num_files]
            else:
                data_file = files[(f_start_id*get_world_size()+get_rank())%num_files]

            previous_file = data_file

            if restored_data_loader is None:
                train_data = pretraining_dataset(data_file, args.max_predictions_per_seq)
                train_sampler = RandomSampler(train_data)
                train_dataloader = DataLoader(train_data, sampler=train_sampler,
                                              batch_size=args.train_batch_size * args.n_gpu,
                                              num_workers=4, worker_init_fn=worker_init,
                                              pin_memory=True)
                # shared_file_list["0"] = (train_dataloader, data_file)
            else:
                train_dataloader = restored_data_loader
                restored_data_loader = None

            overflow_buf = None
            if args.allreduce_post_accumulation:
                overflow_buf = torch.cuda.IntTensor([0])

            for f_id in range(f_start_id + 1, len(files)):
                if get_world_size() > num_files:
                    data_file = files[(f_id*get_world_size()+get_rank() + remainder*f_id)%num_files]
                else:
                    data_file = files[(f_id*get_world_size()+get_rank())%num_files]

                previous_file = data_file

                dataset_future = pool.submit(create_pretraining_dataset, data_file, args.max_predictions_per_seq, shared_file_list, args, worker_init)

                train_iter = tqdm(train_dataloader, desc="Iteration", disable=args.disable_progress_bar) if is_main_process() else train_dataloader

                if raw_train_start is None:
                    raw_train_start = time.time()
                for step, batch in enumerate(train_iter):

                    training_steps += 1

                    batch = [t.to(device) for t in batch]
                    input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
                    prediction_scores, seq_relationship_score = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
                    loss = criterion(prediction_scores, seq_relationship_score, masked_lm_labels, next_sentence_labels)
                    if args.n_gpu > 1:
                        loss = loss.mean()  # mean() to average on multi-gpu.

                    divisor = args.gradient_accumulation_steps
                    if args.gradient_accumulation_steps > 1:
                        if not args.allreduce_post_accumulation:
                            # this division was merged into predivision
                            loss = loss / args.gradient_accumulation_steps
                            divisor = 1.0
                    if args.fp16:
                        with amp.scale_loss(loss, optimizer, delay_overflow_check=args.allreduce_post_accumulation) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    average_loss += loss.item()

                    if training_steps % args.gradient_accumulation_steps == 0:
                        lr_scheduler.step()  # learning rate warmup
                        global_step = take_optimizer_step(args, optimizer, model, overflow_buf, global_step)

                    if global_step >= args.steps_this_run or timeout_sent:
                        train_time_raw = time.time() - raw_train_start
                        last_num_steps = int(training_steps / args.gradient_accumulation_steps) % args.log_freq
                        last_num_steps = args.log_freq if last_num_steps == 0 else last_num_steps
                        average_loss = torch.tensor(average_loss, dtype=torch.float32).cuda()
                        average_loss = average_loss / (last_num_steps * divisor)
                        if (torch.distributed.is_initialized()):
                            average_loss /= get_world_size()
                            torch.distributed.all_reduce(average_loss)
                        final_loss = average_loss.item()
                        if is_main_process():
                            dllogger.log(step=(epoch, global_step, ), data={"final_loss": final_loss})
                    elif training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0:
                        if is_main_process():
                            dllogger.log(step=(epoch, global_step, ), data={"average_loss": average_loss / (args.log_freq * divisor),
                                                                            "step_loss": loss.item() * args.gradient_accumulation_steps / divisor,
                                                                            "learning_rate": optimizer.param_groups[0]['lr']})
                        average_loss = 0

                    if global_step >= args.steps_this_run or training_steps % (
                            args.num_steps_per_checkpoint * args.gradient_accumulation_steps) == 0 or timeout_sent:
                        if is_main_process() and not args.skip_checkpoint:
                            # Save a trained model
                            dllogger.log(step="PARAMETER", data={"checkpoint_step": global_step})
                            model_to_save = model.module if hasattr(model,
                                                                    'module') else model  # Only save the model itself
                            if args.resume_step < 0 or not args.phase2:
                                output_save_file = os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step))
                            else:
                                output_save_file = os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step + args.phase1_end_step))
                            if args.do_train:
                                torch.save({'model': model_to_save.state_dict(),
                                            'optimizer': optimizer.state_dict(),
                                            'master params': list(amp.master_params(optimizer)),
                                            'files': [f_id] + files,
                                            'epoch': epoch,
                                            'data_loader': None if global_step >= args.max_steps else train_dataloader}, output_save_file)

                                most_recent_ckpts_paths.append(output_save_file)
                                if len(most_recent_ckpts_paths) > 3:
                                    ckpt_to_be_removed = most_recent_ckpts_paths.pop(0)
                                    os.remove(ckpt_to_be_removed)

                        # Exiting the training due to hitting max steps, or being sent a
                        # timeout from the cluster scheduler
                        if global_step >= args.steps_this_run or timeout_sent:
                            del train_dataloader
                            # thread.join()
                            return args, final_loss, train_time_raw, global_step

                del train_dataloader
                # thread.join()
                # Make sure pool has finished and switch train_dataloader
                # NOTE: Will block until complete
                train_dataloader, data_file = dataset_future.result(timeout=None)

            epoch += 1
if __name__ == "__main__":

    now = time.time()
    args, final_loss, train_time_raw, global_step = main()
    gpu_count = args.n_gpu
    global_step += args.phase1_end_step if (args.phase2 and args.resume_step > 0) else 0
    if args.resume_step == -1:
        args.resume_step = 0
    if torch.distributed.is_initialized():
        gpu_count = get_world_size()
    if is_main_process():
        e2e_time = time.time() - now
        training_perf = args.train_batch_size * args.gradient_accumulation_steps * gpu_count\
                        * (global_step - args.resume_step + skipped_steps) / train_time_raw
        dllogger.log(step=tuple(), data={"e2e_train_time": e2e_time, "training_sequences_per_second": training_perf,
                                         "final_loss": final_loss, "raw_train_time": train_time_raw})
        dllogger.flush()
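# Note: training_sequences_per_second above is
#   per-GPU micro-batch * gradient_accumulation_steps * gpu_count * optimizer steps taken,
# divided by the raw training time, where the step count includes steps skipped due to
# fp16 overflow (skipped_steps).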