# TextSharding.py
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  13. from collections import defaultdict
  14. from itertools import islice
  15. import multiprocessing
  16. import statistics
  17. class Sharding:
  18. def __init__(self, input_files, output_name_prefix, n_training_shards, n_test_shards, fraction_test_set):
  19. assert len(input_files) > 0, 'The input file list must contain at least one file.'
  20. assert n_training_shards > 0, 'There must be at least one output shard.'
  21. assert n_test_shards > 0, 'There must be at least one output shard.'
  22. self.n_training_shards = n_training_shards
  23. self.n_test_shards = n_test_shards
  24. self.fraction_test_set = fraction_test_set
  25. self.input_files = input_files
  26. self.output_name_prefix = output_name_prefix
  27. self.output_training_identifier = '_training'
  28. self.output_test_identifier = '_test'
  29. self.output_file_extension = '.txt'
  30. self.articles = {} # key: integer identifier, value: list of articles
  31. self.sentences = {} # key: integer identifier, value: list of sentences
  32. self.output_training_files = {} # key: filename, value: list of articles to go into file
  33. self.output_test_files = {} # key: filename, value: list of articles to go into file
  34. self.init_output_files()
  35. # Remember, the input files contain one article per line (the whitespace check is to skip extraneous blank lines)
  36. def load_articles(self):
  37. print('Start: Loading Articles')
  38. global_article_count = 0
  39. for input_file in self.input_files:
  40. print('input file:', input_file)
  41. with open(input_file, mode='r', newline='\n') as f:
  42. for i, line in enumerate(f):
  43. if line.strip():
  44. self.articles[global_article_count] = line.rstrip()
  45. global_article_count += 1
  46. print('End: Loading Articles: There are', len(self.articles), 'articles.')
  47. def segment_articles_into_sentences(self, segmenter):
  48. print('Start: Sentence Segmentation')
  49. if len(self.articles) is 0:
  50. self.load_articles()
  51. assert len(self.articles) is not 0, 'Please check that input files are present and contain data.'
  52. # TODO: WIP: multiprocessing (create independent ranges and spawn processes)
  53. use_multiprocessing = 'serial'
  54. def chunks(data, size=len(self.articles)):
  55. it = iter(data)
  56. for i in range(0, len(data), size):
  57. yield {k: data[k] for k in islice(it, size)}
  58. if use_multiprocessing == 'manager':
  59. manager = multiprocessing.Manager()
  60. return_dict = manager.dict()
  61. jobs = []
  62. n_processes = 7 # in addition to the main process, total = n_proc+1
  63. def work(articles, return_dict):
  64. sentences = {}
  65. for i, article in enumerate(articles):
  66. sentences[i] = segmenter.segment_string(articles[article])
  67. if i % 5000 == 0:
  68. print('Segmenting article', i)
  69. return_dict.update(sentences)
  70. for item in chunks(self.articles, len(self.articles)):
  71. p = multiprocessing.Process(target=work, args=(item, return_dict))
  72. # Busy wait
  73. while len(jobs) >= n_processes:
  74. pass
  75. jobs.append(p)
  76. p.start()
  77. for proc in jobs:
  78. proc.join()
  79. elif use_multiprocessing == 'queue':
  80. work_queue = multiprocessing.Queue()
  81. jobs = []
  82. for item in chunks(self.articles, len(self.articles)):
  83. pass
  84. else: # serial option
  85. for i, article in enumerate(self.articles):
  86. self.sentences[i] = segmenter.segment_string(self.articles[article])
  87. if i % 5000 == 0:
  88. print('Segmenting article', i)
  89. print('End: Sentence Segmentation')
  90. def init_output_files(self):
  91. print('Start: Init Output Files')
  92. assert len(self.output_training_files) is 0, 'Internal storage self.output_files already contains data. This function is intended to be used by the constructor only.'
  93. assert len(self.output_test_files) is 0, 'Internal storage self.output_files already contains data. This function is intended to be used by the constructor only.'
  94. for i in range(self.n_training_shards):
  95. name = self.output_name_prefix + self.output_training_identifier + '_' + str(i) + self.output_file_extension
  96. self.output_training_files[name] = []
  97. for i in range(self.n_test_shards):
  98. name = self.output_name_prefix + self.output_test_identifier + '_' + str(i) + self.output_file_extension
  99. self.output_test_files[name] = []
  100. print('End: Init Output Files')
  101. def get_sentences_per_shard(self, shard):
  102. result = 0
  103. for article_id in shard:
  104. result += len(self.sentences[article_id])
  105. return result
  106. def distribute_articles_over_shards(self):
  107. print('Start: Distribute Articles Over Shards')
  108. assert len(self.articles) >= self.n_training_shards + self.n_test_shards, 'There are fewer articles than shards. Please add more data or reduce the number of shards requested.'
  109. # Create dictionary with - key: sentence count per article, value: article id number
  110. sentence_counts = defaultdict(lambda: [])
  111. max_sentences = 0
  112. total_sentences = 0
  113. for article_id in self.sentences:
  114. current_length = len(self.sentences[article_id])
  115. sentence_counts[current_length].append(article_id)
  116. max_sentences = max(max_sentences, current_length)
  117. total_sentences += current_length
  118. n_sentences_assigned_to_training = int((1 - self.fraction_test_set) * total_sentences)
  119. nominal_sentences_per_training_shard = n_sentences_assigned_to_training // self.n_training_shards
  120. nominal_sentences_per_test_shard = (total_sentences - n_sentences_assigned_to_training) // self.n_test_shards
  121. consumed_article_set = set({})
  122. unused_article_set = set(self.articles.keys())
  123. # Make first pass and add one article worth of lines per file
  124. for file in self.output_training_files:
  125. current_article_id = sentence_counts[max_sentences][-1]
  126. sentence_counts[max_sentences].pop(-1)
  127. self.output_training_files[file].append(current_article_id)
  128. consumed_article_set.add(current_article_id)
  129. unused_article_set.remove(current_article_id)
  130. # Maintain the max sentence count
  131. while len(sentence_counts[max_sentences]) == 0 and max_sentences > 0:
  132. max_sentences -= 1
  133. if len(self.sentences[current_article_id]) > nominal_sentences_per_training_shard:
  134. nominal_sentences_per_training_shard = len(self.sentences[current_article_id])
  135. print('Warning: A single article contains more than the nominal number of sentences per training shard.')
  136. for file in self.output_test_files:
  137. current_article_id = sentence_counts[max_sentences][-1]
  138. sentence_counts[max_sentences].pop(-1)
  139. self.output_test_files[file].append(current_article_id)
  140. consumed_article_set.add(current_article_id)
  141. unused_article_set.remove(current_article_id)
  142. # Maintain the max sentence count
  143. while len(sentence_counts[max_sentences]) == 0 and max_sentences > 0:
  144. max_sentences -= 1
  145. if len(self.sentences[current_article_id]) > nominal_sentences_per_test_shard:
  146. nominal_sentences_per_test_shard = len(self.sentences[current_article_id])
  147. print('Warning: A single article contains more than the nominal number of sentences per test shard.')
  148. training_counts = []
  149. test_counts = []
  150. for shard in self.output_training_files:
  151. training_counts.append(self.get_sentences_per_shard(self.output_training_files[shard]))
  152. for shard in self.output_test_files:
  153. test_counts.append(self.get_sentences_per_shard(self.output_test_files[shard]))
  154. training_median = statistics.median(training_counts)
  155. test_median = statistics.median(test_counts)
  156. # Make subsequent passes over files to find articles to add without going over limit
  157. history_remaining = []
  158. n_history_remaining = 4
  159. while len(consumed_article_set) < len(self.articles):
  160. for fidx, file in enumerate(self.output_training_files):
  161. nominal_next_article_size = min(nominal_sentences_per_training_shard - training_counts[fidx], max_sentences)
  162. # Maintain the max sentence count
  163. while len(sentence_counts[max_sentences]) == 0 and max_sentences > 0:
  164. max_sentences -= 1
  165. while len(sentence_counts[nominal_next_article_size]) == 0 and nominal_next_article_size > 0:
  166. nominal_next_article_size -= 1
  167. if nominal_next_article_size not in sentence_counts or nominal_next_article_size is 0 or training_counts[fidx] > training_median:
  168. continue # skip adding to this file, will come back later if no file can accept unused articles
  169. current_article_id = sentence_counts[nominal_next_article_size][-1]
  170. sentence_counts[nominal_next_article_size].pop(-1)
  171. self.output_training_files[file].append(current_article_id)
  172. consumed_article_set.add(current_article_id)
  173. unused_article_set.remove(current_article_id)
  174. for fidx, file in enumerate(self.output_test_files):
  175. nominal_next_article_size = min(nominal_sentences_per_test_shard - test_counts[fidx], max_sentences)
  176. # Maintain the max sentence count
  177. while len(sentence_counts[max_sentences]) == 0 and max_sentences > 0:
  178. max_sentences -= 1
  179. while len(sentence_counts[nominal_next_article_size]) == 0 and nominal_next_article_size > 0:
  180. nominal_next_article_size -= 1
  181. if nominal_next_article_size not in sentence_counts or nominal_next_article_size is 0 or test_counts[fidx] > test_median:
  182. continue # skip adding to this file, will come back later if no file can accept unused articles
  183. current_article_id = sentence_counts[nominal_next_article_size][-1]
  184. sentence_counts[nominal_next_article_size].pop(-1)
  185. self.output_test_files[file].append(current_article_id)
  186. consumed_article_set.add(current_article_id)
  187. unused_article_set.remove(current_article_id)
  188. # If unable to place articles a few times, bump up nominal sizes by fraction until articles get placed
  189. if len(history_remaining) == n_history_remaining:
  190. history_remaining.pop(0)
  191. history_remaining.append(len(unused_article_set))
  192. history_same = True
  193. for i in range(1, len(history_remaining)):
  194. history_same = history_same and (history_remaining[i-1] == history_remaining[i])
  195. if history_same:
  196. nominal_sentences_per_training_shard += 1
  197. # nominal_sentences_per_test_shard += 1
  198. training_counts = []
  199. test_counts = []
  200. for shard in self.output_training_files:
  201. training_counts.append(self.get_sentences_per_shard(self.output_training_files[shard]))
  202. for shard in self.output_test_files:
  203. test_counts.append(self.get_sentences_per_shard(self.output_test_files[shard]))
  204. training_median = statistics.median(training_counts)
  205. test_median = statistics.median(test_counts)
  206. print('Distributing data over shards:', len(unused_article_set), 'articles remaining.')
  207. if len(unused_article_set) != 0:
  208. print('Warning: Some articles did not make it into output files.')
  209. for shard in self.output_training_files:
  210. print('Training shard:', self.get_sentences_per_shard(self.output_training_files[shard]))
  211. for shard in self.output_test_files:
  212. print('Test shard:', self.get_sentences_per_shard(self.output_test_files[shard]))
  213. print('End: Distribute Articles Over Shards')
  214. def write_shards_to_disk(self):
  215. print('Start: Write Shards to Disk')
  216. for shard in self.output_training_files:
  217. self.write_single_shard(shard, self.output_training_files[shard])
  218. for shard in self.output_test_files:
  219. self.write_single_shard(shard, self.output_test_files[shard])
  220. print('End: Write Shards to Disk')
  221. def write_single_shard(self, shard_name, shard):
  222. with open(shard_name, mode='w', newline='\n') as f:
  223. for article_id in shard:
  224. for line in self.sentences[article_id]:
  225. f.write(line + '\n')
  226. f.write('\n') # Line break between articles
# Third-party dependency used by NLTKSegmenter below. Downloading the 'punkt'
# sentence-tokenizer model happens at import time of this module (requires
# network access on first run; no-op once the model is cached locally).
import nltk
nltk.download('punkt')
  229. class NLTKSegmenter:
  230. def __init(self):
  231. pass
  232. def segment_string(self, article):
  233. return nltk.tokenize.sent_tokenize(article)