# bertPrep.py
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  13. import BookscorpusTextFormatting
  14. import Downloader
  15. import TextSharding
  16. import WikicorpusTextFormatting
  17. import argparse
  18. import itertools
  19. import multiprocessing
  20. import os
  21. import pprint
  22. import subprocess
  23. def main(args):
  24. working_dir = os.environ['BERT_PREP_WORKING_DIR']
  25. print('Working Directory:', working_dir)
  26. print('Action:', args.action)
  27. print('Dataset Name:', args.dataset)
  28. if args.input_files:
  29. args.input_files = args.input_files.split(',')
  30. hdf5_tfrecord_folder_prefix = "_lower_case_" + str(args.do_lower_case) + "_seq_len_" + str(args.max_seq_length) \
  31. + "_max_pred_" + str(args.max_predictions_per_seq) + "_masked_lm_prob_" + str(args.masked_lm_prob) \
  32. + "_random_seed_" + str(args.random_seed) + "_dupe_factor_" + str(args.dupe_factor)
  33. directory_structure = {
  34. 'download' : working_dir + '/download', # Downloaded and decompressed
  35. 'extracted' : working_dir +'/extracted', # Extracted from whatever the initial format is (e.g., wikiextractor)
  36. 'formatted' : working_dir + '/formatted_one_article_per_line', # This is the level where all sources should look the same
  37. 'sharded' : working_dir + '/sharded_' + "training_shards_" + str(args.n_training_shards) + "_test_shards_" + str(args.n_test_shards) + "_fraction_" + str(args.fraction_test_set),
  38. 'tfrecord' : working_dir + '/tfrecord'+ hdf5_tfrecord_folder_prefix,
  39. 'hdf5': working_dir + '/hdf5' + hdf5_tfrecord_folder_prefix
  40. }
  41. print('\nDirectory Structure:')
  42. pp = pprint.PrettyPrinter(indent=2)
  43. pp.pprint(directory_structure)
  44. print('')
  45. if args.action == 'download':
  46. if not os.path.exists(directory_structure['download']):
  47. os.makedirs(directory_structure['download'])
  48. downloader = Downloader.Downloader(args.dataset, directory_structure['download'])
  49. downloader.download()
  50. elif args.action == 'text_formatting':
  51. assert args.dataset != 'google_pretrained_weights' and args.dataset != 'nvidia_pretrained_weights' and args.dataset != 'squad' and args.dataset != 'mrpc', 'Cannot perform text_formatting on pretrained weights'
  52. if not os.path.exists(directory_structure['extracted']):
  53. os.makedirs(directory_structure['extracted'])
  54. if not os.path.exists(directory_structure['formatted']):
  55. os.makedirs(directory_structure['formatted'])
  56. if args.dataset == 'bookscorpus':
  57. books_path = directory_structure['download'] + '/bookscorpus'
  58. #books_path = directory_structure['download']
  59. output_filename = directory_structure['formatted'] + '/bookscorpus_one_book_per_line.txt'
  60. books_formatter = BookscorpusTextFormatting.BookscorpusTextFormatting(books_path, output_filename, recursive=True)
  61. books_formatter.merge()
  62. elif args.dataset == 'wikicorpus_en':
  63. if args.skip_wikiextractor == 0:
  64. path_to_wikiextractor_in_container = '/workspace/wikiextractor/WikiExtractor.py'
  65. wikiextractor_command = path_to_wikiextractor_in_container + ' ' + directory_structure['download'] + '/' + args.dataset + '/wikicorpus_en.xml ' + '-b 100M --processes ' + str(args.n_processes) + ' -o ' + directory_structure['extracted'] + '/' + args.dataset
  66. print('WikiExtractor Command:', wikiextractor_command)
  67. wikiextractor_process = subprocess.run(wikiextractor_command, shell=True, check=True)
  68. #wikiextractor_process.communicate()
  69. wiki_path = directory_structure['extracted'] + '/wikicorpus_en'
  70. output_filename = directory_structure['formatted'] + '/wikicorpus_en_one_article_per_line.txt'
  71. wiki_formatter = WikicorpusTextFormatting.WikicorpusTextFormatting(wiki_path, output_filename, recursive=True)
  72. wiki_formatter.merge()
  73. elif args.dataset == 'wikicorpus_zh':
  74. assert False, 'wikicorpus_zh not fully supported at this time. The simplified/tradition Chinese data needs to be translated and properly segmented still, and should work once this step is added.'
  75. if args.skip_wikiextractor == 0:
  76. path_to_wikiextractor_in_container = '/workspace/wikiextractor/WikiExtractor.py'
  77. wikiextractor_command = path_to_wikiextractor_in_container + ' ' + directory_structure['download'] + '/' + args.dataset + '/wikicorpus_zh.xml ' + '-b 100M --processes ' + str(args.n_processes) + ' -o ' + directory_structure['extracted'] + '/' + args.dataset
  78. print('WikiExtractor Command:', wikiextractor_command)
  79. wikiextractor_process = subprocess.run(wikiextractor_command, shell=True, check=True)
  80. #wikiextractor_process.communicate()
  81. wiki_path = directory_structure['extracted'] + '/wikicorpus_zh'
  82. output_filename = directory_structure['formatted'] + '/wikicorpus_zh_one_article_per_line.txt'
  83. wiki_formatter = WikicorpusTextFormatting.WikicorpusTextFormatting(wiki_path, output_filename, recursive=True)
  84. wiki_formatter.merge()
  85. assert os.stat(output_filename).st_size > 0, 'File glob did not pick up extracted wiki files from WikiExtractor.'
  86. elif args.action == 'sharding':
  87. # Note: books+wiki requires user to provide list of input_files (comma-separated with no spaces)
  88. if args.dataset == 'bookscorpus' or 'wikicorpus' in args.dataset or 'books_wiki' in args.dataset:
  89. if args.input_files is None:
  90. if args.dataset == 'bookscorpus':
  91. args.input_files = [directory_structure['formatted'] + '/bookscorpus_one_book_per_line.txt']
  92. elif args.dataset == 'wikicorpus_en':
  93. args.input_files = [directory_structure['formatted'] + '/wikicorpus_en_one_article_per_line.txt']
  94. elif args.dataset == 'wikicorpus_zh':
  95. args.input_files = [directory_structure['formatted'] + '/wikicorpus_zh_one_article_per_line.txt']
  96. elif args.dataset == 'books_wiki_en_corpus':
  97. args.input_files = [directory_structure['formatted'] + '/bookscorpus_one_book_per_line.txt', directory_structure['formatted'] + '/wikicorpus_en_one_article_per_line.txt']
  98. output_file_prefix = directory_structure['sharded'] + '/' + args.dataset + '/' + args.dataset
  99. if not os.path.exists(directory_structure['sharded']):
  100. os.makedirs(directory_structure['sharded'])
  101. if not os.path.exists(directory_structure['sharded'] + '/' + args.dataset):
  102. os.makedirs(directory_structure['sharded'] + '/' + args.dataset)
  103. # Segmentation is here because all datasets look the same in one article/book/whatever per line format, and
  104. # it seemed unnecessarily complicated to add an additional preprocessing step to call just for this.
  105. # Different languages (e.g., Chinese simplified/traditional) may require translation and
  106. # other packages to be called from here -- just add a conditional branch for those extra steps
  107. segmenter = TextSharding.NLTKSegmenter()
  108. sharding = TextSharding.Sharding(args.input_files, output_file_prefix, args.n_training_shards, args.n_test_shards, args.fraction_test_set)
  109. sharding.load_articles()
  110. sharding.segment_articles_into_sentences(segmenter)
  111. sharding.distribute_articles_over_shards()
  112. sharding.write_shards_to_disk()
  113. else:
  114. assert False, 'Unsupported dataset for sharding'
  115. elif args.action == 'create_tfrecord_files':
  116. assert False, 'TFrecord creation not supported in this PyTorch model example release.' \
  117. ''
  118. if not os.path.exists(directory_structure['tfrecord'] + "/" + args.dataset):
  119. os.makedirs(directory_structure['tfrecord'] + "/" + args.dataset)
  120. def create_record_worker(filename_prefix, shard_id, output_format='tfrecord'):
  121. bert_preprocessing_command = 'python /workspace/bert/create_pretraining_data.py'
  122. bert_preprocessing_command += ' --input_file=' + directory_structure['sharded'] + '/' + args.dataset + '/' + filename_prefix + '_' + str(shard_id) + '.txt'
  123. bert_preprocessing_command += ' --output_file=' + directory_structure['tfrecord'] + '/' + args.dataset + '/' + filename_prefix + '_' + str(shard_id) + '.' + output_format
  124. bert_preprocessing_command += ' --vocab_file=' + args.vocab_file
  125. bert_preprocessing_command += ' --do_lower_case' if args.do_lower_case else ''
  126. bert_preprocessing_command += ' --max_seq_length=' + str(args.max_seq_length)
  127. bert_preprocessing_command += ' --max_predictions_per_seq=' + str(args.max_predictions_per_seq)
  128. bert_preprocessing_command += ' --masked_lm_prob=' + str(args.masked_lm_prob)
  129. bert_preprocessing_command += ' --random_seed=' + str(args.random_seed)
  130. bert_preprocessing_command += ' --dupe_factor=' + str(args.dupe_factor)
  131. bert_preprocessing_process = subprocess.Popen(bert_preprocessing_command, shell=True)
  132. bert_preprocessing_process.communicate()
  133. last_process = bert_preprocessing_process
  134. # This could be better optimized (fine if all take equal time)
  135. if shard_id % args.n_processes == 0 and shard_id > 0:
  136. bert_preprocessing_process.wait()
  137. return last_process
  138. output_file_prefix = args.dataset
  139. for i in range(args.n_training_shards):
  140. last_process =create_record_worker(output_file_prefix + '_training', i)
  141. last_process.wait()
  142. for i in range(args.n_test_shards):
  143. last_process = create_record_worker(output_file_prefix + '_test', i)
  144. last_process.wait()
  145. elif args.action == 'create_hdf5_files':
  146. last_process = None
  147. if not os.path.exists(directory_structure['hdf5'] + "/" + args.dataset):
  148. os.makedirs(directory_structure['hdf5'] + "/" + args.dataset)
  149. def create_record_worker(filename_prefix, shard_id, output_format='hdf5'):
  150. bert_preprocessing_command = 'python /workspace/bert/create_pretraining_data.py'
  151. bert_preprocessing_command += ' --input_file=' + directory_structure['sharded'] + '/' + args.dataset + '/' + filename_prefix + '_' + str(shard_id) + '.txt'
  152. bert_preprocessing_command += ' --output_file=' + directory_structure['hdf5'] + '/' + args.dataset + '/' + filename_prefix + '_' + str(shard_id) + '.' + output_format
  153. bert_preprocessing_command += ' --vocab_file=' + args.vocab_file
  154. bert_preprocessing_command += ' --do_lower_case' if args.do_lower_case else ''
  155. bert_preprocessing_command += ' --max_seq_length=' + str(args.max_seq_length)
  156. bert_preprocessing_command += ' --max_predictions_per_seq=' + str(args.max_predictions_per_seq)
  157. bert_preprocessing_command += ' --masked_lm_prob=' + str(args.masked_lm_prob)
  158. bert_preprocessing_command += ' --random_seed=' + str(args.random_seed)
  159. bert_preprocessing_command += ' --dupe_factor=' + str(args.dupe_factor)
  160. bert_preprocessing_process = subprocess.Popen(bert_preprocessing_command, shell=True)
  161. bert_preprocessing_process.communicate()
  162. last_process = bert_preprocessing_process
  163. # This could be better optimized (fine if all take equal time)
  164. if shard_id % args.n_processes == 0 and shard_id > 0:
  165. bert_preprocessing_process.wait()
  166. return last_process
  167. output_file_prefix = args.dataset
  168. for i in range(args.n_training_shards):
  169. last_process = create_record_worker(output_file_prefix + '_training', i)
  170. last_process.wait()
  171. for i in range(args.n_test_shards):
  172. last_process = create_record_worker(output_file_prefix + '_test', i)
  173. last_process.wait()
  174. if __name__ == "__main__":
  175. parser = argparse.ArgumentParser(
  176. description='Preprocessing Application for Everything BERT-related'
  177. )
  178. parser.add_argument(
  179. '--action',
  180. type=str,
  181. help='Specify the action you want the app to take. e.g., generate vocab, segment, create tfrecords',
  182. choices={
  183. 'download', # Download and verify mdf5/sha sums
  184. 'text_formatting', # Convert into a file that contains one article/book per line
  185. 'sharding', # Convert previous formatted text into shards containing one sentence per line
  186. 'create_tfrecord_files', # Turn each shard into a TFrecord with masking and next sentence prediction info
  187. 'create_hdf5_files' # Turn each shard into a HDF5 file with masking and next sentence prediction info
  188. }
  189. )
  190. parser.add_argument(
  191. '--dataset',
  192. type=str,
  193. help='Specify the dataset to perform --action on',
  194. choices={
  195. 'bookscorpus',
  196. 'wikicorpus_en',
  197. 'wikicorpus_zh',
  198. 'books_wiki_en_corpus',
  199. 'google_pretrained_weights',
  200. 'nvidia_pretrained_weights',
  201. 'mrpc',
  202. 'squad',
  203. 'all'
  204. }
  205. )
  206. parser.add_argument(
  207. '--input_files',
  208. type=str,
  209. help='Specify the input files in a comma-separated list (no spaces)'
  210. )
  211. parser.add_argument(
  212. '--n_training_shards',
  213. type=int,
  214. help='Specify the number of training shards to generate',
  215. default=256
  216. )
  217. parser.add_argument(
  218. '--n_test_shards',
  219. type=int,
  220. help='Specify the number of test shards to generate',
  221. default=256
  222. )
  223. parser.add_argument(
  224. '--fraction_test_set',
  225. type=float,
  226. help='Specify the fraction (0..1) of the data to withhold for the test data split (based on number of sequences)',
  227. default=0.2
  228. )
  229. parser.add_argument(
  230. '--segmentation_method',
  231. type=str,
  232. help='Specify your choice of sentence segmentation',
  233. choices={
  234. 'nltk'
  235. },
  236. default='nltk'
  237. )
  238. parser.add_argument(
  239. '--n_processes',
  240. type=int,
  241. help='Specify the max number of processes to allow at one time',
  242. default=4
  243. )
  244. parser.add_argument(
  245. '--random_seed',
  246. type=int,
  247. help='Specify the base seed to use for any random number generation',
  248. default=12345
  249. )
  250. parser.add_argument(
  251. '--dupe_factor',
  252. type=int,
  253. help='Specify the duplication factor',
  254. default=5
  255. )
  256. parser.add_argument(
  257. '--masked_lm_prob',
  258. type=float,
  259. help='Specify the probability for masked lm',
  260. default=0.15
  261. )
  262. parser.add_argument(
  263. '--max_seq_length',
  264. type=int,
  265. help='Specify the maximum sequence length',
  266. default=512
  267. )
  268. parser.add_argument(
  269. '--max_predictions_per_seq',
  270. type=int,
  271. help='Specify the maximum number of masked words per sequence',
  272. default=20
  273. )
  274. parser.add_argument(
  275. '--do_lower_case',
  276. type=int,
  277. help='Specify whether it is cased (0) or uncased (1) (any number greater than 0 will be treated as uncased)',
  278. default=1
  279. )
  280. parser.add_argument(
  281. '--vocab_file',
  282. type=str,
  283. help='Specify absolute path to vocab file to use)'
  284. )
  285. parser.add_argument(
  286. '--skip_wikiextractor',
  287. type=int,
  288. help='Specify whether to skip wikiextractor step 0=False, 1=True',
  289. default=0
  290. )
  291. parser.add_argument(
  292. '--interactive_json_config_generator',
  293. type=str,
  294. help='Specify the action you want the app to take. e.g., generate vocab, segment, create tfrecords'
  295. )
  296. args = parser.parse_args()
  297. main(args)