# ncf.py
#
# Copyright (c) 2018, deepakn94, codyaustun, robieta. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# -----------------------------------------------------------------------
#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  30. import torch.jit
  31. from apex.optimizers import FusedAdam
  32. import os
  33. import math
  34. import time
  35. import numpy as np
  36. from argparse import ArgumentParser
  37. import torch
  38. import torch.nn as nn
  39. import utils
  40. import dataloading
  41. from neumf import NeuMF
  42. from feature_spec import FeatureSpec
  43. from neumf_constants import USER_CHANNEL_NAME, ITEM_CHANNEL_NAME, LABEL_CHANNEL_NAME
  44. import dllogger
  45. def synchronized_timestamp():
  46. torch.cuda.synchronize()
  47. return time.time()
  48. def parse_args():
  49. parser = ArgumentParser(description="Train a Neural Collaborative"
  50. " Filtering model")
  51. parser.add_argument('--data', type=str,
  52. help='Path to the directory containing the feature specification yaml')
  53. parser.add_argument('--feature_spec_file', type=str, default='feature_spec.yaml',
  54. help='Name of the feature specification file or path relative to the data directory.')
  55. parser.add_argument('-e', '--epochs', type=int, default=30,
  56. help='Number of epochs for training')
  57. parser.add_argument('-b', '--batch_size', type=int, default=2 ** 20,
  58. help='Number of examples for each iteration. This will be divided by the number of devices')
  59. parser.add_argument('--valid_batch_size', type=int, default=2 ** 20,
  60. help='Number of examples in each validation chunk. This will be the maximum size of a batch '
  61. 'on each device.')
  62. parser.add_argument('-f', '--factors', type=int, default=64,
  63. help='Number of predictive factors')
  64. parser.add_argument('--layers', nargs='+', type=int,
  65. default=[256, 256, 128, 64],
  66. help='Sizes of hidden layers for MLP')
  67. parser.add_argument('-n', '--negative_samples', type=int, default=4,
  68. help='Number of negative examples per interaction')
  69. parser.add_argument('-l', '--learning_rate', type=float, default=0.0045,
  70. help='Learning rate for optimizer')
  71. parser.add_argument('-k', '--topk', type=int, default=10,
  72. help='Rank for test examples to be considered a hit')
  73. parser.add_argument('--seed', '-s', type=int, default=None,
  74. help='Manually set random seed for torch')
  75. parser.add_argument('--threshold', '-t', type=float, default=1.0,
  76. help='Stop training early at threshold')
  77. parser.add_argument('--beta1', '-b1', type=float, default=0.25,
  78. help='Beta1 for Adam')
  79. parser.add_argument('--beta2', '-b2', type=float, default=0.5,
  80. help='Beta1 for Adam')
  81. parser.add_argument('--eps', type=float, default=1e-8,
  82. help='Epsilon for Adam')
  83. parser.add_argument('--dropout', type=float, default=0.5,
  84. help='Dropout probability, if equal to 0 will not use dropout at all')
  85. parser.add_argument('--checkpoint_dir', default='', type=str,
  86. help='Path to the directory storing the checkpoint file, '
  87. 'passing an empty path disables checkpoint saving')
  88. parser.add_argument('--load_checkpoint_path', default=None, type=str,
  89. help='Path to the checkpoint file to be loaded before training/evaluation')
  90. parser.add_argument('--mode', choices=['train', 'test'], default='train', type=str,
  91. help='Passing "test" will only run a single evaluation; '
  92. 'otherwise, full training will be performed')
  93. parser.add_argument('--grads_accumulated', default=1, type=int,
  94. help='Number of gradients to accumulate before performing an optimization step')
  95. parser.add_argument('--amp', action='store_true', help='Enable mixed precision training')
  96. parser.add_argument('--log_path', default='log.json', type=str,
  97. help='Path for the JSON training log')
  98. return parser.parse_args()
  99. def init_distributed(args):
  100. args.world_size = int(os.environ.get('WORLD_SIZE', default=1))
  101. args.distributed = args.world_size > 1
  102. if args.distributed:
  103. args.local_rank = int(os.environ['LOCAL_RANK'])
  104. '''
  105. Set cuda device so everything is done on the right GPU.
  106. THIS MUST BE DONE AS SOON AS POSSIBLE.
  107. '''
  108. torch.cuda.set_device(args.local_rank)
  109. '''Initialize distributed communication'''
  110. torch.distributed.init_process_group(backend='nccl',
  111. init_method='env://')
  112. else:
  113. args.local_rank = 0
  114. def val_epoch(model, dataloader: dataloading.TestDataLoader, k, distributed=False, world_size=1):
  115. model.eval()
  116. user_feature_name = dataloader.channel_spec[USER_CHANNEL_NAME][0]
  117. item_feature_name = dataloader.channel_spec[ITEM_CHANNEL_NAME][0]
  118. label_feature_name = dataloader.channel_spec[LABEL_CHANNEL_NAME][0]
  119. with torch.no_grad():
  120. p = []
  121. labels_list = []
  122. losses = []
  123. for batch_dict in dataloader.get_epoch_data():
  124. user_batch = batch_dict[USER_CHANNEL_NAME][user_feature_name]
  125. item_batch = batch_dict[ITEM_CHANNEL_NAME][item_feature_name]
  126. label_batch = batch_dict[LABEL_CHANNEL_NAME][label_feature_name]
  127. prediction_batch = model(user_batch, item_batch, sigmoid=True).detach()
  128. loss_batch = torch.nn.functional.binary_cross_entropy(input=prediction_batch.reshape([-1]),
  129. target=label_batch)
  130. losses.append(loss_batch)
  131. p.append(prediction_batch)
  132. labels_list.append(label_batch)
  133. ignore_mask = dataloader.get_ignore_mask().view(-1, dataloader.samples_in_series)
  134. ratings = torch.cat(p).view(-1, dataloader.samples_in_series)
  135. ratings[ignore_mask] = -1
  136. labels = torch.cat(labels_list).view(-1, dataloader.samples_in_series)
  137. del p, labels_list
  138. top_indices = torch.topk(ratings, k)[1]
  139. # Positive items are always first in a given series
  140. labels_of_selected = torch.gather(labels, 1, top_indices)
  141. ifzero = (labels_of_selected == 1)
  142. hits = ifzero.sum()
  143. ndcg = (math.log(2) / (torch.nonzero(ifzero)[:, 1].view(-1).to(torch.float) + 2).log_()).sum()
  144. total_validation_loss = torch.mean(torch.stack(losses, dim=0))
  145. # torch.nonzero may cause host-device synchronization
  146. if distributed:
  147. torch.distributed.all_reduce(hits, op=torch.distributed.ReduceOp.SUM)
  148. torch.distributed.all_reduce(ndcg, op=torch.distributed.ReduceOp.SUM)
  149. torch.distributed.all_reduce(total_validation_loss, op=torch.distributed.ReduceOp.SUM)
  150. total_validation_loss = total_validation_loss / world_size
  151. num_test_cases = dataloader.raw_dataset_length / dataloader.samples_in_series
  152. hr = hits.item() / num_test_cases
  153. ndcg = ndcg.item() / num_test_cases
  154. model.train()
  155. return hr, ndcg, total_validation_loss
  156. def main():
  157. args = parse_args()
  158. init_distributed(args)
  159. if args.local_rank == 0:
  160. dllogger.init(backends=[dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
  161. filename=args.log_path),
  162. dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE)])
  163. else:
  164. dllogger.init(backends=[])
  165. dllogger.metadata('train_throughput', {"name": 'train_throughput', 'unit': 'samples/s', 'format': ":.3e"})
  166. dllogger.metadata('best_train_throughput', {'unit': 'samples/s'})
  167. dllogger.metadata('mean_train_throughput', {'unit': 'samples/s'})
  168. dllogger.metadata('eval_throughput', {"name": 'eval_throughput', 'unit': 'samples/s', 'format': ":.3e"})
  169. dllogger.metadata('best_eval_throughput', {'unit': 'samples/s'})
  170. dllogger.metadata('mean_eval_throughput', {'unit': 'samples/s'})
  171. dllogger.metadata('train_epoch_time', {"name": 'train_epoch_time', 'unit': 's', 'format': ":.3f"})
  172. dllogger.metadata('validation_epoch_time', {"name": 'validation_epoch_time', 'unit': 's', 'format': ":.3f"})
  173. dllogger.metadata('time_to_target', {'unit': 's'})
  174. dllogger.metadata('time_to_best_model', {'unit': 's'})
  175. dllogger.metadata('hr@10', {"name": 'hr@10', 'unit': None, 'format': ":.5f"})
  176. dllogger.metadata('best_accuracy', {'unit': None})
  177. dllogger.metadata('best_epoch', {'unit': None})
  178. dllogger.metadata('validation_loss', {"name": 'validation_loss', 'unit': None, 'format': ":.5f"})
  179. dllogger.metadata('train_loss', {"name": 'train_loss', 'unit': None, 'format': ":.5f"})
  180. dllogger.log(data=vars(args), step='PARAMETER')
  181. if args.seed is not None:
  182. torch.manual_seed(args.seed)
  183. if not os.path.exists(args.checkpoint_dir) and args.checkpoint_dir:
  184. print("Saving results to {}".format(args.checkpoint_dir))
  185. os.makedirs(args.checkpoint_dir, exist_ok=True)
  186. # sync workers before timing
  187. if args.distributed:
  188. torch.distributed.broadcast(torch.tensor([1], device="cuda"), 0)
  189. torch.cuda.synchronize()
  190. main_start_time = synchronized_timestamp()
  191. feature_spec_path = os.path.join(args.data, args.feature_spec_file)
  192. feature_spec = FeatureSpec.from_yaml(feature_spec_path)
  193. trainset = dataloading.TorchTensorDataset(feature_spec, mapping_name='train', args=args)
  194. testset = dataloading.TorchTensorDataset(feature_spec, mapping_name='test', args=args)
  195. train_loader = dataloading.TrainDataloader(trainset, args)
  196. test_loader = dataloading.TestDataLoader(testset, args)
  197. # make pytorch memory behavior more consistent later
  198. torch.cuda.empty_cache()
  199. # Create model
  200. user_feature_name = feature_spec.channel_spec[USER_CHANNEL_NAME][0]
  201. item_feature_name = feature_spec.channel_spec[ITEM_CHANNEL_NAME][0]
  202. label_feature_name = feature_spec.channel_spec[LABEL_CHANNEL_NAME][0]
  203. model = NeuMF(nb_users=feature_spec.feature_spec[user_feature_name]['cardinality'],
  204. nb_items=feature_spec.feature_spec[item_feature_name]['cardinality'],
  205. mf_dim=args.factors,
  206. mlp_layer_sizes=args.layers,
  207. dropout=args.dropout)
  208. optimizer = FusedAdam(model.parameters(), lr=args.learning_rate,
  209. betas=(args.beta1, args.beta2), eps=args.eps)
  210. criterion = nn.BCEWithLogitsLoss(reduction='none') # use torch.mean() with dim later to avoid copy to host
  211. # Move model and loss to GPU
  212. model = model.cuda()
  213. criterion = criterion.cuda()
  214. if args.distributed:
  215. model = torch.nn.parallel.DistributedDataParallel(model)
  216. local_batch = args.batch_size // args.world_size
  217. traced_criterion = torch.jit.trace(criterion.forward,
  218. (torch.rand(local_batch, 1), torch.rand(local_batch, 1)))
  219. print(model)
  220. print("{} parameters".format(utils.count_parameters(model)))
  221. if args.load_checkpoint_path:
  222. state_dict = torch.load(args.load_checkpoint_path)
  223. state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
  224. model.load_state_dict(state_dict)
  225. if args.mode == 'test':
  226. start = synchronized_timestamp()
  227. hr, ndcg, val_loss = val_epoch(model, test_loader, args.topk,
  228. distributed=args.distributed, world_size=args.world_size)
  229. val_time = synchronized_timestamp() - start
  230. eval_size = test_loader.raw_dataset_length
  231. eval_throughput = eval_size / val_time
  232. dllogger.log(step=tuple(), data={'best_eval_throughput': eval_throughput,
  233. 'hr@10': hr,
  234. 'validation_loss': float(val_loss.item())})
  235. return
  236. # this should always be overridden if hr>0.
  237. # It is theoretically possible for the hit rate to be zero in the first epoch, which would result in referring
  238. # to an uninitialized variable.
  239. max_hr = 0
  240. best_epoch = 0
  241. best_model_timestamp = synchronized_timestamp()
  242. train_throughputs, eval_throughputs = [], []
  243. scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
  244. for epoch in range(args.epochs):
  245. begin = synchronized_timestamp()
  246. batch_dict_list = train_loader.get_epoch_data()
  247. num_batches = len(batch_dict_list)
  248. for i in range(num_batches // args.grads_accumulated):
  249. for j in range(args.grads_accumulated):
  250. batch_idx = (args.grads_accumulated * i) + j
  251. batch_dict = batch_dict_list[batch_idx]
  252. user_features = batch_dict[USER_CHANNEL_NAME]
  253. item_features = batch_dict[ITEM_CHANNEL_NAME]
  254. user_batch = user_features[user_feature_name]
  255. item_batch = item_features[item_feature_name]
  256. label_features = batch_dict[LABEL_CHANNEL_NAME]
  257. label_batch = label_features[label_feature_name]
  258. with torch.cuda.amp.autocast(enabled=args.amp):
  259. outputs = model(user_batch, item_batch)
  260. loss = traced_criterion(outputs, label_batch.view(-1, 1))
  261. loss = torch.mean(loss.float().view(-1), 0)
  262. scaler.scale(loss).backward()
  263. scaler.step(optimizer)
  264. scaler.update()
  265. for p in model.parameters():
  266. p.grad = None
  267. del batch_dict_list
  268. train_time = synchronized_timestamp() - begin
  269. begin = synchronized_timestamp()
  270. epoch_samples = train_loader.length_after_augmentation
  271. train_throughput = epoch_samples / train_time
  272. train_throughputs.append(train_throughput)
  273. hr, ndcg, val_loss = val_epoch(model, test_loader, args.topk,
  274. distributed=args.distributed, world_size=args.world_size)
  275. val_time = synchronized_timestamp() - begin
  276. eval_size = test_loader.raw_dataset_length
  277. eval_throughput = eval_size / val_time
  278. eval_throughputs.append(eval_throughput)
  279. if args.distributed:
  280. torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
  281. loss = loss / args.world_size
  282. dllogger.log(step=(epoch,),
  283. data={'train_throughput': train_throughput,
  284. 'hr@10': hr,
  285. 'train_epoch_time': train_time,
  286. 'validation_epoch_time': val_time,
  287. 'eval_throughput': eval_throughput,
  288. 'validation_loss': float(val_loss.item()),
  289. 'train_loss': float(loss.item())})
  290. if hr > max_hr and args.local_rank == 0:
  291. max_hr = hr
  292. best_epoch = epoch
  293. print("New best hr!")
  294. if args.checkpoint_dir:
  295. save_checkpoint_path = os.path.join(args.checkpoint_dir, 'model.pth')
  296. print("Saving the model to: ", save_checkpoint_path)
  297. torch.save(model.state_dict(), save_checkpoint_path)
  298. best_model_timestamp = synchronized_timestamp()
  299. if args.threshold is not None:
  300. if hr >= args.threshold:
  301. print("Hit threshold of {}".format(args.threshold))
  302. break
  303. if args.local_rank == 0:
  304. dllogger.log(data={'best_train_throughput': max(train_throughputs),
  305. 'best_eval_throughput': max(eval_throughputs),
  306. 'mean_train_throughput': np.mean(train_throughputs),
  307. 'mean_eval_throughput': np.mean(eval_throughputs),
  308. 'best_accuracy': max_hr,
  309. 'best_epoch': best_epoch,
  310. 'time_to_target': synchronized_timestamp() - main_start_time,
  311. 'time_to_best_model': best_model_timestamp - main_start_time,
  312. 'validation_loss': float(val_loss.item()),
  313. 'train_loss': float(loss.item())},
  314. step=tuple())
  315. if __name__ == '__main__':
  316. main()