@@ -23,13 +23,17 @@ def main(args):
     if args.input_files:
         args.input_files = args.input_files.split(',')

+    hdf5_tfrecord_folder_prefix = "_lower_case_" + str(args.do_lower_case) + "_seq_len_" + str(args.max_seq_length) \
+        + "_max_pred_" + str(args.max_predictions_per_seq) + "_masked_lm_prob_" + str(args.masked_lm_prob) \
+        + "_random_seed_" + str(args.random_seed) + "_dupe_factor_" + str(args.dupe_factor)
+
     directory_structure = {
         'download' : working_dir + '/download',    # Downloaded and decompressed
         'extracted' : working_dir + '/extracted',    # Extracted from whatever the initial format is (e.g., wikiextractor)
         'formatted' : working_dir + '/formatted_one_article_per_line',    # This is the level where all sources should look the same
-        'sharded' : working_dir + '/sharded',
-        'tfrecord' : working_dir + '/tfrecord',
-        'hdf5': working_dir + '/hdf5'
+        'sharded' : working_dir + '/sharded_training_shards_' + str(args.n_training_shards) + '_test_shards_' + str(args.n_test_shards) + '_fraction_' + str(args.fraction_test_set),
+        'tfrecord' : working_dir + '/tfrecord' + hdf5_tfrecord_folder_prefix,
+        'hdf5': working_dir + '/hdf5' + hdf5_tfrecord_folder_prefix
     }

     print('\nDirectory Structure:')
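
The added hdf5_tfrecord_folder_prefix (a suffix, despite the name) folds every content-affecting hyperparameter into the tfrecord and hdf5 directory names, so runs with different settings write to different directories instead of silently overwriting each other. A minimal sketch of the naming scheme, with hypothetical argument values:

    # Sketch only: mirrors the suffix built above; all values are hypothetical.
    def output_suffix(do_lower_case=True, max_seq_length=128, max_predictions_per_seq=20,
                      masked_lm_prob=0.15, random_seed=12345, dupe_factor=5):
        return ('_lower_case_' + str(do_lower_case) + '_seq_len_' + str(max_seq_length)
                + '_max_pred_' + str(max_predictions_per_seq) + '_masked_lm_prob_' + str(masked_lm_prob)
                + '_random_seed_' + str(random_seed) + '_dupe_factor_' + str(dupe_factor))

    # -> /workspace/data/hdf5_lower_case_True_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5
    print('/workspace/data/hdf5' + output_suffix())
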
@@ -100,8 +104,7 @@ def main(args):
         elif args.dataset == 'books_wiki_en_corpus':
             args.input_files = [directory_structure['formatted'] + '/bookscorpus_one_book_per_line.txt', directory_structure['formatted'] + '/wikicorpus_en_one_article_per_line.txt']

-        if args.output_file_prefix is None:
-            args.output_file_prefix = directory_structure['sharded'] + '/' + args.dataset + '/' + args.dataset
+        output_file_prefix = directory_structure['sharded'] + '/' + args.dataset + '/' + args.dataset

         if not os.path.exists(directory_structure['sharded']):
             os.makedirs(directory_structure['sharded'])
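
The exists()/makedirs() pair above is the pre-3.2 idiom; os.makedirs has accepted exist_ok=True since Python 3.2, which does the check and creation in one call and avoids a race when several processes bootstrap the same tree. A one-line equivalent, using a stand-in path:

    import os

    # 'sharded_dir' stands in for directory_structure['sharded'].
    sharded_dir = '/workspace/data/sharded_training_shards_256_test_shards_128_fraction_0.1'
    os.makedirs(sharded_dir, exist_ok=True)
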
@@ -114,7 +117,7 @@ def main(args):
         # Different languages (e.g., Chinese simplified/traditional) may require translation and
         # other packages to be called from here -- just add a conditional branch for those extra steps
         segmenter = TextSharding.NLTKSegmenter()
-        sharding = TextSharding.Sharding(args.input_files, args.output_file_prefix, args.n_training_shards, args.n_test_shards, args.fraction_test_set)
+        sharding = TextSharding.Sharding(args.input_files, output_file_prefix, args.n_training_shards, args.n_test_shards, args.fraction_test_set)

         sharding.load_articles()
         sharding.segment_articles_into_sentences(segmenter)
@@ -127,15 +130,15 @@ def main(args):
     elif args.action == 'create_tfrecord_files':
         assert False, 'TFrecord creation not supported in this PyTorch model example release.' \
             ''
-        if not os.path.exists(directory_structure['tfrecord']):
-            os.makedirs(directory_structure['tfrecord'])
+        if not os.path.exists(directory_structure['tfrecord'] + "/" + args.dataset):
+            os.makedirs(directory_structure['tfrecord'] + "/" + args.dataset)

         def create_record_worker(filename_prefix, shard_id, output_format='tfrecord'):
             bert_preprocessing_command = 'python /workspace/bert/create_pretraining_data.py'
             bert_preprocessing_command += ' --input_file=' + directory_structure['sharded'] + '/' + args.dataset + '/' + filename_prefix + '_' + str(shard_id) + '.txt'
             bert_preprocessing_command += ' --output_file=' + directory_structure['tfrecord'] + '/' + args.dataset + '/' + filename_prefix + '_' + str(shard_id) + '.' + output_format
             bert_preprocessing_command += ' --vocab_file=' + args.vocab_file
-            bert_preprocessing_command += ' --do_lower_case=' + 'true' if args.do_lower_case else 'false'
+            bert_preprocessing_command += ' --do_lower_case' if args.do_lower_case else ''
             bert_preprocessing_command += ' --max_seq_length=' + str(args.max_seq_length)
             bert_preprocessing_command += ' --max_predictions_per_seq=' + str(args.max_predictions_per_seq)
             bert_preprocessing_command += ' --masked_lm_prob=' + str(args.masked_lm_prob)
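
The replaced --do_lower_case line fixes an operator-precedence bug: '+' binds tighter than a conditional expression, so the old right-hand side parsed as (' --do_lower_case=' + 'true') if args.do_lower_case else 'false', and with lower-casing disabled the bare string 'false' was appended with no flag name at all. A quick demonstration:

    # '+' binds tighter than 'if/else':
    do_lower_case = False
    old = ' --do_lower_case=' + 'true' if do_lower_case else 'false'
    print(repr(old))  # 'false' -- the flag name is lost
    new = ' --do_lower_case' if do_lower_case else ''
    print(repr(new))  # ''      -- the new form simply omits the flag
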
@@ -149,14 +152,17 @@ def main(args):
             # This could be better optimized (fine if all take equal time)
             if shard_id % args.n_processes == 0 and shard_id > 0:
                 bert_preprocessing_process.wait()
+            return last_process
+
+        output_file_prefix = args.dataset

         for i in range(args.n_training_shards):
-            create_record_worker(args.output_file_prefix + '_training', i)
+            last_process = create_record_worker(output_file_prefix + '_training', i)

         last_process.wait()

         for i in range(args.n_test_shards):
-            create_record_worker(args.output_file_prefix + '_test', i)
+            last_process = create_record_worker(output_file_prefix + '_test', i)

         last_process.wait()
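
create_record_worker now returns the Popen handle it launched; each loop keeps the most recent handle and waits on it once per split. Combined with the in-worker wait() on every n_processes-th shard, this caps the number of concurrent create_pretraining_data.py children, under the script's own assumption that shards take roughly equal time. A self-contained sketch of the pattern (the echo command and throttle width are placeholders):

    import subprocess

    n_processes = 4  # hypothetical throttle width

    def launch(shard_id):
        proc = subprocess.Popen('echo shard %d' % shard_id, shell=True)
        # Every n_processes-th launch blocks on its own child, which keeps
        # roughly n_processes children alive at once if runtimes are similar.
        if shard_id % n_processes == 0 and shard_id > 0:
            proc.wait()
        return proc

    last_process = None
    for i in range(10):
        last_process = launch(i)
    last_process.wait()  # only guarantees the *final* shard has exited

As the last comment hints, waiting on last_process alone does not prove the earlier shards finished; that looseness is inherited from the original design rather than introduced by this change.
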
@@ -164,17 +170,20 @@ def main(args):
     elif args.action == 'create_hdf5_files':
         last_process = None

+        if not os.path.exists(directory_structure['hdf5'] + "/" + args.dataset):
+            os.makedirs(directory_structure['hdf5'] + "/" + args.dataset)
+
         def create_record_worker(filename_prefix, shard_id, output_format='hdf5'):
             bert_preprocessing_command = 'python /workspace/bert/create_pretraining_data.py'
             bert_preprocessing_command += ' --input_file=' + directory_structure['sharded'] + '/' + args.dataset + '/' + filename_prefix + '_' + str(shard_id) + '.txt'
-            bert_preprocessing_command += ' --output_file=' + directory_structure['tfrecord'] + '/' + args.dataset + '/' + filename_prefix + '_' + str(shard_id) + '.' + output_format
+            bert_preprocessing_command += ' --output_file=' + directory_structure['hdf5'] + '/' + args.dataset + '/' + filename_prefix + '_' + str(shard_id) + '.' + output_format
             bert_preprocessing_command += ' --vocab_file=' + args.vocab_file
             bert_preprocessing_command += ' --do_lower_case' if args.do_lower_case else ''
-            bert_preprocessing_command += ' --max_seq_length=' + args.max_seq_length
-            bert_preprocessing_command += ' --max_predictions_per_seq=' + args.max_predictions_per_seq
-            bert_preprocessing_command += ' --masked_lm_prob=' + args.masked_lm_prob
-            bert_preprocessing_command += ' --random_seed=' + args.random_seed
-            bert_preprocessing_command += ' --dupe_factor=' + args.dupe_factor
+            bert_preprocessing_command += ' --max_seq_length=' + str(args.max_seq_length)
+            bert_preprocessing_command += ' --max_predictions_per_seq=' + str(args.max_predictions_per_seq)
+            bert_preprocessing_command += ' --masked_lm_prob=' + str(args.masked_lm_prob)
+            bert_preprocessing_command += ' --random_seed=' + str(args.random_seed)
+            bert_preprocessing_command += ' --dupe_factor=' + str(args.dupe_factor)
             bert_preprocessing_process = subprocess.Popen(bert_preprocessing_command, shell=True)
             bert_preprocessing_process.communicate()
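
One thing to note in the context above: Popen.communicate() blocks until the child exits, so as written each shard is processed to completion before the next is launched, and the periodic wait() throttle below it never has pending work to wait for. If concurrent workers are the intent, this call is the one to revisit:

    import subprocess

    p = subprocess.Popen('sleep 1', shell=True)
    p.communicate()  # blocks until the child exits -- serializes the loop
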
@@ -183,14 +192,17 @@ def main(args):
             # This could be better optimized (fine if all take equal time)
             if shard_id % args.n_processes == 0 and shard_id > 0:
                 bert_preprocessing_process.wait()
+            return last_process
+
+        output_file_prefix = args.dataset

         for i in range(args.n_training_shards):
-            create_record_worker(args.output_file_prefix + '_training', i)
+            last_process = create_record_worker(output_file_prefix + '_training', i)

         last_process.wait()

         for i in range(args.n_test_shards):
-            create_record_worker(args.output_file_prefix + '_test', i)
+            last_process = create_record_worker(output_file_prefix + '_test', i)

         last_process.wait()
@@ -236,12 +248,6 @@ if __name__ == "__main__":
         help='Specify the input files in a comma-separated list (no spaces)'
     )

-    parser.add_argument(
-        '--output_file_prefix',
-        type=str,
-        help='Specify the naming convention (prefix) of the output files'
-    )
-
     parser.add_argument(
         '--n_training_shards',
         type=int,
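
With the --output_file_prefix flag removed, every output location is derived from working_dir, the dataset name, and the hyperparameter-bearing directory names, so the CLI can no longer point two differently configured runs at the same files. A sketch of the resulting layout, with hypothetical values:

    # All values hypothetical; shows where outputs land after this change.
    working_dir = '/workspace/data'
    dataset = 'wikicorpus_en'
    sharded_dir = working_dir + '/sharded_training_shards_256_test_shards_128_fraction_0.1'
    hdf5_dir = working_dir + '/hdf5_lower_case_True_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5'
    print(sharded_dir + '/' + dataset + '/' + dataset + '_training_0.txt')  # shard 0, training split
    print(hdf5_dir + '/' + dataset + '/' + dataset + '_training_0.hdf5')    # matching HDF5 output
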