runner.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import print_function
  15. import os
  16. import multiprocessing
  17. import warnings
  18. import tensorflow as tf
  19. import numpy as np
  20. import horovod.tensorflow as hvd
  21. from model import resnet
  22. from utils import hooks
  23. from utils import data_utils
  24. from utils import hvd_utils
  25. from runtime import runner_utils
  26. import dllogger
  27. __all__ = [
  28. 'Runner',
  29. ]
  30. class Runner(object):
  def __init__(
      self,
      # ========= Model HParams ========= #
      n_classes=1001,
      architecture='resnet50',
      input_format='NHWC',  # NCHW or NHWC
      compute_format='NCHW',  # NCHW or NHWC
      dtype=tf.float32,  # tf.float32 or tf.float16
      n_channels=3,
      height=224,
      width=224,
      distort_colors=False,
      model_dir=None,
      log_dir=None,
      data_dir=None,
      data_idx_dir=None,
      weight_init="fan_out",
      # ======= Optimization HParams ======== #
      use_xla=False,
      use_tf_amp=False,
      use_dali=False,
      use_cpu=False,
      gpu_memory_fraction=1.0,
      gpu_id=0,
      # ======== Debug Flags ======== #
      debug_verbosity=0,
      seed=None):
    """Validate the configuration, export process-wide flags and build the model.

    The constructor checks its arguments, sets a fixed list of environment
    variables, merges every setting into `self.run_hparams` and instantiates
    `resnet.ResnetModel` as `self._model`.

    Args:
      n_classes: Number of classes, forwarded to the model.
      architecture: Key looked up in `resnet.model_architectures`.
      input_format: 'NHWC' or 'NCHW'.
      compute_format: 'NHWC' or 'NCHW'.
      dtype: `tf.float32` or `tf.float16`.
      n_channels: 1 (grayscale) or 3 (color).
      height: Image height stored in the hyper-parameters.
      width: Image width stored in the hyper-parameters.
      distort_colors: Stored in the hyper-parameters.
      model_dir: Stored in the hyper-parameters (same value on every rank).
      log_dir: Stored in the hyper-parameters; replaced by None on Horovod
        ranks other than 0.
      data_dir: Stored in the hyper-parameters.
      data_idx_dir: Stored in the hyper-parameters.
      weight_init: Forwarded to the model.
      use_xla: Stored in the hyper-parameters.
      use_tf_amp: Enables the TF automatic mixed precision graph rewrite;
        rejected together with `dtype=tf.float16`.
      use_dali: Stored in the hyper-parameters and forwarded to the model.
      use_cpu: Stored in the hyper-parameters and forwarded to the model.
      gpu_memory_fraction: Stored in the hyper-parameters.
      gpu_id: Stored in the hyper-parameters.
      debug_verbosity: Accepted for interface compatibility; not read here.
      seed: Base random seed, or None to leave TensorFlow unseeded.

    Raises:
      ValueError: On an unsupported `dtype`, `compute_format`, `input_format`
        or `n_channels`.
      RuntimeError: If `use_tf_amp` is requested together with FP16.
    """
    if dtype not in [tf.float32, tf.float16]:
      raise ValueError("Unknown dtype received: %s (allowed: `tf.float32` and `tf.float16`)" % dtype)

    if compute_format not in ["NHWC", 'NCHW']:
      raise ValueError("Unknown `compute_format` received: %s (allowed: ['NHWC', 'NCHW'])" % compute_format)

    if input_format not in ["NHWC", 'NCHW']:
      raise ValueError("Unknown `input_format` received: %s (allowed: ['NHWC', 'NCHW'])" % input_format)

    if n_channels not in [1, 3]:
      raise ValueError("Unsupported number of channels: %d (allowed: 1 (grayscale) and 3 (color))" % n_channels)

    # Derive the seed the same way `_get_run_config` does: only consult the
    # Horovod rank when Horovod is actually in use.
    if seed is None:
      tf_seed = None
    elif hvd_utils.is_using_hvd():
      tf_seed = 2 * (seed + hvd.rank())
    else:
      tf_seed = 2 * seed

    # ============================================
    # Optimisation Flags - Do not remove
    # ============================================

    os.environ['CUDA_CACHE_DISABLE'] = '0'

    os.environ['HOROVOD_GPU_ALLREDUCE'] = 'NCCL'

    # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'
    os.environ['TF_GPU_THREAD_COUNT'] = '2'

    os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'

    os.environ['TF_ADJUST_HUE_FUSED'] = '1'
    os.environ['TF_ADJUST_SATURATION_FUSED'] = '1'
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    os.environ['TF_SYNC_ON_FINISH'] = '0'
    os.environ['TF_AUTOTUNE_THRESHOLD'] = '2'
    os.environ['TF_DISABLE_NVTX_RANGES'] = '1'
    # Appended to whatever the caller already exported in TF_XLA_FLAGS.
    os.environ["TF_XLA_FLAGS"] = (os.environ.get("TF_XLA_FLAGS", "") + " --tf_xla_enable_lazy_compilation=false")

    # ============================================
    # TF-AMP Setup - Do not remove
    # ============================================

    if dtype == tf.float16:
      if use_tf_amp:
        raise RuntimeError("TF AMP can not be activated for FP16 precision")
    elif use_tf_amp:
      os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"
    else:
      os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "0"

    # =================================================

    # width/height are passed straight through (they used to be swapped,
    # which only went unnoticed because the defaults are square).
    model_hparams = tf.contrib.training.HParams(
        width=width,
        height=height,
        n_channels=n_channels,
        n_classes=n_classes,
        dtype=dtype,
        input_format=input_format,
        compute_format=compute_format,
        distort_colors=distort_colors,
        seed=tf_seed)

    num_preprocessing_threads = 10 if not use_dali else 4
    run_config_performance = tf.contrib.training.HParams(
        num_preprocessing_threads=num_preprocessing_threads,
        use_tf_amp=use_tf_amp,
        use_xla=use_xla,
        use_dali=use_dali,
        use_cpu=use_cpu,
        gpu_memory_fraction=gpu_memory_fraction,
        gpu_id=gpu_id)

    # `num_preprocessing_threads` is registered once (above); listing it here
    # as well only triggered a "duplicate" warning in `_build_hparams`.
    # `model_dir` is given to every rank; `log_dir` only to rank 0.
    keep_log_dir = not hvd_utils.is_using_hvd() or hvd.rank() == 0
    run_config_additional = tf.contrib.training.HParams(
        model_dir=model_dir,
        log_dir=log_dir if keep_log_dir else None,
        data_dir=data_dir,
        data_idx_dir=data_idx_dir)

    self.run_hparams = Runner._build_hparams(model_hparams, run_config_additional, run_config_performance)

    model_name = architecture
    architecture = resnet.model_architectures[architecture]

    self._model = resnet.ResnetModel(
        model_name=model_name,
        n_classes=model_hparams.n_classes,
        layers_count=architecture["layers"],
        layers_depth=architecture["widths"],
        expansions=architecture["expansions"],
        input_format=model_hparams.input_format,
        compute_format=model_hparams.compute_format,
        dtype=model_hparams.dtype,
        weight_init=weight_init,
        use_dali=use_dali,
        use_cpu=use_cpu,
        # Optional architecture entries fall back to these defaults.
        cardinality=architecture['cardinality'] if 'cardinality' in architecture else 1,
        use_se=architecture['use_se'] if 'use_se' in architecture else False,
        se_ratio=architecture['se_ratio'] if 'se_ratio' in architecture else 1)

    if self.run_hparams.seed is not None:
      tf.set_random_seed(self.run_hparams.seed)

    # Populated by `train` / `evaluate` on rank 0.
    self.training_logging_hook = None
    self.eval_logging_hook = None
  @staticmethod
  def _build_hparams(*args):
    """Merge several `tf.contrib.training.HParams` objects into a new one.

    Arguments are processed in order. When `add_hparam` rejects a name with
    `ValueError` (the handler below treats that as "already exists"), the
    value added first is kept and a warning is emitted.

    Args:
      *args: `tf.contrib.training.HParams` instances to merge.

    Returns:
      A new `HParams` holding the hyper-parameters of all arguments.

    Raises:
      ValueError: If an argument is not an `HParams` instance.
    """
    hparams = tf.contrib.training.HParams()

    for _hparams in args:
      if not isinstance(_hparams, tf.contrib.training.HParams):
        # Format the offending object into the message instead of passing it
        # as a second exception argument.
        raise ValueError("Non valid HParams argument object detected: %s" % (_hparams,))

      for key, val in _hparams.values().items():
        try:
          hparams.add_hparam(name=key, value=val)
        except ValueError:
          warnings.warn(
              "the parameter `{}` already exists - existing value: {} and duplicated value: {}".format(
                  key, hparams.get(key), val))

    return hparams
  @staticmethod
  def _get_global_batch_size(worker_batch_size):
    """Return the batch size summed over all workers.

    Args:
      worker_batch_size: Batch size processed by a single worker.

    Returns:
      `worker_batch_size * hvd.size()` when Horovod is in use, otherwise
      `worker_batch_size` unchanged.
    """
    if hvd_utils.is_using_hvd():
      return worker_batch_size * hvd.size()
    return worker_batch_size
  @staticmethod
  def _get_session_config(mode, use_xla, use_dali, use_cpu, gpu_memory_fraction, gpu_id=0):
    """Build the `tf.ConfigProto` used for the estimator session.

    Args:
      mode: One of 'train', 'validation', 'benchmark', 'inference'.
      use_xla: When true, sets the global JIT level to `ON_1`.
      use_dali: When true (and not on CPU), caps GPU memory at
        `gpu_memory_fraction` and disables `allow_growth`; otherwise
        `allow_growth` is enabled.
      use_cpu: When true, no GPU option is touched.
      gpu_memory_fraction: Per-process GPU memory fraction for the DALI case.
      gpu_id: Visible GPU when Horovod is not in use; with Horovod the local
        rank is used instead.

    Returns:
      The configured `tf.ConfigProto`.

    Raises:
      ValueError: If `mode` is not one of the allowed values.
    """
    # NOTE(review): the listing this block was recovered from had lost its
    # indentation. The XLA and train-mode sections are placed outside the
    # `if not use_cpu` block (the train section re-tests `use_cpu` itself) -
    # confirm against the upstream file.
    if mode not in ["train", 'validation', 'benchmark', 'inference']:
      raise ValueError("Unknown mode received: %s (allowed: 'train', 'validation', 'benchmark', 'inference')" %
                       mode)

    config = tf.ConfigProto()

    if not use_cpu:
      # Limit available GPU memory (tune the size)
      if use_dali:
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_memory_fraction)
        config = tf.ConfigProto(gpu_options=gpu_options)
        config.gpu_options.allow_growth = False
      else:
        config.gpu_options.allow_growth = True

      config.allow_soft_placement = True
      config.log_device_placement = False

      config.gpu_options.visible_device_list = str(gpu_id)
      config.gpu_options.force_gpu_compatible = True  # Force pinned memory

      if hvd_utils.is_using_hvd():
        # One process per GPU: each rank only sees its own device.
        config.gpu_options.visible_device_list = str(hvd.local_rank())
        config.gpu_options.force_gpu_compatible = True  # Force pinned memory

    if use_xla:
      config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    if mode == 'train' and not use_cpu:
      # Only query Horovod when it is in use, as done for `local_rank` above.
      # Any worker count up to 8 gives the same divisor, so this is unchanged
      # for an initialised single-process run.
      num_workers = hvd.size() if hvd_utils.is_using_hvd() else 1
      config.intra_op_parallelism_threads = 1  # Avoid pool of Eigen threads
      config.inter_op_parallelism_threads = max(2, (multiprocessing.cpu_count() // max(num_workers, 8) - 2))

    return config
  @staticmethod
  def _get_run_config(mode, model_dir, use_xla, use_dali, use_cpu, gpu_memory_fraction, gpu_id=0, seed=None):
    """Build the `tf.estimator.RunConfig` for the given mode.

    Args:
      mode: One of 'train', 'validation', 'benchmark', 'inference'.
      model_dir: Directory handed to the `RunConfig`.
      use_xla: Forwarded to `_get_session_config`.
      use_dali: Forwarded to `_get_session_config`.
      use_cpu: Forwarded to `_get_session_config`.
      gpu_memory_fraction: Forwarded to `_get_session_config`.
      gpu_id: Forwarded to `_get_session_config`.
      seed: Base seed. The TF seed is `2 * (seed + hvd.rank())` with Horovod,
        `2 * seed` without, and None when `seed` is None.

    Returns:
      The `RunConfig`. In 'train' mode checkpoints are saved every 1000 steps
      (only on rank 0 when Horovod is in use).

    Raises:
      ValueError: If `mode` is not one of the allowed values.
    """
    if mode not in ["train", 'validation', 'benchmark', 'inference']:
      raise ValueError("Unknown mode received: %s (allowed: 'train', 'validation', 'benchmark', 'inference')" %
                       mode)

    if seed is None:
      tf_random_seed = None
    elif hvd_utils.is_using_hvd():
      tf_random_seed = 2 * (seed + hvd.rank())
    else:
      tf_random_seed = 2 * seed

    session_config = Runner._get_session_config(
        mode=mode,
        use_xla=use_xla,
        use_dali=use_dali,
        use_cpu=use_cpu,
        gpu_memory_fraction=gpu_memory_fraction,
        gpu_id=gpu_id)

    config = tf.estimator.RunConfig(
        model_dir=model_dir,
        tf_random_seed=tf_random_seed,
        save_summary_steps=100 if mode in ['train', 'validation'] else 1e9,  # disabled in benchmark mode
        save_checkpoints_steps=None,
        save_checkpoints_secs=None,
        session_config=session_config,
        keep_checkpoint_max=5,
        keep_checkpoint_every_n_hours=1e6,  # disabled
        log_step_count_steps=1e9,
        train_distribute=None,
        device_fn=None,
        protocol=None,
        eval_distribute=None,
        experimental_distribute=None)

    if mode == 'train':
      # Without Horovod the single process writes checkpoints; with Horovod
      # only rank 0 does.
      writes_checkpoints = not hvd_utils.is_using_hvd() or hvd.rank() == 0
      config = config.replace(
          save_checkpoints_steps=1000 if writes_checkpoints else None,
          keep_checkpoint_every_n_hours=3)

    return config
  def _get_estimator(self, mode, run_params, use_xla, use_dali, gpu_memory_fraction, gpu_id=0):
    """Create the `tf.estimator.Estimator` wrapping `self._model`.

    Args:
      mode: One of 'train', 'validation', 'benchmark', 'inference'.
      run_params: Dict passed to the estimator as `params`.
      use_xla: Forwarded to `_get_run_config`.
      use_dali: Forwarded to `_get_run_config`.
      gpu_memory_fraction: Forwarded to `_get_run_config`.
      gpu_id: Forwarded to `_get_run_config`.

    Returns:
      An estimator using `self.run_hparams.model_dir` as its model directory.
      `use_cpu` and `seed` are read from `self.run_hparams`.

    Raises:
      ValueError: If `mode` is not one of the allowed values.
    """
    if mode not in ["train", 'validation', 'benchmark', 'inference']:
      raise ValueError("Unknown mode received: %s (allowed: 'train', 'validation', 'benchmark', 'inference')" %
                       mode)

    run_config = Runner._get_run_config(
        mode=mode,
        model_dir=self.run_hparams.model_dir,
        use_xla=use_xla,
        use_dali=use_dali,
        use_cpu=self.run_hparams.use_cpu,
        gpu_memory_fraction=gpu_memory_fraction,
        gpu_id=gpu_id,
        seed=self.run_hparams.seed)

    return tf.estimator.Estimator(
        model_fn=self._model,
        model_dir=self.run_hparams.model_dir,
        config=run_config,
        params=run_params)
  def train(
      self,
      iter_unit,
      num_iter,
      run_iter,
      batch_size,
      warmup_steps=50,
      weight_decay=1e-4,
      lr_init=0.1,
      lr_warmup_epochs=5,
      momentum=0.9,
      log_every_n_steps=1,
      loss_scale=256,
      label_smoothing=0.0,
      mixup=0.0,
      use_cosine_lr=False,
      use_static_loss_scaling=False,
      is_benchmark=False,
      quantize=False,
      symmetric=False,
      quant_delay=0,
      finetune_checkpoint=None,
      use_final_conv=False,
      use_qdq=False):
    """Train the model with the estimator API.

    Args:
      iter_unit: 'epoch' or 'batch'; unit of `num_iter` and `run_iter`.
      num_iter: Total amount of training, in `iter_unit`.
      run_iter: Amount to run in this call, in `iter_unit`; -1 means
        "all steps". The value is clipped to the steps still missing
        according to the checkpointed `global_step`.
      batch_size: Per-worker batch size.
      warmup_steps: Warm-up steps for the benchmark logging hook.
      weight_decay, lr_init, lr_warmup_epochs, momentum, loss_scale,
      label_smoothing, mixup, use_cosine_lr, quantize, symmetric,
      quant_delay, use_final_conv, use_qdq: Passed to the model through the
        estimator `params` dict.
      log_every_n_steps: Logging period of the logging hook.
      use_static_loss_scaling: Passed to the model as `apply_loss_scaling`;
        forced to False unless AMP or FP16 is active.
      is_benchmark: Use the benchmark logging hook and allow a missing
        `data_dir` (synthetic data).
      finetune_checkpoint: If truthy, added to the estimator `params`.

    Raises:
      ValueError: On an unknown `iter_unit`, or if `data_dir` is missing and
        `is_benchmark` is false.
    """
    if iter_unit not in ["epoch", "batch"]:
      raise ValueError('`iter_unit` value is unknown: %s (allowed: ["epoch", "batch"])' % iter_unit)

    if self.run_hparams.data_dir is None and not is_benchmark:
      raise ValueError('`data_dir` must be specified for training!')

    if self.run_hparams.use_tf_amp or self.run_hparams.dtype == tf.float16:
      # Automatic loss scaling is switched off when static scaling is asked.
      if use_static_loss_scaling:
        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING"] = "0"
      else:
        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING"] = "1"
    else:
      use_static_loss_scaling = False  # Make sure it hasn't been set to True on FP32 training

    num_gpus = 1 if not hvd_utils.is_using_hvd() else hvd.size()
    global_batch_size = batch_size * num_gpus

    if self.run_hparams.data_dir is not None:
      filenames, num_samples, num_steps, num_epochs, num_decay_steps = runner_utils.parse_tfrecords_dataset(
          data_dir=self.run_hparams.data_dir,
          mode="train",
          iter_unit=iter_unit,
          num_iter=num_iter,
          global_batch_size=global_batch_size,
      )
      steps_per_epoch = num_steps / num_epochs
    else:
      # Synthetic data: treat the whole run as a single epoch.
      num_epochs = 1
      num_steps = num_iter
      steps_per_epoch = num_steps
      num_decay_steps = num_steps
      num_samples = num_steps * batch_size

    if run_iter == -1:
      run_iter = num_steps
    elif iter_unit == "epoch":
      run_iter = steps_per_epoch * run_iter

    use_dali_input = self.run_hparams.use_dali and self.run_hparams.data_idx_dir is not None
    if use_dali_input:
      idx_filenames = runner_utils.parse_dali_idx_dataset(data_idx_dir=self.run_hparams.data_idx_dir,
                                                          mode="train")

    training_hooks = []

    if hvd.rank() == 0:
      print('Starting Model Training...')
      print("Training Epochs", num_epochs)
      print("Total Steps", num_steps)
      print("Steps per Epoch", steps_per_epoch)
      print("Decay Steps", num_decay_steps)
      print("Weight Decay Factor", weight_decay)
      print("Init Learning Rate", lr_init)
      print("Momentum", momentum)
      print("Num GPUs", num_gpus)
      print("Per-GPU Batch Size", batch_size)

      if is_benchmark:
        self.training_logging_hook = hooks.BenchmarkLoggingHook(
            global_batch_size=global_batch_size, warmup_steps=warmup_steps, logging_steps=log_every_n_steps
        )
      else:
        self.training_logging_hook = hooks.TrainingLoggingHook(
            global_batch_size=global_batch_size,
            num_steps=num_steps,
            num_samples=num_samples,
            num_epochs=num_epochs,
            steps_per_epoch=steps_per_epoch,
            logging_steps=log_every_n_steps
        )
      training_hooks.append(self.training_logging_hook)

    if hvd_utils.is_using_hvd():
      bcast_hook = hvd.BroadcastGlobalVariablesHook(0)
      training_hooks.append(bcast_hook)

    training_hooks.append(hooks.PrefillStagingAreasHook())
    training_hooks.append(hooks.TrainingPartitionHook())

    estimator_params = {
        'batch_size': batch_size,
        'steps_per_epoch': steps_per_epoch,
        'num_gpus': num_gpus,
        'momentum': momentum,
        'lr_init': lr_init,
        'lr_warmup_epochs': lr_warmup_epochs,
        'weight_decay': weight_decay,
        'loss_scale': loss_scale,
        'apply_loss_scaling': use_static_loss_scaling,
        'label_smoothing': label_smoothing,
        'mixup': mixup,
        'num_decay_steps': num_decay_steps,
        'use_cosine_lr': use_cosine_lr,
        'use_final_conv': use_final_conv,
        'quantize': quantize,
        'use_qdq': use_qdq,
        'symmetric': symmetric,
        'quant_delay': quant_delay
    }

    if finetune_checkpoint:
      estimator_params['finetune_checkpoint'] = finetune_checkpoint

    image_classifier = self._get_estimator(
        mode='train',
        run_params=estimator_params,
        use_xla=self.run_hparams.use_xla,
        use_dali=self.run_hparams.use_dali,
        gpu_memory_fraction=self.run_hparams.gpu_memory_fraction,
        gpu_id=self.run_hparams.gpu_id)

    def training_data_fn():
      """Return the input pipeline: DALI, TFRecords or synthetic data."""
      # A seeded run asks the pipelines for deterministic behaviour.
      deterministic = self.run_hparams.seed is not None

      if self.run_hparams.use_dali and self.run_hparams.data_idx_dir is not None:
        if hvd.rank() == 0:
          print("Using DALI input... ")

        return data_utils.get_dali_input_fn(
            filenames=filenames,
            idx_filenames=idx_filenames,
            batch_size=batch_size,
            height=self.run_hparams.height,
            width=self.run_hparams.width,
            training=True,
            distort_color=self.run_hparams.distort_colors,
            num_threads=self.run_hparams.num_preprocessing_threads,
            deterministic=deterministic)

      if self.run_hparams.data_dir is not None:
        return data_utils.get_tfrecords_input_fn(
            filenames=filenames,
            batch_size=batch_size,
            height=self.run_hparams.height,
            width=self.run_hparams.width,
            training=True,
            distort_color=self.run_hparams.distort_colors,
            num_threads=self.run_hparams.num_preprocessing_threads,
            deterministic=deterministic)

      if hvd.rank() == 0:
        print("Using Synthetic Data ...")
      return data_utils.get_synth_input_fn(
          batch_size=batch_size,
          height=self.run_hparams.height,
          width=self.run_hparams.width,
          num_channels=self.run_hparams.n_channels,
          data_format=self.run_hparams.input_format,
          num_classes=self.run_hparams.n_classes,
          dtype=self.run_hparams.dtype,
      )

    # The lookup raises ValueError when there is nothing to read the step
    # from; start from step 0 in that case.
    try:
      current_step = image_classifier.get_variable_value("global_step")
    except ValueError:
      current_step = 0

    # Never run past the total number of steps.
    run_iter = max(0, min(run_iter, num_steps - current_step))
    print("Current step:", current_step)

    if run_iter > 0:
      try:
        image_classifier.train(
            input_fn=training_data_fn,
            steps=run_iter,
            hooks=training_hooks,
        )
      except KeyboardInterrupt:
        print("Keyboard interrupt")

    if hvd.rank() == 0:
      if run_iter > 0:
        print('Ending Model Training ...')
        train_throughput = self.training_logging_hook.mean_throughput.value()
        dllogger.log(data={'train_throughput': train_throughput}, step=tuple())
      else:
        print('Model already trained required number of steps. Skipped')
  def evaluate(
      self,
      iter_unit,
      num_iter,
      batch_size,
      warmup_steps=50,
      log_every_n_steps=1,
      is_benchmark=False,
      export_dir=None,
      quantize=False,
      symmetric=False,
      use_qdq=False,
      use_final_conv=False,
  ):
    """Evaluate the model, log accuracy/latency and optionally export it.

    Args:
      iter_unit: 'epoch' or 'batch'; unit of `num_iter`.
      num_iter: Amount of evaluation, in `iter_unit`.
      batch_size: Evaluation batch size.
      warmup_steps: Warm-up steps for the benchmark logging hook.
      log_every_n_steps: Logging period of the logging hook.
      is_benchmark: Allow a missing `data_dir` (synthetic data).
      export_dir: If not None, a SavedModel is exported there after the
        metrics have been logged.
      quantize, symmetric, use_qdq, use_final_conv: Passed to the model
        through the estimator `params` dict.

    Raises:
      ValueError: On an unknown `iter_unit`, or if `data_dir` is missing and
        `is_benchmark` is false.
      RuntimeError: When called under Horovod on a rank other than 0.
    """
    if iter_unit not in ["epoch", "batch"]:
      raise ValueError('`iter_unit` value is unknown: %s (allowed: ["epoch", "batch"])' % iter_unit)

    if self.run_hparams.data_dir is None and not is_benchmark:
      raise ValueError('`data_dir` must be specified for evaluation!')

    if hvd_utils.is_using_hvd() and hvd.rank() != 0:
      raise RuntimeError('Multi-GPU inference is not supported')

    estimator_params = {
        'quantize': quantize,
        'symmetric': symmetric,
        'use_qdq': use_qdq,
        'use_final_conv': use_final_conv
    }

    image_classifier = self._get_estimator(
        mode='validation',
        run_params=estimator_params,
        use_xla=self.run_hparams.use_xla,
        use_dali=self.run_hparams.use_dali,
        gpu_memory_fraction=self.run_hparams.gpu_memory_fraction,
        gpu_id=self.run_hparams.gpu_id)

    if self.run_hparams.data_dir is not None:
      filenames, num_samples, num_steps, num_epochs, num_decay_steps = runner_utils.parse_tfrecords_dataset(
          data_dir=self.run_hparams.data_dir,
          mode="validation",
          iter_unit=iter_unit,
          num_iter=num_iter,
          global_batch_size=batch_size,
      )
    else:
      # Synthetic data: `num_iter` is used directly as the step count.
      num_epochs = 1
      num_decay_steps = -1
      num_steps = num_iter

    if self.run_hparams.use_dali and self.run_hparams.data_idx_dir is not None:
      idx_filenames = runner_utils.parse_dali_idx_dataset(data_idx_dir=self.run_hparams.data_idx_dir,
                                                          mode="validation")

    eval_hooks = []

    if hvd.rank() == 0:
      self.eval_logging_hook = hooks.BenchmarkLoggingHook(
          global_batch_size=batch_size, warmup_steps=warmup_steps, logging_steps=log_every_n_steps
      )
      eval_hooks.append(self.eval_logging_hook)

      print('Starting Model Evaluation...')
      print("Evaluation Epochs", num_epochs)
      print("Evaluation Steps", num_steps)
      print("Decay Steps", num_decay_steps)
      print("Global Batch Size", batch_size)

    def evaluation_data_fn():
      """Return the input pipeline: DALI, TFRecords or synthetic data."""
      # A seeded run asks the pipelines for deterministic behaviour.
      deterministic = self.run_hparams.seed is not None

      if self.run_hparams.use_dali and self.run_hparams.data_idx_dir is not None:
        if hvd.rank() == 0:
          print("Using DALI input... ")

        return data_utils.get_dali_input_fn(
            filenames=filenames,
            idx_filenames=idx_filenames,
            batch_size=batch_size,
            height=self.run_hparams.height,
            width=self.run_hparams.width,
            training=False,
            distort_color=self.run_hparams.distort_colors,
            num_threads=self.run_hparams.num_preprocessing_threads,
            deterministic=deterministic)

      if self.run_hparams.data_dir is not None:
        return data_utils.get_tfrecords_input_fn(
            filenames=filenames,
            batch_size=batch_size,
            height=self.run_hparams.height,
            width=self.run_hparams.width,
            training=False,
            distort_color=self.run_hparams.distort_colors,
            num_threads=self.run_hparams.num_preprocessing_threads,
            deterministic=deterministic)

      print("Using Synthetic Data ...\n")
      return data_utils.get_synth_input_fn(
          batch_size=batch_size,
          height=self.run_hparams.height,
          width=self.run_hparams.width,
          num_channels=self.run_hparams.n_channels,
          data_format=self.run_hparams.input_format,
          num_classes=self.run_hparams.n_classes,
          dtype=self.run_hparams.dtype,
      )

    try:
      eval_results = image_classifier.evaluate(
          input_fn=evaluation_data_fn,
          steps=num_steps,
          hooks=eval_hooks,
      )

      eval_throughput = self.eval_logging_hook.mean_throughput.value()

      # Recorded latencies are scaled by 1000 before being summarised.
      eval_latencies = np.array(self.eval_logging_hook.latencies) * 1000
      eval_latencies_q = np.quantile(eval_latencies, q=[0.9, 0.95, 0.99])
      eval_latencies_mean = np.mean(eval_latencies)

      dllogger.log(
          data={
              'top1_accuracy': float(eval_results['top1_accuracy']),
              'top5_accuracy': float(eval_results['top5_accuracy']),
              'eval_throughput': eval_throughput,
              'eval_latency_avg': eval_latencies_mean,
              'eval_latency_p90': eval_latencies_q[0],
              'eval_latency_p95': eval_latencies_q[1],
              'eval_latency_p99': eval_latencies_q[2],
          },
          step=tuple())

      if export_dir is not None:
        dllogger.log(data={'export_dir': export_dir}, step=tuple())
        input_receiver_fn = data_utils.get_serving_input_receiver_fn(
            batch_size=None,
            height=self.run_hparams.height,
            width=self.run_hparams.width,
            num_channels=self.run_hparams.n_channels,
            data_format=self.run_hparams.input_format,
            dtype=self.run_hparams.dtype)

        image_classifier.export_savedmodel(export_dir, input_receiver_fn)

    except KeyboardInterrupt:
      print("Keyboard interrupt")

    print('Model evaluation finished')
  def predict(self, to_predict, quantize=False, symmetric=False, use_qdq=False, use_final_conv=False):
    """Run inference on `to_predict` and print one result per example.

    For every example the predicted class and the probability assigned to
    that class are printed.

    Args:
      to_predict: Inference input, resolved to file names by
        `runner_utils.parse_inference_input`.
      quantize, symmetric, use_qdq, use_final_conv: Passed to the model
        through the estimator `params` dict.

    Raises:
      ValueError: If `to_predict` is None.
    """
    # Fail early with a clear message; previously a None input left
    # `filenames` unbound and surfaced later as a NameError.
    if to_predict is None:
      raise ValueError('`to_predict` must be specified for inference!')

    estimator_params = {
        'quantize': quantize,
        'symmetric': symmetric,
        'use_qdq': use_qdq,
        'use_final_conv': use_final_conv
    }

    filenames = runner_utils.parse_inference_input(to_predict)

    # Forward the configured GPU, as `train` and `evaluate` do, instead of
    # silently falling back to the default device 0.
    image_classifier = self._get_estimator(
        mode='inference',
        run_params=estimator_params,
        use_xla=self.run_hparams.use_xla,
        use_dali=self.run_hparams.use_dali,
        gpu_memory_fraction=self.run_hparams.gpu_memory_fraction,
        gpu_id=self.run_hparams.gpu_id)

    inference_hooks = []

    def inference_data_fn():
      """Return the inference input pipeline over `filenames`."""
      return data_utils.get_inference_input_fn(
          filenames=filenames,
          height=self.run_hparams.height,
          width=self.run_hparams.width,
          num_threads=self.run_hparams.num_preprocessing_threads)

    try:
      inference_results = image_classifier.predict(
          input_fn=inference_data_fn,
          predict_keys=None,
          hooks=inference_hooks,
          yield_single_examples=True)

      for result in inference_results:
        print(result['classes'], str(result['probabilities'][result['classes']]))

    except KeyboardInterrupt:
      print("Keyboard interrupt")

    print('Ending Inference ...')