# dataloader.py
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# author: Tomasz Grel ([email protected]), Tomasz Cheda ([email protected])
  16. import os
  17. from .defaults import TRAIN_MAPPING, TEST_MAPPING
  18. from .feature_spec import FeatureSpec
  19. from .raw_binary_dataset import TfRawBinaryDataset, DatasetMetadata
  20. from .synthetic_dataset import SyntheticDataset
  21. from .split_tfrecords_multihot_dataset import SplitTFRecordsDataset
  22. def get_dataset_metadata(dataset_path, feature_spec):
  23. fspec_path = os.path.join(dataset_path, feature_spec)
  24. feature_spec = FeatureSpec.from_yaml(fspec_path)
  25. dataset_metadata = DatasetMetadata(num_numerical_features=feature_spec.get_number_of_numerical_features(),
  26. categorical_cardinalities=feature_spec.get_categorical_sizes())
  27. return dataset_metadata
  28. def _create_pipelines_synthetic_fspec(**kwargs):
  29. fspec_path = os.path.join(kwargs['dataset_path'], kwargs['feature_spec'])
  30. feature_spec = FeatureSpec.from_yaml(fspec_path)
  31. dataset_metadata = DatasetMetadata(num_numerical_features=feature_spec.get_number_of_numerical_features(),
  32. categorical_cardinalities=feature_spec.get_categorical_sizes())
  33. local_table_sizes = [dataset_metadata.categorical_cardinalities[i] for i in kwargs['table_ids']]
  34. names = feature_spec.get_categorical_feature_names()
  35. local_names = [names[i] for i in kwargs['table_ids']]
  36. local_table_hotness = [feature_spec.feature_spec[name]["hotness"] for name in local_names]
  37. local_table_alpha = [feature_spec.feature_spec[name]["alpha"] for name in local_names]
  38. print('local table sizes: ', local_table_sizes)
  39. print('Local table hotness: ', local_table_hotness)
  40. train_dataset = SyntheticDataset(batch_size=kwargs['train_batch_size'],
  41. num_numerical_features=dataset_metadata.num_numerical_features,
  42. categorical_feature_cardinalities=local_table_sizes,
  43. categorical_feature_hotness=local_table_hotness,
  44. categorical_feature_alpha=local_table_alpha,
  45. num_batches=kwargs.get('synthetic_dataset_train_batches', int(1e9)),
  46. num_workers=kwargs['world_size'],
  47. variable_hotness=False)
  48. test_dataset = SyntheticDataset(batch_size=kwargs['test_batch_size'],
  49. num_numerical_features=dataset_metadata.num_numerical_features,
  50. categorical_feature_cardinalities=local_table_sizes,
  51. categorical_feature_hotness=local_table_hotness,
  52. categorical_feature_alpha=local_table_alpha,
  53. num_batches=kwargs.get('synthetic_dataset_valid_batches', int(1e9)),
  54. num_workers=kwargs['world_size'],
  55. variable_hotness=False)
  56. return train_dataset, test_dataset
  57. def _create_pipelines_tf_raw(**kwargs):
  58. fspec_path = os.path.join(kwargs['dataset_path'], kwargs['feature_spec'])
  59. feature_spec = FeatureSpec.from_yaml(fspec_path)
  60. local_categorical_names = feature_spec.cat_positions_to_names(kwargs['table_ids'])
  61. train_dataset = TfRawBinaryDataset(feature_spec=feature_spec,
  62. instance=TRAIN_MAPPING,
  63. batch_size=kwargs['train_batch_size'],
  64. numerical_features_enabled=True,
  65. local_categorical_feature_names=local_categorical_names,
  66. rank=kwargs['rank'],
  67. world_size=kwargs['world_size'],
  68. concat_features=kwargs['concat_features'],
  69. data_parallel_categoricals=kwargs['data_parallel_input'])
  70. test_dataset = TfRawBinaryDataset(feature_spec=feature_spec,
  71. instance=TEST_MAPPING,
  72. batch_size=kwargs['test_batch_size'],
  73. numerical_features_enabled=True,
  74. local_categorical_feature_names=local_categorical_names,
  75. rank=kwargs['rank'],
  76. world_size=kwargs['world_size'],
  77. concat_features=kwargs['concat_features'],
  78. data_parallel_categoricals=kwargs['data_parallel_input'])
  79. return train_dataset, test_dataset
  80. def _create_pipelines_split_tfrecords(**kwargs):
  81. fspec_path = os.path.join(kwargs['dataset_path'], kwargs['feature_spec'])
  82. feature_spec = FeatureSpec.from_yaml(fspec_path)
  83. train_dataset = SplitTFRecordsDataset(dataset_dir=feature_spec.base_directory + '/train/',
  84. feature_ids=kwargs['table_ids'],
  85. num_numerical=feature_spec.get_number_of_numerical_features(),
  86. rank=kwargs['rank'], world_size=kwargs['world_size'],
  87. batch_size=kwargs['train_batch_size'])
  88. test_dataset = SplitTFRecordsDataset(dataset_dir=feature_spec.base_directory + '/test/',
  89. feature_ids=kwargs['table_ids'],
  90. num_numerical=feature_spec.get_number_of_numerical_features(),
  91. rank=kwargs['rank'], world_size=kwargs['world_size'],
  92. batch_size=kwargs['test_batch_size'])
  93. return train_dataset, test_dataset
  94. def create_input_pipelines(dataset_type, dataset_path, train_batch_size, test_batch_size,
  95. table_ids, feature_spec, rank=0, world_size=1, concat_features=False,
  96. data_parallel_input=False):
  97. # pass along all arguments except dataset type
  98. kwargs = locals()
  99. del kwargs['dataset_type']
  100. #hardcoded for now
  101. kwargs['synthetic_dataset_use_feature_spec'] = True
  102. if dataset_type == 'synthetic' and not kwargs['synthetic_dataset_use_feature_spec']:
  103. return _create_pipelines_synthetic(**kwargs)
  104. elif dataset_type == 'synthetic' and kwargs['synthetic_dataset_use_feature_spec']: # synthetic based on feature spec
  105. return _create_pipelines_synthetic_fspec(**kwargs)
  106. elif dataset_type == 'tf_raw':
  107. return _create_pipelines_tf_raw(**kwargs)
  108. elif dataset_type == 'split_tfrecords':
  109. return _create_pipelines_split_tfrecords(**kwargs)
  110. else:
  111. raise ValueError(f'Unsupported dataset type: {dataset_type}')