# dataloading.py
# Copyright (c) 2018, deepakn94, codyaustun, robieta. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# -----------------------------------------------------------------------
#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import torch

from feature_spec import FeatureSpec
from neumf_constants import USER_CHANNEL_NAME, ITEM_CHANNEL_NAME, LABEL_CHANNEL_NAME, TEST_SAMPLES_PER_SERIES
  34. class TorchTensorDataset:
  35. """ Warning! This dataset/loader uses torch.load. Torch.load implicitly uses pickle. Pickle is insecure.
  36. It is trivial to achieve arbitrary code execution using a prepared pickle payload. Only unpickle data you trust."""
  37. def __init__(self, feature_spec: FeatureSpec, mapping_name: str, args):
  38. self.local_rank = args.local_rank
  39. self.mapping_name = mapping_name
  40. self.features = dict()
  41. self.feature_spec = feature_spec
  42. self._load_features()
  43. def _load_features(self):
  44. chunks = self.feature_spec.source_spec[self.mapping_name]
  45. for chunk in chunks:
  46. assert chunk['type'] == 'torch_tensor', "Only torch_tensor files supported in this loader"
  47. files_list = chunk['files']
  48. assert len(files_list) == 1, "Only one file per chunk supported in this loader"
  49. file_relative_path = files_list[0]
  50. path_to_load = os.path.join(self.feature_spec.base_directory, file_relative_path)
  51. chunk_data = torch.load(path_to_load, map_location=torch.device('cuda:{}'.format(self.local_rank)))
  52. running_pos = 0
  53. for feature_name in chunk['features']:
  54. next_running_pos = running_pos + 1
  55. feature_data = chunk_data[:, running_pos:next_running_pos]
  56. # This is needed because slicing instead of indexing keeps the data 2-dimensional
  57. feature_data = feature_data.reshape(-1, 1)
  58. running_pos = next_running_pos
  59. self.features[feature_name] = feature_data
  60. class TestDataLoader:
  61. def __init__(self, dataset: TorchTensorDataset, args):
  62. self.dataset = dataset
  63. self.feature_spec = dataset.feature_spec
  64. self.channel_spec = self.feature_spec.channel_spec
  65. self.samples_in_series = self.feature_spec.metadata[TEST_SAMPLES_PER_SERIES]
  66. self.raw_dataset_length = None # First feature loaded sets this. Total length before splitting across cards
  67. self.data = dict()
  68. self.world_size = args.world_size
  69. self.local_rank = args.local_rank
  70. self.batch_size = args.valid_batch_size
  71. self._build_channel_dict()
  72. self._deduplication_augmentation()
  73. self._split_between_devices()
  74. self._split_into_batches()
  75. def _build_channel_dict(self):
  76. for channel_name, channel_features in self.channel_spec.items():
  77. channel_tensors = dict()
  78. for feature_name in channel_features:
  79. channel_tensors[feature_name] = self.dataset.features[feature_name]
  80. if not self.raw_dataset_length:
  81. self.raw_dataset_length = channel_tensors[feature_name].shape[0]
  82. else:
  83. assert self.raw_dataset_length == channel_tensors[feature_name].shape[0]
  84. self.data[channel_name] = channel_tensors
  85. def _deduplication_augmentation(self):
  86. # Augmentation
  87. # This adds a duplication mask tensor.
  88. # This is here to exactly replicate the MLPerf training regime. Moving this deduplication to the candidate item
  89. # generation stage increases the real diversity of the candidates, which makes the ranking task harder
  90. # and results in a drop in HR@10 of approx 0.01. This has been deemed unacceptable (May 2021).
  91. # We need the duplication mask to determine if a given item should be skipped during ranking
  92. # If an item with label 1 is duplicated in the sampled ones, we need to be careful to not mark the one with
  93. # label 1 as a duplicate. If an item appears repeatedly only with label 1, no duplicates are marked.
  94. # To easily compute candidates, we sort the items. This will impact the distribution of examples between
  95. # devices, but should not influence the numerics or performance meaningfully.
  96. # We need to assure that the positive item, which we don't want to mark as a duplicate, appears first.
  97. # We do this by adding labels as a secondary factor
  98. # Reshape the tensors to have items for a given user in a single row
  99. user_feature_name = self.channel_spec[USER_CHANNEL_NAME][0]
  100. item_feature_name = self.channel_spec[ITEM_CHANNEL_NAME][0]
  101. label_feature_name = self.channel_spec[LABEL_CHANNEL_NAME][0]
  102. self.ignore_mask_channel_name = 'mask_ch'
  103. self.ignore_mask_feature_name = 'mask'
  104. items = self.data[ITEM_CHANNEL_NAME][item_feature_name].view(-1, self.samples_in_series)
  105. users = self.data[USER_CHANNEL_NAME][user_feature_name].view(-1, self.samples_in_series)
  106. labels = self.data[LABEL_CHANNEL_NAME][label_feature_name].view(-1, self.samples_in_series)
  107. sorting_weights = items.float() - labels.float() * 0.5
  108. _, indices = torch.sort(sorting_weights)
  109. # The gather reorders according to the indices decided by the sort above
  110. sorted_items = torch.gather(items, 1, indices)
  111. sorted_labels = torch.gather(labels, 1, indices)
  112. sorted_users = torch.gather(users, 1, indices)
  113. dup_mask = sorted_items[:, 0:-1] == sorted_items[:, 1:] # This says if a given item is equal to the next one
  114. dup_mask = dup_mask.type(torch.bool)
  115. # The first item for a given user can never be a duplicate:
  116. dup_mask = torch.cat((torch.zeros_like(dup_mask[:, 0:1]), dup_mask), dim=1)
  117. # Reshape them back
  118. self.data[ITEM_CHANNEL_NAME][item_feature_name] = sorted_items.view(-1, 1)
  119. self.data[USER_CHANNEL_NAME][user_feature_name] = sorted_users.view(-1, 1)
  120. self.data[LABEL_CHANNEL_NAME][label_feature_name] = sorted_labels.view(-1, 1)
  121. self.data[self.ignore_mask_channel_name] = dict()
  122. self.data[self.ignore_mask_channel_name][self.ignore_mask_feature_name] = dup_mask.view(-1, 1)
  123. def _split_between_devices(self):
  124. if self.world_size > 1:
  125. # DO NOT REPLACE WITH torch.chunk (number of returned chunks can silently be lower than requested).
  126. # It would break compatibility with small datasets.
  127. num_test_cases = self.raw_dataset_length / self.samples_in_series
  128. smaller_batch = (int(num_test_cases // self.world_size)) * self.samples_in_series
  129. bigger_batch = smaller_batch + self.samples_in_series
  130. remainder = int(num_test_cases % self.world_size)
  131. samples_per_card = [bigger_batch] * remainder + [smaller_batch] * (self.world_size - remainder)
  132. for channel_name, channel_dict in self.data.items():
  133. for feature_name, feature_tensor in channel_dict.items():
  134. channel_dict[feature_name] = \
  135. channel_dict[feature_name].split(samples_per_card)[self.local_rank]
  136. def _split_into_batches(self):
  137. self.batches = None
  138. # This is the structure of each batch, waiting to be copied and filled in with data
  139. for channel_name, channel_dict in self.data.items():
  140. for feature_name, feature_tensor in channel_dict.items():
  141. feature_batches = feature_tensor.view(-1).split(self.batch_size)
  142. if not self.batches:
  143. self.batches = list(
  144. {channel_name: dict() for channel_name in self.data.keys()} for _ in feature_batches)
  145. for pos, feature_batch_data in enumerate(feature_batches):
  146. self.batches[pos][channel_name][feature_name] = feature_batch_data
  147. def get_epoch_data(self):
  148. return self.batches
  149. def get_ignore_mask(self):
  150. return self.data[self.ignore_mask_channel_name][self.ignore_mask_feature_name]
  151. class TrainDataloader:
  152. def __init__(self, dataset: TorchTensorDataset, args):
  153. self.dataset = dataset
  154. self.local_rank = args.local_rank
  155. if args.distributed:
  156. self.local_batch = args.batch_size // args.world_size
  157. else:
  158. self.local_batch = args.batch_size
  159. self.feature_spec = dataset.feature_spec
  160. self.channel_spec = self.feature_spec.channel_spec
  161. self.negative_samples = args.negative_samples
  162. self.data = dict()
  163. self.raw_dataset_length = None # first feature loaded sets this
  164. self._build_channel_dict()
  165. self.length_after_augmentation = self.raw_dataset_length * (self.negative_samples + 1)
  166. samples_per_worker = self.length_after_augmentation / args.world_size
  167. self.samples_begin = int(samples_per_worker * args.local_rank)
  168. self.samples_end = int(samples_per_worker * (args.local_rank + 1))
  169. def _build_channel_dict(self):
  170. for channel_name, channel_features in self.channel_spec.items():
  171. channel_tensors = dict()
  172. for feature_name in channel_features:
  173. channel_tensors[feature_name] = self.dataset.features[feature_name]
  174. if not self.raw_dataset_length:
  175. self.raw_dataset_length = channel_tensors[feature_name].shape[0]
  176. else:
  177. assert self.raw_dataset_length == channel_tensors[feature_name].shape[0]
  178. self.data[channel_name] = channel_tensors
  179. def get_epoch_data(self):
  180. # Augment, appending args.negative_samples times the original set, now with random items end negative labels
  181. augmented_data = {channel_name: dict() for channel_name in self.data.keys()}
  182. user_feature_name = self.channel_spec[USER_CHANNEL_NAME][0]
  183. item_feature_name = self.channel_spec[ITEM_CHANNEL_NAME][0]
  184. label_feature_name = self.channel_spec[LABEL_CHANNEL_NAME][0]
  185. # USER
  186. user_tensor = self.data[USER_CHANNEL_NAME][user_feature_name]
  187. neg_users = user_tensor.repeat(self.negative_samples, 1)
  188. augmented_users = torch.cat((user_tensor, neg_users))
  189. augmented_data[USER_CHANNEL_NAME][user_feature_name] = augmented_users
  190. del neg_users
  191. # ITEM
  192. item_tensor = self.data[ITEM_CHANNEL_NAME][item_feature_name]
  193. neg_items = torch.empty_like(item_tensor).repeat(self.negative_samples, 1) \
  194. .random_(0, self.feature_spec.feature_spec[item_feature_name]['cardinality'])
  195. augmented_items = torch.cat((item_tensor, neg_items))
  196. augmented_data[ITEM_CHANNEL_NAME][item_feature_name] = augmented_items
  197. del neg_items
  198. # LABEL
  199. label_tensor = self.data[LABEL_CHANNEL_NAME][label_feature_name]
  200. neg_label = torch.zeros_like(label_tensor, dtype=torch.float32).repeat(self.negative_samples, 1)
  201. augmented_labels = torch.cat((label_tensor, neg_label))
  202. del neg_label
  203. augmented_data[LABEL_CHANNEL_NAME][label_feature_name] = augmented_labels
  204. # Labels are not shuffled between cards.
  205. # This replicates previous behaviour.
  206. epoch_indices = torch.randperm(self.samples_end - self.samples_begin, device='cuda:{}'.format(self.local_rank))
  207. epoch_indices += self.samples_begin
  208. batches = None
  209. for channel_name, channel_dict in augmented_data.items():
  210. for feature_name, feature_tensor in channel_dict.items():
  211. # the last batch will almost certainly be smaller, drop it
  212. # Warning: may not work if there's only one
  213. feature_batches = feature_tensor.view(-1)[epoch_indices].split(self.local_batch)[:-1]
  214. if not batches:
  215. batches = list({channel_name: dict() for channel_name in self.data.keys()} for _ in feature_batches)
  216. for pos, feature_batch_data in enumerate(feature_batches):
  217. batches[pos][channel_name][feature_name] = feature_batch_data
  218. return batches