# dataloader.py
# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
from functools import partial

import tensorflow as tf

from sim.data.defaults import (DIMENSIONS_SELECTOR, LABEL_CHANNEL, NEGATIVE_HISTORY_CHANNEL, POSITIVE_HISTORY_CHANNEL,
                               TARGET_ITEM_FEATURES_CHANNEL, USER_FEATURES_CHANNEL, REMAINDER_FILENAME)
  19. def _remap_column_values_tfrecord(sample, feature_spec, long_seq_length):
  20. channel_spec = feature_spec.channel_spec
  21. features = feature_spec.feature_spec
  22. user_features = {
  23. f_name: tf.reshape(sample[f_name], [-1]) for f_name in channel_spec[USER_FEATURES_CHANNEL]
  24. }
  25. target_item_features = {
  26. f_name: tf.reshape(sample[f_name], [-1]) for f_name in channel_spec[TARGET_ITEM_FEATURES_CHANNEL]
  27. }
  28. padded_positive = {
  29. f_name: tf.reshape(sample[f_name], [-1, features[f_name][DIMENSIONS_SELECTOR][0]])
  30. for f_name in channel_spec[POSITIVE_HISTORY_CHANNEL]
  31. }
  32. padded_negative = {
  33. f_name: tf.reshape(sample[f_name], [-1, features[f_name][DIMENSIONS_SELECTOR][0]])
  34. for f_name in channel_spec[NEGATIVE_HISTORY_CHANNEL]
  35. }
  36. long_sequence_features = {
  37. f_name: val[:, :long_seq_length] for f_name, val in padded_positive.items()
  38. }
  39. short_sequence_features = {
  40. f_name: val[:, long_seq_length:] for f_name, val in padded_positive.items()
  41. }
  42. short_neg_sequence_features = {
  43. f_name: val[:, long_seq_length:] for f_name, val in padded_negative.items()
  44. }
  45. first_positive_feature_name = channel_spec[POSITIVE_HISTORY_CHANNEL][0]
  46. first_positive_feature = padded_positive[first_positive_feature_name]
  47. history_mask = tf.cast(tf.greater(first_positive_feature, 0), tf.float32)
  48. long_sequence_mask = history_mask[:, :long_seq_length]
  49. short_sequence_mask = history_mask[:, long_seq_length:]
  50. label_name = channel_spec[LABEL_CHANNEL][0]
  51. target = tf.reshape(sample[label_name], [-1])
  52. return {
  53. "user_features": user_features,
  54. "target_item_features": target_item_features,
  55. "long_sequence_features": long_sequence_features,
  56. "short_sequence_features": short_sequence_features,
  57. "short_neg_sequence_features": short_neg_sequence_features,
  58. "long_sequence_mask": long_sequence_mask,
  59. "short_sequence_mask": short_sequence_mask,
  60. "other_features": None
  61. }, target
  62. def split_prebatch(sample, split_into):
  63. res = {}
  64. for f_name, val in sample.items():
  65. res[f_name] = tf.reshape(val, [split_into, -1])
  66. return tf.data.Dataset.from_tensor_slices(res)
  67. def get_dataloader_tfrecord(
  68. file_paths,
  69. feature_spec,
  70. batch_size,
  71. long_seq_length,
  72. num_gpus=1,
  73. id=0,
  74. drop_remainder=False,
  75. repeat_count=0,
  76. prefetch_buffer_size=90,
  77. num_parallel_calls=None,
  78. disable_cache=False,
  79. prebatch_size=0
  80. ):
  81. features = feature_spec.feature_spec
  82. prebatched = prebatch_size > 0
  83. remainder_file = None
  84. if file_paths[-1].name == REMAINDER_FILENAME:
  85. remainder_file = file_paths[-1:]
  86. file_paths = file_paths[:-1]
  87. tf_feature_spec = {}
  88. for name, feature in features.items():
  89. dimensions = feature.get(DIMENSIONS_SELECTOR)
  90. if dimensions is None:
  91. dimensions = [1] if prebatched else []
  92. if prebatched:
  93. dimensions = dimensions.copy()
  94. dimensions[0] *= prebatch_size
  95. tf_feature_spec[name] = tf.io.FixedLenFeature(dimensions, tf.int64)
  96. if num_parallel_calls is None:
  97. num_cpus = multiprocessing.cpu_count()
  98. num_parallel_calls = 4 * num_cpus // num_gpus
  99. dataset = tf.data.TFRecordDataset(file_paths, num_parallel_reads=num_parallel_calls)
  100. dataset = dataset.shard(num_gpus, id)
  101. splitting_function = None
  102. if prebatched:
  103. if batch_size >= prebatch_size:
  104. batch_size = batch_size // prebatch_size
  105. else:
  106. split_into = prebatch_size // batch_size
  107. splitting_function = partial(split_prebatch, split_into=split_into)
  108. batch_size = 1
  109. dataset = dataset.batch(
  110. batch_size, drop_remainder=drop_remainder, num_parallel_calls=num_parallel_calls
  111. )
  112. dataset = dataset.map(
  113. map_func=partial(tf.io.parse_example, features=tf_feature_spec),
  114. num_parallel_calls=num_parallel_calls
  115. )
  116. if splitting_function is not None:
  117. dataset = dataset.flat_map(splitting_function)
  118. if not drop_remainder and id == 0 and remainder_file is not None:
  119. tf_feature_spec_remainder = {
  120. name: tf.io.RaggedFeature(tf.int64) for name in tf_feature_spec
  121. }
  122. remainder = tf.data.TFRecordDataset(remainder_file)
  123. remainder = remainder.map(
  124. map_func=partial(tf.io.parse_example, features=tf_feature_spec_remainder)
  125. )
  126. dataset = dataset.concatenate(remainder)
  127. dataset = dataset.map(
  128. map_func=partial(_remap_column_values_tfrecord, feature_spec=feature_spec, long_seq_length=long_seq_length),
  129. num_parallel_calls=num_parallel_calls
  130. )
  131. if repeat_count > 0:
  132. dataset = dataset.repeat(
  133. count=repeat_count
  134. )
  135. if prefetch_buffer_size > 0:
  136. dataset = dataset.prefetch(
  137. buffer_size=prefetch_buffer_size
  138. )
  139. if not disable_cache:
  140. dataset = dataset.cache()
  141. return dataset