# main.py
#!/usr/bin/python3
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. import os
  16. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
  17. from functools import partial
  18. import json
  19. import logging
  20. from argparse import ArgumentParser
  21. import tensorflow as tf
  22. tf.logging.set_verbosity(tf.logging.ERROR)
  23. import numpy as np
  24. import horovod.tensorflow as hvd
  25. from mpi4py import MPI
  26. import dllogger
  27. import time
  28. from vae.utils.round import round_8
  29. from vae.metrics.recall import recall
  30. from vae.metrics.ndcg import ndcg
  31. from vae.models.train import VAE
  32. from vae.load.preprocessing import load_and_parse_ML_20M
  33. def main():
  34. hvd.init()
  35. mpi_comm = MPI.COMM_WORLD
  36. parser = ArgumentParser(description="Train a Variational Autoencoder for Collaborative Filtering in TensorFlow")
  37. parser.add_argument('--train', action='store_true',
  38. help='Run training of VAE')
  39. parser.add_argument('--test', action='store_true',
  40. help='Run validation of VAE')
  41. parser.add_argument('--inference_benchmark', action='store_true',
  42. help='Measure inference latency and throughput on a variety of batch sizes')
  43. parser.add_argument('--amp', action='store_true', default=False,
  44. help='Enable Automatic Mixed Precision')
  45. parser.add_argument('--epochs', type=int, default=400,
  46. help='Number of epochs to train')
  47. parser.add_argument('--batch_size_train', type=int, default=24576,
  48. help='Global batch size for training')
  49. parser.add_argument('--batch_size_validation', type=int, default=10000,
  50. help='Used both for validation and testing')
  51. parser.add_argument('--validation_step', type=int, default=50,
  52. help='Train epochs for one validation')
  53. parser.add_argument('--warm_up_epochs', type=int, default=5,
  54. help='Number of epochs to omit during benchmark')
  55. parser.add_argument('--total_anneal_steps', type=int, default=15000,
  56. help='Number of annealing steps')
  57. parser.add_argument('--anneal_cap', type=float, default=0.1,
  58. help='Annealing cap')
  59. parser.add_argument('--lam', type=float, default=1.00,
  60. help='Regularization parameter')
  61. parser.add_argument('--lr', type=float, default=0.004,
  62. help='Learning rate')
  63. parser.add_argument('--beta1', type=float, default=0.90,
  64. help='Adam beta1')
  65. parser.add_argument('--beta2', type=float, default=0.90,
  66. help='Adam beta2')
  67. parser.add_argument('--top_results', type=int, default=100,
  68. help='Number of results to be recommended')
  69. parser.add_argument('--xla', action='store_true', default=False,
  70. help='Enable XLA')
  71. parser.add_argument('--trace', action='store_true', default=False,
  72. help='Save profiling traces')
  73. parser.add_argument('--activation', type=str, default='tanh',
  74. help='Activation function')
  75. parser.add_argument('--log_path', type=str, default='./vae_cf.log',
  76. help='Path to the detailed training log to be created')
  77. parser.add_argument('--seed', type=int, default=0,
  78. help='Random seed for TensorFlow and numpy')
  79. parser.add_argument('--data_dir', default='/data', type=str,
  80. help='Directory for storing the training data')
  81. parser.add_argument('--checkpoint_dir', type=str,
  82. default=None,
  83. help='Path for saving a checkpoint after the training')
  84. args = parser.parse_args()
  85. args.world_size = hvd.size()
  86. if args.batch_size_train % hvd.size() != 0:
  87. raise ValueError('Global batch size should be a multiple of the number of workers')
  88. args.local_batch_size = args.batch_size_train // hvd.size()
  89. logger = logging.getLogger("VAE")
  90. if hvd.rank() == 0:
  91. logger.setLevel(logging.INFO)
  92. dllogger.init(backends=[dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
  93. filename=args.log_path),
  94. dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE)])
  95. else:
  96. dllogger.init(backends=[])
  97. logger.setLevel(logging.ERROR)
  98. dllogger.metadata("final_ndcg@100", {"unit": None})
  99. dllogger.metadata("mean_inference_throughput", {"unit": "samples/s"})
  100. dllogger.metadata("mean_training_throughput", {"unit": "samples/s"})
  101. if args.seed is None:
  102. if hvd.rank() == 0:
  103. seed = int(time.time())
  104. else:
  105. seed = None
  106. seed = mpi_comm.bcast(seed, root=0)
  107. else:
  108. seed = args.seed
  109. tf.random.set_random_seed(seed)
  110. np.random.seed(seed)
  111. args.seed = seed
  112. dllogger.log(data=vars(args), step='PARAMETER')
  113. # Suppress TF warnings
  114. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  115. # set AMP
  116. os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1' if args.amp else '0'
  117. # load dataset
  118. (train_data,
  119. validation_data_input,
  120. validation_data_true,
  121. test_data_input,
  122. test_data_true) = load_and_parse_ML_20M(args.data_dir)
  123. # make sure all dims and sizes are divisible by 8
  124. number_of_train_users, number_of_items = train_data.shape
  125. number_of_items = round_8(number_of_items)
  126. for data in [train_data,
  127. validation_data_input,
  128. validation_data_true,
  129. test_data_input,
  130. test_data_true]:
  131. number_of_users, _ = data.shape
  132. data.resize(number_of_users, number_of_items)
  133. number_of_users, number_of_items = train_data.shape
  134. encoder_dims = [number_of_items, 600, 200]
  135. vae = VAE(train_data, encoder_dims, total_anneal_steps=args.total_anneal_steps,
  136. anneal_cap=args.anneal_cap, batch_size_train=args.local_batch_size,
  137. batch_size_validation=args.batch_size_validation, lam=args.lam,
  138. lr=args.lr, beta1=args.beta1, beta2=args.beta2, activation=args.activation,
  139. xla=args.xla, checkpoint_dir=args.checkpoint_dir, trace=args.trace,
  140. top_results=args.top_results)
  141. metrics = {'ndcg@100': partial(ndcg, R=100),
  142. 'recall@20': partial(recall, R=20),
  143. 'recall@50': partial(recall, R=50)}
  144. if args.train:
  145. vae.train(n_epochs=args.epochs, validation_data_input=validation_data_input,
  146. validation_data_true=validation_data_true, metrics=metrics,
  147. validation_step=args.validation_step)
  148. if args.test and hvd.size() <= 1:
  149. test_results = vae.test(test_data_input=test_data_input,
  150. test_data_true=test_data_true, metrics=metrics)
  151. for k, v in test_results.items():
  152. print("{}:\t{}".format(k, v))
  153. elif args.test and hvd.size() > 1:
  154. print("Testing is not supported with horovod multigpu yet")
  155. elif args.test and hvd.size() > 1:
  156. print("Testing is not supported with horovod multigpu yet")
  157. if args.inference_benchmark:
  158. items_per_user = 10
  159. item_indices = np.random.randint(low=0, high=10000, size=items_per_user)
  160. user_indices = np.zeros(len(item_indices))
  161. indices = np.stack([user_indices, item_indices], axis=1)
  162. num_batches = 200
  163. latencies = []
  164. for i in range(num_batches):
  165. start_time = time.time()
  166. _ = vae.query(indices=indices)
  167. end_time = time.time()
  168. if i < 10:
  169. #warmup steps
  170. continue
  171. latencies.append(end_time - start_time)
  172. result_data = {}
  173. result_data[f'batch_1_mean_throughput'] = 1 / np.mean(latencies)
  174. result_data[f'batch_1_mean_latency'] = np.mean(latencies)
  175. result_data[f'batch_1_p90_latency'] = np.percentile(latencies, 90)
  176. result_data[f'batch_1_p95_latency'] = np.percentile(latencies, 95)
  177. result_data[f'batch_1_p99_latency'] = np.percentile(latencies, 99)
  178. dllogger.log(data=result_data, step=tuple())
  179. vae.close_session()
  180. dllogger.flush()
  181. if __name__ == '__main__':
  182. main()