run_squad.py 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155
  1. # coding=utf-8
  2. # Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
  3. # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Run BERT on SQuAD."""
  16. from __future__ import absolute_import, division, print_function
  17. import argparse
  18. import collections
  19. import json
  20. import logging
  21. import math
  22. import os
  23. import random
  24. import sys
  25. from io import open
  26. import numpy as np
  27. import torch
  28. from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
  29. TensorDataset)
  30. from torch.utils.data.distributed import DistributedSampler
  31. from tqdm import tqdm, trange
  32. from apex import amp
  33. from schedulers import LinearWarmUpScheduler
  34. from file_utils import PYTORCH_PRETRAINED_BERT_CACHE
  35. from modeling import BertForQuestionAnswering, BertConfig, WEIGHTS_NAME, CONFIG_NAME
  36. from optimization import BertAdam, warmup_linear
  37. from tokenization import (BasicTokenizer, BertTokenizer, whitespace_tokenize)
  38. from utils import is_main_process
  39. if sys.version_info[0] == 2:
  40. import cPickle as pickle
  41. else:
  42. import pickle
  43. logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
  44. datefmt='%m/%d/%Y %H:%M:%S',
  45. level=logging.INFO)
  46. logger = logging.getLogger(__name__)
  47. class SquadExample(object):
  48. """
  49. A single training/test example for the Squad dataset.
  50. For examples without an answer, the start and end position are -1.
  51. """
  52. def __init__(self,
  53. qas_id,
  54. question_text,
  55. doc_tokens,
  56. orig_answer_text=None,
  57. start_position=None,
  58. end_position=None,
  59. is_impossible=None):
  60. self.qas_id = qas_id
  61. self.question_text = question_text
  62. self.doc_tokens = doc_tokens
  63. self.orig_answer_text = orig_answer_text
  64. self.start_position = start_position
  65. self.end_position = end_position
  66. self.is_impossible = is_impossible
  67. def __str__(self):
  68. return self.__repr__()
  69. def __repr__(self):
  70. s = ""
  71. s += "qas_id: %s" % (self.qas_id)
  72. s += ", question_text: %s" % (
  73. self.question_text)
  74. s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
  75. if self.start_position:
  76. s += ", start_position: %d" % (self.start_position)
  77. if self.end_position:
  78. s += ", end_position: %d" % (self.end_position)
  79. if self.is_impossible:
  80. s += ", is_impossible: %r" % (self.is_impossible)
  81. return s
  82. class InputFeatures(object):
  83. """A single set of features of data."""
  84. def __init__(self,
  85. unique_id,
  86. example_index,
  87. doc_span_index,
  88. tokens,
  89. token_to_orig_map,
  90. token_is_max_context,
  91. input_ids,
  92. input_mask,
  93. segment_ids,
  94. start_position=None,
  95. end_position=None,
  96. is_impossible=None):
  97. self.unique_id = unique_id
  98. self.example_index = example_index
  99. self.doc_span_index = doc_span_index
  100. self.tokens = tokens
  101. self.token_to_orig_map = token_to_orig_map
  102. self.token_is_max_context = token_is_max_context
  103. self.input_ids = input_ids
  104. self.input_mask = input_mask
  105. self.segment_ids = segment_ids
  106. self.start_position = start_position
  107. self.end_position = end_position
  108. self.is_impossible = is_impossible
  109. def read_squad_examples(input_file, is_training, version_2_with_negative):
  110. """Read a SQuAD json file into a list of SquadExample."""
  111. with open(input_file, "r", encoding='utf-8') as reader:
  112. input_data = json.load(reader)["data"]
  113. def is_whitespace(c):
  114. if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
  115. return True
  116. return False
  117. examples = []
  118. for entry in input_data:
  119. for paragraph in entry["paragraphs"]:
  120. paragraph_text = paragraph["context"]
  121. doc_tokens = []
  122. char_to_word_offset = []
  123. prev_is_whitespace = True
  124. for c in paragraph_text:
  125. if is_whitespace(c):
  126. prev_is_whitespace = True
  127. else:
  128. if prev_is_whitespace:
  129. doc_tokens.append(c)
  130. else:
  131. doc_tokens[-1] += c
  132. prev_is_whitespace = False
  133. char_to_word_offset.append(len(doc_tokens) - 1)
  134. for qa in paragraph["qas"]:
  135. qas_id = qa["id"]
  136. question_text = qa["question"]
  137. start_position = None
  138. end_position = None
  139. orig_answer_text = None
  140. is_impossible = False
  141. if is_training:
  142. if version_2_with_negative:
  143. is_impossible = qa["is_impossible"]
  144. if (len(qa["answers"]) != 1) and (not is_impossible):
  145. raise ValueError(
  146. "For training, each question should have exactly 1 answer.")
  147. if not is_impossible:
  148. answer = qa["answers"][0]
  149. orig_answer_text = answer["text"]
  150. answer_offset = answer["answer_start"]
  151. answer_length = len(orig_answer_text)
  152. start_position = char_to_word_offset[answer_offset]
  153. end_position = char_to_word_offset[answer_offset + answer_length - 1]
  154. # Only add answers where the text can be exactly recovered from the
  155. # document. If this CAN'T happen it's likely due to weird Unicode
  156. # stuff so we will just skip the example.
  157. #
  158. # Note that this means for training mode, every example is NOT
  159. # guaranteed to be preserved.
  160. actual_text = " ".join(doc_tokens[start_position:(end_position + 1)])
  161. cleaned_answer_text = " ".join(
  162. whitespace_tokenize(orig_answer_text))
  163. if actual_text.find(cleaned_answer_text) == -1:
  164. logger.warning("Could not find answer: '%s' vs. '%s'",
  165. actual_text, cleaned_answer_text)
  166. continue
  167. else:
  168. start_position = -1
  169. end_position = -1
  170. orig_answer_text = ""
  171. example = SquadExample(
  172. qas_id=qas_id,
  173. question_text=question_text,
  174. doc_tokens=doc_tokens,
  175. orig_answer_text=orig_answer_text,
  176. start_position=start_position,
  177. end_position=end_position,
  178. is_impossible=is_impossible)
  179. examples.append(example)
  180. return examples
  181. def convert_examples_to_features(examples, tokenizer, max_seq_length,
  182. doc_stride, max_query_length, is_training):
  183. """Loads a data file into a list of `InputBatch`s."""
  184. unique_id = 1000000000
  185. features = []
  186. for (example_index, example) in enumerate(examples):
  187. query_tokens = tokenizer.tokenize(example.question_text)
  188. if len(query_tokens) > max_query_length:
  189. query_tokens = query_tokens[0:max_query_length]
  190. tok_to_orig_index = []
  191. orig_to_tok_index = []
  192. all_doc_tokens = []
  193. for (i, token) in enumerate(example.doc_tokens):
  194. orig_to_tok_index.append(len(all_doc_tokens))
  195. sub_tokens = tokenizer.tokenize(token)
  196. for sub_token in sub_tokens:
  197. tok_to_orig_index.append(i)
  198. all_doc_tokens.append(sub_token)
  199. tok_start_position = None
  200. tok_end_position = None
  201. if is_training and example.is_impossible:
  202. tok_start_position = -1
  203. tok_end_position = -1
  204. if is_training and not example.is_impossible:
  205. tok_start_position = orig_to_tok_index[example.start_position]
  206. if example.end_position < len(example.doc_tokens) - 1:
  207. tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
  208. else:
  209. tok_end_position = len(all_doc_tokens) - 1
  210. (tok_start_position, tok_end_position) = _improve_answer_span(
  211. all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
  212. example.orig_answer_text)
  213. # The -3 accounts for [CLS], [SEP] and [SEP]
  214. max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
  215. # We can have documents that are longer than the maximum sequence length.
  216. # To deal with this we do a sliding window approach, where we take chunks
  217. # of the up to our max length with a stride of `doc_stride`.
  218. _DocSpan = collections.namedtuple( # pylint: disable=invalid-name
  219. "DocSpan", ["start", "length"])
  220. doc_spans = []
  221. start_offset = 0
  222. while start_offset < len(all_doc_tokens):
  223. length = len(all_doc_tokens) - start_offset
  224. if length > max_tokens_for_doc:
  225. length = max_tokens_for_doc
  226. doc_spans.append(_DocSpan(start=start_offset, length=length))
  227. if start_offset + length == len(all_doc_tokens):
  228. break
  229. start_offset += min(length, doc_stride)
  230. for (doc_span_index, doc_span) in enumerate(doc_spans):
  231. tokens = []
  232. token_to_orig_map = {}
  233. token_is_max_context = {}
  234. segment_ids = []
  235. tokens.append("[CLS]")
  236. segment_ids.append(0)
  237. for token in query_tokens:
  238. tokens.append(token)
  239. segment_ids.append(0)
  240. tokens.append("[SEP]")
  241. segment_ids.append(0)
  242. for i in range(doc_span.length):
  243. split_token_index = doc_span.start + i
  244. token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
  245. is_max_context = _check_is_max_context(doc_spans, doc_span_index,
  246. split_token_index)
  247. token_is_max_context[len(tokens)] = is_max_context
  248. tokens.append(all_doc_tokens[split_token_index])
  249. segment_ids.append(1)
  250. tokens.append("[SEP]")
  251. segment_ids.append(1)
  252. input_ids = tokenizer.convert_tokens_to_ids(tokens)
  253. # The mask has 1 for real tokens and 0 for padding tokens. Only real
  254. # tokens are attended to.
  255. input_mask = [1] * len(input_ids)
  256. # Zero-pad up to the sequence length.
  257. while len(input_ids) < max_seq_length:
  258. input_ids.append(0)
  259. input_mask.append(0)
  260. segment_ids.append(0)
  261. assert len(input_ids) == max_seq_length
  262. assert len(input_mask) == max_seq_length
  263. assert len(segment_ids) == max_seq_length
  264. start_position = None
  265. end_position = None
  266. if is_training and not example.is_impossible:
  267. # For training, if our document chunk does not contain an annotation
  268. # we throw it out, since there is nothing to predict.
  269. doc_start = doc_span.start
  270. doc_end = doc_span.start + doc_span.length - 1
  271. out_of_span = False
  272. if not (tok_start_position >= doc_start and
  273. tok_end_position <= doc_end):
  274. out_of_span = True
  275. if out_of_span:
  276. start_position = 0
  277. end_position = 0
  278. else:
  279. doc_offset = len(query_tokens) + 2
  280. start_position = tok_start_position - doc_start + doc_offset
  281. end_position = tok_end_position - doc_start + doc_offset
  282. if is_training and example.is_impossible:
  283. start_position = 0
  284. end_position = 0
  285. if example_index < 20:
  286. logger.info("*** Example ***")
  287. logger.info("unique_id: %s" % (unique_id))
  288. logger.info("example_index: %s" % (example_index))
  289. logger.info("doc_span_index: %s" % (doc_span_index))
  290. logger.info("tokens: %s" % " ".join(tokens))
  291. logger.info("token_to_orig_map: %s" % " ".join([
  292. "%d:%d" % (x, y) for (x, y) in token_to_orig_map.items()]))
  293. logger.info("token_is_max_context: %s" % " ".join([
  294. "%d:%s" % (x, y) for (x, y) in token_is_max_context.items()
  295. ]))
  296. logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
  297. logger.info(
  298. "input_mask: %s" % " ".join([str(x) for x in input_mask]))
  299. logger.info(
  300. "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
  301. if is_training and example.is_impossible:
  302. logger.info("impossible example")
  303. if is_training and not example.is_impossible:
  304. answer_text = " ".join(tokens[start_position:(end_position + 1)])
  305. logger.info("start_position: %d" % (start_position))
  306. logger.info("end_position: %d" % (end_position))
  307. logger.info(
  308. "answer: %s" % (answer_text))
  309. features.append(
  310. InputFeatures(
  311. unique_id=unique_id,
  312. example_index=example_index,
  313. doc_span_index=doc_span_index,
  314. tokens=tokens,
  315. token_to_orig_map=token_to_orig_map,
  316. token_is_max_context=token_is_max_context,
  317. input_ids=input_ids,
  318. input_mask=input_mask,
  319. segment_ids=segment_ids,
  320. start_position=start_position,
  321. end_position=end_position,
  322. is_impossible=example.is_impossible))
  323. unique_id += 1
  324. return features
  325. def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
  326. orig_answer_text):
  327. """Returns tokenized answer spans that better match the annotated answer."""
  328. # The SQuAD annotations are character based. We first project them to
  329. # whitespace-tokenized words. But then after WordPiece tokenization, we can
  330. # often find a "better match". For example:
  331. #
  332. # Question: What year was John Smith born?
  333. # Context: The leader was John Smith (1895-1943).
  334. # Answer: 1895
  335. #
  336. # The original whitespace-tokenized answer will be "(1895-1943).". However
  337. # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
  338. # the exact answer, 1895.
  339. #
  340. # However, this is not always possible. Consider the following:
  341. #
  342. # Question: What country is the top exporter of electornics?
  343. # Context: The Japanese electronics industry is the lagest in the world.
  344. # Answer: Japan
  345. #
  346. # In this case, the annotator chose "Japan" as a character sub-span of
  347. # the word "Japanese". Since our WordPiece tokenizer does not split
  348. # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
  349. # in SQuAD, but does happen.
  350. tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
  351. for new_start in range(input_start, input_end + 1):
  352. for new_end in range(input_end, new_start - 1, -1):
  353. text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
  354. if text_span == tok_answer_text:
  355. return (new_start, new_end)
  356. return (input_start, input_end)
  357. def _check_is_max_context(doc_spans, cur_span_index, position):
  358. """Check if this is the 'max context' doc span for the token."""
  359. # Because of the sliding window approach taken to scoring documents, a single
  360. # token can appear in multiple documents. E.g.
  361. # Doc: the man went to the store and bought a gallon of milk
  362. # Span A: the man went to the
  363. # Span B: to the store and bought
  364. # Span C: and bought a gallon of
  365. # ...
  366. #
  367. # Now the word 'bought' will have two scores from spans B and C. We only
  368. # want to consider the score with "maximum context", which we define as
  369. # the *minimum* of its left and right context (the *sum* of left and
  370. # right context will always be the same, of course).
  371. #
  372. # In the example the maximum context for 'bought' would be span C since
  373. # it has 1 left context and 3 right context, while span B has 4 left context
  374. # and 0 right context.
  375. best_score = None
  376. best_span_index = None
  377. for (span_index, doc_span) in enumerate(doc_spans):
  378. end = doc_span.start + doc_span.length - 1
  379. if position < doc_span.start:
  380. continue
  381. if position > end:
  382. continue
  383. num_left_context = position - doc_span.start
  384. num_right_context = end - position
  385. score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
  386. if best_score is None or score > best_score:
  387. best_score = score
  388. best_span_index = span_index
  389. return cur_span_index == best_span_index
  390. RawResult = collections.namedtuple("RawResult",
  391. ["unique_id", "start_logits", "end_logits"])
  392. def write_predictions(all_examples, all_features, all_results, n_best_size,
  393. max_answer_length, do_lower_case, output_prediction_file,
  394. output_nbest_file, output_null_log_odds_file, verbose_logging,
  395. version_2_with_negative, null_score_diff_threshold):
  396. """Write final predictions to the json file and log-odds of null if needed."""
  397. logger.info("Writing predictions to: %s" % (output_prediction_file))
  398. logger.info("Writing nbest to: %s" % (output_nbest_file))
  399. example_index_to_features = collections.defaultdict(list)
  400. for feature in all_features:
  401. example_index_to_features[feature.example_index].append(feature)
  402. unique_id_to_result = {}
  403. for result in all_results:
  404. unique_id_to_result[result.unique_id] = result
  405. _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
  406. "PrelimPrediction",
  407. ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
  408. all_predictions = collections.OrderedDict()
  409. all_nbest_json = collections.OrderedDict()
  410. scores_diff_json = collections.OrderedDict()
  411. for (example_index, example) in enumerate(all_examples):
  412. features = example_index_to_features[example_index]
  413. prelim_predictions = []
  414. # keep track of the minimum score of null start+end of position 0
  415. score_null = 1000000 # large and positive
  416. min_null_feature_index = 0 # the paragraph slice with min mull score
  417. null_start_logit = 0 # the start logit at the slice with min null score
  418. null_end_logit = 0 # the end logit at the slice with min null score
  419. for (feature_index, feature) in enumerate(features):
  420. result = unique_id_to_result[feature.unique_id]
  421. start_indexes = _get_best_indexes(result.start_logits, n_best_size)
  422. end_indexes = _get_best_indexes(result.end_logits, n_best_size)
  423. # if we could have irrelevant answers, get the min score of irrelevant
  424. if version_2_with_negative:
  425. feature_null_score = result.start_logits[0] + result.end_logits[0]
  426. if feature_null_score < score_null:
  427. score_null = feature_null_score
  428. min_null_feature_index = feature_index
  429. null_start_logit = result.start_logits[0]
  430. null_end_logit = result.end_logits[0]
  431. for start_index in start_indexes:
  432. for end_index in end_indexes:
  433. # We could hypothetically create invalid predictions, e.g., predict
  434. # that the start of the span is in the question. We throw out all
  435. # invalid predictions.
  436. if start_index >= len(feature.tokens):
  437. continue
  438. if end_index >= len(feature.tokens):
  439. continue
  440. if start_index not in feature.token_to_orig_map:
  441. continue
  442. if end_index not in feature.token_to_orig_map:
  443. continue
  444. if not feature.token_is_max_context.get(start_index, False):
  445. continue
  446. if end_index < start_index:
  447. continue
  448. length = end_index - start_index + 1
  449. if length > max_answer_length:
  450. continue
  451. prelim_predictions.append(
  452. _PrelimPrediction(
  453. feature_index=feature_index,
  454. start_index=start_index,
  455. end_index=end_index,
  456. start_logit=result.start_logits[start_index],
  457. end_logit=result.end_logits[end_index]))
  458. if version_2_with_negative:
  459. prelim_predictions.append(
  460. _PrelimPrediction(
  461. feature_index=min_null_feature_index,
  462. start_index=0,
  463. end_index=0,
  464. start_logit=null_start_logit,
  465. end_logit=null_end_logit))
  466. prelim_predictions = sorted(
  467. prelim_predictions,
  468. key=lambda x: (x.start_logit + x.end_logit),
  469. reverse=True)
  470. _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
  471. "NbestPrediction", ["text", "start_logit", "end_logit"])
  472. seen_predictions = {}
  473. nbest = []
  474. for pred in prelim_predictions:
  475. if len(nbest) >= n_best_size:
  476. break
  477. feature = features[pred.feature_index]
  478. if pred.start_index > 0: # this is a non-null prediction
  479. tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
  480. orig_doc_start = feature.token_to_orig_map[pred.start_index]
  481. orig_doc_end = feature.token_to_orig_map[pred.end_index]
  482. orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
  483. tok_text = " ".join(tok_tokens)
  484. # De-tokenize WordPieces that have been split off.
  485. tok_text = tok_text.replace(" ##", "")
  486. tok_text = tok_text.replace("##", "")
  487. # Clean whitespace
  488. tok_text = tok_text.strip()
  489. tok_text = " ".join(tok_text.split())
  490. orig_text = " ".join(orig_tokens)
  491. final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
  492. if final_text in seen_predictions:
  493. continue
  494. seen_predictions[final_text] = True
  495. else:
  496. final_text = ""
  497. seen_predictions[final_text] = True
  498. nbest.append(
  499. _NbestPrediction(
  500. text=final_text,
  501. start_logit=pred.start_logit,
  502. end_logit=pred.end_logit))
  503. # if we didn't include the empty option in the n-best, include it
  504. if version_2_with_negative:
  505. if "" not in seen_predictions:
  506. nbest.append(
  507. _NbestPrediction(
  508. text="",
  509. start_logit=null_start_logit,
  510. end_logit=null_end_logit))
  511. # In very rare edge cases we could only have single null prediction.
  512. # So we just create a nonce prediction in this case to avoid failure.
  513. if len(nbest) == 1:
  514. nbest.insert(0,
  515. _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
  516. # In very rare edge cases we could have no valid predictions. So we
  517. # just create a nonce prediction in this case to avoid failure.
  518. if not nbest:
  519. nbest.append(
  520. _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
  521. assert len(nbest) >= 1
  522. total_scores = []
  523. best_non_null_entry = None
  524. for entry in nbest:
  525. total_scores.append(entry.start_logit + entry.end_logit)
  526. if not best_non_null_entry:
  527. if entry.text:
  528. best_non_null_entry = entry
  529. probs = _compute_softmax(total_scores)
  530. nbest_json = []
  531. for (i, entry) in enumerate(nbest):
  532. output = collections.OrderedDict()
  533. output["text"] = entry.text
  534. output["probability"] = probs[i]
  535. output["start_logit"] = entry.start_logit
  536. output["end_logit"] = entry.end_logit
  537. nbest_json.append(output)
  538. assert len(nbest_json) >= 1
  539. if not version_2_with_negative:
  540. all_predictions[example.qas_id] = nbest_json[0]["text"]
  541. else:
  542. # predict "" iff the null score - the score of best non-null > threshold
  543. score_diff = score_null - best_non_null_entry.start_logit - (
  544. best_non_null_entry.end_logit)
  545. scores_diff_json[example.qas_id] = score_diff
  546. if score_diff > null_score_diff_threshold:
  547. all_predictions[example.qas_id] = ""
  548. else:
  549. all_predictions[example.qas_id] = best_non_null_entry.text
  550. all_nbest_json[example.qas_id] = nbest_json
  551. with open(output_prediction_file, "w") as writer:
  552. writer.write(json.dumps(all_predictions, indent=4) + "\n")
  553. with open(output_nbest_file, "w") as writer:
  554. writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
  555. if version_2_with_negative:
  556. with open(output_null_log_odds_file, "w") as writer:
  557. writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
  558. def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
  559. """Project the tokenized prediction back to the original text."""
  560. # When we created the data, we kept track of the alignment between original
  561. # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
  562. # now `orig_text` contains the span of our original text corresponding to the
  563. # span that we predicted.
  564. #
  565. # However, `orig_text` may contain extra characters that we don't want in
  566. # our prediction.
  567. #
  568. # For example, let's say:
  569. # pred_text = steve smith
  570. # orig_text = Steve Smith's
  571. #
  572. # We don't want to return `orig_text` because it contains the extra "'s".
  573. #
  574. # We don't want to return `pred_text` because it's already been normalized
  575. # (the SQuAD eval script also does punctuation stripping/lower casing but
  576. # our tokenizer does additional normalization like stripping accent
  577. # characters).
  578. #
  579. # What we really want to return is "Steve Smith".
  580. #
  581. # Therefore, we have to apply a semi-complicated alignment heruistic between
  582. # `pred_text` and `orig_text` to get a character-to-charcter alignment. This
  583. # can fail in certain cases in which case we just return `orig_text`.
  584. def _strip_spaces(text):
  585. ns_chars = []
  586. ns_to_s_map = collections.OrderedDict()
  587. for (i, c) in enumerate(text):
  588. if c == " ":
  589. continue
  590. ns_to_s_map[len(ns_chars)] = i
  591. ns_chars.append(c)
  592. ns_text = "".join(ns_chars)
  593. return (ns_text, ns_to_s_map)
  594. # We first tokenize `orig_text`, strip whitespace from the result
  595. # and `pred_text`, and check if they are the same length. If they are
  596. # NOT the same length, the heuristic has failed. If they are the same
  597. # length, we assume the characters are one-to-one aligned.
  598. tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
  599. tok_text = " ".join(tokenizer.tokenize(orig_text))
  600. start_position = tok_text.find(pred_text)
  601. if start_position == -1:
  602. if verbose_logging:
  603. logger.info(
  604. "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
  605. return orig_text
  606. end_position = start_position + len(pred_text) - 1
  607. (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
  608. (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
  609. if len(orig_ns_text) != len(tok_ns_text):
  610. if verbose_logging:
  611. logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
  612. orig_ns_text, tok_ns_text)
  613. return orig_text
  614. # We then project the characters in `pred_text` back to `orig_text` using
  615. # the character-to-character alignment.
  616. tok_s_to_ns_map = {}
  617. for (i, tok_index) in tok_ns_to_s_map.items():
  618. tok_s_to_ns_map[tok_index] = i
  619. orig_start_position = None
  620. if start_position in tok_s_to_ns_map:
  621. ns_start_position = tok_s_to_ns_map[start_position]
  622. if ns_start_position in orig_ns_to_s_map:
  623. orig_start_position = orig_ns_to_s_map[ns_start_position]
  624. if orig_start_position is None:
  625. if verbose_logging:
  626. logger.info("Couldn't map start position")
  627. return orig_text
  628. orig_end_position = None
  629. if end_position in tok_s_to_ns_map:
  630. ns_end_position = tok_s_to_ns_map[end_position]
  631. if ns_end_position in orig_ns_to_s_map:
  632. orig_end_position = orig_ns_to_s_map[ns_end_position]
  633. if orig_end_position is None:
  634. if verbose_logging:
  635. logger.info("Couldn't map end position")
  636. return orig_text
  637. output_text = orig_text[orig_start_position:(orig_end_position + 1)]
  638. return output_text
  639. def _get_best_indexes(logits, n_best_size):
  640. """Get the n-best logits from a list."""
  641. index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
  642. best_indexes = []
  643. for i in range(len(index_and_score)):
  644. if i >= n_best_size:
  645. break
  646. best_indexes.append(index_and_score[i][0])
  647. return best_indexes
  648. def _compute_softmax(scores):
  649. """Compute softmax probability over raw logits."""
  650. if not scores:
  651. return []
  652. max_score = None
  653. for score in scores:
  654. if max_score is None or score > max_score:
  655. max_score = score
  656. exp_scores = []
  657. total_sum = 0.0
  658. for score in scores:
  659. x = math.exp(score - max_score)
  660. exp_scores.append(x)
  661. total_sum += x
  662. probs = []
  663. for score in exp_scores:
  664. probs.append(score / total_sum)
  665. return probs
  666. def main():
  667. parser = argparse.ArgumentParser()
  668. ## Required parameters
  669. parser.add_argument("--bert_model", default=None, type=str, required=True,
  670. help="Bert pre-trained model selected in the list: bert-base-uncased, "
  671. "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
  672. "bert-base-multilingual-cased, bert-base-chinese.")
  673. parser.add_argument("--output_dir", default=None, type=str, required=True,
  674. help="The output directory where the model checkpoints and predictions will be written.")
  675. parser.add_argument("--init_checkpoint",
  676. default=None,
  677. type=str,
  678. required=True,
  679. help="The checkpoint file from pretraining")
  680. ## Other parameters
  681. parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json")
  682. parser.add_argument("--predict_file", default=None, type=str,
  683. help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
  684. parser.add_argument("--max_seq_length", default=384, type=int,
  685. help="The maximum total input sequence length after WordPiece tokenization. Sequences "
  686. "longer than this will be truncated, and sequences shorter than this will be padded.")
  687. parser.add_argument("--doc_stride", default=128, type=int,
  688. help="When splitting up a long document into chunks, how much stride to take between chunks.")
  689. parser.add_argument("--max_query_length", default=64, type=int,
  690. help="The maximum number of tokens for the question. Questions longer than this will "
  691. "be truncated to this length.")
  692. parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
  693. parser.add_argument("--old", action='store_true', help="use old fp16 optimizer")
  694. parser.add_argument("--do_predict", action='store_true', help="Whether to run eval on the dev set.")
  695. parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.")
  696. parser.add_argument("--predict_batch_size", default=8, type=int, help="Total batch size for predictions.")
  697. parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
  698. parser.add_argument("--num_train_epochs", default=3.0, type=float,
  699. help="Total number of training epochs to perform.")
  700. parser.add_argument("--max_steps", default=-1.0, type=float,
  701. help="Total number of training steps to perform.")
  702. parser.add_argument("--warmup_proportion", default=0.1, type=float,
  703. help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
  704. "of training.")
  705. parser.add_argument("--n_best_size", default=20, type=int,
  706. help="The total number of n-best predictions to generate in the nbest_predictions.json "
  707. "output file.")
  708. parser.add_argument("--max_answer_length", default=30, type=int,
  709. help="The maximum length of an answer that can be generated. This is needed because the start "
  710. "and end predictions are not conditioned on one another.")
  711. parser.add_argument("--verbose_logging", action='store_true',
  712. help="If true, all of the warnings related to data processing will be printed. "
  713. "A number of warnings are expected for a normal SQuAD evaluation.")
  714. parser.add_argument("--no_cuda",
  715. action='store_true',
  716. help="Whether not to use CUDA when available")
  717. parser.add_argument('--seed',
  718. type=int,
  719. default=42,
  720. help="random seed for initialization")
  721. parser.add_argument('--gradient_accumulation_steps',
  722. type=int,
  723. default=1,
  724. help="Number of updates steps to accumulate before performing a backward/update pass.")
  725. parser.add_argument("--do_lower_case",
  726. action='store_true',
  727. help="Whether to lower case the input text. True for uncased models, False for cased models.")
  728. parser.add_argument("--local_rank",
  729. type=int,
  730. default=-1,
  731. help="local_rank for distributed training on gpus")
  732. parser.add_argument('--fp16',
  733. action='store_true',
  734. help="Whether to use 16-bit float precision instead of 32-bit")
  735. parser.add_argument('--loss_scale',
  736. type=float, default=0,
  737. help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
  738. "0 (default value): dynamic loss scaling.\n"
  739. "Positive power of 2: static loss scaling value.\n")
  740. parser.add_argument('--version_2_with_negative',
  741. action='store_true',
  742. help='If true, the SQuAD examples contain some that do not have an answer.')
  743. parser.add_argument('--null_score_diff_threshold',
  744. type=float, default=0.0,
  745. help="If null_score - best_non_null is greater than the threshold predict null.")
  746. parser.add_argument('--vocab_file',
  747. type=str, default=None, required=True,
  748. help="Vocabulary mapping/file BERT was pretrainined on")
  749. parser.add_argument("--config_file",
  750. default=None,
  751. type=str,
  752. required=True,
  753. help="The BERT model config")
  754. parser.add_argument('--log_freq',
  755. type=int, default=50,
  756. help='frequency of logging loss.')
  757. args = parser.parse_args()
  758. if args.local_rank == -1 or args.no_cuda:
  759. device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
  760. n_gpu = torch.cuda.device_count()
  761. else:
  762. torch.cuda.set_device(args.local_rank)
  763. device = torch.device("cuda", args.local_rank)
  764. n_gpu = 1
  765. # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
  766. torch.distributed.init_process_group(backend='nccl', init_method='env://')
  767. logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
  768. device, n_gpu, bool(args.local_rank != -1), args.fp16))
  769. if args.gradient_accumulation_steps < 1:
  770. raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
  771. args.gradient_accumulation_steps))
  772. args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
  773. random.seed(args.seed)
  774. np.random.seed(args.seed)
  775. torch.manual_seed(args.seed)
  776. if n_gpu > 0:
  777. torch.cuda.manual_seed_all(args.seed)
  778. if not args.do_train and not args.do_predict:
  779. raise ValueError("At least one of `do_train` or `do_predict` must be True.")
  780. if args.do_train:
  781. if not args.train_file:
  782. raise ValueError(
  783. "If `do_train` is True, then `train_file` must be specified.")
  784. if args.do_predict:
  785. if not args.predict_file:
  786. raise ValueError(
  787. "If `do_predict` is True, then `predict_file` must be specified.")
  788. if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and os.listdir(args.output_dir)!=['logfile.txt']:
  789. print("WARNING: Output directory {} already exists and is not empty.".format(args.output_dir), os.listdir(args.output_dir))
  790. if not os.path.exists(args.output_dir):
  791. os.makedirs(args.output_dir)
  792. tokenizer = BertTokenizer(args.vocab_file, do_lower_case=args.do_lower_case, max_len=512) # for bert large
  793. # tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
  794. train_examples = None
  795. num_train_optimization_steps = None
  796. if args.do_train:
  797. train_examples = read_squad_examples(
  798. input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
  799. num_train_optimization_steps = int(
  800. len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
  801. if args.local_rank != -1:
  802. num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
  803. # Prepare model
  804. config = BertConfig.from_json_file(args.config_file)
  805. # Padding for divisibility by 8
  806. if config.vocab_size % 8 != 0:
  807. config.vocab_size += 8 - (config.vocab_size % 8)
  808. model = BertForQuestionAnswering(config)
  809. # model = BertForQuestionAnswering.from_pretrained(args.bert_model,
  810. # cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)))
  811. if is_main_process():
  812. print("LOADING CHECKOINT")
  813. model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')["model"], strict=False)
  814. if is_main_process():
  815. print("LOADED CHECKPOINT")
  816. model.to(device)
  817. if args.fp16 and args.old:
  818. model.half()
  819. # Prepare optimizer
  820. param_optimizer = list(model.named_parameters())
  821. # hack to remove pooler, which is not used
  822. # thus it produce None grad that break apex
  823. param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
  824. no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
  825. optimizer_grouped_parameters = [
  826. {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
  827. {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
  828. ]
  829. if args.do_train:
  830. if args.fp16:
  831. try:
  832. # from fused_adam_local import FusedAdamBert as FusedAdam
  833. from apex.optimizers import FusedAdam
  834. from apex.optimizers import FP16_Optimizer
  835. except ImportError:
  836. raise ImportError(
  837. "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
  838. # import ipdb; ipdb.set_trace()
  839. optimizer = FusedAdam(optimizer_grouped_parameters,
  840. lr=args.learning_rate,
  841. bias_correction=False,
  842. max_grad_norm=1.0)
  843. if args.loss_scale == 0:
  844. if args.old:
  845. optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
  846. else:
  847. model, optimizer = amp.initialize(model, optimizer, opt_level="O2", keep_batchnorm_fp32=False,
  848. loss_scale="dynamic")
  849. else:
  850. if args.old:
  851. optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
  852. else:
  853. model, optimizer = amp.initialize(model, optimizer, opt_level="O2", keep_batchnorm_fp32=False, loss_scale=args.loss_scale)
  854. if not args.old and args.do_train:
  855. scheduler = LinearWarmUpScheduler(optimizer, warmup=args.warmup_proportion, total_steps=num_train_optimization_steps)
  856. else:
  857. optimizer = BertAdam(optimizer_grouped_parameters,
  858. lr=args.learning_rate,
  859. warmup=args.warmup_proportion,
  860. t_total=num_train_optimization_steps)
  861. #print(model)
  862. if args.local_rank != -1:
  863. try:
  864. from apex.parallel import DistributedDataParallel as DDP
  865. except ImportError:
  866. raise ImportError(
  867. "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
  868. model = DDP(model)
  869. elif n_gpu > 1:
  870. model = torch.nn.DataParallel(model)
  871. global_step = 0
  872. if args.do_train:
  873. cached_train_features_file = args.train_file + '_{0}_{1}_{2}_{3}'.format(
  874. list(filter(None, args.bert_model.split('/'))).pop(), str(args.max_seq_length), str(args.doc_stride),
  875. str(args.max_query_length))
  876. train_features = None
  877. try:
  878. with open(cached_train_features_file, "rb") as reader:
  879. train_features = pickle.load(reader)
  880. except:
  881. train_features = convert_examples_to_features(
  882. examples=train_examples,
  883. tokenizer=tokenizer,
  884. max_seq_length=args.max_seq_length,
  885. doc_stride=args.doc_stride,
  886. max_query_length=args.max_query_length,
  887. is_training=True)
  888. if args.local_rank == -1 or torch.distributed.get_rank() == 0:
  889. logger.info(" Saving train features into cached file %s", cached_train_features_file)
  890. with open(cached_train_features_file, "wb") as writer:
  891. pickle.dump(train_features, writer)
  892. logger.info("***** Running training *****")
  893. logger.info(" Num orig examples = %d", len(train_examples))
  894. logger.info(" Num split examples = %d", len(train_features))
  895. logger.info(" Batch size = %d", args.train_batch_size)
  896. logger.info(" Num steps = %d", num_train_optimization_steps)
  897. all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
  898. all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
  899. all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
  900. all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
  901. all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
  902. train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
  903. all_start_positions, all_end_positions)
  904. if args.local_rank == -1:
  905. train_sampler = RandomSampler(train_data)
  906. else:
  907. train_sampler = DistributedSampler(train_data)
  908. train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
  909. model.train()
  910. for _ in trange(int(args.num_train_epochs), desc="Epoch"):
  911. for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
  912. # Terminate early for benchmarking
  913. if args.max_steps > 0 and global_step > args.max_steps:
  914. break
  915. if n_gpu == 1:
  916. batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self
  917. input_ids, input_mask, segment_ids, start_positions, end_positions = batch
  918. loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
  919. if n_gpu > 1:
  920. loss = loss.mean() # mean() to average on multi-gpu.
  921. if args.gradient_accumulation_steps > 1:
  922. loss = loss / args.gradient_accumulation_steps
  923. if args.fp16:
  924. if args.old:
  925. optimizer.backward(loss)
  926. else:
  927. with amp.scale_loss(loss, optimizer) as scaled_loss:
  928. scaled_loss.backward()
  929. else:
  930. loss.backward()
  931. # if args.fp16:
  932. # optimizer.backward(loss)
  933. # else:
  934. # loss.backward()
  935. if (step + 1) % args.gradient_accumulation_steps == 0:
  936. if args.fp16 :
  937. # modify learning rate with special warm up for BERT which FusedAdam doesn't do
  938. if not args.old:
  939. scheduler.step()
  940. else:
  941. lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
  942. for param_group in optimizer.param_groups:
  943. param_group['lr'] = lr_this_step
  944. optimizer.step()
  945. optimizer.zero_grad()
  946. global_step += 1
  947. if step % args.log_freq == 0:
  948. # logger.info("Step {}: Loss {}, LR {} ".format(global_step, loss.item(), lr_this_step))
  949. logger.info(
  950. "Step {}: Loss {}, LR {} ".format(global_step, loss.item(), optimizer.param_groups[0]['lr']))
  951. if args.do_train:
  952. # Save a trained model and the associated configuration
  953. model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
  954. output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
  955. torch.save(model_to_save.state_dict(), output_model_file)
  956. output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
  957. with open(output_config_file, 'w') as f:
  958. f.write(model_to_save.config.to_json_string())
  959. # # Load a trained model and config that you have fine-tuned
  960. # config = BertConfig(output_config_file)
  961. # model = BertForQuestionAnswering(config)
  962. # model.load_state_dict(torch.load(output_model_file))
  963. # else:
  964. # model = BertForQuestionAnswering.from_pretrained(args.bert_model)
  965. if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
  966. if not args.do_train and args.fp16:
  967. model.half()
  968. eval_examples = read_squad_examples(
  969. input_file=args.predict_file, is_training=False, version_2_with_negative=args.version_2_with_negative)
  970. eval_features = convert_examples_to_features(
  971. examples=eval_examples,
  972. tokenizer=tokenizer,
  973. max_seq_length=args.max_seq_length,
  974. doc_stride=args.doc_stride,
  975. max_query_length=args.max_query_length,
  976. is_training=False)
  977. logger.info("***** Running predictions *****")
  978. logger.info(" Num orig examples = %d", len(eval_examples))
  979. logger.info(" Num split examples = %d", len(eval_features))
  980. logger.info(" Batch size = %d", args.predict_batch_size)
  981. all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
  982. all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
  983. all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
  984. all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
  985. eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
  986. # Run prediction for full data
  987. eval_sampler = SequentialSampler(eval_data)
  988. eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size)
  989. model.eval()
  990. all_results = []
  991. logger.info("Start evaluating")
  992. for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
  993. if len(all_results) % 1000 == 0:
  994. logger.info("Processing example: %d" % (len(all_results)))
  995. input_ids = input_ids.to(device)
  996. input_mask = input_mask.to(device)
  997. segment_ids = segment_ids.to(device)
  998. with torch.no_grad():
  999. batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
  1000. for i, example_index in enumerate(example_indices):
  1001. start_logits = batch_start_logits[i].detach().cpu().tolist()
  1002. end_logits = batch_end_logits[i].detach().cpu().tolist()
  1003. eval_feature = eval_features[example_index.item()]
  1004. unique_id = int(eval_feature.unique_id)
  1005. all_results.append(RawResult(unique_id=unique_id,
  1006. start_logits=start_logits,
  1007. end_logits=end_logits))
  1008. output_prediction_file = os.path.join(args.output_dir, "predictions.json")
  1009. output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
  1010. output_null_log_odds_file = os.path.join(args.output_dir, "null_odds.json")
  1011. write_predictions(eval_examples, eval_features, all_results,
  1012. args.n_best_size, args.max_answer_length,
  1013. args.do_lower_case, output_prediction_file,
  1014. output_nbest_file, output_null_log_odds_file, args.verbose_logging,
  1015. args.version_2_with_negative, args.null_score_diff_threshold)
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()