ncf.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. # Copyright (c) 2018, deepakn94, codyaustun, robieta. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # -----------------------------------------------------------------------
  16. #
  17. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  18. #
  19. # Licensed under the Apache License, Version 2.0 (the "License");
  20. # you may not use this file except in compliance with the License.
  21. # You may obtain a copy of the License at
  22. #
  23. # http://www.apache.org/licenses/LICENSE-2.0
  24. #
  25. # Unless required by applicable law or agreed to in writing, software
  26. # distributed under the License is distributed on an "AS IS" BASIS,
  27. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  28. # See the License for the specific language governing permissions and
  29. # limitations under the License.
  30. import os
  31. import sys
  32. import time
  33. from argparse import ArgumentParser
  34. import tensorflow as tf
  35. import pandas as pd
  36. import numpy as np
  37. import cupy as cp
  38. import horovod.tensorflow as hvd
  39. from mpi4py import MPI
  40. from neumf import ncf_model_ops
  41. from input_pipeline import DataGenerator
  42. import dllogger
  43. def parse_args():
  44. """
  45. Parse command line arguments.
  46. """
  47. parser = ArgumentParser(description="Train a Neural Collaborative"
  48. " Filtering model")
  49. parser.add_argument('--data', type=str,
  50. help='path to test and training data files')
  51. parser.add_argument('-e', '--epochs', type=int, default=30,
  52. help='number of epochs to train for')
  53. parser.add_argument('-b', '--batch-size', type=int, default=1048576,
  54. help='number of examples for each iteration')
  55. parser.add_argument('--valid-users-per-batch', type=int, default=5000,
  56. help='Number of users tested in each evaluation batch')
  57. parser.add_argument('-f', '--factors', type=int, default=64,
  58. help='number of predictive factors')
  59. parser.add_argument('--layers', nargs='+', type=int,
  60. default=[256, 256, 128, 64],
  61. help='size of hidden layers for MLP')
  62. parser.add_argument('-n', '--negative-samples', type=int, default=4,
  63. help='number of negative examples per interaction')
  64. parser.add_argument('-l', '--learning-rate', type=float, default=0.0045,
  65. help='learning rate for optimizer')
  66. parser.add_argument('-k', '--topk', type=int, default=10,
  67. help='rank for test examples to be considered a hit')
  68. parser.add_argument('--seed', '-s', type=int, default=None,
  69. help='manually set random seed for random number generation')
  70. parser.add_argument('--target', '-t', type=float, default=0.9562,
  71. help='stop training early at target')
  72. parser.add_argument('--amp', action='store_true', dest='amp', default=False,
  73. help='enable half-precision computations using automatic mixed precision \
  74. (only available in supported containers)')
  75. parser.add_argument('--xla', action='store_true',
  76. help='enable TensorFlow XLA (Accelerated Linear Algebra)')
  77. parser.add_argument('--valid-negative', type=int, default=100,
  78. help='Number of negative samples for each positive test example')
  79. parser.add_argument('--beta1', '-b1', type=float, default=0.25,
  80. help='beta1 for Adam')
  81. parser.add_argument('--beta2', '-b2', type=float, default=0.5,
  82. help='beta2 for Adam')
  83. parser.add_argument('--eps', type=float, default=1e-8,
  84. help='epsilon for Adam')
  85. parser.add_argument('--dropout', type=float, default=0.5,
  86. help='Dropout probability, if equal to 0 will not use dropout at all')
  87. parser.add_argument('--loss-scale', default=8192, type=int,
  88. help='Loss scale value to use when manually enabling mixed precision')
  89. parser.add_argument('--checkpoint-dir', default=None, type=str,
  90. help='Path to the store the result checkpoint file for training')
  91. parser.add_argument('--load-checkpoint-path', default=None, type=str,
  92. help='Path to the checkpoint for initialization. If None will initialize with random weights')
  93. parser.add_argument('--mode', choices=['train', 'test'], default='train', type=str,
  94. help='Passing "test" will only run a single evaluation, \
  95. otherwise full training will be performed')
  96. parser.add_argument('--eval-after', type=int, default=8,
  97. help='Perform evaluations only after this many epochs')
  98. parser.add_argument('--log-path', default='log.json', type=str,
  99. help='Path for the JSON training log')
  100. return parser.parse_args()
  101. def hvd_init():
  102. """
  103. Initialize Horovod
  104. """
  105. # Reduce logging
  106. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
  107. tf.logging.set_verbosity(tf.logging.ERROR)
  108. # Initialize horovod
  109. hvd.init()
  110. if hvd.rank() == 0:
  111. print('PY', sys.version)
  112. print('TF', tf.__version__)
  113. def get_local_train_data(pos_train_users, pos_train_items, negative_samples):
  114. """
  115. For distributed, split up the train data and only keep the local portion
  116. """
  117. num_pos_samples = pos_train_users.shape[0]
  118. # Create the entire train set
  119. all_train_users = np.tile(pos_train_users, negative_samples+1)
  120. all_train_items = np.tile(pos_train_items, negative_samples+1)
  121. all_train_labels = np.zeros_like(all_train_users, dtype=np.float32)
  122. all_train_labels[:num_pos_samples] = 1.0
  123. # Get local training set
  124. split_size = all_train_users.shape[0] // hvd.size() + 1
  125. split_indices = np.arange(split_size, all_train_users.shape[0], split_size)
  126. all_train_users_splits = np.split(all_train_users, split_indices)
  127. all_train_items_splits = np.split(all_train_items, split_indices)
  128. all_train_labels_splits = np.split(all_train_labels, split_indices)
  129. assert len(all_train_users_splits) == hvd.size()
  130. local_train_users = all_train_users_splits[hvd.rank()]
  131. local_train_items = all_train_items_splits[hvd.rank()]
  132. local_train_labels = all_train_labels_splits[hvd.rank()]
  133. return local_train_users, local_train_items, local_train_labels
  134. def get_local_test_data(pos_test_users, pos_test_items):
  135. """
  136. For distributed, split up the test data and only keep the local portion
  137. """
  138. split_size = pos_test_users.shape[0] // hvd.size() + 1
  139. split_indices = np.arange(split_size, pos_test_users.shape[0], split_size)
  140. test_users_splits = np.split(pos_test_users, split_indices)
  141. test_items_splits = np.split(pos_test_items, split_indices)
  142. assert len(test_users_splits) == hvd.size()
  143. local_test_users = test_users_splits[hvd.rank()]
  144. local_test_items = test_items_splits[hvd.rank()]
  145. return local_test_users, local_test_items
  146. def main():
  147. script_start = time.time()
  148. hvd_init()
  149. mpi_comm = MPI.COMM_WORLD
  150. args = parse_args()
  151. if hvd.rank() == 0:
  152. dllogger.init(backends=[dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
  153. filename=args.log_path),
  154. dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE)])
  155. else:
  156. dllogger.init(backends=[])
  157. dllogger.metadata("best_epoch", {"unit": None})
  158. dllogger.metadata("first_epoch_to_hit", {"unit": None})
  159. dllogger.metadata("best_hr", {"unit": None})
  160. dllogger.metadata("average_eval_time_per_epoch", {"unit": "s"})
  161. dllogger.metadata("average_train_time_per_epoch", {"unit": "s"})
  162. dllogger.metadata("average_train_throughput", {"unit": "samples/s"})
  163. dllogger.metadata("average_eval_throughput", {"unit": "samples/s"})
  164. args.world_size = hvd.size()
  165. dllogger.log(data=vars(args), step='PARAMETER')
  166. if args.seed is None:
  167. if hvd.rank() == 0:
  168. seed = int(time.time())
  169. else:
  170. seed = None
  171. seed = mpi_comm.bcast(seed, root=0)
  172. else:
  173. seed = args.seed
  174. tf.random.set_random_seed(seed)
  175. np.random.seed(seed)
  176. cp.random.seed(seed)
  177. if args.amp:
  178. os.environ["TF_ENABLE_AUTO_MIXED_PRECISION"] = "1"
  179. if args.checkpoint_dir is not None:
  180. os.makedirs(args.checkpoint_dir, exist_ok=True)
  181. final_checkpoint_path = os.path.join(args.checkpoint_dir, 'model.ckpt')
  182. else:
  183. final_checkpoint_path = None
  184. # Load converted data and get statistics
  185. train_df = pd.read_pickle(args.data+'/train_ratings.pickle')
  186. test_df = pd.read_pickle(args.data+'/test_ratings.pickle')
  187. nb_users, nb_items = train_df.max() + 1
  188. # Extract train and test feature tensors from dataframe
  189. pos_train_users = train_df.iloc[:, 0].values.astype(np.int32)
  190. pos_train_items = train_df.iloc[:, 1].values.astype(np.int32)
  191. pos_test_users = test_df.iloc[:, 0].values.astype(np.int32)
  192. pos_test_items = test_df.iloc[:, 1].values.astype(np.int32)
  193. # Negatives indicator for negatives generation
  194. neg_mat = np.ones((nb_users, nb_items), dtype=np.bool)
  195. neg_mat[pos_train_users, pos_train_items] = 0
  196. # Get the local training/test data
  197. train_users, train_items, train_labels = get_local_train_data(
  198. pos_train_users, pos_train_items, args.negative_samples
  199. )
  200. test_users, test_items = get_local_test_data(
  201. pos_test_users, pos_test_items
  202. )
  203. # Create and run Data Generator in a separate thread
  204. data_generator = DataGenerator(
  205. args.seed,
  206. hvd.local_rank(),
  207. nb_users,
  208. nb_items,
  209. neg_mat,
  210. train_users,
  211. train_items,
  212. train_labels,
  213. args.batch_size // hvd.size(),
  214. args.negative_samples,
  215. test_users,
  216. test_items,
  217. args.valid_users_per_batch,
  218. args.valid_negative,
  219. )
  220. # Create tensorflow session and saver
  221. config = tf.ConfigProto()
  222. config.gpu_options.allow_growth = True
  223. config.gpu_options.visible_device_list = str(hvd.local_rank())
  224. if args.xla:
  225. config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
  226. sess = tf.Session(config=config)
  227. # Input tensors
  228. users = tf.placeholder(tf.int32, shape=(None,))
  229. items = tf.placeholder(tf.int32, shape=(None,))
  230. labels = tf.placeholder(tf.int32, shape=(None,))
  231. is_dup = tf.placeholder(tf.float32, shape=(None,))
  232. dropout = tf.placeholder_with_default(args.dropout, shape=())
  233. # Model ops and saver
  234. hit_rate, ndcg, eval_op, train_op = ncf_model_ops(
  235. users,
  236. items,
  237. labels,
  238. is_dup,
  239. params={
  240. 'val_batch_size': args.valid_negative+1,
  241. 'top_k': args.topk,
  242. 'learning_rate': args.learning_rate,
  243. 'beta_1': args.beta1,
  244. 'beta_2': args.beta2,
  245. 'epsilon': args.eps,
  246. 'num_users': nb_users,
  247. 'num_items': nb_items,
  248. 'num_factors': args.factors,
  249. 'mf_reg': 0,
  250. 'layer_sizes': args.layers,
  251. 'layer_regs': [0. for i in args.layers],
  252. 'dropout': dropout,
  253. 'sigmoid': True,
  254. 'loss_scale': args.loss_scale
  255. },
  256. mode='TRAIN' if args.mode == 'train' else 'EVAL'
  257. )
  258. saver = tf.train.Saver()
  259. # Accuracy metric tensors
  260. hr_sum = tf.get_default_graph().get_tensor_by_name('neumf/hit_rate/total:0')
  261. hr_cnt = tf.get_default_graph().get_tensor_by_name('neumf/hit_rate/count:0')
  262. ndcg_sum = tf.get_default_graph().get_tensor_by_name('neumf/ndcg/total:0')
  263. ndcg_cnt = tf.get_default_graph().get_tensor_by_name('neumf/ndcg/count:0')
  264. # Prepare evaluation data
  265. data_generator.prepare_eval_data()
  266. if args.load_checkpoint_path:
  267. saver.restore(sess, args.load_checkpoint_path)
  268. else:
  269. # Manual initialize weights
  270. sess.run(tf.global_variables_initializer())
  271. # If test mode, run one eval
  272. if args.mode == 'test':
  273. sess.run(tf.local_variables_initializer())
  274. eval_start = time.time()
  275. for user_batch, item_batch, dup_batch \
  276. in zip(data_generator.eval_users, data_generator.eval_items, data_generator.dup_mask):
  277. sess.run(
  278. eval_op,
  279. feed_dict={
  280. users: user_batch,
  281. items: item_batch,
  282. is_dup:dup_batch, dropout: 0.0
  283. }
  284. )
  285. eval_duration = time.time() - eval_start
  286. # Report results
  287. hit_rate_sum = sess.run(hvd.allreduce(hr_sum, average=False))
  288. hit_rate_cnt = sess.run(hvd.allreduce(hr_cnt, average=False))
  289. ndcg_sum = sess.run(hvd.allreduce(ndcg_sum, average=False))
  290. ndcg_cnt = sess.run(hvd.allreduce(ndcg_cnt, average=False))
  291. hit_rate = hit_rate_sum / hit_rate_cnt
  292. ndcg = ndcg_sum / ndcg_cnt
  293. if hvd.rank() == 0:
  294. eval_throughput = pos_test_users.shape[0] * (args.valid_negative + 1) / eval_duration
  295. dllogger.log(step=tuple(), data={'eval_throughput': eval_throughput,
  296. 'eval_time': eval_duration,
  297. 'hr@10': float(hit_rate),
  298. 'ndcg': float(ndcg)})
  299. return
  300. # Performance Metrics
  301. train_times = list()
  302. eval_times = list()
  303. # Accuracy Metrics
  304. first_to_target = None
  305. best_hr = 0
  306. best_epoch = 0
  307. # Buffers for global metrics
  308. global_hr_sum = np.ones(1)
  309. global_hr_count = np.ones(1)
  310. global_ndcg_sum = np.ones(1)
  311. global_ndcg_count = np.ones(1)
  312. # Buffers for local metrics
  313. local_hr_sum = np.ones(1)
  314. local_hr_count = np.ones(1)
  315. local_ndcg_sum = np.ones(1)
  316. local_ndcg_count = np.ones(1)
  317. # Begin training
  318. for epoch in range(args.epochs):
  319. # Train for one epoch
  320. train_start = time.time()
  321. data_generator.prepare_train_data()
  322. for user_batch, item_batch, label_batch \
  323. in zip(data_generator.train_users_batches,
  324. data_generator.train_items_batches,
  325. data_generator.train_labels_batches):
  326. sess.run(
  327. train_op,
  328. feed_dict={
  329. users: user_batch.get(),
  330. items: item_batch.get(),
  331. labels: label_batch.get()
  332. }
  333. )
  334. train_duration = time.time() - train_start
  335. # Only log "warm" epochs
  336. if epoch >= 1:
  337. train_times.append(train_duration)
  338. # Evaluate
  339. if epoch > args.eval_after:
  340. eval_start = time.time()
  341. sess.run(tf.local_variables_initializer())
  342. for user_batch, item_batch, dup_batch \
  343. in zip(data_generator.eval_users,
  344. data_generator.eval_items,
  345. data_generator.dup_mask):
  346. sess.run(
  347. eval_op,
  348. feed_dict={
  349. users: user_batch,
  350. items: item_batch,
  351. is_dup: dup_batch,
  352. dropout: 0.0
  353. }
  354. )
  355. # Compute local metrics
  356. local_hr_sum[0] = sess.run(hr_sum)
  357. local_hr_count[0] = sess.run(hr_cnt)
  358. local_ndcg_sum[0] = sess.run(ndcg_sum)
  359. local_ndcg_count[0] = sess.run(ndcg_cnt)
  360. # Reduce metrics across all workers
  361. mpi_comm.Reduce(local_hr_count, global_hr_count)
  362. mpi_comm.Reduce(local_hr_sum, global_hr_sum)
  363. mpi_comm.Reduce(local_ndcg_count, global_ndcg_count)
  364. mpi_comm.Reduce(local_ndcg_sum, global_ndcg_sum)
  365. # Calculate metrics
  366. hit_rate = global_hr_sum[0] / global_hr_count[0]
  367. ndcg = global_ndcg_sum[0] / global_ndcg_count[0]
  368. eval_duration = time.time() - eval_start
  369. # Only log "warm" epochs
  370. if epoch >= 1:
  371. eval_times.append(eval_duration)
  372. if hvd.rank() == 0:
  373. dllogger.log(step=(epoch,), data={
  374. 'train_time': train_duration,
  375. 'eval_time': eval_duration,
  376. 'hr@10': hit_rate,
  377. 'ndcg': ndcg})
  378. # Update summary metrics
  379. if hit_rate > args.target and first_to_target is None:
  380. first_to_target = epoch
  381. if hit_rate > best_hr:
  382. best_hr = hit_rate
  383. best_epoch = epoch
  384. if hit_rate > args.target and final_checkpoint_path:
  385. saver.save(sess, final_checkpoint_path)
  386. # Final Summary
  387. if hvd.rank() == 0:
  388. train_times = np.array(train_times)
  389. train_throughputs = pos_train_users.shape[0]*(args.negative_samples+1) / train_times
  390. eval_times = np.array(eval_times)
  391. eval_throughputs = pos_test_users.shape[0]*(args.valid_negative+1) / eval_times
  392. dllogger.log(step=tuple(), data={
  393. 'average_train_time_per_epoch': np.mean(train_times),
  394. 'average_train_throughput': np.mean(train_throughputs),
  395. 'average_eval_time_per_epoch': np.mean(eval_times),
  396. 'average_eval_throughput': np.mean(eval_throughputs),
  397. 'first_epoch_to_hit': first_to_target,
  398. 'best_hr': best_hr,
  399. 'best_epoch': best_epoch})
  400. dllogger.flush()
  401. sess.close()
  402. return
# Script entry point: run training/evaluation when invoked directly.
if __name__ == '__main__':
    main()