extract_features.py

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extract pre-computed feature vectors from a PyTorch BERT model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import collections
import logging
import json
import re

import torch
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from tokenization import BertTokenizer
from modeling import BertModel

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)
class InputExample(object):
    """A single sentence, or sentence pair, to embed."""

    def __init__(self, unique_id, text_a, text_b):
        self.unique_id = unique_id
        self.text_a = text_a
        self.text_b = text_b


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
        self.unique_id = unique_id
        self.tokens = tokens
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.input_type_ids = input_type_ids
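# Illustrative sketch of how the padded arrays line up for seq_length=8 and
# the single sentence "the dog" (ids assume the bert-base-uncased vocab):
#
#   tokens:         ["[CLS]", "the", "dog", "[SEP]"]
#   input_ids:      [101, 1996, 3899, 102, 0, 0, 0, 0]
#   input_mask:     [1, 1, 1, 1, 0, 0, 0, 0]
#   input_type_ids: [0, 0, 0, 0, 0, 0, 0, 0]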
def convert_examples_to_features(examples, seq_length, tokenizer):
    """Converts a list of `InputExample`s into a list of `InputFeatures`."""
    features = []
    for (ex_index, example) in enumerate(examples):
        tokens_a = tokenizer.tokenize(example.text_a)

        tokens_b = None
        if example.text_b:
            tokens_b = tokenizer.tokenize(example.text_b)

        if tokens_b:
            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3"
            _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
        else:
            # Account for [CLS] and [SEP] with "- 2"
            if len(tokens_a) > seq_length - 2:
                tokens_a = tokens_a[0:(seq_length - 2)]

        # The convention in BERT is:
        # (a) For sequence pairs:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids:   0   0    0    0    0     0     0   0   1  1  1  1  1   1
        # (b) For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids:   0    0   0  0    0   0   0
        #
        # Where "type_ids" are used to indicate whether this is the first
        # sequence or the second sequence. The embedding vectors for `type=0`
        # and `type=1` were learned during pre-training and are added to the
        # wordpiece embedding vector (and position vector). This is not
        # *strictly* necessary since the [SEP] token unambiguously separates
        # the sequences, but it makes it easier for the model to learn the
        # concept of sequences.
        #
        # For classification tasks, the first vector (corresponding to [CLS])
        # is used as the "sentence vector". Note that this only makes sense
        # because the entire model is fine-tuned.
        tokens = []
        input_type_ids = []
        tokens.append("[CLS]")
        input_type_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            input_type_ids.append(0)
        tokens.append("[SEP]")
        input_type_ids.append(0)

        if tokens_b:
            for token in tokens_b:
                tokens.append(token)
                input_type_ids.append(1)
            tokens.append("[SEP]")
            input_type_ids.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        while len(input_ids) < seq_length:
            input_ids.append(0)
            input_mask.append(0)
            input_type_ids.append(0)

        assert len(input_ids) == seq_length
        assert len(input_mask) == seq_length
        assert len(input_type_ids) == seq_length

        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("unique_id: %s" % (example.unique_id))
            logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            logger.info(
                "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids]))

        features.append(
            InputFeatures(
                unique_id=example.unique_id,
                tokens=tokens,
                input_ids=input_ids,
                input_mask=input_mask,
                input_type_ids=input_type_ids))
    return features
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""
    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal
    # percentage of tokens from each, since if one sequence is very short then
    # each truncated token likely carries more information than a token cut
    # from the longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()
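# Worked example: with max_length=4, each iteration pops from whichever side
# is currently longer, so both sides shrink toward equal length:
#
#   tokens_a = ["a", "b", "c"]
#   tokens_b = ["x", "y", "z"]
#   _truncate_seq_pair(tokens_a, tokens_b, 4)
#   # tokens_a == ["a", "b"], tokens_b == ["x", "y"]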
def read_examples(input_file):
    """Read a list of `InputExample`s from an input file."""
    examples = []
    unique_id = 0
    with open(input_file, "r", encoding='utf-8') as reader:
        while True:
            line = reader.readline()
            if not line:
                break
            line = line.strip()
            text_a = None
            text_b = None
            m = re.match(r"^(.*) \|\|\| (.*)$", line)
            if m is None:
                text_a = line
            else:
                text_a = m.group(1)
                text_b = m.group(2)
            examples.append(
                InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
            unique_id += 1
    return examples
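# Expected input format, per the regex above: one example per line, either a
# single sentence or a pair separated by " ||| ", e.g. (illustrative):
#
#   Who was Jim Henson ?
#   Who was Jim Henson ? ||| Jim Henson was a puppeteer
#
# The first form leaves text_b as None; the second maps the two sides to
# text_a and text_b.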
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--input_file", default=None, type=str, required=True)
    parser.add_argument("--output_file", default=None, type=str, required=True)
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")

    ## Other parameters
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--layers", default="-1,-2,-3,-4", type=str,
                        help="Comma-separated encoder layer indices to output, "
                             "e.g. \"-1,-2,-3,-4\" for the last four layers.")
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. "
                             "Sequences longer than this will be truncated, and sequences shorter "
                             "than this will be padded.")
    parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed inference on GPUs")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Do not use CUDA even when it is available.")
    args = parser.parse_args()
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {} distributed inference: {}".format(device, n_gpu, bool(args.local_rank != -1)))
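    # Distributed runs are expected to launch one process per GPU, each with
    # its own --local_rank. An illustrative recipe (assumes 4 GPUs on a single
    # node), using the torch.distributed.launch helper that fills in
    # --local_rank automatically:
    #
    #   python -m torch.distributed.launch --nproc_per_node=4 \
    #     extract_features.py --input_file=... --output_file=... --bert_model=...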
    layer_indexes = [int(x) for x in args.layers.split(",")]

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    examples = read_examples(args.input_file)

    features = convert_examples_to_features(
        examples=examples, seq_length=args.max_seq_length, tokenizer=tokenizer)

    unique_id_to_feature = {}
    for feature in features:
        unique_id_to_feature[feature.unique_id] = feature

    model = BertModel.from_pretrained(args.bert_model)
    model.to(device)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)

    eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index)
    if args.local_rank == -1:
        eval_sampler = SequentialSampler(eval_data)
    else:
        eval_sampler = DistributedSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)
    model.eval()
    with open(args.output_file, "w", encoding='utf-8') as writer:
        for input_ids, input_mask, example_indices in eval_dataloader:
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)

            # Forward pass; `all_encoder_layers` is a list with one hidden-state
            # tensor of shape [batch, seq_length, hidden] per encoder layer.
            all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask)

            for b, example_index in enumerate(example_indices):
                feature = features[example_index.item()]
                unique_id = int(feature.unique_id)
                # feature = unique_id_to_feature[unique_id]
                output_json = collections.OrderedDict()
                output_json["linex_index"] = unique_id
                all_out_features = []
                for (i, token) in enumerate(feature.tokens):
                    all_layers = []
                    for (j, layer_index) in enumerate(layer_indexes):
                        layer_output = all_encoder_layers[int(layer_index)].detach().cpu().numpy()
                        layer_output = layer_output[b]
                        layers = collections.OrderedDict()
                        layers["index"] = layer_index
                        layers["values"] = [
                            round(x.item(), 6) for x in layer_output[i]
                        ]
                        all_layers.append(layers)
                    out_features = collections.OrderedDict()
                    out_features["token"] = token
                    out_features["layers"] = all_layers
                    all_out_features.append(out_features)
                output_json["features"] = all_out_features
                writer.write(json.dumps(output_json) + "\n")


if __name__ == "__main__":
    main()
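
# Each line of --output_file is one JSON object per input example, of the
# form sketched below (values abridged):
#
#   {"linex_index": 0,
#    "features": [
#      {"token": "[CLS]",
#       "layers": [{"index": -1, "values": [0.123456, ...]},
#                  {"index": -2, "values": [...]}]},
#      ...]}
#
# with one "features" entry per wordpiece token and one "layers" entry per
# index requested via --layers.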