# pretraining_dataset.py
  1. # Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import random
  15. import h5py
  16. import numpy as np
  17. import paddle
  18. from paddle.io import DataLoader, Dataset
  19. from utils.collate import Stack
def create_pretraining_dataset(args,
                               input_file,
                               data_holders,
                               worker_init=None,
                               places=None):
    """Build a static-graph DataLoader over one HDF5 pretraining shard.

    Args:
        args: command-line namespace; this function reads
            ``max_predictions_per_seq``, ``batch_size``, ``amp`` and
            ``use_pure_fp16`` from it.
        input_file: path of one pre-processed HDF5 shard file.
        data_holders: ``paddle.static.data`` placeholders the loader feeds
            (see ``create_pretraining_data_holder``).
        worker_init: optional per-worker init callback (e.g. ``WorkerInitObj``).
        places: devices the DataLoader should place batches on.

    Returns:
        A ``paddle.io.DataLoader`` yielding collated pretraining batches.
    """
    train_data = PretrainingDataset(
        input_file=input_file, max_pred_length=args.max_predictions_per_seq)
    train_batch_sampler = paddle.io.BatchSampler(
        train_data, batch_size=args.batch_size, shuffle=True)

    def _collate_data(data, stack_fn=Stack()):
        # Each sample is the 6-element list returned by
        # PretrainingDataset.__getitem__; a 7th field (masked_lm_scale)
        # is appended below.
        num_fields = len(data[0])
        out = [None] * num_fields
        # Symbolic indices into the per-sample field list.
        [
            input_ids, segment_ids, input_mask, masked_lm_positions,
            masked_lm_labels, next_sentence_labels, masked_lm_scale
        ] = [0, 1, 2, 3, 4, 5, 6]
        # Fixed-length fields are simply stacked along a new batch dim.
        for i in (input_ids, segment_ids, input_mask, next_sentence_labels):
            out[i] = stack_fn([x[i] for x in data])
        _, seq_length = out[input_ids].shape
        # Variable-length mask fields are flattened across the batch; pad
        # the total count up to a multiple of 8 (hardware-friendly size).
        size = sum(len(x[masked_lm_positions]) for x in data)
        if size % 8 != 0:
            size += 8 - (size % 8)
        # Padding slots keep position 0 and label -1 (ignored by the loss).
        out[masked_lm_positions] = np.full(size, 0, dtype=np.int32)
        out[masked_lm_labels] = np.full([size, 1], -1, dtype=np.int64)
        mask_token_num = 0
        for i, x in enumerate(data):
            for j, pos in enumerate(x[masked_lm_positions]):
                # Positions index into the flattened (batch * seq) ids.
                out[masked_lm_positions][mask_token_num] = i * seq_length + pos
                out[masked_lm_labels][mask_token_num] = x[masked_lm_labels][j]
                mask_token_num += 1
        # The value of masked_lm_scale is equal to mask_token_num,
        # which would be used to compute average masked_lm_loss.
        out.append(np.asarray([mask_token_num], dtype=np.float32))
        if args.amp and args.use_pure_fp16:
            #out[input_mask] = out[input_mask].astype(np.float16)
            out[masked_lm_scale] = out[masked_lm_scale].astype(np.float16)
        return out

    train_data_loader = DataLoader(
        dataset=train_data,
        places=places,
        feed_list=data_holders,
        batch_sampler=train_batch_sampler,
        collate_fn=_collate_data,
        num_workers=0,
        worker_init_fn=worker_init,
        return_list=False)
    return train_data_loader
  67. def create_pretraining_data_holder():
  68. input_ids = paddle.static.data(
  69. name="input_ids", shape=[-1, -1], dtype="int64")
  70. segment_ids = paddle.static.data(
  71. name="segment_ids", shape=[-1, -1], dtype="int64")
  72. input_mask = paddle.static.data(
  73. name="input_mask", shape=[-1, 1, 1, -1], dtype="int64")
  74. masked_lm_positions = paddle.static.data(
  75. name="masked_lm_positions", shape=[-1], dtype="int32")
  76. masked_lm_labels = paddle.static.data(
  77. name="masked_lm_labels", shape=[-1, 1], dtype="int64")
  78. next_sentence_labels = paddle.static.data(
  79. name="next_sentence_labels", shape=[-1, 1], dtype="int64")
  80. masked_lm_scale = paddle.static.data(
  81. name="masked_lm_scale", shape=[-1, 1], dtype="float32")
  82. return [
  83. input_ids, segment_ids, input_mask, masked_lm_positions,
  84. masked_lm_labels, next_sentence_labels, masked_lm_scale
  85. ]
  86. def select_dataset_file_for_each_worker(files, f_start_id, num_trainers,
  87. trainer_id):
  88. """
  89. Spliting the train file according to the worker index.
  90. """
  91. num_files = len(files)
  92. if num_trainers > num_files:
  93. remainder = num_trainers % num_files
  94. data_file = files[(
  95. f_start_id * num_trainers + trainer_id + remainder * f_start_id) %
  96. num_files]
  97. else:
  98. data_file = files[(f_start_id * num_trainers + trainer_id) % num_files]
  99. return data_file
  100. class WorkerInitObj:
  101. "Construct the object with different seed, and the Dataloader will generate the data "
  102. "with different seed in each worker."
  103. def __init__(self, seed):
  104. self.seed = seed
  105. def __call__(self, pid):
  106. np.random.seed(seed=self.seed + pid)
  107. random.seed(self.seed + pid)
  108. class PretrainingDataset(Dataset):
  109. def __init__(self, input_file, max_pred_length):
  110. self.input_file = input_file
  111. self.max_pred_length = max_pred_length
  112. f = h5py.File(input_file, "r")
  113. keys = [
  114. 'input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions',
  115. 'masked_lm_ids', 'next_sentence_labels'
  116. ]
  117. self.inputs = [np.asarray(f[key][:]) for key in keys]
  118. f.close()
  119. def __len__(self):
  120. 'Denotes the total number of samples'
  121. return len(self.inputs[0])
  122. def __getitem__(self, index):
  123. # convert next_sentence_labels (index=5) to np.ndarray type
  124. [
  125. input_ids, input_mask, segment_ids, masked_lm_positions,
  126. masked_lm_ids, next_sentence_labels
  127. ] = [
  128. input[index].astype(np.int64)
  129. if indice < 5 else np.asarray(input[index].astype(np.int64))
  130. for indice, input in enumerate(self.inputs)
  131. ]
  132. # input_mask = (1 - np.reshape(
  133. # input_mask.astype(np.float32), [1, 1, input_mask.shape[0]])) * -1e4
  134. input_mask = np.reshape(input_mask, [1, 1, input_mask.shape[0]])
  135. index = self.max_pred_length
  136. padded_mask_indices = (masked_lm_positions == 0).nonzero()[0]
  137. if len(padded_mask_indices) != 0:
  138. index = padded_mask_indices[0].item()
  139. else:
  140. index = self.max_pred_length
  141. masked_lm_labels = masked_lm_ids[:index]
  142. masked_lm_positions = masked_lm_positions[:index]
  143. # softmax_with_cross_entropy enforce last dim size equal 1
  144. masked_lm_labels = np.expand_dims(masked_lm_labels, axis=-1)
  145. next_sentence_labels = np.expand_dims(next_sentence_labels, axis=-1)
  146. return [
  147. input_ids, segment_ids, input_mask, masked_lm_positions,
  148. masked_lm_labels, next_sentence_labels
  149. ]