main.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # author: Tomasz Grel ([email protected])
  16. from absl import app, flags
  17. import os
  18. import sys
  19. import json
  20. from distributed_embeddings.python.layers import dist_model_parallel as dmp
  21. # Define the flags first before importing TensorFlow.
  22. # Otherwise, enabling XLA-Lite would be impossible with a command-line flag
  23. def define_common_flags():
  24. flags.DEFINE_enum("mode", default="train", enum_values=['inference', 'eval', 'train'],
  25. help='Choose "train" to train the model, "inference" to benchmark inference'
  26. ' and "eval" to run validation')
  27. # Debug parameters
  28. flags.DEFINE_bool("run_eagerly", default=False, help="Disable all tf.function decorators for debugging")
  29. flags.DEFINE_bool("tfdata_debug", default=False, help="Run tf.data operations eagerly (experimental)")
  30. flags.DEFINE_integer("seed", default=None, help="Random seed")
  31. flags.DEFINE_bool("embedding_zeros_initializer", default=False,
  32. help="Initialize the embeddings to zeros. This takes much less time so it's useful"
  33. " for benchmarking and debugging.")
  34. flags.DEFINE_bool("embedding_trainable", default=True, help="If True the embeddings will be trainable, otherwise frozen")
  35. # Hardware and performance features
  36. flags.DEFINE_bool("amp", default=False, help="Enable automatic mixed precision")
  37. flags.DEFINE_bool("use_mde_embeddings", default=True,
  38. help="Use the embedding implementation from the TensorFlow Distributed Embeddings package")
  39. flags.DEFINE_bool("concat_embedding", default=False,
  40. help="Concatenate embeddings with the same dimension. Only supported for singleGPU.")
  41. flags.DEFINE_string("dist_strategy", default='memory_balanced',
  42. help="Strategy for the Distributed Embeddings to use. Supported options are"
  43. "'memory_balanced', 'basic' and 'memory_optimized'")
  44. flags.DEFINE_integer("column_slice_threshold", default=5*1000*1000*1000,
  45. help='Number of elements above which a distributed embedding will be sliced across'
  46. 'multiple devices')
  47. flags.DEFINE_integer("row_slice_threshold", default=10*1000*1000*1000,
  48. help='Number of elements above which a distributed embedding will be sliced across'
  49. 'multiple devices')
  50. flags.DEFINE_integer("data_parallel_threshold", default=None,
  51. help='Number of elements above which a distributed embedding will be sliced across'
  52. 'multiple devices')
  53. flags.DEFINE_integer("cpu_offloading_threshold_gb", default=75,
  54. help='Size of the embedding tables in GB above which '
  55. 'offloading to CPU memory should be employed.'
  56. 'Applies only to singleGPU at the moment.')
  57. flags.DEFINE_bool('cpu', default=False, help='Place the entire model on CPU')
  58. flags.DEFINE_bool("xla", default=False, help="Enable XLA")
  59. flags.DEFINE_integer("loss_scale", default=65536, help="Static loss scale to use with mixed precision training")
  60. flags.DEFINE_integer("inter_op_parallelism", default=None, help='Number of inter op threads')
  61. flags.DEFINE_integer("intra_op_parallelism", default=None, help='Number of intra op threads')
  62. # Checkpointing
  63. flags.DEFINE_string("save_checkpoint_path", default=None,
  64. help="Path to which to save a checkpoint file at the end of the training")
  65. flags.DEFINE_string("restore_checkpoint_path", default=None,
  66. help="Path from which to restore a checkpoint before training")
  67. # Evaluation, logging, profiling
  68. flags.DEFINE_integer("auc_thresholds", default=8000,
  69. help="Number of thresholds for the AUC computation")
  70. flags.DEFINE_integer("epochs", default=1, help="Number of epochs to train for")
  71. flags.DEFINE_integer("max_steps", default=-1, help="Stop the training/inference after this many optimiation steps")
  72. flags.DEFINE_integer("evals_per_epoch", default=1, help='Number of evaluations per epoch')
  73. flags.DEFINE_float("print_freq", default=100, help='Number of steps between debug prints')
  74. flags.DEFINE_integer("profiler_start_step", default=None, help='Step at which to start profiling')
  75. flags.DEFINE_integer("profiled_rank", default=1, help='Rank to profile')
  76. flags.DEFINE_string("log_path", default='dlrm_tf_log.json', help="Path to JSON file for storing benchmark results")
  77. # dataset and dataloading settings
  78. flags.DEFINE_string("dataset_path", default=None,
  79. help="Path to dataset directory")
  80. flags.DEFINE_string("feature_spec", default="feature_spec.yaml",
  81. help="Name of the feature spec file in the dataset directory")
  82. flags.DEFINE_enum("dataset_type", default="tf_raw",
  83. enum_values=['tf_raw', 'synthetic', 'split_tfrecords'],
  84. help='The type of the dataset to use')
  85. flags.DEFINE_boolean("data_parallel_input", default=False, help="Use a data-parallel dataloader,"
  86. " i.e., load a local batch of of data for all input features")
  87. # Synthetic dataset settings
  88. flags.DEFINE_boolean("synthetic_dataset_use_feature_spec", default=False,
  89. help="Create a temporary synthetic dataset based on a real one. "
  90. "Uses --dataset_path and --feature_spec"
  91. "Overrides synthetic dataset dimension flags, except the number of batches")
  92. flags.DEFINE_integer('synthetic_dataset_train_batches', default=64008,
  93. help='Number of training batches in the synthetic dataset')
  94. flags.DEFINE_integer('synthetic_dataset_valid_batches', default=1350,
  95. help='Number of validation batches in the synthetic dataset')
  96. flags.DEFINE_list('synthetic_dataset_cardinalities', default=26*[1000],
  97. help='Number of categories for each embedding table of the synthetic dataset')
  98. flags.DEFINE_list('synthetic_dataset_hotness', default=26*[20],
  99. help='Number of categories for each embedding table of the synthetic dataset')
  100. flags.DEFINE_integer('synthetic_dataset_num_numerical_features', default=13,
  101. help='Number of numerical features of the synthetic dataset')
  102. define_common_flags()
  103. FLAGS = flags.FLAGS
  104. app.define_help_flags()
  105. app.parse_flags_with_usage(sys.argv)
  106. if FLAGS.xla:
  107. if FLAGS.cpu:
  108. os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=fusible --tf_xla_cpu_global_jit'
  109. else:
  110. os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=fusible'
  111. import time
  112. import tensorflow as tf
  113. import tensorflow_addons as tfa
  114. import numpy as np
  115. import horovod.tensorflow as hvd
  116. from tensorflow.keras.mixed_precision import LossScaleOptimizer
  117. import dllogger
  118. from utils.logging import IterTimer, init_logging
  119. from utils.distributed import dist_print
  120. from dataloading.dataloader import create_input_pipelines, get_dataset_metadata
  121. from nn.lr_scheduler import LearningRateScheduler
  122. from nn.model import Model
  123. from nn.evaluator import Evaluator
  124. from nn.trainer import Trainer
  125. def init_tf(FLAGS):
  126. """
  127. Set global options for TensorFlow
  128. """
  129. gpus = tf.config.experimental.list_physical_devices('GPU')
  130. for gpu in gpus:
  131. tf.config.experimental.set_memory_growth(gpu, True)
  132. visible_gpus = []
  133. if gpus and not FLAGS.cpu:
  134. visible_gpus = gpus[hvd.local_rank()]
  135. tf.config.experimental.set_visible_devices(visible_gpus, 'GPU')
  136. if FLAGS.amp:
  137. policy = tf.keras.mixed_precision.Policy("mixed_float16")
  138. tf.keras.mixed_precision.set_global_policy(policy)
  139. tf.config.run_functions_eagerly(FLAGS.run_eagerly)
  140. if FLAGS.tfdata_debug:
  141. tf.data.experimental.enable_debug_mode()
  142. if FLAGS.inter_op_parallelism:
  143. tf.config.threading.set_inter_op_parallelism_threads(FLAGS.inter_op_parallelism)
  144. if FLAGS.intra_op_parallelism:
  145. tf.config.threading.set_intra_op_parallelism_threads(FLAGS.intra_op_parallelism)
  146. tf.random.set_seed(hash((FLAGS.seed, hvd.rank())))
  147. def parse_embedding_dimension(embedding_dim, num_embeddings):
  148. try:
  149. embedding_dim = int(embedding_dim)
  150. embedding_dim = [embedding_dim] * num_embeddings
  151. return embedding_dim
  152. except:
  153. pass
  154. if not isinstance(embedding_dim, str):
  155. return ValueError(f'Unsupported embedding_dimension type: f{type(embedding_dim)}')
  156. if os.path.exists(embedding_dim):
  157. # json file with a list of dimensions for each feature
  158. with open(embedding_dim) as f:
  159. edim = json.load(f)
  160. else:
  161. edim = embedding_dim.split(',')
  162. edim = [int(d) for d in edim]
  163. if len(edim) != num_embeddings:
  164. raise ValueError(f'Length of specified embedding dimensions ({len(edim)}) does not match'
  165. f' the number of embedding layers in the neural network ({num_embeddings})')
  166. return edim
  167. def compute_eval_points(train_batches, evals_per_epoch):
  168. eval_points = np.linspace(0, train_batches - 1, evals_per_epoch + 1)[1:]
  169. eval_points = np.round(eval_points).tolist()
  170. return eval_points
  171. def inference_benchmark(validation_pipeline, dlrm, timer, FLAGS):
  172. if FLAGS.max_steps == -1:
  173. FLAGS.max_steps = 1000
  174. evaluator = Evaluator(model=dlrm, timer=timer, auc_thresholds=FLAGS.auc_thresholds,
  175. max_steps=FLAGS.max_steps, cast_dtype=None)
  176. auc, test_loss, latencies = evaluator(validation_pipeline)
  177. # don't benchmark the first few warmup steps
  178. latencies = latencies[10:]
  179. result_data = {
  180. 'mean_inference_throughput': FLAGS.valid_batch_size / np.mean(latencies),
  181. 'mean_inference_latency': np.mean(latencies)
  182. }
  183. for percentile in [90, 95, 99]:
  184. result_data[f'p{percentile}_inference_latency'] = np.percentile(latencies, percentile)
  185. result_data['auc'] = auc
  186. if hvd.rank() == 0:
  187. dllogger.log(data=result_data, step=tuple())
  188. def validate_cmd_line_flags():
  189. if FLAGS.cpu and hvd.size() > 1:
  190. raise ValueError('MultiGPU mode is not supported when training on CPU')
  191. if FLAGS.cpu and FLAGS.interaction == 'custom_cuda':
  192. raise ValueError('"custom_cuda" dot interaction not supported for CPU. '
  193. 'Please specify "--dot_interaction tensorflow" if you want to run on CPU')
  194. if FLAGS.concat_embedding and hvd.size() != 1:
  195. raise ValueError('Concat embedding is currently unsupported in multiGPU mode.')
  196. if FLAGS.concat_embedding and FLAGS.dataset_type != 'tf_raw':
  197. raise ValueError('Concat embedding is only supported for dataset_type="tf_raw",'
  198. f'got dataset_type={FLAGS.dataset_type}')
  199. all_embedding_dims_equal = all(dim == FLAGS.embedding_dim[0] for dim in FLAGS.embedding_dim)
  200. if FLAGS.concat_embedding and not all_embedding_dims_equal:
  201. raise ValueError('Concat embedding is only supported when all embeddings have the same output dimension,'
  202. f'got embedding_dim={FLAGS.embedding_dim}')
  203. def create_optimizers(flags):
  204. if flags.optimizer == 'sgd':
  205. embedding_optimizer = tf.keras.optimizers.legacy.SGD(learning_rate=flags.learning_rate, momentum=0)
  206. if flags.amp:
  207. embedding_optimizer = LossScaleOptimizer(embedding_optimizer,
  208. initial_scale=flags.loss_scale,
  209. dynamic=False)
  210. mlp_optimizer = embedding_optimizer
  211. elif flags.optimizer == 'adam':
  212. embedding_optimizer = tfa.optimizers.LazyAdam(learning_rate=flags.learning_rate,
  213. beta_1=flags.beta1, beta_2=flags.beta2)
  214. mlp_optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=flags.learning_rate,
  215. beta_1=flags.beta1, beta_2=flags.beta2)
  216. if flags.amp:
  217. # only wrap the mlp optimizer and not the embedding optimizer because the embeddings are not run in FP16
  218. mlp_optimizer = LossScaleOptimizer(mlp_optimizer, initial_scale=flags.loss_scale, dynamic=False)
  219. return mlp_optimizer, embedding_optimizer
def main():
    """Entry point: build the model and input pipelines, then dispatch on
    --mode to training, evaluation or an inference benchmark."""
    hvd.init()
    # Only rank 0 writes the benchmark log to avoid duplicated output.
    init_logging(log_path=FLAGS.log_path, params_dict=FLAGS.flag_values_dict(), enabled=hvd.rank()==0)
    init_tf(FLAGS)

    dataset_metadata = get_dataset_metadata(FLAGS.dataset_path, FLAGS.feature_spec)
    # Expand --embedding_dim (int / comma-list / json file) into one dim per table.
    FLAGS.embedding_dim = parse_embedding_dimension(FLAGS.embedding_dim,
                                                    num_embeddings=len(dataset_metadata.categorical_cardinalities))
    validate_cmd_line_flags()

    if FLAGS.restore_checkpoint_path is not None:
        model = Model.create_from_checkpoint(FLAGS.restore_checkpoint_path)
    else:
        model = Model(**FLAGS.flag_values_dict(), num_numerical_features=dataset_metadata.num_numerical_features,
                      categorical_cardinalities=dataset_metadata.categorical_cardinalities,
                      transpose=False)

    # Model-parallel embeddings: each rank hosts only a subset of the tables.
    table_ids = model.sparse_model.get_local_table_ids(hvd.rank())
    print(f'local feature ids={table_ids}')

    train_pipeline, validation_pipeline = create_input_pipelines(dataset_type=FLAGS.dataset_type,
                                                                 dataset_path=FLAGS.dataset_path,
                                                                 train_batch_size=FLAGS.batch_size,
                                                                 test_batch_size=FLAGS.valid_batch_size,
                                                                 table_ids=table_ids,
                                                                 feature_spec=FLAGS.feature_spec,
                                                                 rank=hvd.rank(), world_size=hvd.size(),
                                                                 concat_features=FLAGS.concat_embedding,
                                                                 data_parallel_input=FLAGS.data_parallel_input)

    mlp_optimizer, embedding_optimizer = create_optimizers(FLAGS)

    scheduler = LearningRateScheduler([mlp_optimizer, embedding_optimizer],
                                      warmup_steps=FLAGS.warmup_steps,
                                      base_lr=FLAGS.learning_rate,
                                      decay_start_step=FLAGS.decay_start_step,
                                      decay_steps=FLAGS.decay_steps)

    timer = IterTimer(train_batch_size=FLAGS.batch_size, test_batch_size=FLAGS.batch_size,
                      optimizer=embedding_optimizer, print_freq=FLAGS.print_freq, enabled=hvd.rank() == 0)

    if FLAGS.mode == 'inference':
        inference_benchmark(validation_pipeline, model, timer, FLAGS)
        return
    elif FLAGS.mode == 'eval':
        evaluator = Evaluator(model=model, timer=timer, auc_thresholds=FLAGS.auc_thresholds, max_steps=FLAGS.max_steps)
        test_auc, test_loss, _ = evaluator(validation_pipeline)
        if hvd.rank() == 0:
            dllogger.log(data=dict(auc=test_auc, test_loss=test_loss), step=tuple())
        return

    # --mode=train from here on.
    eval_points = compute_eval_points(train_batches=len(train_pipeline),
                                      evals_per_epoch=FLAGS.evals_per_epoch)

    trainer = Trainer(model, embedding_optimizer=embedding_optimizer, mlp_optimizer=mlp_optimizer, amp=FLAGS.amp,
                      lr_scheduler=scheduler, tf_dataset_op=train_pipeline.op, cpu=FLAGS.cpu)

    evaluator = Evaluator(model=model, timer=timer, auc_thresholds=FLAGS.auc_thresholds, distributed=hvd.size() > 1)

    best_auc = 0
    best_loss = 1e6
    train_begin = time.time()
    for epoch in range(FLAGS.epochs):
        print('Starting epoch: ', epoch)
        for step in range(len(train_pipeline)):
            # Optionally profile a 100-step window on the selected rank.
            if step == FLAGS.profiler_start_step and hvd.rank() == FLAGS.profiled_rank:
                tf.profiler.experimental.start('logdir')
            if FLAGS.profiler_start_step and step == FLAGS.profiler_start_step + 100 and hvd.rank() == FLAGS.profiled_rank:
                tf.profiler.experimental.stop()

            loss = trainer.train_step()

            # Sync initial variables across workers after the first step has
            # created them (multi-GPU only).
            if step == 0 and hvd.size() > 1:
                dmp.broadcast_variables(model.variables, root_rank=0)

            if step % FLAGS.print_freq == 0:
                # NaN check only every print_freq steps to keep the host/device
                # sync off the common path.
                if tf.math.is_nan(loss):
                    print('NaN loss encountered in training. Aborting.')
                    break
                timer.step_train(loss=loss)

            if FLAGS.max_steps != -1 and step > FLAGS.max_steps:
                dist_print(f'Max steps of {FLAGS.max_steps} reached, exiting')
                break

            if step in eval_points:
                test_auc, test_loss, _ = evaluator(validation_pipeline)
                dist_print(f'Evaluation completed, AUC: {test_auc:.6f}, test_loss: {test_loss:.6f}')
                timer.test_idx = 0
                best_auc = max(best_auc, test_auc)
                best_loss = min(best_loss, test_loss)

    elapsed = time.time() - train_begin

    if FLAGS.save_checkpoint_path is not None:
        model.save_checkpoint(FLAGS.save_checkpoint_path)

    if hvd.rank() == 0:
        dist_print(f'Training run completed, elapsed: {elapsed:.0f} [s]')
        results = {
            'throughput': FLAGS.batch_size / timer.mean_train_time(),
            'mean_step_time_ms': timer.mean_train_time() * 1000,
            'auc': best_auc,
            'validation_loss': best_loss
        }
        dllogger.log(data=results, step=tuple())