# *****************************************************************************
#  Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#      * Redistributions of source code must retain the above copyright
#        notice, this list of conditions and the following disclaimer.
#      * Redistributions in binary form must reproduce the above copyright
#        notice, this list of conditions and the following disclaimer in the
#        documentation and/or other materials provided with the distribution.
#      * Neither the name of the NVIDIA CORPORATION nor the
#        names of its contributors may be used to endorse or promote products
#        derived from this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
#  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
#  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#  DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
#  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
#  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
#  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
#  ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
#  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************

import os
import time
import argparse
import numpy as np
from contextlib import contextmanager

import torch
from torch.utils.data import DataLoader
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

import models
import loss_functions
import data_functions

from tacotron2_common.utils import ParseFromConfigFile

import dllogger as DLLogger
from dllogger import StdOutBackend, JSONStreamBackend, Verbosity


def parse_args(parser):
    """
    Parse commandline arguments.
    """
    parser.add_argument('-o', '--output', type=str, required=True,
                        help='Directory to save checkpoints')
    parser.add_argument('-d', '--dataset-path', type=str,
                        default='./', help='Path to dataset')
    parser.add_argument('-m', '--model-name', type=str, default='', required=True,
                        help='Model to train')
    parser.add_argument('--log-file', type=str, default='nvlog.json',
                        help='Filename for logging')
    parser.add_argument('--anneal-steps', nargs='*',
                        help='Epochs after which to decrease the learning rate')
    parser.add_argument('--anneal-factor', type=float, choices=[0.1, 0.3], default=0.1,
                        help='Factor for annealing the learning rate')
    parser.add_argument('--config-file', action=ParseFromConfigFile,
                        type=str, help='Path to configuration file')
    parser.add_argument('--seed', default=None, type=int,
                        help='Seed for random number generators')

    # training
    training = parser.add_argument_group('training setup')
    training.add_argument('--epochs', type=int, required=True,
                          help='Number of total epochs to run')
    training.add_argument('--epochs-per-checkpoint', type=int, default=50,
                          help='Number of epochs per checkpoint')
    training.add_argument('--checkpoint-path', type=str, default='',
                          help='Checkpoint path to resume training')
    training.add_argument('--resume-from-last', action='store_true',
                          help='Resumes training from the last checkpoint; uses the '
                               'directory provided with the \'--output\' option to '
                               'search for the checkpoint "checkpoint_<model_name>_last.pt"')
    training.add_argument('--dynamic-loss-scaling', type=bool, default=True,
                          help='Enable dynamic loss scaling')
    training.add_argument('--amp', action='store_true',
                          help='Enable AMP')
    training.add_argument('--cudnn-enabled', action='store_true',
                          help='Enable cudnn')
    training.add_argument('--cudnn-benchmark', action='store_true',
                          help='Run cudnn benchmark')
    training.add_argument('--disable-uniform-initialize-bn-weight', action='store_true',
                          help='Disable uniform initialization of batchnorm layer weight')

    optimization = parser.add_argument_group('optimization setup')
    optimization.add_argument(
        '--use-saved-learning-rate', default=False, type=bool)
    optimization.add_argument('-lr', '--learning-rate', type=float, required=True,
                              help='Learning rate')
    optimization.add_argument('--weight-decay', default=1e-6, type=float,
                              help='Weight decay')
    optimization.add_argument('--grad-clip-thresh', default=1.0, type=float,
                              help='Clip threshold for gradients')
    optimization.add_argument('-bs', '--batch-size', type=int, required=True,
                              help='Batch size per GPU')
    optimization.add_argument('--grad-clip', default=5.0, type=float,
                              help='Enables gradient clipping and sets maximum gradient norm value')

    # dataset parameters
    dataset = parser.add_argument_group('dataset parameters')
    dataset.add_argument('--load-mel-from-disk', action='store_true',
                         help='Loads mel spectrograms from disk instead of computing them on the fly')
    dataset.add_argument('--training-files',
                         default='filelists/ljs_audio_text_train_filelist.txt',
                         type=str, help='Path to training filelist')
    dataset.add_argument('--validation-files',
                         default='filelists/ljs_audio_text_val_filelist.txt',
                         type=str, help='Path to validation filelist')
    dataset.add_argument('--text-cleaners', nargs='*',
                         default=['english_cleaners'], type=str,
                         help='Type of text cleaners for input text')

    # audio parameters
    audio = parser.add_argument_group('audio parameters')
    audio.add_argument('--max-wav-value', default=32768.0, type=float,
                       help='Maximum audiowave value')
    audio.add_argument('--sampling-rate', default=22050, type=int,
                       help='Sampling rate')
    audio.add_argument('--filter-length', default=1024, type=int,
                       help='Filter length')
    audio.add_argument('--hop-length', default=256, type=int,
                       help='Hop (stride) length')
    audio.add_argument('--win-length', default=1024, type=int,
                       help='Window length')
    audio.add_argument('--mel-fmin', default=0.0, type=float,
                       help='Minimum mel frequency')
    audio.add_argument('--mel-fmax', default=8000.0, type=float,
                       help='Maximum mel frequency')

    distributed = parser.add_argument_group('distributed setup')
    # distributed.add_argument('--distributed-run', default=True, type=bool,
    #                          help='enable distributed run')
    distributed.add_argument('--rank', default=0, type=int,
                             help='Rank of the process, do not set! Done by multiproc module')
    distributed.add_argument('--world-size', default=1, type=int,
                             help='Number of processes, do not set! Done by multiproc module')
    distributed.add_argument('--dist-url', type=str, default='tcp://localhost:23456',
                             help='Url used to set up distributed training')
    distributed.add_argument('--group-name', type=str, default='group_name',
                             required=False, help='Distributed group name')
    distributed.add_argument('--dist-backend', default='nccl', type=str, choices={'nccl'},
                             help='Distributed run backend')

    benchmark = parser.add_argument_group('benchmark')
    benchmark.add_argument('--bench-class', type=str, default='')

    return parser


def reduce_tensor(tensor, num_gpus):
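    # Sum across all ranks, then average floating-point values (losses) or
    # floor-divide integer counts so every rank sees the same reduced value.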
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    if rt.is_floating_point():
        rt = rt / num_gpus
    else:
        rt = torch.div(rt, num_gpus, rounding_mode='floor')
    return rt


def init_distributed(args, world_size, rank, group_name):
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
    print("Initializing Distributed")

    # Set cuda device so everything is done on the right GPU.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Initialize distributed communication
    dist.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url,
        world_size=world_size, rank=rank, group_name=group_name)

    print("Done initializing distributed")


def save_checkpoint(model, optimizer, scaler, epoch, config, output_dir,
                    model_name, local_rank, world_size):
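    # Gather the CPU and CUDA RNG states from every rank so a resumed run can
    # restore per-device randomness (see load_checkpoint below). The NCCL
    # backend requires CUDA tensors for all_gather, hence the .cuda() copies.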
    random_rng_state = torch.random.get_rng_state().cuda()
    cuda_rng_state = torch.cuda.get_rng_state(local_rank).cuda()

    random_rng_states_all = [torch.empty_like(random_rng_state) for _ in range(world_size)]
    cuda_rng_states_all = [torch.empty_like(cuda_rng_state) for _ in range(world_size)]

    if world_size > 1:
        dist.all_gather(random_rng_states_all, random_rng_state)
        dist.all_gather(cuda_rng_states_all, cuda_rng_state)
    else:
        random_rng_states_all = [random_rng_state]
        cuda_rng_states_all = [cuda_rng_state]

    random_rng_states_all = torch.stack(random_rng_states_all).cpu()
    cuda_rng_states_all = torch.stack(cuda_rng_states_all).cpu()

    if local_rank == 0:
        checkpoint = {'epoch': epoch,
                      'cuda_rng_state_all': cuda_rng_states_all,
                      'random_rng_states_all': random_rng_states_all,
                      'config': config,
                      'state_dict': model.state_dict(),
                      'optimizer': optimizer.state_dict(),
                      'scaler': scaler.state_dict()}

        checkpoint_filename = "checkpoint_{}_{}.pt".format(model_name, epoch)
        checkpoint_path = os.path.join(output_dir, checkpoint_filename)
        print("Saving model and optimizer state at epoch {} to {}".format(
            epoch, checkpoint_path))
        torch.save(checkpoint, checkpoint_path)

        symlink_src = checkpoint_filename
        symlink_dst = os.path.join(
            output_dir, "checkpoint_{}_last.pt".format(model_name))
        if os.path.exists(symlink_dst) and os.path.islink(symlink_dst):
            print("Updating symlink", symlink_dst, "to point to", symlink_src)
            os.remove(symlink_dst)

        os.symlink(symlink_src, symlink_dst)


def get_last_checkpoint_filename(output_dir, model_name):
    symlink = os.path.join(output_dir, "checkpoint_{}_last.pt".format(model_name))
    if os.path.exists(symlink):
        print("Loading checkpoint from symlink", symlink)
        return os.path.join(output_dir, os.readlink(symlink))
    else:
        print("No last checkpoint available - starting from epoch 0")
        return ""


def load_checkpoint(model, optimizer, scaler, epoch, filepath, local_rank):
    checkpoint = torch.load(filepath, map_location='cpu')
    epoch[0] = checkpoint['epoch'] + 1
    device_id = local_rank % torch.cuda.device_count()
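    # Restore the RNG state that save_checkpoint recorded for this device.
    # Older checkpoints stored a single 'random_rng_state' rather than one
    # state per rank, so both layouts are accepted.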
    torch.cuda.set_rng_state(checkpoint['cuda_rng_state_all'][device_id])

    if 'random_rng_states_all' in checkpoint:
        torch.random.set_rng_state(checkpoint['random_rng_states_all'][device_id])
    elif 'random_rng_state' in checkpoint:
        torch.random.set_rng_state(checkpoint['random_rng_state'])
    else:
        raise Exception("Model checkpoint must have either 'random_rng_state'"
                        " or 'random_rng_states_all' key.")

    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scaler.load_state_dict(checkpoint['scaler'])
    return checkpoint['config']


# adapted from: https://discuss.pytorch.org/t/opinion-eval-should-be-a-context-manager/18998/3
# Following snippet is licensed under MIT license
@contextmanager
def evaluating(model):
    '''Temporarily switch to evaluation mode.'''
    istrain = model.training
    try:
        model.eval()
        yield model
    finally:
        if istrain:
            model.train()


def validate(model, criterion, valset, epoch, batch_iter, batch_size,
             world_size, collate_fn, distributed_run, perf_bench, batch_to_gpu, amp_run):
    """Handles all the validation scoring and printing"""
    with evaluating(model), torch.no_grad():
        val_sampler = DistributedSampler(valset) if distributed_run else None
        val_loader = DataLoader(valset, num_workers=1, shuffle=False,
                                sampler=val_sampler,
                                batch_size=batch_size, pin_memory=False,
                                collate_fn=collate_fn,
                                drop_last=perf_bench)

        val_loss = 0.0
        num_iters = 0
        val_items_per_sec = 0.0
        for i, batch in enumerate(val_loader):
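            # Synchronize before reading the clock so the timer only measures
            # completed GPU work.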
            torch.cuda.synchronize()
            iter_start_time = time.perf_counter()

            x, y, num_items = batch_to_gpu(batch)
            # AMP upstream autocast
            with torch.cuda.amp.autocast(enabled=amp_run):
                y_pred = model(x)
                loss = criterion(y_pred, y)

            if distributed_run:
                reduced_val_loss = reduce_tensor(loss.data, world_size).item()
                reduced_num_items = reduce_tensor(num_items.data, 1).item()
            else:
                reduced_val_loss = loss.item()
                reduced_num_items = num_items.item()
            val_loss += reduced_val_loss

            torch.cuda.synchronize()
            iter_stop_time = time.perf_counter()
            iter_time = iter_stop_time - iter_start_time

            items_per_sec = reduced_num_items / iter_time
            DLLogger.log(step=(epoch, batch_iter, i),
                         data={'val_items_per_sec': items_per_sec})
            val_items_per_sec += items_per_sec
            num_iters += 1

        val_loss = val_loss / num_iters
        val_items_per_sec = val_items_per_sec / num_iters

    DLLogger.log(step=(epoch,), data={'val_loss': val_loss})
    DLLogger.log(step=(epoch,), data={'val_items_per_sec': val_items_per_sec})
    return val_loss, val_items_per_sec


def adjust_learning_rate(iteration, epoch, optimizer, learning_rate,
                         anneal_steps, anneal_factor, rank):
    p = 0
    if anneal_steps is not None:
        for i, a_step in enumerate(anneal_steps):
            if epoch >= int(a_step):
                p = p + 1
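
    # With anneal_factor 0.3 the schedule alternates x0.3 and x(1/3) steps, so
    # each pair of anneal steps lowers the lr tenfold: 1.0, 0.3, 0.1, 0.03, ...
    # Otherwise each passed step simply multiplies the lr by anneal_factor.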
    if anneal_factor == 0.3:
        lr = learning_rate * ((0.1 ** (p // 2)) * (1.0 if p % 2 == 0 else 0.3))
    else:
        lr = learning_rate * (anneal_factor ** p)

    if optimizer.param_groups[0]['lr'] != lr:
        DLLogger.log(step=(epoch, iteration),
                     data={'learning_rate changed': str(optimizer.param_groups[0]['lr']) + " -> " + str(lr)})

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def main():
    parser = argparse.ArgumentParser(description='PyTorch Tacotron 2 Training')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    if 'LOCAL_RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
    else:
        local_rank = args.rank
        world_size = args.world_size
    distributed_run = world_size > 1

    if args.seed is not None:
        torch.manual_seed(args.seed + local_rank)
        np.random.seed(args.seed + local_rank)

    if local_rank == 0:
        log_file = os.path.join(args.output, args.log_file)
        DLLogger.init(backends=[JSONStreamBackend(Verbosity.DEFAULT, log_file),
                                StdOutBackend(Verbosity.VERBOSE)])
    else:
        DLLogger.init(backends=[])

    for k, v in vars(args).items():
        DLLogger.log(step="PARAMETER", data={k: v})
    DLLogger.log(step="PARAMETER", data={'model_name': 'Tacotron2_PyT'})

    DLLogger.metadata('run_time', {'unit': 's'})
    DLLogger.metadata('val_loss', {'unit': None})
    DLLogger.metadata('train_items_per_sec', {'unit': 'items/s'})
    DLLogger.metadata('val_items_per_sec', {'unit': 'items/s'})

    model_name = args.model_name
    parser = models.model_parser(model_name, parser)
    args, _ = parser.parse_known_args()

    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, world_size, local_rank, args.group_name)

    torch.cuda.synchronize()
    run_start_time = time.perf_counter()

    model_config = models.get_model_config(model_name, args)
    model = models.get_model(model_name, model_config,
                             cpu_run=False,
                             uniform_initialize_bn_weight=not args.disable_uniform_initialize_bn_weight)

    if distributed_run:
        model = DDP(model, device_ids=[local_rank], output_device=local_rank)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
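    # GradScaler is a pass-through when enabled=False, so the same training
    # and checkpointing code serves both FP32 and AMP runs.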

    try:
        sigma = args.sigma
    except AttributeError:
        sigma = None

    start_epoch = [0]
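    # start_epoch is a one-element list so load_checkpoint can update it
    # in place when resuming.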

    if args.resume_from_last:
        args.checkpoint_path = get_last_checkpoint_filename(args.output, model_name)

    if args.checkpoint_path != "":
        model_config = load_checkpoint(model, optimizer, scaler, start_epoch,
                                       args.checkpoint_path, local_rank)

    start_epoch = start_epoch[0]

    criterion = loss_functions.get_loss_function(model_name, sigma)

    try:
        n_frames_per_step = args.n_frames_per_step
    except AttributeError:
        n_frames_per_step = None

    collate_fn = data_functions.get_collate_function(
        model_name, n_frames_per_step)
    trainset = data_functions.get_data_loader(
        model_name, args.dataset_path, args.training_files, args)

    if distributed_run:
        train_sampler = DistributedSampler(trainset, seed=(args.seed or 0))
        shuffle = False
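        # DistributedSampler shuffles per epoch itself (driven by set_epoch
        # below), so the DataLoader must not shuffle on top of it.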
    else:
        train_sampler = None
        shuffle = True

    train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=args.batch_size, pin_memory=False,
                              drop_last=True, collate_fn=collate_fn)

    valset = data_functions.get_data_loader(
        model_name, args.dataset_path, args.validation_files, args)

    batch_to_gpu = data_functions.get_batch_to_gpu(model_name)

    iteration = 0
    train_epoch_items_per_sec = 0.0
    val_loss = 0.0
    num_iters = 0

    model.train()

    for epoch in range(start_epoch, args.epochs):
        torch.cuda.synchronize()
        epoch_start_time = time.perf_counter()

        # used to calculate avg items/sec over epoch
        reduced_num_items_epoch = 0
        train_epoch_items_per_sec = 0.0
        num_iters = 0
        reduced_loss = 0

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            torch.cuda.synchronize()
            iter_start_time = time.perf_counter()
            DLLogger.log(step=(epoch, i),
                         data={'glob_iter/iters_per_epoch': str(iteration) + "/" + str(len(train_loader))})

            adjust_learning_rate(iteration, epoch, optimizer, args.learning_rate,
                                 args.anneal_steps, args.anneal_factor, local_rank)

            model.zero_grad()
            x, y, num_items = batch_to_gpu(batch)

            # AMP upstream autocast
            with torch.cuda.amp.autocast(enabled=args.amp):
                y_pred = model(x)
                loss = criterion(y_pred, y)

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, world_size).item()
                reduced_num_items = reduce_tensor(num_items.data, 1).item()
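                # num_gpus=1 here: item counts are summed across ranks rather
                # than averaged, giving the global item count for this step.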
            else:
                reduced_loss = loss.item()
                reduced_num_items = num_items.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            DLLogger.log(step=(epoch, i), data={'train_loss': reduced_loss})

            num_iters += 1

            # accumulate number of items processed in this epoch
            reduced_num_items_epoch += reduced_num_items

            if args.amp:
                scaler.scale(loss).backward()
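                # Unscale the gradients before clipping so the norm threshold
                # applies to the true gradient values, not the scaled ones.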
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.grad_clip_thresh)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.grad_clip_thresh)
                optimizer.step()

            model.zero_grad(set_to_none=True)

            torch.cuda.synchronize()
            iter_stop_time = time.perf_counter()
            iter_time = iter_stop_time - iter_start_time
            items_per_sec = reduced_num_items / iter_time
            train_epoch_items_per_sec += items_per_sec

            DLLogger.log(step=(epoch, i), data={'train_items_per_sec': items_per_sec})
            DLLogger.log(step=(epoch, i), data={'train_iter_time': iter_time})
            iteration += 1

        torch.cuda.synchronize()
        epoch_stop_time = time.perf_counter()
        epoch_time = epoch_stop_time - epoch_start_time

        DLLogger.log(step=(epoch,), data={'train_items_per_sec':
                                          (train_epoch_items_per_sec / num_iters if num_iters > 0 else 0.0)})
        DLLogger.log(step=(epoch,), data={'train_loss': reduced_loss})
        DLLogger.log(step=(epoch,), data={'train_epoch_time': epoch_time})

        val_loss, val_items_per_sec = validate(model, criterion, valset, epoch,
                                               iteration, args.batch_size,
                                               world_size, collate_fn,
                                               distributed_run,
                                               args.bench_class == "perf-train",
                                               batch_to_gpu,
                                               args.amp)
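
        # Checkpoint only for real training runs (not perf benchmarks), every
        # `epochs_per_checkpoint` epochs.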
        if (epoch % args.epochs_per_checkpoint == 0) and (args.bench_class == "" or args.bench_class == "train"):
            save_checkpoint(model, optimizer, scaler, epoch, model_config,
                            args.output, args.model_name, local_rank, world_size)
        if local_rank == 0:
            DLLogger.flush()

    torch.cuda.synchronize()
    run_stop_time = time.perf_counter()
    run_time = run_stop_time - run_start_time

    DLLogger.log(step=tuple(), data={'run_time': run_time})
    DLLogger.log(step=tuple(), data={'val_loss': val_loss})
    DLLogger.log(step=tuple(), data={'train_loss': reduced_loss})
    DLLogger.log(step=tuple(), data={'train_items_per_sec':
                                     (train_epoch_items_per_sec / num_iters if num_iters > 0 else 0.0)})
    DLLogger.log(step=tuple(), data={'val_items_per_sec': val_items_per_sec})

    if local_rank == 0:
        DLLogger.flush()


if __name__ == '__main__':
    main()
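
# Example single-GPU invocation (values are illustrative, not tuned defaults):
#   python train.py -m Tacotron2 -o ./output -lr 1e-3 --epochs 1501 -bs 80 \
#       --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled \
#       --anneal-steps 500 1000 1500 --anneal-factor 0.1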