# ncf.py
  1. # Copyright (c) 2018, deepakn94, codyaustun, robieta. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # -----------------------------------------------------------------------
  16. #
  17. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  18. #
  19. # Licensed under the Apache License, Version 2.0 (the "License");
  20. # you may not use this file except in compliance with the License.
  21. # You may obtain a copy of the License at
  22. #
  23. # http://www.apache.org/licenses/LICENSE-2.0
  24. #
  25. # Unless required by applicable law or agreed to in writing, software
  26. # distributed under the License is distributed on an "AS IS" BASIS,
  27. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  28. # See the License for the specific language governing permissions and
  29. # limitations under the License.
  30. import os
  31. import sys
  32. import time
  33. from argparse import ArgumentParser
  34. import tensorflow as tf
  35. import pandas as pd
  36. import numpy as np
  37. import cupy as cp
  38. import horovod.tensorflow as hvd
  39. from mpi4py import MPI
  40. from neumf import ncf_model_ops
  41. from input_pipeline import DataGenerator
  42. import dllogger
  43. def parse_args():
  44. """
  45. Parse command line arguments.
  46. """
  47. parser = ArgumentParser(description="Train a Neural Collaborative"
  48. " Filtering model")
  49. parser.add_argument('--data', type=str,
  50. help='path to test and training data files')
  51. parser.add_argument('-e', '--epochs', type=int, default=30,
  52. help='number of epochs to train for')
  53. parser.add_argument('-b', '--batch-size', type=int, default=1048576,
  54. help='number of examples for each iteration')
  55. parser.add_argument('--valid-users-per-batch', type=int, default=5000,
  56. help='Number of users tested in each evaluation batch')
  57. parser.add_argument('-f', '--factors', type=int, default=64,
  58. help='number of predictive factors')
  59. parser.add_argument('--layers', nargs='+', type=int,
  60. default=[256, 256, 128, 64],
  61. help='size of hidden layers for MLP')
  62. parser.add_argument('-n', '--negative-samples', type=int, default=4,
  63. help='number of negative examples per interaction')
  64. parser.add_argument('-l', '--learning-rate', type=float, default=0.0045,
  65. help='learning rate for optimizer')
  66. parser.add_argument('-k', '--topk', type=int, default=10,
  67. help='rank for test examples to be considered a hit')
  68. parser.add_argument('--seed', '-s', type=int, default=None,
  69. help='manually set random seed for random number generation')
  70. parser.add_argument('--target', '-t', type=float, default=0.9562,
  71. help='stop training early at target')
  72. parser.add_argument('--amp', action='store_true', dest='amp', default=False,
  73. help='enable half-precision computations using automatic mixed precision \
  74. (only available in supported containers)')
  75. parser.add_argument('--xla', action='store_true',
  76. help='enable TensorFlow XLA (Accelerated Linear Algebra)')
  77. parser.add_argument('--valid-negative', type=int, default=100,
  78. help='Number of negative samples for each positive test example')
  79. parser.add_argument('--beta1', '-b1', type=float, default=0.25,
  80. help='beta1 for Adam')
  81. parser.add_argument('--beta2', '-b2', type=float, default=0.5,
  82. help='beta2 for Adam')
  83. parser.add_argument('--eps', type=float, default=1e-8,
  84. help='epsilon for Adam')
  85. parser.add_argument('--dropout', type=float, default=0.5,
  86. help='Dropout probability, if equal to 0 will not use dropout at all')
  87. parser.add_argument('--loss-scale', default=8192, type=int,
  88. help='Loss scale value to use when manually enabling mixed precision')
  89. parser.add_argument('--checkpoint-dir', default=None, type=str,
  90. help='Path to the store the result checkpoint file for training')
  91. parser.add_argument('--load-checkpoint-path', default=None, type=str,
  92. help='Path to the checkpoint for initialization. If None will initialize with random weights')
  93. parser.add_argument('--mode', choices=['train', 'test'], default='train', type=str,
  94. help='Passing "test" will only run a single evaluation, \
  95. otherwise full training will be performed')
  96. parser.add_argument('--eval-after', type=int, default=8,
  97. help='Perform evaluations only after this many epochs')
  98. parser.add_argument('--log-path', default='log.json', type=str,
  99. help='Path for the JSON training log')
  100. return parser.parse_args()
  101. def hvd_init():
  102. """
  103. Initialize Horovod
  104. """
  105. # Reduce logging
  106. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
  107. tf.logging.set_verbosity(tf.logging.ERROR)
  108. # Initialize horovod
  109. hvd.init()
  110. if hvd.rank() == 0:
  111. print('PY', sys.version)
  112. print('TF', tf.__version__)
  113. def get_local_train_data(pos_train_users, pos_train_items, negative_samples):
  114. """
  115. For distributed, split up the train data and only keep the local portion
  116. """
  117. num_pos_samples = pos_train_users.shape[0]
  118. # Create the entire train set
  119. all_train_users = np.tile(pos_train_users, negative_samples+1)
  120. all_train_items = np.tile(pos_train_items, negative_samples+1)
  121. all_train_labels = np.zeros_like(all_train_users, dtype=np.float32)
  122. all_train_labels[:num_pos_samples] = 1.0
  123. # Get local training set
  124. split_size = all_train_users.shape[0] // hvd.size() + 1
  125. split_indices = np.arange(split_size, all_train_users.shape[0], split_size)
  126. all_train_users_splits = np.split(all_train_users, split_indices)
  127. all_train_items_splits = np.split(all_train_items, split_indices)
  128. all_train_labels_splits = np.split(all_train_labels, split_indices)
  129. assert len(all_train_users_splits) == hvd.size()
  130. local_train_users = all_train_users_splits[hvd.rank()]
  131. local_train_items = all_train_items_splits[hvd.rank()]
  132. local_train_labels = all_train_labels_splits[hvd.rank()]
  133. return local_train_users, local_train_items, local_train_labels
  134. def get_local_test_data(pos_test_users, pos_test_items):
  135. """
  136. For distributed, split up the test data and only keep the local portion
  137. """
  138. split_size = pos_test_users.shape[0] // hvd.size() + 1
  139. split_indices = np.arange(split_size, pos_test_users.shape[0], split_size)
  140. test_users_splits = np.split(pos_test_users, split_indices)
  141. test_items_splits = np.split(pos_test_items, split_indices)
  142. assert len(test_users_splits) == hvd.size()
  143. local_test_users = test_users_splits[hvd.rank()]
  144. local_test_items = test_items_splits[hvd.rank()]
  145. return local_test_users, local_test_items
  146. def main():
  147. script_start = time.time()
  148. hvd_init()
  149. mpi_comm = MPI.COMM_WORLD
  150. args = parse_args()
  151. if hvd.rank() == 0:
  152. dllogger.init(backends=[dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
  153. filename=args.log_path),
  154. dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE)])
  155. else:
  156. dllogger.init(backends=[])
  157. dllogger.metadata("best_epoch", {"unit": None})
  158. dllogger.metadata("first_epoch_to_hit", {"unit": None})
  159. dllogger.metadata("best_hr", {"unit": None})
  160. dllogger.metadata("average_eval_time_per_epoch", {"unit": "s"})
  161. dllogger.metadata("average_train_time_per_epoch", {"unit": "s"})
  162. dllogger.metadata("time_to_best", {"unit": "s"})
  163. dllogger.metadata("time_to_train", {"unit": "s"})
  164. dllogger.metadata("average_train_throughput", {"unit": "samples/s"})
  165. dllogger.metadata("average_eval_throughput", {"unit": "samples/s"})
  166. args.world_size = hvd.size()
  167. dllogger.log(data=vars(args), step='PARAMETER')
  168. if args.seed is None:
  169. if hvd.rank() == 0:
  170. seed = int(time.time())
  171. else:
  172. seed = None
  173. seed = mpi_comm.bcast(seed, root=0)
  174. else:
  175. seed = args.seed
  176. tf.random.set_random_seed(seed)
  177. np.random.seed(seed)
  178. cp.random.seed(seed)
  179. if args.amp:
  180. os.environ["TF_ENABLE_AUTO_MIXED_PRECISION"] = "1"
  181. if args.checkpoint_dir is not None:
  182. os.makedirs(args.checkpoint_dir, exist_ok=True)
  183. final_checkpoint_path = os.path.join(args.checkpoint_dir, 'model.ckpt')
  184. else:
  185. final_checkpoint_path = None
  186. # Load converted data and get statistics
  187. train_df = pd.read_pickle(args.data+'/train_ratings.pickle')
  188. test_df = pd.read_pickle(args.data+'/test_ratings.pickle')
  189. nb_users, nb_items = train_df.max() + 1
  190. # Extract train and test feature tensors from dataframe
  191. pos_train_users = train_df.iloc[:, 0].values.astype(np.int32)
  192. pos_train_items = train_df.iloc[:, 1].values.astype(np.int32)
  193. pos_test_users = test_df.iloc[:, 0].values.astype(np.int32)
  194. pos_test_items = test_df.iloc[:, 1].values.astype(np.int32)
  195. # Negatives indicator for negatives generation
  196. neg_mat = np.ones((nb_users, nb_items), dtype=np.bool)
  197. neg_mat[pos_train_users, pos_train_items] = 0
  198. # Get the local training/test data
  199. train_users, train_items, train_labels = get_local_train_data(
  200. pos_train_users, pos_train_items, args.negative_samples
  201. )
  202. test_users, test_items = get_local_test_data(
  203. pos_test_users, pos_test_items
  204. )
  205. # Create and run Data Generator in a separate thread
  206. data_generator = DataGenerator(
  207. args.seed,
  208. hvd.local_rank(),
  209. nb_users,
  210. nb_items,
  211. neg_mat,
  212. train_users,
  213. train_items,
  214. train_labels,
  215. args.batch_size // hvd.size(),
  216. args.negative_samples,
  217. test_users,
  218. test_items,
  219. args.valid_users_per_batch,
  220. args.valid_negative,
  221. )
  222. # Create tensorflow session and saver
  223. config = tf.ConfigProto()
  224. config.gpu_options.allow_growth = True
  225. config.gpu_options.visible_device_list = str(hvd.local_rank())
  226. if args.xla:
  227. config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
  228. sess = tf.Session(config=config)
  229. # Input tensors
  230. users = tf.placeholder(tf.int32, shape=(None,))
  231. items = tf.placeholder(tf.int32, shape=(None,))
  232. labels = tf.placeholder(tf.int32, shape=(None,))
  233. is_dup = tf.placeholder(tf.float32, shape=(None,))
  234. dropout = tf.placeholder_with_default(args.dropout, shape=())
  235. # Model ops and saver
  236. hit_rate, ndcg, eval_op, train_op = ncf_model_ops(
  237. users,
  238. items,
  239. labels,
  240. is_dup,
  241. params={
  242. 'val_batch_size': args.valid_negative+1,
  243. 'top_k': args.topk,
  244. 'learning_rate': args.learning_rate,
  245. 'beta_1': args.beta1,
  246. 'beta_2': args.beta2,
  247. 'epsilon': args.eps,
  248. 'num_users': nb_users,
  249. 'num_items': nb_items,
  250. 'num_factors': args.factors,
  251. 'mf_reg': 0,
  252. 'layer_sizes': args.layers,
  253. 'layer_regs': [0. for i in args.layers],
  254. 'dropout': dropout,
  255. 'sigmoid': True,
  256. 'loss_scale': args.loss_scale
  257. },
  258. mode='TRAIN' if args.mode == 'train' else 'EVAL'
  259. )
  260. saver = tf.train.Saver()
  261. # Accuracy metric tensors
  262. hr_sum = tf.get_default_graph().get_tensor_by_name('neumf/hit_rate/total:0')
  263. hr_cnt = tf.get_default_graph().get_tensor_by_name('neumf/hit_rate/count:0')
  264. ndcg_sum = tf.get_default_graph().get_tensor_by_name('neumf/ndcg/total:0')
  265. ndcg_cnt = tf.get_default_graph().get_tensor_by_name('neumf/ndcg/count:0')
  266. # Prepare evaluation data
  267. data_generator.prepare_eval_data()
  268. if args.load_checkpoint_path:
  269. saver.restore(sess, args.load_checkpoint_path)
  270. else:
  271. # Manual initialize weights
  272. sess.run(tf.global_variables_initializer())
  273. # If test mode, run one eval
  274. if args.mode == 'test':
  275. sess.run(tf.local_variables_initializer())
  276. eval_start = time.time()
  277. for user_batch, item_batch, dup_batch \
  278. in zip(data_generator.eval_users, data_generator.eval_items, data_generator.dup_mask):
  279. sess.run(
  280. eval_op,
  281. feed_dict={
  282. users: user_batch,
  283. items: item_batch,
  284. is_dup:dup_batch, dropout: 0.0
  285. }
  286. )
  287. eval_duration = time.time() - eval_start
  288. # Report results
  289. hit_rate_sum = sess.run(hvd.allreduce(hr_sum, average=False))
  290. hit_rate_cnt = sess.run(hvd.allreduce(hr_cnt, average=False))
  291. ndcg_sum = sess.run(hvd.allreduce(ndcg_sum, average=False))
  292. ndcg_cnt = sess.run(hvd.allreduce(ndcg_cnt, average=False))
  293. hit_rate = hit_rate_sum / hit_rate_cnt
  294. ndcg = ndcg_sum / ndcg_cnt
  295. if hvd.rank() == 0:
  296. eval_throughput = pos_test_users.shape[0] * (args.valid_negative + 1) / eval_duration
  297. dllogger.log(step=tuple(), data={'eval_throughput': eval_throughput,
  298. 'eval_time': eval_duration,
  299. 'hr@10': float(hit_rate),
  300. 'ndcg': float(ndcg)})
  301. return
  302. # Performance Metrics
  303. train_times = list()
  304. eval_times = list()
  305. # Accuracy Metrics
  306. first_to_target = None
  307. time_to_train = 0.0
  308. best_hr = 0
  309. best_epoch = 0
  310. # Buffers for global metrics
  311. global_hr_sum = np.ones(1)
  312. global_hr_count = np.ones(1)
  313. global_ndcg_sum = np.ones(1)
  314. global_ndcg_count = np.ones(1)
  315. # Buffers for local metrics
  316. local_hr_sum = np.ones(1)
  317. local_hr_count = np.ones(1)
  318. local_ndcg_sum = np.ones(1)
  319. local_ndcg_count = np.ones(1)
  320. # Begin training
  321. begin_train = time.time()
  322. for epoch in range(args.epochs):
  323. # Train for one epoch
  324. train_start = time.time()
  325. data_generator.prepare_train_data()
  326. for user_batch, item_batch, label_batch \
  327. in zip(data_generator.train_users_batches,
  328. data_generator.train_items_batches,
  329. data_generator.train_labels_batches):
  330. sess.run(
  331. train_op,
  332. feed_dict={
  333. users: user_batch.get(),
  334. items: item_batch.get(),
  335. labels: label_batch.get()
  336. }
  337. )
  338. train_duration = time.time() - train_start
  339. # Only log "warm" epochs
  340. if epoch >= 1:
  341. train_times.append(train_duration)
  342. # Evaluate
  343. if epoch > args.eval_after:
  344. eval_start = time.time()
  345. sess.run(tf.local_variables_initializer())
  346. for user_batch, item_batch, dup_batch \
  347. in zip(data_generator.eval_users,
  348. data_generator.eval_items,
  349. data_generator.dup_mask):
  350. sess.run(
  351. eval_op,
  352. feed_dict={
  353. users: user_batch,
  354. items: item_batch,
  355. is_dup: dup_batch,
  356. dropout: 0.0
  357. }
  358. )
  359. # Compute local metrics
  360. local_hr_sum[0] = sess.run(hr_sum)
  361. local_hr_count[0] = sess.run(hr_cnt)
  362. local_ndcg_sum[0] = sess.run(ndcg_sum)
  363. local_ndcg_count[0] = sess.run(ndcg_cnt)
  364. # Reduce metrics across all workers
  365. mpi_comm.Reduce(local_hr_count, global_hr_count)
  366. mpi_comm.Reduce(local_hr_sum, global_hr_sum)
  367. mpi_comm.Reduce(local_ndcg_count, global_ndcg_count)
  368. mpi_comm.Reduce(local_ndcg_sum, global_ndcg_sum)
  369. # Calculate metrics
  370. hit_rate = global_hr_sum[0] / global_hr_count[0]
  371. ndcg = global_ndcg_sum[0] / global_ndcg_count[0]
  372. eval_duration = time.time() - eval_start
  373. # Only log "warm" epochs
  374. if epoch >= 1:
  375. eval_times.append(eval_duration)
  376. if hvd.rank() == 0:
  377. dllogger.log(step=(epoch,), data={
  378. 'train_time': train_duration,
  379. 'eval_time': eval_duration,
  380. 'hr@10': hit_rate,
  381. 'ndcg': ndcg})
  382. # Update summary metrics
  383. if hit_rate > args.target and first_to_target is None:
  384. first_to_target = epoch
  385. time_to_train = time.time() - begin_train
  386. if hit_rate > best_hr:
  387. best_hr = hit_rate
  388. best_epoch = epoch
  389. time_to_best = time.time() - begin_train
  390. if hit_rate > args.target and final_checkpoint_path:
  391. saver.save(sess, final_checkpoint_path)
  392. # Final Summary
  393. if hvd.rank() == 0:
  394. train_times = np.array(train_times)
  395. train_throughputs = pos_train_users.shape[0]*(args.negative_samples+1) / train_times
  396. eval_times = np.array(eval_times)
  397. eval_throughputs = pos_test_users.shape[0]*(args.valid_negative+1) / eval_times
  398. dllogger.log(step=tuple(), data={
  399. 'average_train_time_per_epoch': np.mean(train_times),
  400. 'average_train_throughput': np.mean(train_throughputs),
  401. 'average_eval_time_per_epoch': np.mean(eval_times),
  402. 'average_eval_throughput': np.mean(eval_throughputs),
  403. 'first_epoch_to_hit': first_to_target,
  404. 'time_to_train': time_to_train,
  405. 'time_to_best': time_to_best,
  406. 'best_hr': best_hr,
  407. 'best_epoch': best_epoch})
  408. dllogger.flush()
  409. sess.close()
  410. return
# Script entry point: run training/evaluation only when executed directly,
# not when imported as a module.
if __name__ == '__main__':
    main()