# create_pretraining_data.py
  1. # coding=utf-8
  2. # Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
  3. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Create masked LM/next sentence masked_lm TF examples for BERT."""
  16. from __future__ import absolute_import, division, print_function, unicode_literals
  17. import argparse
  18. import logging
  19. import os
  20. import random
  21. from io import open
  22. import h5py
  23. import numpy as np
  24. from tqdm import tqdm, trange
  25. from tokenization import BertTokenizer
  26. import tokenization as tokenization
  27. import random
  28. import collections
  29. class TrainingInstance(object):
  30. """A single training instance (sentence pair)."""
  31. def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
  32. is_random_next):
  33. self.tokens = tokens
  34. self.segment_ids = segment_ids
  35. self.is_random_next = is_random_next
  36. self.masked_lm_positions = masked_lm_positions
  37. self.masked_lm_labels = masked_lm_labels
  38. def __str__(self):
  39. s = ""
  40. s += "tokens: %s\n" % (" ".join(
  41. [tokenization.printable_text(x) for x in self.tokens]))
  42. s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
  43. s += "is_random_next: %s\n" % self.is_random_next
  44. s += "masked_lm_positions: %s\n" % (" ".join(
  45. [str(x) for x in self.masked_lm_positions]))
  46. s += "masked_lm_labels: %s\n" % (" ".join(
  47. [tokenization.printable_text(x) for x in self.masked_lm_labels]))
  48. s += "\n"
  49. return s
  50. def __repr__(self):
  51. return self.__str__()
  52. def write_instance_to_example_file(instances, tokenizer, max_seq_length,
  53. max_predictions_per_seq, output_file):
  54. """Create TF example files from `TrainingInstance`s."""
  55. total_written = 0
  56. features = collections.OrderedDict()
  57. num_instances = len(instances)
  58. features["input_ids"] = np.zeros([num_instances, max_seq_length], dtype="int32")
  59. features["input_mask"] = np.zeros([num_instances, max_seq_length], dtype="int32")
  60. features["segment_ids"] = np.zeros([num_instances, max_seq_length], dtype="int32")
  61. features["masked_lm_positions"] = np.zeros([num_instances, max_predictions_per_seq], dtype="int32")
  62. features["masked_lm_ids"] = np.zeros([num_instances, max_predictions_per_seq], dtype="int32")
  63. features["next_sentence_labels"] = np.zeros(num_instances, dtype="int32")
  64. for inst_index, instance in enumerate(tqdm(instances)):
  65. input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
  66. input_mask = [1] * len(input_ids)
  67. segment_ids = list(instance.segment_ids)
  68. assert len(input_ids) <= max_seq_length
  69. while len(input_ids) < max_seq_length:
  70. input_ids.append(0)
  71. input_mask.append(0)
  72. segment_ids.append(0)
  73. assert len(input_ids) == max_seq_length
  74. assert len(input_mask) == max_seq_length
  75. assert len(segment_ids) == max_seq_length
  76. masked_lm_positions = list(instance.masked_lm_positions)
  77. masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
  78. masked_lm_weights = [1.0] * len(masked_lm_ids)
  79. while len(masked_lm_positions) < max_predictions_per_seq:
  80. masked_lm_positions.append(0)
  81. masked_lm_ids.append(0)
  82. masked_lm_weights.append(0.0)
  83. next_sentence_label = 1 if instance.is_random_next else 0
  84. features["input_ids"][inst_index] = input_ids
  85. features["input_mask"][inst_index] = input_mask
  86. features["segment_ids"][inst_index] = segment_ids
  87. features["masked_lm_positions"][inst_index] = masked_lm_positions
  88. features["masked_lm_ids"][inst_index] = masked_lm_ids
  89. features["next_sentence_labels"][inst_index] = next_sentence_label
  90. total_written += 1
  91. # if inst_index < 20:
  92. # tf.logging.info("*** Example ***")
  93. # tf.logging.info("tokens: %s" % " ".join(
  94. # [tokenization.printable_text(x) for x in instance.tokens]))
  95. # for feature_name in features.keys():
  96. # feature = features[feature_name]
  97. # values = []
  98. # if feature.int64_list.value:
  99. # values = feature.int64_list.value
  100. # elif feature.float_list.value:
  101. # values = feature.float_list.value
  102. # tf.logging.info(
  103. # "%s: %s" % (feature_name, " ".join([str(x) for x in values])))
  104. print("saving data")
  105. f= h5py.File(output_file, 'w')
  106. f.create_dataset("input_ids", data=features["input_ids"], dtype='i4', compression='gzip')
  107. f.create_dataset("input_mask", data=features["input_mask"], dtype='i1', compression='gzip')
  108. f.create_dataset("segment_ids", data=features["segment_ids"], dtype='i1', compression='gzip')
  109. f.create_dataset("masked_lm_positions", data=features["masked_lm_positions"], dtype='i4', compression='gzip')
  110. f.create_dataset("masked_lm_ids", data=features["masked_lm_ids"], dtype='i4', compression='gzip')
  111. f.create_dataset("next_sentence_labels", data=features["next_sentence_labels"], dtype='i1', compression='gzip')
  112. f.flush()
  113. f.close()
  114. def create_training_instances(input_files, tokenizer, max_seq_length,
  115. dupe_factor, short_seq_prob, masked_lm_prob,
  116. max_predictions_per_seq, rng):
  117. """Create `TrainingInstance`s from raw text."""
  118. all_documents = [[]]
  119. # Input file format:
  120. # (1) One sentence per line. These should ideally be actual sentences, not
  121. # entire paragraphs or arbitrary spans of text. (Because we use the
  122. # sentence boundaries for the "next sentence prediction" task).
  123. # (2) Blank lines between documents. Document boundaries are needed so
  124. # that the "next sentence prediction" task doesn't span between documents.
  125. for input_file in input_files:
  126. print("creating instance from {}".format(input_file))
  127. with open(input_file, "r") as reader:
  128. while True:
  129. line = tokenization.convert_to_unicode(reader.readline())
  130. if not line:
  131. break
  132. line = line.strip()
  133. # Empty lines are used as document delimiters
  134. if not line:
  135. all_documents.append([])
  136. tokens = tokenizer.tokenize(line)
  137. if tokens:
  138. all_documents[-1].append(tokens)
  139. # Remove empty documents
  140. all_documents = [x for x in all_documents if x]
  141. rng.shuffle(all_documents)
  142. vocab_words = list(tokenizer.vocab.keys())
  143. instances = []
  144. for _ in range(dupe_factor):
  145. for document_index in range(len(all_documents)):
  146. instances.extend(
  147. create_instances_from_document(
  148. all_documents, document_index, max_seq_length, short_seq_prob,
  149. masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
  150. rng.shuffle(instances)
  151. return instances
def create_instances_from_document(
        all_documents, document_index, max_seq_length, short_seq_prob,
        masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
    """Creates `TrainingInstance`s for a single document.

    Args:
        all_documents: list of documents; each document is a list of
            sentences, each sentence a list of wordpiece tokens.
        document_index: index of the document supplying segment "A".
        max_seq_length: hard cap on tokens per instance, including
            [CLS] and the two [SEP] markers.
        short_seq_prob: probability of targeting a shorter sequence.
        masked_lm_prob: fraction of tokens to mask.
        max_predictions_per_seq: cap on masked positions per instance.
        vocab_words: vocabulary list for random-word replacement.
        rng: `random.Random` driving all sampling; call order matters
            for reproducibility.

    Returns:
        List of `TrainingInstance` objects built from this document.
    """
    document = all_documents[document_index]

    # Account for [CLS], [SEP], [SEP]
    max_num_tokens = max_seq_length - 3

    # We *usually* want to fill up the entire sequence since we are padding
    # to `max_seq_length` anyways, so short sequences are generally wasted
    # computation. However, we *sometimes*
    # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
    # sequences to minimize the mismatch between pre-training and fine-tuning.
    # The `target_seq_length` is just a rough target however, whereas
    # `max_seq_length` is a hard limit.
    target_seq_length = max_num_tokens
    if rng.random() < short_seq_prob:
        target_seq_length = rng.randint(2, max_num_tokens)

    # We DON'T just concatenate all of the tokens from a document into a long
    # sequence and choose an arbitrary split point because this would make the
    # next sentence prediction task too easy. Instead, we split the input into
    # segments "A" and "B" based on the actual "sentences" provided by the user
    # input.
    instances = []
    current_chunk = []    # sentences accumulated for the next instance
    current_length = 0    # total tokens in current_chunk
    i = 0
    while i < len(document):
        segment = document[i]
        current_chunk.append(segment)
        current_length += len(segment)
        # Emit an instance once we hit the target length or run out of document.
        if i == len(document) - 1 or current_length >= target_seq_length:
            if current_chunk:
                # `a_end` is how many segments from `current_chunk` go into the `A`
                # (first) sentence.
                a_end = 1
                if len(current_chunk) >= 2:
                    a_end = rng.randint(1, len(current_chunk) - 1)

                tokens_a = []
                for j in range(a_end):
                    tokens_a.extend(current_chunk[j])

                tokens_b = []
                # Random next: 50% of the time (always, if the chunk has only
                # one sentence) segment B comes from a different document.
                is_random_next = False
                if len(current_chunk) == 1 or rng.random() < 0.5:
                    is_random_next = True
                    target_b_length = target_seq_length - len(tokens_a)

                    # This should rarely go for more than one iteration for large
                    # corpora. However, just to be careful, we try to make sure that
                    # the random document is not the same as the document
                    # we're processing.
                    for _ in range(10):
                        random_document_index = rng.randint(0, len(all_documents) - 1)
                        if random_document_index != document_index:
                            break

                    # If picked random document is the same as the current document
                    # (all 10 tries failed), fall back to labeling it "actual next"
                    # so the label stays truthful.
                    if random_document_index == document_index:
                        is_random_next = False

                    random_document = all_documents[random_document_index]
                    random_start = rng.randint(0, len(random_document) - 1)
                    for j in range(random_start, len(random_document)):
                        tokens_b.extend(random_document[j])
                        if len(tokens_b) >= target_b_length:
                            break
                    # We didn't actually use these segments so we "put them back" so
                    # they don't go to waste.
                    num_unused_segments = len(current_chunk) - a_end
                    i -= num_unused_segments
                # Actual next: B is the remainder of the current chunk.
                else:
                    is_random_next = False
                    for j in range(a_end, len(current_chunk)):
                        tokens_b.extend(current_chunk[j])
                truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)

                assert len(tokens_a) >= 1
                assert len(tokens_b) >= 1

                # Assemble [CLS] A [SEP] B [SEP] with segment ids 0/0/0/1/1.
                tokens = []
                segment_ids = []
                tokens.append("[CLS]")
                segment_ids.append(0)
                for token in tokens_a:
                    tokens.append(token)
                    segment_ids.append(0)
                tokens.append("[SEP]")
                segment_ids.append(0)
                for token in tokens_b:
                    tokens.append(token)
                    segment_ids.append(1)
                tokens.append("[SEP]")
                segment_ids.append(1)

                (tokens, masked_lm_positions,
                 masked_lm_labels) = create_masked_lm_predictions(
                     tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
                instance = TrainingInstance(
                    tokens=tokens,
                    segment_ids=segment_ids,
                    is_random_next=is_random_next,
                    masked_lm_positions=masked_lm_positions,
                    masked_lm_labels=masked_lm_labels)
                instances.append(instance)
            current_chunk = []
            current_length = 0
        i += 1

    return instances
  255. MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
  256. ["index", "label"])
  257. def create_masked_lm_predictions(tokens, masked_lm_prob,
  258. max_predictions_per_seq, vocab_words, rng):
  259. """Creates the predictions for the masked LM objective."""
  260. cand_indexes = []
  261. for (i, token) in enumerate(tokens):
  262. if token == "[CLS]" or token == "[SEP]":
  263. continue
  264. cand_indexes.append(i)
  265. rng.shuffle(cand_indexes)
  266. output_tokens = list(tokens)
  267. num_to_predict = min(max_predictions_per_seq,
  268. max(1, int(round(len(tokens) * masked_lm_prob))))
  269. masked_lms = []
  270. covered_indexes = set()
  271. for index in cand_indexes:
  272. if len(masked_lms) >= num_to_predict:
  273. break
  274. if index in covered_indexes:
  275. continue
  276. covered_indexes.add(index)
  277. masked_token = None
  278. # 80% of the time, replace with [MASK]
  279. if rng.random() < 0.8:
  280. masked_token = "[MASK]"
  281. else:
  282. # 10% of the time, keep original
  283. if rng.random() < 0.5:
  284. masked_token = tokens[index]
  285. # 10% of the time, replace with random word
  286. else:
  287. masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
  288. output_tokens[index] = masked_token
  289. masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
  290. masked_lms = sorted(masked_lms, key=lambda x: x.index)
  291. masked_lm_positions = []
  292. masked_lm_labels = []
  293. for p in masked_lms:
  294. masked_lm_positions.append(p.index)
  295. masked_lm_labels.append(p.label)
  296. return (output_tokens, masked_lm_positions, masked_lm_labels)
  297. def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
  298. """Truncates a pair of sequences to a maximum sequence length."""
  299. while True:
  300. total_length = len(tokens_a) + len(tokens_b)
  301. if total_length <= max_num_tokens:
  302. break
  303. trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
  304. assert len(trunc_tokens) >= 1
  305. # We want to sometimes truncate from the front and sometimes from the
  306. # back to add more randomness and avoid biases.
  307. if rng.random() < 0.5:
  308. del trunc_tokens[0]
  309. else:
  310. trunc_tokens.pop()
  311. def main():
  312. parser = argparse.ArgumentParser()
  313. ## Required parameters
  314. parser.add_argument("--vocab_file",
  315. default=None,
  316. type=str,
  317. required=True,
  318. help="The vocabulary the BERT model will train on.")
  319. parser.add_argument("--input_file",
  320. default=None,
  321. type=str,
  322. required=True,
  323. help="The input train corpus. can be directory with .txt files or a path to a single file")
  324. parser.add_argument("--output_file",
  325. default=None,
  326. type=str,
  327. required=True,
  328. help="The output file where the model checkpoints will be written.")
  329. ## Other parameters
  330. # str
  331. parser.add_argument("--bert_model", default="bert-large-uncased", type=str, required=False,
  332. help="Bert pre-trained model selected in the list: bert-base-uncased, "
  333. "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
  334. #int
  335. parser.add_argument("--max_seq_length",
  336. default=128,
  337. type=int,
  338. help="The maximum total input sequence length after WordPiece tokenization. \n"
  339. "Sequences longer than this will be truncated, and sequences shorter \n"
  340. "than this will be padded.")
  341. parser.add_argument("--dupe_factor",
  342. default=10,
  343. type=int,
  344. help="Number of times to duplicate the input data (with different masks).")
  345. parser.add_argument("--max_predictions_per_seq",
  346. default=20,
  347. type=int,
  348. help="Maximum sequence length.")
  349. # floats
  350. parser.add_argument("--masked_lm_prob",
  351. default=0.15,
  352. type=float,
  353. help="Masked LM probability.")
  354. parser.add_argument("--short_seq_prob",
  355. default=0.1,
  356. type=float,
  357. help="Probability to create a sequence shorter than maximum sequence length")
  358. parser.add_argument("--do_lower_case",
  359. action='store_true',
  360. default=True,
  361. help="Whether to lower case the input text. True for uncased models, False for cased models.")
  362. parser.add_argument('--random_seed',
  363. type=int,
  364. default=12345,
  365. help="random seed for initialization")
  366. args = parser.parse_args()
  367. tokenizer = BertTokenizer(args.vocab_file, do_lower_case=args.do_lower_case, max_len=512)
  368. input_files = []
  369. if os.path.isfile(args.input_file):
  370. input_files.append(args.input_file)
  371. elif os.path.isdir(args.input_file):
  372. input_files = [os.path.join(args.input_file, f) for f in os.listdir(args.input_file) if (os.path.isfile(os.path.join(args.input_file, f)) and f.endswith('.txt') )]
  373. else:
  374. raise ValueError("{} is not a valid path".format(args.input_file))
  375. rng = random.Random(args.random_seed)
  376. instances = create_training_instances(
  377. input_files, tokenizer, args.max_seq_length, args.dupe_factor,
  378. args.short_seq_prob, args.masked_lm_prob, args.max_predictions_per_seq,
  379. rng)
  380. output_file = args.output_file
  381. write_instance_to_example_file(instances, tokenizer, args.max_seq_length,
  382. args.max_predictions_per_seq, output_file)
  383. if __name__ == "__main__":
  384. main()