# datasets.py
# Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent
# `import concurrent` alone does not load the `futures` submodule; import it
# explicitly instead of relying on torch having imported it as a side effect.
import concurrent.futures
import math
import os
import queue
from typing import List, Optional, Sequence, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset

from dlrm.data.defaults import CATEGORICAL_CHANNEL, NUMERICAL_CHANNEL, LABEL_CHANNEL, \
    DTYPE_SELECTOR, FEATURES_SELECTOR, FILES_SELECTOR
from dlrm.data.feature_spec import FeatureSpec
  25. class SyntheticDataset(Dataset):
  26. """Synthetic dataset version of criteo dataset."""
  27. def __init__(
  28. self,
  29. num_entries: int,
  30. device: str = 'cuda',
  31. batch_size: int = 32768,
  32. numerical_features: Optional[int] = None,
  33. categorical_feature_sizes: Optional[Sequence[int]] = None # features are returned in this order
  34. ):
  35. cat_features_count = len(categorical_feature_sizes) if categorical_feature_sizes is not None else 0
  36. num_features_count = numerical_features if numerical_features is not None else 0
  37. self._batches_per_epoch = math.ceil(num_entries / batch_size)
  38. self._num_tensor = torch.rand(size=(batch_size, num_features_count), device=device, dtype=torch.float32) \
  39. if num_features_count > 0 else None
  40. self._label_tensor = torch.randint(low=0, high=2, size=(batch_size,), device=device, dtype=torch.float32)
  41. self._cat_tensor = torch.cat(
  42. [torch.randint(low=0, high=cardinality, size=(batch_size, 1), device=device, dtype=torch.long)
  43. for cardinality in categorical_feature_sizes], dim=1) if cat_features_count > 0 else None
  44. def __len__(self):
  45. return self._batches_per_epoch
  46. def __getitem__(self, idx: int):
  47. if idx >= self._batches_per_epoch:
  48. raise IndexError()
  49. return self._num_tensor, self._cat_tensor, self._label_tensor
  50. class ParametricDataset(Dataset):
  51. def __init__(
  52. self,
  53. feature_spec: FeatureSpec,
  54. mapping: str,
  55. batch_size: int = 1,
  56. numerical_features_enabled: bool = False,
  57. categorical_features_to_read: List[str] = None, # This parameter dictates order of returned features
  58. prefetch_depth: int = 10,
  59. drop_last_batch: bool = False,
  60. **kwargs
  61. ):
  62. self._feature_spec = feature_spec
  63. self._batch_size = batch_size
  64. self._mapping = mapping
  65. feature_spec.check_feature_spec()
  66. categorical_features = feature_spec.channel_spec[CATEGORICAL_CHANNEL]
  67. numerical_features = feature_spec.channel_spec[NUMERICAL_CHANNEL]
  68. label_features = feature_spec.channel_spec[LABEL_CHANNEL]
  69. set_of_categorical_features = set(categorical_features)
  70. set_of_numerical_features = set(numerical_features)
  71. set_of_label_features = set(label_features)
  72. set_of_categoricals_to_read = set(categorical_features_to_read)
  73. bytes_per_feature = {feature_name: np.dtype(feature_spec.feature_spec[feature_name][DTYPE_SELECTOR]).itemsize
  74. for feature_name in feature_spec.feature_spec.keys()}
  75. self._numerical_features_file = None
  76. self._label_file = None
  77. self._numerical_bytes_per_batch = bytes_per_feature[numerical_features[0]] * \
  78. len(numerical_features) * batch_size
  79. self._label_bytes_per_batch = np.dtype(bool).itemsize * batch_size
  80. self._number_of_numerical_features = len(numerical_features)
  81. chosen_mapping = feature_spec.source_spec[mapping]
  82. categorical_feature_files = {}
  83. root_path = feature_spec.base_directory
  84. number_of_batches = None
  85. for chunk in chosen_mapping:
  86. contained_features = chunk[FEATURES_SELECTOR]
  87. containing_file = chunk[FILES_SELECTOR][0]
  88. first_feature = contained_features[0]
  89. if first_feature in set_of_categorical_features:
  90. # Load categorical
  91. if first_feature not in set_of_categoricals_to_read:
  92. continue # skip chunk
  93. path_to_open = os.path.join(root_path, containing_file)
  94. cat_file = os.open(path_to_open, os.O_RDONLY)
  95. bytes_per_batch = bytes_per_feature[first_feature] * self._batch_size
  96. batch_num_float = os.fstat(cat_file).st_size / bytes_per_batch
  97. categorical_feature_files[first_feature] = cat_file
  98. elif first_feature in set_of_numerical_features:
  99. # Load numerical
  100. if not numerical_features_enabled:
  101. continue # skip chunk
  102. path_to_open = os.path.join(root_path, containing_file)
  103. self._numerical_features_file = os.open(path_to_open, os.O_RDONLY)
  104. batch_num_float = os.fstat(self._numerical_features_file).st_size / self._numerical_bytes_per_batch
  105. elif first_feature in set_of_label_features:
  106. # Load label
  107. path_to_open = os.path.join(root_path, containing_file)
  108. self._label_file = os.open(path_to_open, os.O_RDONLY)
  109. batch_num_float = os.fstat(self._label_file).st_size / self._label_bytes_per_batch
  110. else:
  111. raise ValueError("Unknown chunk type")
  112. local_number_of_batches = math.ceil(batch_num_float) if not drop_last_batch else math.floor(batch_num_float)
  113. if number_of_batches is not None:
  114. if local_number_of_batches != number_of_batches:
  115. raise ValueError("Size mismatch in data files")
  116. else:
  117. number_of_batches = local_number_of_batches
  118. self._categorical_features_files = None
  119. if len(categorical_features_to_read) > 0:
  120. self._categorical_features_files = [categorical_feature_files[feature] for feature in
  121. categorical_features_to_read]
  122. self._categorical_bytes_per_batch = [bytes_per_feature[feature] * self._batch_size for feature in
  123. categorical_features_to_read]
  124. self._categorical_types = [feature_spec.feature_spec[feature][DTYPE_SELECTOR] for feature in
  125. categorical_features_to_read]
  126. self._num_entries = number_of_batches
  127. self._prefetch_depth = min(prefetch_depth, self._num_entries)
  128. self._prefetch_queue = queue.Queue()
  129. self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
  130. def __len__(self):
  131. return self._num_entries
  132. def __getitem__(self, idx: int):
  133. """ Numerical features are returned in the order they appear in the channel spec section
  134. For performance reasons, this is required to be the order they are saved in, as specified
  135. by the relevant chunk in source spec.
  136. Categorical features are returned in the order they appear in the channel spec section """
  137. if idx >= self._num_entries:
  138. raise IndexError()
  139. if self._prefetch_depth <= 1:
  140. return self._get_item(idx)
  141. # At the start, fill up the prefetching queue
  142. if idx == 0:
  143. for i in range(self._prefetch_depth):
  144. self._prefetch_queue.put(self._executor.submit(self._get_item, (i)))
  145. # Extend the prefetching window by one if not at the end of the dataset
  146. if idx < self._num_entries - self._prefetch_depth:
  147. self._prefetch_queue.put(self._executor.submit(self._get_item, (idx + self._prefetch_depth)))
  148. return self._prefetch_queue.get().result()
  149. def _get_item(self, idx: int) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
  150. click = self._get_label(idx)
  151. numerical_features = self._get_numerical_features(idx)
  152. categorical_features = self._get_categorical_features(idx)
  153. return numerical_features, categorical_features, click
  154. def _get_label(self, idx: int) -> torch.Tensor:
  155. raw_label_data = os.pread(self._label_file, self._label_bytes_per_batch,
  156. idx * self._label_bytes_per_batch)
  157. array = np.frombuffer(raw_label_data, dtype=bool)
  158. return torch.from_numpy(array).to(torch.float32)
  159. def _get_numerical_features(self, idx: int) -> Optional[torch.Tensor]:
  160. if self._numerical_features_file is None:
  161. return None
  162. raw_numerical_data = os.pread(self._numerical_features_file, self._numerical_bytes_per_batch,
  163. idx * self._numerical_bytes_per_batch)
  164. array = np.frombuffer(raw_numerical_data, dtype=np.float16)
  165. return torch.from_numpy(array).view(-1, self._number_of_numerical_features)
  166. def _get_categorical_features(self, idx: int) -> Optional[torch.Tensor]:
  167. if self._categorical_features_files is None:
  168. return None
  169. categorical_features = []
  170. for cat_bytes, cat_type, cat_file in zip(self._categorical_bytes_per_batch,
  171. self._categorical_types,
  172. self._categorical_features_files):
  173. raw_cat_data = os.pread(cat_file, cat_bytes, idx * cat_bytes)
  174. array = np.frombuffer(raw_cat_data, dtype=cat_type)
  175. tensor = torch.from_numpy(array).unsqueeze(1).to(torch.long)
  176. categorical_features.append(tensor)
  177. return torch.cat(categorical_features, dim=1)
  178. def __del__(self):
  179. data_files = [self._label_file, self._numerical_features_file]
  180. if self._categorical_features_files is not None:
  181. data_files += self._categorical_features_files
  182. for data_file in data_files:
  183. if data_file is not None:
  184. os.close(data_file)