run_pretraining.py

# coding=utf-8
# Copyright 2020 The Google Research Authors.
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pre-trains an ELECTRA model."""
import argparse
import collections
import json
import time
import datetime
import os
import tensorflow as tf
import horovod.tensorflow as hvd
from horovod.tensorflow.compression import Compression
from gpu_affinity import set_affinity
import utils
import sys
import pretrain_utils
from utils import get_rank, get_world_size, is_main_process, log, log_config, setup_logger, postprocess_dllog
from tokenization import ElectraTokenizer
from modeling import PretrainingModel
from optimization import create_optimizer, GradientAccumulator
import dllogger

class PretrainingConfig(object):
    """Defines pre-training hyperparameters."""

    def __init__(self, model_name, **kwargs):
        self.model_name = model_name
        self.seed = 42

        self.debug = False  # debug mode for quickly running things
        self.do_train = True  # pre-train ELECTRA
        self.do_eval = False  # evaluate generator/discriminator on unlabeled data
        self.phase2 = False

        # amp
        self.amp = True
        self.xla = True
        self.fp16_compression = False

        # optimizer type
        self.optimizer = 'adam'
        self.gradient_accumulation_steps = 1

        # lamb whitelisting for LN and biases
        self.skip_adaptive = False

        # loss functions
        self.electra_objective = True  # if False, use the BERT objective instead
        self.gen_weight = 1.0  # masked language modeling / generator loss
        self.disc_weight = 50.0  # discriminator loss
        self.mask_prob = 0.15  # percent of input tokens to mask out / replace

        # optimization
        self.learning_rate = 5e-4
        self.lr_decay_power = 0.5
        self.weight_decay_rate = 0.01
        self.num_warmup_steps = 10000
        self.opt_beta_1 = 0.878
        self.opt_beta_2 = 0.974
        self.end_lr = 0.0

        # training settings
        self.log_freq = 10
        self.skip_checkpoint = False
        self.save_checkpoints_steps = 1000
        self.num_train_steps = 1000000
        self.num_eval_steps = 100
        self.keep_checkpoint_max = 5  # maximum number of recent checkpoint files to keep; change to 0 or None to keep all checkpoints
        self.restore_checkpoint = None
        self.load_weights = False
        self.steps_this_run = -1

        # model settings
        self.model_size = "base"  # one of "small", "base", or "large"
        # override the default transformer hparams for the provided model size; see
        # modeling.BertConfig for the possible hparams and util.training_utils for
        # the defaults
        self.model_hparam_overrides = (
            kwargs["model_hparam_overrides"]
            if "model_hparam_overrides" in kwargs else {})
        self.embedding_size = None  # bert hidden size by default
        self.vocab_size = 30522  # number of tokens in the vocabulary
        self.do_lower_case = True  # lowercase the input?

        # generator settings
        self.uniform_generator = False  # generator is uniform at random
        self.shared_embeddings = True  # share generator/discriminator token embeddings?
        # self.untied_generator = True  # tie all generator/discriminator weights?
        self.generator_layers = 1.0  # frac of discriminator layers for generator
        self.generator_hidden_size = 0.25  # frac of discrim hidden size for gen
        self.disallow_correct = False  # force the generator to sample incorrect
                                       # tokens (so 15% of tokens are always
                                       # fake)
        self.temperature = 1.0  # temperature for sampling from generator

        # batch sizes
        self.max_seq_length = 128
        self.train_batch_size = 128
        self.eval_batch_size = 128

        self.results_dir = "results"
        self.json_summary = None
        self.update(kwargs)

        # default locations of data files
        self.pretrain_tfrecords = os.path.join(
            "data", "pretrain_tfrecords/pretrain_data.tfrecord*")
        self.vocab_file = os.path.join("vocab", "vocab.txt")
        self.model_dir = os.path.join(self.results_dir, "models", model_name)
        self.checkpoints_dir = os.path.join(self.model_dir, "checkpoints")
        self.weights_dir = os.path.join(self.model_dir, "weights")
        self.results_txt = os.path.join(self.results_dir, "unsup_results.txt")
        self.results_pkl = os.path.join(self.results_dir, "unsup_results.pkl")
        self.log_dir = os.path.join(self.model_dir, "logs")

        self.max_predictions_per_seq = int((self.mask_prob + 0.005) *
                                           self.max_seq_length)

        # defaults for different-sized model
        if self.model_size == "base":
            self.embedding_size = 768
            self.hidden_size = 768
            self.num_hidden_layers = 12
            if self.hidden_size % 64 != 0:
                raise ValueError("Hidden size {} should be divisible by 64. Number of attention heads is hidden size {} / 64".format(self.hidden_size, self.hidden_size))
            self.num_attention_heads = int(self.hidden_size / 64.)
        elif self.model_size == "large":
            self.embedding_size = 1024
            self.hidden_size = 1024
            self.num_hidden_layers = 24
            if self.hidden_size % 64 != 0:
                raise ValueError("Hidden size {} should be divisible by 64. Number of attention heads is hidden size {} / 64".format(self.hidden_size, self.hidden_size))
            self.num_attention_heads = int(self.hidden_size / 64.)
        else:
            raise ValueError("--model_size : only 'base' and 'large' are supported.")
        self.act_func = "gelu"
        self.hidden_dropout_prob = 0.1
        self.attention_probs_dropout_prob = 0.1
        self.update(kwargs)

    def update(self, kwargs):
        for k, v in kwargs.items():
            if v is not None:
                self.__dict__[k] = v

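# Note: with the defaults above, max_predictions_per_seq = int((0.15 + 0.005) * 128) = 19
# masked positions per 128-token sequence. Illustrative construction (the model name and
# override values below are placeholders, not a recommended configuration):
#   config = PretrainingConfig("electra_base", amp=True, train_batch_size=176)
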
def metric_fn(config, metrics, eval_fn_inputs):
    """Computes the loss and accuracy of the model."""
    d = eval_fn_inputs
    metrics["masked_lm_accuracy"].update_state(
        y_true=tf.reshape(d["masked_lm_ids"], [-1]),
        y_pred=tf.reshape(d["masked_lm_preds"], [-1]),
        sample_weight=tf.reshape(d["masked_lm_weights"], [-1]))
    metrics["masked_lm_loss"].update_state(
        values=tf.reshape(d["mlm_loss"], [-1]),
        sample_weight=tf.reshape(d["masked_lm_weights"], [-1]))
    if config.electra_objective:
        metrics["sampled_masked_lm_accuracy"].update_state(
            y_true=tf.reshape(d["masked_lm_ids"], [-1]),
            y_pred=tf.reshape(d["sampled_tokids"], [-1]),
            sample_weight=tf.reshape(d["masked_lm_weights"], [-1]))
        if config.disc_weight > 0:
            metrics["disc_loss"].update_state(d["disc_loss"])
            # metrics["disc_auc"].update_state(
            #     d["disc_labels"] * d["input_mask"],
            #     d["disc_probs"] * tf.cast(d["input_mask"], tf.float32))
            metrics["disc_accuracy"].update_state(
                y_true=d["disc_labels"], y_pred=d["disc_preds"],
                sample_weight=d["input_mask"])
            metrics["disc_precision"].update_state(
                y_true=d["disc_labels"], y_pred=d["disc_preds"],
                sample_weight=d["disc_preds"] * d["input_mask"])
            metrics["disc_recall"].update_state(
                y_true=d["disc_labels"], y_pred=d["disc_preds"],
                sample_weight=d["disc_labels"] * d["input_mask"])
    return metrics

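# Note on the discriminator metrics above: precision and recall reuse
# tf.keras.metrics.Accuracy with masked sample weights. Precision restricts the
# comparison to positions the discriminator flagged as replaced
# (disc_preds * input_mask); recall restricts it to positions that really were
# replaced (disc_labels * input_mask).
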
@tf.function
def train_one_step(config, model, optimizer, features, accumulator, first_step, take_step, clip_norm=1.0):
    # Forward and backward pass
    with tf.GradientTape() as tape:
        total_loss, eval_fn_inputs = model(features, is_training=True)
        unscaled_loss = tf.stop_gradient(total_loss)
        if config.amp:
            total_loss = optimizer.get_scaled_loss(total_loss)

    # Backpropagate gradients
    # tape = hvd.DistributedGradientTape(
    #     tape, sparse_as_dense=True,
    #     compression=Compression.fp16 if config.amp and config.fp16_compression else Compression.none)
    gradients = tape.gradient(total_loss, model.trainable_variables)

    # Get unscaled gradients if AMP
    if config.amp:
        gradients = optimizer.get_unscaled_gradients(gradients)

    # Accumulate gradients
    accumulator(gradients)

    # Need to call apply_gradients on the very first step irrespective of gradient accumulation.
    # This is required for the optimizer to build its state.
    if first_step or take_step:
        # All-reduce and clip the accumulated gradients
        allreduced_accumulated_gradients = [
            None if g is None else hvd.allreduce(g / tf.cast(config.gradient_accumulation_steps, g.dtype),
                                                 compression=Compression.fp16 if config.amp and config.fp16_compression else Compression.none)
            for g in accumulator.gradients]
        (clipped_accumulated_gradients, _) = tf.clip_by_global_norm(allreduced_accumulated_gradients, clip_norm=clip_norm)
        # Weight update
        optimizer.apply_gradients(zip(clipped_accumulated_gradients, model.trainable_variables))
        accumulator.reset()

    # Broadcast model weights after the first train step
    if first_step:
        hvd.broadcast_variables(model.variables, root_rank=0)
        hvd.broadcast_variables(optimizer.variables(), root_rank=0)

    return unscaled_loss, eval_fn_inputs

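# Note: gradients are applied only when `first_step` or `take_step` is set, so the
# effective per-update batch is train_batch_size * gradient_accumulation_steps *
# hvd.size(). The extra apply_gradients on the first step exists to create the
# optimizer slot variables before hvd.broadcast_variables synchronizes all ranks.
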
def main(e2e_start_time):
    # Parse essential arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", required=True)
    parser.add_argument("--model_size", default="base", type=str, help="base or large")
    parser.add_argument("--pretrain_tfrecords", type=str)
    parser.add_argument("--phase2", action='store_true')
    parser.add_argument("--fp16_compression", action='store_true')
    parser.add_argument("--amp", action='store_true',
                        help="Whether to use fp16.")
    parser.add_argument("--xla", action='store_true',
                        help="Whether to use xla.")
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--num_train_steps", type=int)
    parser.add_argument("--num_warmup_steps", type=int)
    parser.add_argument("--learning_rate", type=float)
    parser.add_argument("--train_batch_size", type=int)
    parser.add_argument("--max_seq_length", type=int)
    parser.add_argument("--mask_prob", type=float)
    parser.add_argument("--disc_weight", type=float)
    parser.add_argument("--generator_hidden_size", type=float)
    parser.add_argument("--log_freq", type=int, default=10, help="Training metrics logging frequency")
    parser.add_argument("--save_checkpoints_steps", type=int)
    parser.add_argument("--steps_this_run", type=int, default=-1, help="run a fixed number of steps only")
    parser.add_argument("--keep_checkpoint_max", type=int)
    parser.add_argument("--restore_checkpoint", default=None, type=str)
    parser.add_argument("--load_weights", action='store_true')
    parser.add_argument("--weights_dir")
    parser.add_argument("--optimizer", default="adam", type=str, help="adam or lamb")
    parser.add_argument("--skip_adaptive", action='store_true', help="Whether to apply adaptive LR on LayerNorm and biases")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of Gradient Accumulation steps")
    parser.add_argument("--lr_decay_power", type=float, default=0.5, help="LR decay power")
    parser.add_argument("--opt_beta_1", type=float, default=0.878, help="Optimizer beta1")
    parser.add_argument("--opt_beta_2", type=float, default=0.974, help="Optimizer beta2")
    parser.add_argument("--end_lr", type=float, default=0.0, help="Ending LR")
    parser.add_argument("--log_dir", type=str, default=None, help="Path to store logs")
    parser.add_argument("--results_dir", type=str, default=None, help="Path to store all model results")
    parser.add_argument("--skip_checkpoint", action='store_true', default=False, help="Whether to skip saving checkpoints")
    parser.add_argument('--json-summary', type=str, default=None,
                        help='If provided, the json summary will be written to the specified file.')
    args = parser.parse_args()
    config = PretrainingConfig(**args.__dict__)

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)
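    # Padding the vocabulary to a multiple of 8 (30522 -> 30528 for the default
    # BERT vocab) keeps the embedding and output-projection dimensions aligned
    # for Tensor Core kernels when AMP is enabled.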
    # Set up tensorflow
    hvd.init()

    args.log_dir = config.log_dir
    # DLLogger
    setup_logger(args)
    dllogger.metadata('training_sequences_per_second', {'unit': 'sequences/s'})
    dllogger.metadata('final_loss', {'unit': None})
    dllogger.metadata('e2e_train_time', {'unit': 's'})

    set_affinity(hvd.local_rank())
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')
    tf.config.optimizer.set_jit(config.xla)
    # tf.config.optimizer.set_experimental_options({"auto_mixed_precision": config.amp})

    if config.amp:
        policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16", loss_scale="dynamic")
        tf.keras.mixed_precision.experimental.set_policy(policy)
        print('Compute dtype: %s' % policy.compute_dtype)  # Compute dtype: float16
        print('Variable dtype: %s' % policy.variable_dtype)  # Variable dtype: float32

    # tf.random.set_seed(config.seed)

    # Set up config (cont.)
    if config.load_weights and config.restore_checkpoint:
        raise ValueError("`load_weights` and `restore_checkpoint` should not be on at the same time.")
    if config.phase2 and not config.restore_checkpoint:
        raise ValueError("`phase2` cannot be used without `restore_checkpoint`.")
    utils.heading("Config:")
    log_config(config)

    # Save pretrain configs
    pretrain_config_json = os.path.join(config.checkpoints_dir, 'pretrain_config.json')
    if is_main_process():
        utils.write_json(config.__dict__, pretrain_config_json)
        log("Configuration saved in {}".format(pretrain_config_json))
    # Set up model
    model = PretrainingModel(config)

    # Set up metrics
    metrics = dict()
    metrics["train_perf"] = tf.keras.metrics.Mean(name="train_perf")
    metrics["total_loss"] = tf.keras.metrics.Mean(name="total_loss")
    metrics["masked_lm_accuracy"] = tf.keras.metrics.Accuracy(name="masked_lm_accuracy")
    metrics["masked_lm_loss"] = tf.keras.metrics.Mean(name="masked_lm_loss")
    if config.electra_objective:
        metrics["sampled_masked_lm_accuracy"] = tf.keras.metrics.Accuracy(name="sampled_masked_lm_accuracy")
        if config.disc_weight > 0:
            metrics["disc_loss"] = tf.keras.metrics.Mean(name="disc_loss")
            metrics["disc_auc"] = tf.keras.metrics.AUC(name="disc_auc")
            metrics["disc_accuracy"] = tf.keras.metrics.Accuracy(name="disc_accuracy")
            metrics["disc_precision"] = tf.keras.metrics.Accuracy(name="disc_precision")
            metrics["disc_recall"] = tf.keras.metrics.Accuracy(name="disc_recall")

    # Set up tensorboard
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_log_dir = os.path.join(config.log_dir, current_time,
                                 'train_' + str(get_rank()) + '_of_' + str(get_world_size()))
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)

    # Set up dataset
    dataset = pretrain_utils.get_dataset(
        config, config.train_batch_size, world_size=get_world_size(), rank=get_rank())
    train_iterator = iter(dataset)

    # Set up optimizer
    optimizer = create_optimizer(
        init_lr=config.learning_rate,
        num_train_steps=config.num_train_steps,
        num_warmup_steps=config.num_warmup_steps,
        weight_decay_rate=config.weight_decay_rate,
        optimizer=config.optimizer,
        skip_adaptive=config.skip_adaptive,
        power=config.lr_decay_power,
        beta_1=config.opt_beta_1,
        beta_2=config.opt_beta_2,
        end_lr=config.end_lr)

    accumulator = GradientAccumulator()
    if config.amp:
        optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, "dynamic")
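    # create_optimizer (defined in optimization.py) is expected to build the usual
    # ELECTRA schedule from these arguments: warmup over num_warmup_steps, then
    # polynomial decay with exponent lr_decay_power from learning_rate toward
    # end_lr over num_train_steps. With AMP, LossScaleOptimizer adds dynamic loss
    # scaling on top of the mixed_float16 policy set above.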
    # Set up model checkpoint
    checkpoint = tf.train.Checkpoint(
        step=tf.Variable(0), phase2=tf.Variable(False), optimizer=optimizer, model=model)
    manager = tf.train.CheckpointManager(checkpoint, config.checkpoints_dir, max_to_keep=config.keep_checkpoint_max)
    if config.restore_checkpoint and config.restore_checkpoint != "latest":
        checkpoint.restore(config.restore_checkpoint)
        log(" ** Restored model checkpoint from {}".format(config.restore_checkpoint))
    elif config.restore_checkpoint and config.restore_checkpoint == "latest" and manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)
        log(" ** Restored model checkpoint from {}".format(manager.latest_checkpoint))
    elif config.load_weights:
        model.generator(model.generator.dummy_inputs)
        model.discriminator(model.discriminator.dummy_inputs)
        model.generator.load_weights(os.path.join(config.weights_dir, 'generator', 'tf_model.h5'))
        model.discriminator.load_weights(os.path.join(config.weights_dir, 'discriminator', 'tf_model.h5'))
    else:
        log(" ** Initializing from scratch.")

    restore_iterator = bool(config.restore_checkpoint) and config.restore_checkpoint == "latest"
    # Initialize global step for phase2
    if config.phase2 and not bool(checkpoint.phase2):
        optimizer.iterations.assign(0)
        checkpoint.step.assign(0)
        checkpoint.phase2.assign(True)
        restore_iterator = False
    if bool(checkpoint.phase2):
        manager = tf.train.CheckpointManager(
            checkpoint, config.checkpoints_dir,
            checkpoint_name='ckpt-p2',
            max_to_keep=config.keep_checkpoint_max)

    # Set up iterator checkpoint
    iter_checkpoint = tf.train.Checkpoint(
        train_iterator=train_iterator, world_size=tf.Variable(get_world_size()), rank=tf.Variable(get_rank()))
    iter_manager = tf.train.CheckpointManager(
        iter_checkpoint,
        os.path.join(config.checkpoints_dir, 'iter_ckpt_rank_' + '{:02}'.format(get_rank())),
        checkpoint_name='iter_ckpt_rank_' + '{:02}'.format(get_rank()),
        max_to_keep=config.keep_checkpoint_max)
    if restore_iterator and iter_manager.latest_checkpoint:
        ckpt_world_size = tf.train.load_variable(
            iter_manager.latest_checkpoint, 'world_size/.ATTRIBUTES/VARIABLE_VALUE')
        if ckpt_world_size == get_world_size():
            iter_checkpoint.restore(iter_manager.latest_checkpoint)
            log(" ** Restored iterator checkpoint from {}".format(iter_manager.latest_checkpoint), all_rank=True)
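    # Each rank writes its own iterator checkpoint (iter_ckpt_rank_XX) so the
    # input pipeline can resume where it left off; the stored world_size is
    # compared with the current get_world_size() and the iterator state is
    # restored only when they match, since per-rank data shards would not line
    # up otherwise.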
    utils.heading("Running training")
    accumulator.reset()
    train_start, start_step = time.time(), int(checkpoint.step) - 1
    local_step = 0
    saved_ckpt = False
    while int(checkpoint.step) <= config.num_train_steps:
        saved_ckpt = False
        step = int(checkpoint.step)
        features = next(train_iterator)
        iter_start = time.time()

        # if step == 200: tf.profiler.experimental.start(logdir=train_log_dir)
        total_loss, eval_fn_inputs = train_one_step(config, model, optimizer, features, accumulator,
                                                    local_step == 1, take_step=local_step % args.gradient_accumulation_steps == 0)
        # if step == 300: tf.profiler.experimental.stop()

        metrics["train_perf"].update_state(
            config.train_batch_size * get_world_size() / (time.time() - iter_start))
        metrics["total_loss"].update_state(values=total_loss)
        metric_fn(config, metrics, eval_fn_inputs)

        if (step % args.log_freq == 0) and (local_step % args.gradient_accumulation_steps == 0):
            log_info_dict = {k: float(v.result().numpy() * 100) if "accuracy" in k else float(v.result().numpy()) for k, v in metrics.items()}
            dllogger.log(step=(step,), data=log_info_dict, verbosity=0)
            log('Step:{step:6d}, Loss:{total_loss:10.6f}, Gen_loss:{masked_lm_loss:10.6f}, Disc_loss:{disc_loss:10.6f}, Gen_acc:{masked_lm_accuracy:6.2f}, '
                'Disc_acc:{disc_accuracy:6.2f}, Perf:{train_perf:4.0f}, Loss Scaler: {loss_scale}, Elapsed: {elapsed}, ETA: {eta}, '.format(
                    step=step, **log_info_dict,
                    loss_scale=optimizer.loss_scale if config.amp else 1,
                    elapsed=utils.get_readable_time(time.time() - train_start),
                    eta=utils.get_readable_time(
                        (time.time() - train_start) / (step - start_step) * (config.num_train_steps - step))),
                all_rank=True)

            with train_summary_writer.as_default():
                for key, m in metrics.items():
                    tf.summary.scalar(key, m.result(), step=step)

            if int(checkpoint.step) < config.num_train_steps:
                for m in metrics.values():
                    m.reset_states()

        # Print allreduced metrics on the last step
        if (int(checkpoint.step) == config.num_train_steps and (local_step % args.gradient_accumulation_steps == 0)) or ((local_step + 1) % (config.save_checkpoints_steps * args.gradient_accumulation_steps) == 0):
            log_info_dict = {k: float(hvd.allreduce(v.result()).numpy() * 100) if "accuracy" in k else float(hvd.allreduce(v.result()).numpy()) for k, v in metrics.items()}
            log_info_dict["training_sequences_per_second"] = log_info_dict["train_perf"]
            log_info_dict["final_loss"] = log_info_dict["total_loss"]
            log_info_dict["e2e_train_time"] = time.time() - e2e_start_time
            dllogger.log(step=(), data=log_info_dict, verbosity=0)
            log('<FINAL STEP METRICS> Step:{step:6d}, Loss:{total_loss:10.6f}, Gen_loss:{masked_lm_loss:10.6f}, Disc_loss:{disc_loss:10.6f}, Gen_acc:{masked_lm_accuracy:6.2f}, '
                'Disc_acc:{disc_accuracy:6.2f}, Perf:{train_perf:4.0f},'.format(
                    step=step, **log_info_dict),
                all_rank=False)

        if local_step % args.gradient_accumulation_steps == 0:
            checkpoint.step.assign(int(optimizer.iterations))

        if not config.skip_checkpoint and (local_step % (config.save_checkpoints_steps * args.gradient_accumulation_steps) == 0):
            saved_ckpt = True
            if is_main_process():
                save_path = manager.save(checkpoint_number=step)
                log(" ** Saved model checkpoint for step {}: {}".format(step, save_path))
            iter_save_path = iter_manager.save(checkpoint_number=step)
            log(" ** Saved iterator checkpoint for step {}: {}".format(step, iter_save_path), all_rank=True)

        local_step += 1
        if config.steps_this_run != -1 and (local_step % (config.steps_this_run * args.gradient_accumulation_steps) == 0):
            # terminating the run sooner as steps_this_run has been reached
            log("terminating as steps_this_run:{} has been reached".format(config.steps_this_run))
            break

    step = (int(checkpoint.step) - 1)
    dllogger.flush()
    if not config.skip_checkpoint and not saved_ckpt:
        if is_main_process():
            save_path = manager.save(checkpoint_number=step)
            log(" ** Saved model checkpoint for step {}: {}".format(step, save_path))
        iter_save_path = iter_manager.save(checkpoint_number=step)
        log(" ** Saved iterator checkpoint for step {}: {}".format(step, iter_save_path), all_rank=True)

    return args

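# Illustrative multi-GPU launch (assumed values; paths and hyperparameters are
# placeholders, not a tested configuration):
#   horovodrun -np 8 python run_pretraining.py \
#       --model_name=electra_base --model_size=base --amp --xla \
#       --pretrain_tfrecords=data/pretrain_tfrecords/pretrain_data.tfrecord* \
#       --train_batch_size=176 --num_train_steps=10000 --num_warmup_steps=1000 \
#       --learning_rate=5e-4 --gradient_accumulation_steps=1
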
if __name__ == "__main__":
    start_time = time.time()
    args = main(start_time)
    log("Total Time:{:.4f}".format(time.time() - start_time))
    if is_main_process():
        postprocess_dllog(args)