# dataloading.py
# Copyright (c) 2018, deepakn94, codyaustun, robieta. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# -----------------------------------------------------------------------
#
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import torch
  32. def create_test_data(test_ratings, test_negs, args):
  33. test_users = test_ratings[:,0]
  34. test_pos = test_ratings[:,1].reshape(-1,1)
  35. # create items with real sample at last position
  36. num_valid_negative = test_negs.shape[1]
  37. test_users = test_users.reshape(-1,1).repeat(1, 1 + num_valid_negative)
  38. test_items = torch.cat((test_negs, test_pos), dim=1)
  39. del test_ratings, test_negs
  40. # generate dup mask and real indices for exact same behavior on duplication compare to reference
  41. # here we need a sort that is stable(keep order of duplicates)
  42. sorted_items, indices = torch.sort(test_items) # [1,1,1,2], [3,1,0,2]
  43. sum_item_indices = sorted_items.float()+indices.float()/len(indices[0]) #[1.75,1.25,1.0,2.5]
  44. indices_order = torch.sort(sum_item_indices)[1] #[2,1,0,3]
  45. stable_indices = torch.gather(indices, 1, indices_order) #[0,1,3,2]
  46. # produce -1 mask
  47. dup_mask = (sorted_items[:,0:-1] == sorted_items[:,1:])
  48. dup_mask = dup_mask.type(torch.uint8)
  49. dup_mask = torch.cat((torch.zeros_like(test_pos, dtype=torch.uint8), dup_mask), dim=1)
  50. dup_mask = torch.gather(dup_mask, 1, stable_indices.sort()[1])
  51. # produce real sample indices to later check in topk
  52. sorted_items, indices = (test_items != test_pos).type(torch.uint8).sort()
  53. sum_item_indices = sorted_items.float()+indices.float()/len(indices[0])
  54. indices_order = torch.sort(sum_item_indices)[1]
  55. stable_indices = torch.gather(indices, 1, indices_order)
  56. real_indices = stable_indices[:,0]
  57. if args.distributed:
  58. test_users = torch.chunk(test_users, args.world_size)[args.local_rank]
  59. test_items = torch.chunk(test_items, args.world_size)[args.local_rank]
  60. dup_mask = torch.chunk(dup_mask, args.world_size)[args.local_rank]
  61. real_indices = torch.chunk(real_indices, args.world_size)[args.local_rank]
  62. test_users = test_users.view(-1).split(args.valid_batch_size)
  63. test_items = test_items.view(-1).split(args.valid_batch_size)
  64. return test_users, test_items, dup_mask, real_indices
  65. def prepare_epoch_train_data(train_ratings, nb_items, args):
  66. # create label
  67. train_label = torch.ones_like(train_ratings[:,0], dtype=torch.float32)
  68. neg_label = torch.zeros_like(train_label, dtype=torch.float32)
  69. neg_label = neg_label.repeat(args.negative_samples)
  70. train_label = torch.cat((train_label,neg_label))
  71. del neg_label
  72. train_users = train_ratings[:,0]
  73. train_items = train_ratings[:,1]
  74. train_users_per_worker = len(train_label) / args.world_size
  75. train_users_begin = int(train_users_per_worker * args.local_rank)
  76. train_users_end = int(train_users_per_worker * (args.local_rank + 1))
  77. # prepare data for epoch
  78. neg_users = train_users.repeat(args.negative_samples)
  79. neg_items = torch.empty_like(neg_users, dtype=torch.int64).random_(0, nb_items)
  80. epoch_users = torch.cat((train_users, neg_users))
  81. epoch_items = torch.cat((train_items, neg_items))
  82. del neg_users, neg_items
  83. # shuffle prepared data and split into batches
  84. epoch_indices = torch.randperm(train_users_end - train_users_begin, device='cuda:{}'.format(args.local_rank))
  85. epoch_indices += train_users_begin
  86. epoch_users = epoch_users[epoch_indices]
  87. epoch_items = epoch_items[epoch_indices]
  88. epoch_label = train_label[epoch_indices]
  89. if args.distributed:
  90. local_batch = args.batch_size // args.world_size
  91. else:
  92. local_batch = args.batch_size
  93. epoch_users = epoch_users.split(local_batch)
  94. epoch_items = epoch_items.split(local_batch)
  95. epoch_label = epoch_label.split(local_batch)
  96. # the last batch will almost certainly be smaller, drop it
  97. epoch_users = epoch_users[:-1]
  98. epoch_items = epoch_items[:-1]
  99. epoch_label = epoch_label[:-1]
  100. return epoch_users, epoch_items, epoch_label