fit.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640
  1. # Copyright 2017-2018 The Apache Software Foundation
  2. #
  3. # Licensed to the Apache Software Foundation (ASF) under one
  4. # or more contributor license agreements. See the NOTICE file
  5. # distributed with this work for additional information
  6. # regarding copyright ownership. The ASF licenses this file
  7. # to you under the Apache License, Version 2.0 (the
  8. # "License"); you may not use this file except in compliance
  9. # with the License. You may obtain a copy of the License at
  10. #
  11. # http://www.apache.org/licenses/LICENSE-2.0
  12. #
  13. # Unless required by applicable law or agreed to in writing,
  14. # software distributed under the License is distributed on an
  15. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  16. # KIND, either express or implied. See the License for the
  17. # specific language governing permissions and limitations
  18. # under the License.
  19. #
  20. # -----------------------------------------------------------------------
  21. #
  22. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  23. #
  24. # Licensed under the Apache License, Version 2.0 (the "License");
  25. # you may not use this file except in compliance with the License.
  26. # You may obtain a copy of the License at
  27. #
  28. # http://www.apache.org/licenses/LICENSE-2.0
  29. #
  30. # Unless required by applicable law or agreed to in writing, software
  31. # distributed under the License is distributed on an "AS IS" BASIS,
  32. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  33. # See the License for the specific language governing permissions and
  34. # limitations under the License.
  35. """ train fit utility """
  36. import logging
  37. import math
  38. import glob
  39. import os
  40. import random
  41. import sys
  42. import time
  43. import re
  44. from itertools import starmap
  45. import signal
  46. import pickle
  47. import dllogger
  48. import horovod.mxnet as hvd
  49. import mxnet as mx
  50. import mxnet.contrib.amp as amp
  51. import numpy as np
  52. from mxnet import autograd as ag
  53. from mxnet import gluon
  54. import data
  55. from benchmarking import BenchmarkingDataIter
  56. from global_metrics import CompositeMeter, MaxMeter, MinMeter, AvgMeter, PercentileMeter
  57. class PartitionSignalHandler():
  58. def __init__(self, sync_freq: int = 10):
  59. self.step = 0
  60. self.freq = sync_freq
  61. self.t = mx.nd.array([0])
  62. signal.signal(signal.SIGUSR1, self._signal_handler)
  63. signal.signal(signal.SIGTERM, self._signal_handler)
  64. def sync(self) -> bool:
  65. if self.step % self.freq == 0:
  66. new_sync = hvd.allreduce(self.t, average=False)
  67. if new_sync[0] > 0:
  68. self.t[0] = 1
  69. self.step += 1
  70. return self.should_end()
  71. def should_end(self) -> bool:
  72. return bool(self.t[0] > 0)
  73. def _signal_handler(self, signum, frame):
  74. print("Signal received")
  75. self.t[0] = 1
  76. def add_fit_args(parser):
  77. def int_list(x):
  78. return list(map(int, x.split(',')))
  79. def float_list(x):
  80. return list(map(float, x.split(',')))
  81. train = parser.add_argument_group('Training')
  82. train.add_argument('--mode', default='train_val', choices=('train_val', 'train', 'val', 'pred'),
  83. help='mode')
  84. train.add_argument('--seed', type=int, default=None,
  85. help='random seed')
  86. train.add_argument('--gpus', type=int_list, default=[0],
  87. help='list of gpus to run, e.g. 0 or 0,2,5')
  88. train.add_argument('--kv-store', type=str, default='device', choices=('device', 'horovod'),
  89. help='key-value store type')
  90. train.add_argument('--dtype', type=str, default='float16', choices=('float32', 'float16'),
  91. help='precision')
  92. train.add_argument('--amp', action='store_true',
  93. help='If enabled, turn on AMP (Automatic Mixed Precision)')
  94. train.add_argument('--batch-size', type=int, default=192,
  95. help='the batch size')
  96. train.add_argument('--num-epochs', type=int, default=90,
  97. help='number of epochs')
  98. train.add_argument('--run-epochs', type=int, default=-1,
  99. help='number of epochs to run in single run')
  100. train.add_argument('--lr', type=float, default=0.1,
  101. help='initial learning rate')
  102. train.add_argument('--lr-schedule', choices=('multistep', 'cosine'), default='cosine',
  103. help='learning rate schedule')
  104. train.add_argument('--lr-factor', type=float, default=0.256,
  105. help='the ratio to reduce lr on each step')
  106. train.add_argument('--lr-steps', type=float_list, default=[],
  107. help='the epochs to reduce the lr, e.g. 30,60')
  108. train.add_argument('--warmup-epochs', type=int, default=5,
  109. help='the epochs to ramp-up lr to scaled large-batch value')
  110. train.add_argument('--optimizer', type=str, default='sgd',
  111. help='the optimizer type')
  112. train.add_argument('--mom', type=float, default=0.875,
  113. help='momentum for sgd')
  114. train.add_argument('--wd', type=float, default=1 / 32768,
  115. help='weight decay for sgd')
  116. train.add_argument('--label-smoothing', type=float, default=0.1,
  117. help='label smoothing factor')
  118. train.add_argument('--mixup', type=float, default=0,
  119. help='alpha parameter for mixup (if 0 then mixup is not applied)')
  120. train.add_argument('--disp-batches', type=int, default=20,
  121. help='show progress for every n batches')
  122. train.add_argument('--model-prefix', type=str, default='model',
  123. help='model checkpoint prefix')
  124. train.add_argument('--save-frequency', type=int, default=-1,
  125. help='frequency of saving model in epochs (--model-prefix must be specified). '
  126. 'If -1 then save only best model. If 0 then do not save anything.')
  127. train.add_argument('--begin-epoch', type=int, default=0,
  128. help='start the model from an epoch')
  129. train.add_argument('--load', help='checkpoint to load')
  130. train.add_argument('--test-io', action='store_true',
  131. help='test reading speed without training')
  132. train.add_argument('--test-io-mode', default='train', choices=('train', 'val'),
  133. help='data to test')
  134. train.add_argument('--log', type=str, default='log.log',
  135. help='file where to save the log from the experiment')
  136. train.add_argument('--dllogger-log', type=str, default='dllogger_log.log',
  137. help='file where to save the dllogger log from the experiment')
  138. train.add_argument('--workspace', type=str, default='./',
  139. help='path to directory where results will be stored')
  140. train.add_argument('--logdir', type=str, default=None,
  141. help="path to directory where logs will be stored")
  142. train.add_argument('--no-metrics', action='store_true',
  143. help='do not calculate evaluation metrics (for benchmarking)')
  144. train.add_argument('--benchmark-iters', type=int, default=None,
  145. help='run only benchmark-iters iterations from each epoch')
  146. return train
  147. def get_epoch_size(args, kv):
  148. return math.ceil(args.num_examples / args.batch_size)
  149. def get_lr_scheduler(args):
  150. def multistep_schedule(x):
  151. lr = args.lr * \
  152. (args.lr_factor ** (len(list(filter(lambda step: step <= x, args.lr_steps)))))
  153. warmup_coeff = min(1, x / args.warmup_epochs)
  154. return warmup_coeff * lr
  155. def cosine_schedule(x):
  156. steps = args.lr_steps
  157. if not steps or steps[0] > args.warmup_epochs:
  158. steps = [args.warmup_epochs] + steps
  159. elif not steps or steps[0] != 0:
  160. steps = [0] + steps
  161. if steps[-1] != args.num_epochs:
  162. steps.append(args.num_epochs)
  163. if x < args.warmup_epochs:
  164. return args.lr * x / args.warmup_epochs
  165. for i, (step, next_step) in enumerate(zip(steps, steps[1:])):
  166. if next_step > x:
  167. return args.lr * 0.5 * (1 + math.cos(math.pi * (x - step) / (next_step - step))) * (args.lr_factor ** i)
  168. return 0
  169. schedules = {
  170. 'multistep': multistep_schedule,
  171. 'cosine': cosine_schedule,
  172. }
  173. return schedules[args.lr_schedule]
  174. def load_model(args, model):
  175. file = list(glob.glob(
  176. f"{args.workspace}/{args.model_prefix}_*.params"))
  177. if len(file) == 0:
  178. return -1
  179. file = [x for x in sorted(file) if "best.params" not in x]
  180. if len(file) == 0:
  181. return -1
  182. file = file[-1]
  183. epoch = re.match(f".*{args.model_prefix}_([0-9]*)\.params", file)
  184. if epoch is None:
  185. return -1
  186. epoch = int(epoch.group(1))
  187. model.load_parameters(file)
  188. logging.info('Loaded model {}'.format(file))
  189. return epoch
  190. def save_checkpoint(net, epoch, top1, best_acc, model_prefix, workspace, save_frequency, kvstore, force_save=False):
  191. if model_prefix is None or save_frequency == 0 or ('horovod' in kvstore and hvd.rank() != 0):
  192. return
  193. if (save_frequency > 0 and (epoch + 1) % save_frequency == 0) or force_save:
  194. fname = '{}_{:04}.params'.format(model_prefix, epoch)
  195. fname = os.path.join(workspace, fname)
  196. net.save_parameters(fname)
  197. logging.info('[Epoch {}] Saving checkpoint to {} with Accuracy: {:.4f}'.format(
  198. epoch, fname, top1))
  199. if top1 > best_acc:
  200. fname = os.path.join(workspace, f'{model_prefix}_best.params')
  201. net.save_parameters(fname)
  202. logging.info('[Epoch {}] Saving checkpoint to {} with Accuracy: {:.4f}'.format(
  203. epoch, fname, top1))
  204. def model_pred(args, model, image):
  205. from imagenet_classes import classes
  206. output = model(image.reshape(-1, *image.shape)
  207. )[0].softmax().as_in_context(mx.cpu())
  208. top = output.argsort(is_ascend=False)[:10]
  209. for i, ind in enumerate(top):
  210. ind = int(ind.asscalar())
  211. logging.info('{:2d}. {:5.2f}% -> {}'.format(i + 1,
  212. output[ind].asscalar() * 100, classes[ind]))
  213. def reduce_metrics(args, metrics, kvstore):
  214. if 'horovod' not in kvstore or not metrics[0] or hvd.size() == 1:
  215. return metrics
  216. m = mx.ndarray.array(metrics[1], ctx=mx.gpu(args.gpus[0]))
  217. reduced = hvd.allreduce(m)
  218. values = reduced.as_in_context(mx.cpu()).asnumpy().tolist()
  219. return (metrics[0], values)
  220. def model_score(args, net, val_data, metric, kvstore):
  221. if val_data is None:
  222. logging.info('Omitting validation: no data')
  223. return [], []
  224. if not isinstance(metric, mx.metric.EvalMetric):
  225. metric = mx.metric.create(metric)
  226. metric.reset()
  227. val_data.reset()
  228. total_batch_size = val_data.batch_size * val_data._num_gpus * \
  229. (hvd.size() if 'horovod' in kvstore else 1)
  230. durations = []
  231. tic = time.time()
  232. outputs = []
  233. for batches in val_data:
  234. # synchronize to previous iteration
  235. for o in outputs:
  236. o.wait_to_read()
  237. data = [b.data[0] for b in batches]
  238. label = [b.label[0][:len(b.data[0]) - b.pad]
  239. for b in batches if len(b.data[0]) != b.pad]
  240. outputs = [net(X) for X, b in zip(data, batches)]
  241. outputs = [o[:len(b.data[0]) - b.pad]
  242. for o, b in zip(outputs, batches) if len(b.data[0]) != b.pad]
  243. metric.update(label, outputs)
  244. durations.append(time.time() - tic)
  245. tic = time.time()
  246. metric = reduce_metrics(args, metric.get_global(), kvstore)
  247. durations = durations[min(len(durations) // 10, 100):]
  248. duration_stats = {
  249. 'ips': total_batch_size / np.mean(durations),
  250. 'latency_avg': np.mean(durations),
  251. }
  252. return metric, duration_stats, durations
  253. class ScalarMetric(mx.metric.Loss):
  254. def update(self, _, scalar):
  255. self.sum_metric += scalar
  256. self.global_sum_metric += scalar
  257. self.num_inst += 1
  258. self.global_num_inst += 1
  259. def label_smoothing(labels, classes, eta):
  260. return labels.one_hot(classes, on_value=1 - eta + eta / classes, off_value=eta / classes)
def model_fit(args, net, train_data, eval_metric, optimizer,
              optimizer_params, lr_scheduler, eval_data, global_metrics, kvstore, kv,
              begin_epoch, num_epoch, run_epoch, model_prefix):
    """Run the training loop, optionally with per-epoch validation.

    Trains `net` on `train_data` for epochs in
    [begin_epoch, min(begin_epoch + run_epoch, num_epoch)) (all remaining
    epochs when run_epoch == -1). Per-iteration and per-epoch statistics
    go through dllogger and `global_metrics`. When args.mode ==
    'train_val', each epoch ends with validation on `eval_data` and a
    checkpoint via save_checkpoint. Training stops early once
    PartitionSignalHandler reports that any worker received
    SIGUSR1/SIGTERM.
    """
    if not isinstance(eval_metric, mx.metric.EvalMetric):
        eval_metric = mx.metric.create(eval_metric)
    loss_metric = ScalarMetric()

    # Horovod ships its own trainer that averages gradients across workers.
    if 'horovod' in kvstore:
        trainer = hvd.DistributedTrainer(
            net.collect_params(), optimizer, optimizer_params)
    else:
        trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params,
                                kvstore=kv, update_on_kvstore=False)

    if args.amp:
        amp.init_trainer(trainer)

    # Sync frequency of 1: the stop flag is exchanged every epoch.
    partition_handler = PartitionSignalHandler(1)

    # Integer (sparse) labels only work when neither label smoothing nor
    # mixup turns the labels into dense distributions.
    sparse_label_loss = (args.label_smoothing == 0 and args.mixup == 0)
    loss = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)
    loss.hybridize(static_shape=True, static_alloc=True)

    local_batch_size = train_data.batch_size
    # Effective global batch: per-GPU batch * GPUs * horovod workers.
    total_batch_size = local_batch_size * train_data._num_gpus * \
        (hvd.size() if 'horovod' in kvstore else 1)
    durations = []

    epoch_size = get_epoch_size(args, kv)
    # -1 means "run until num_epoch"; otherwise run_epoch more epochs.
    run_epoch = num_epoch if (run_epoch == -1) else (begin_epoch + run_epoch)

    def transform_data(images, labels):
        # Apply mixup and/or label smoothing to one device batch.
        if args.mixup != 0:
            # Per-sample mixing coefficients drawn from Beta(alpha, alpha).
            coeffs = mx.nd.array(np.random.beta(args.mixup, args.mixup, size=images.shape[0])).as_in_context(
                images.context)
            image_coeffs = coeffs.astype(
                images.dtype, copy=False).reshape(*coeffs.shape, 1, 1, 1)
            # Mix each sample with the batch traversed in reverse order.
            ret_images = image_coeffs * images + \
                (1 - image_coeffs) * images[::-1]
            ret_labels = label_smoothing(
                labels, args.num_classes, args.label_smoothing)
            label_coeffs = coeffs.reshape(*coeffs.shape, 1)
            ret_labels = label_coeffs * ret_labels + \
                (1 - label_coeffs) * ret_labels[::-1]
        else:
            ret_images = images
            if not sparse_label_loss:
                ret_labels = label_smoothing(
                    labels, args.num_classes, args.label_smoothing)
            else:
                ret_labels = labels
        return ret_images, ret_labels

    i = -1
    best_accuracy = -1
    for epoch in range(begin_epoch, min(run_epoch, num_epoch)):
        tic = time.time()
        btic = time.time()
        etic = time.time()  # NOTE(review): unused local
        train_data.reset()
        eval_metric.reset()
        loss_metric.reset()
        logging.info('Starting epoch {}'.format(epoch))
        outputs = []
        if not partition_handler.should_end():
            for i, batches in enumerate(train_data):
                # synchronize to previous iteration
                # for o in outputs:
                #     o.wait_to_read()
                # Learning rate is scheduled on fractional epochs.
                trainer.set_learning_rate(lr_scheduler(epoch + i / epoch_size))
                data = [b.data[0] for b in batches]
                label = [b.label[0].as_in_context(
                    b.data[0].context) for b in batches]
                orig_label = label  # NOTE(review): unused local
                data, label = zip(*starmap(transform_data, zip(data, label)))
                outputs = []
                Ls = []
                with ag.record():
                    for x, y in zip(data, label):
                        z = net(x)
                        L = loss(z, y)
                        # store the loss and do backward after we have done forward
                        # on all GPUs for better speed on multiple GPUs.
                        Ls.append(L)
                        outputs.append(z)
                    if args.amp:
                        # Loss scaling for mixed precision.
                        with amp.scale_loss(Ls, trainer) as scaled_loss:
                            ag.backward(scaled_loss)
                    else:
                        ag.backward(Ls)
                # The horovod trainer normalizes by the local batch size;
                # the gluon trainer gets the accumulated global batch size.
                if 'horovod' in kvstore:
                    trainer.step(local_batch_size)
                else:
                    trainer.step(total_batch_size)
                # Ellipsis is a dummy label — ScalarMetric ignores it.
                loss_metric.update(..., np.mean(
                    [l.asnumpy() for l in Ls]).item())
                if args.disp_batches and not (i + 1) % args.disp_batches:
                    dllogger_it_data = {
                        'train.loss': loss_metric.get()[1],
                        'train.ips': args.disp_batches * total_batch_size / (time.time() - btic),
                        'train.lr': trainer.learning_rate
                    }
                    dllogger.log((epoch, i), data=dllogger_it_data)
                    loss_metric.reset_local()
                    btic = time.time()
                durations.append(time.time() - tic)
                tic = time.time()
        else:
            # A stop was already requested before this epoch started.
            break

        # Drop warmup iterations (10%, capped at 100) from timing stats.
        durations = durations[min(len(durations) // 10, 100):]
        dllogger_epoch_data = {
            'train.loss': loss_metric.get_global()[1],
            'train.ips': total_batch_size / np.mean(durations)
        }

        # Exchange the stop flag; a pending stop forces a checkpoint below.
        should_break = partition_handler.sync()
        if args.mode == 'train_val':
            logging.info('Validating epoch {}'.format(epoch))
            score, duration_stats, _ = model_score(
                args, net, eval_data, eval_metric, kvstore)
            # Prefix validation metric names with 'val.'.
            dllogger_epoch_data.update(
                starmap(lambda key, val: (
                    'val.{}'.format(key), val), zip(*score))
            )
            dllogger_epoch_data.update(
                starmap(lambda key, val: ('val.{}'.format(key), val),
                        duration_stats.items())
            )
            score = dict(zip(*score))
            accuracy = score.get('accuracy', -1)
            save_checkpoint(net, epoch, accuracy, best_accuracy,
                            model_prefix, args.workspace,
                            args.save_frequency if args.mode == "train_val" else -1,
                            kvstore, force_save=should_break)
            best_accuracy = max(best_accuracy, accuracy)
        global_metrics.update_dict(dllogger_epoch_data)
        dllogger.log(step=(epoch,), data=dllogger_epoch_data)
def fit(args, model, data_loader):
    """
    train a model
    args : argparse returns
    model : the neural network model
    data_loader : function that returns the train and val data iterators
    """
    start_time = time.time()  # NOTE(review): unused local

    # select gpu for horovod process
    if 'horovod' in args.kv_store:
        args.gpus = [args.gpus[hvd.local_rank()]]

    if args.amp:
        amp.init()

    if args.seed is not None:
        logging.info('Setting seeds to {}'.format(args.seed))
        random.seed(args.seed)
        np.random.seed(args.seed)
        mx.random.seed(args.seed)

    # kvstore
    if 'horovod' in args.kv_store:
        kv = None
        rank = hvd.rank()          # NOTE(review): rank/num_workers unused below
        num_workers = hvd.size()
    else:
        kv = mx.kvstore.create(args.kv_store)
        rank = kv.rank
        num_workers = kv.num_workers

    # I/O-only benchmark: drain the iterator without training.
    if args.test_io:
        train, val = data_loader(args, kv)

        if args.test_io_mode == 'train':
            data_iter = train
        else:
            data_iter = val

        tic = time.time()
        for i, batch in enumerate(data_iter):
            # Force the async engine to actually materialize the data.
            if isinstance(batch, list):
                for b in batch:
                    for j in b.data:
                        j.wait_to_read()
            else:
                for j in batch.data:
                    j.wait_to_read()
            if (i + 1) % args.disp_batches == 0:
                logging.info('Batch [{}]\tSpeed: {:.2f} samples/sec'.format(
                    i, args.disp_batches * args.batch_size / (time.time() - tic)))
                tic = time.time()
        return

    # Resume from the newest checkpoint if one exists (-1 + 1 == 0 -> fresh).
    start_epoch = load_model(args, model) + 1
    if start_epoch == 0:
        # all initializers should be specified in the model definition.
        # if not, this will raise an error
        model.initialize(mx.init.Initializer())
    logging.info(f"starting epoch {start_epoch}")

    # devices for training
    devs = list(map(mx.gpu, args.gpus))
    model.collect_params().reset_ctx(devs)

    if args.mode == 'pred':
        # args.data_pred is presumably registered by the data module's
        # argument parser — not visible in this file.
        logging.info('Infering image {}'.format(args.data_pred))
        model_pred(args, model, data.load_image(args, args.data_pred, devs[0]))
        return

    # learning rate
    lr_scheduler = get_lr_scheduler(args)
    optimizer_params = {
        'learning_rate': 0,  # overwritten every iteration by the scheduler
        'wd': args.wd,
        'multi_precision': True,
    }

    # Only a limited number of optimizers have 'momentum' property
    has_momentum = {'sgd', 'dcasgd', 'nag', 'signum', 'lbsgd'}
    if args.optimizer in has_momentum:
        optimizer_params['momentum'] = args.mom

    # evaluation metrices
    if not args.no_metrics:
        eval_metrics = ['accuracy']
        eval_metrics.append(mx.metric.create(
            'top_k_accuracy', top_k=5))
    else:
        eval_metrics = []

    train, val = data_loader(args, kv)
    # Wrap iterators so runs can be limited to benchmark_iters iterations.
    train = BenchmarkingDataIter(train, args.benchmark_iters)
    if val is not None:
        val = BenchmarkingDataIter(val, args.benchmark_iters)

    if 'horovod' in args.kv_store:
        # Fetch and broadcast parameters
        params = model.collect_params()
        if params is not None:
            hvd.broadcast_parameters(params, root_rank=0)
        # NOTE(review): looks like a warm-up/sanity allreduce to initialize
        # horovod communication — result is discarded; confirm intent.
        ctx = mx.gpu(hvd.local_rank())
        tensor1 = mx.nd.zeros(shape=(1,), dtype='float32', ctx=ctx)
        tensor2 = mx.nd.zeros(shape=(1,), dtype='float32', ctx=ctx)
        tensor1, tensor2 = hvd.grouped_allreduce([tensor1, tensor2])

    # Cross-epoch aggregates reported once at the very end.
    global_metrics = CompositeMeter()
    if args.mode in ['train_val', 'train']:
        global_metrics.register_metric('train.loss', MinMeter())
        global_metrics.register_metric('train.ips', AvgMeter())
    if args.mode in ['train_val', 'val']:
        global_metrics.register_metric('val.accuracy', MaxMeter())
        global_metrics.register_metric('val.top_k_accuracy_5', MaxMeter())
        global_metrics.register_metric('val.ips', AvgMeter())
        global_metrics.register_metric('val.latency_avg', AvgMeter())
    if args.mode in ['val']:
        global_metrics.register_metric('val.latency_50', PercentileMeter(50))
        global_metrics.register_metric('val.latency_90', PercentileMeter(90))
        global_metrics.register_metric('val.latency_95', PercentileMeter(95))
        global_metrics.register_metric('val.latency_99', PercentileMeter(99))
        global_metrics.register_metric('val.latency_100', PercentileMeter(100))

    # run
    if args.mode in ['train_val', 'train']:
        model_fit(
            args,
            model,
            train,
            begin_epoch=start_epoch,
            num_epoch=args.num_epochs,
            run_epoch=args.run_epochs,
            eval_data=val,
            eval_metric=eval_metrics,
            global_metrics=global_metrics,
            kvstore=args.kv_store,
            kv=kv,
            optimizer=args.optimizer,
            optimizer_params=optimizer_params,
            lr_scheduler=lr_scheduler,
            model_prefix=args.model_prefix,
        )
    elif args.mode == 'val':
        for epoch in range(args.num_epochs):  # loop for benchmarking
            score, duration_stats, durations = model_score(
                args, model, val, eval_metrics, args.kv_store)
            # Prefix validation metric names with 'val.'.
            dllogger_data = dict(starmap(lambda key, val: (
                'val.{}'.format(key), val), zip(*score)))
            dllogger_data.update(
                starmap(lambda key, val: ('val.{}'.format(key), val),
                        duration_stats.items())
            )
            global_metrics.update_dict(dllogger_data)
            for percentile in [50, 90, 95, 99, 100]:
                metric_name = 'val.latency_{}'.format(percentile)
                dllogger_data[metric_name] = np.percentile(
                    durations, percentile)
                global_metrics.update_metric(metric_name, durations)
            dllogger.log(step=(epoch,), data=dllogger_data)
    else:
        raise ValueError('Wrong mode')

    # Block until all pending async operations complete before reporting.
    mx.nd.waitall()
    dllogger.log(tuple(), data=global_metrics.get())