run_classifier_with_tfhub.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. # coding=utf-8
  2. # Copyright 2018 The Google AI Language Team Authors.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """BERT finetuning runner with TF-Hub."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import os
  20. import optimization
  21. import run_classifier
  22. import tokenization
  23. import tensorflow as tf
  24. import tensorflow_hub as hub
  25. flags = tf.flags
  26. FLAGS = flags.FLAGS
  27. flags.DEFINE_string(
  28. "bert_hub_module_handle", None,
  29. "Handle for the BERT TF-Hub module.")
  30. def create_model(is_training, input_ids, input_mask, segment_ids, labels,
  31. num_labels, bert_hub_module_handle):
  32. """Creates a classification model."""
  33. tags = set()
  34. if is_training:
  35. tags.add("train")
  36. bert_module = hub.Module(bert_hub_module_handle, tags=tags, trainable=True)
  37. bert_inputs = dict(
  38. input_ids=input_ids,
  39. input_mask=input_mask,
  40. segment_ids=segment_ids)
  41. bert_outputs = bert_module(
  42. inputs=bert_inputs,
  43. signature="tokens",
  44. as_dict=True)
  45. # In the demo, we are doing a simple classification task on the entire
  46. # segment.
  47. #
  48. # If you want to use the token-level output, use
  49. # bert_outputs["sequence_output"] instead.
  50. output_layer = bert_outputs["pooled_output"]
  51. hidden_size = output_layer.shape[-1].value
  52. output_weights = tf.get_variable(
  53. "output_weights", [num_labels, hidden_size],
  54. initializer=tf.truncated_normal_initializer(stddev=0.02))
  55. output_bias = tf.get_variable(
  56. "output_bias", [num_labels], initializer=tf.zeros_initializer())
  57. with tf.variable_scope("loss"):
  58. if is_training:
  59. # I.e., 0.1 dropout
  60. output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
  61. logits = tf.matmul(output_layer, output_weights, transpose_b=True)
  62. logits = tf.nn.bias_add(logits, output_bias)
  63. probabilities = tf.nn.softmax(logits, axis=-1)
  64. log_probs = tf.nn.log_softmax(logits, axis=-1)
  65. one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
  66. per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
  67. loss = tf.reduce_mean(per_example_loss)
  68. return (loss, per_example_loss, logits, probabilities)
  69. def model_fn_builder(num_labels, learning_rate, num_train_steps,
  70. num_warmup_steps, use_tpu, bert_hub_module_handle):
  71. """Returns `model_fn` closure for TPUEstimator."""
  72. def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
  73. """The `model_fn` for TPUEstimator."""
  74. tf.logging.info("*** Features ***")
  75. for name in sorted(features.keys()):
  76. tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
  77. input_ids = features["input_ids"]
  78. input_mask = features["input_mask"]
  79. segment_ids = features["segment_ids"]
  80. label_ids = features["label_ids"]
  81. is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  82. (total_loss, per_example_loss, logits, probabilities) = create_model(
  83. is_training, input_ids, input_mask, segment_ids, label_ids, num_labels,
  84. bert_hub_module_handle)
  85. output_spec = None
  86. if mode == tf.estimator.ModeKeys.TRAIN:
  87. train_op = optimization.create_optimizer(
  88. total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
  89. output_spec = tf.contrib.tpu.TPUEstimatorSpec(
  90. mode=mode,
  91. loss=total_loss,
  92. train_op=train_op)
  93. elif mode == tf.estimator.ModeKeys.EVAL:
  94. def metric_fn(per_example_loss, label_ids, logits):
  95. predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
  96. accuracy = tf.metrics.accuracy(label_ids, predictions)
  97. loss = tf.metrics.mean(per_example_loss)
  98. return {
  99. "eval_accuracy": accuracy,
  100. "eval_loss": loss,
  101. }
  102. eval_metrics = (metric_fn, [per_example_loss, label_ids, logits])
  103. output_spec = tf.contrib.tpu.TPUEstimatorSpec(
  104. mode=mode,
  105. loss=total_loss,
  106. eval_metrics=eval_metrics)
  107. elif mode == tf.estimator.ModeKeys.PREDICT:
  108. output_spec = tf.contrib.tpu.TPUEstimatorSpec(
  109. mode=mode, predictions={"probabilities": probabilities})
  110. else:
  111. raise ValueError(
  112. "Only TRAIN, EVAL and PREDICT modes are supported: %s" % (mode))
  113. return output_spec
  114. return model_fn
  115. def create_tokenizer_from_hub_module(bert_hub_module_handle):
  116. """Get the vocab file and casing info from the Hub module."""
  117. with tf.Graph().as_default():
  118. bert_module = hub.Module(bert_hub_module_handle)
  119. tokenization_info = bert_module(signature="tokenization_info", as_dict=True)
  120. with tf.Session() as sess:
  121. vocab_file, do_lower_case = sess.run([tokenization_info["vocab_file"],
  122. tokenization_info["do_lower_case"]])
  123. return tokenization.FullTokenizer(
  124. vocab_file=vocab_file, do_lower_case=do_lower_case)
  125. def main(_):
  126. tf.logging.set_verbosity(tf.logging.INFO)
  127. processors = {
  128. "cola": run_classifier.ColaProcessor,
  129. "mnli": run_classifier.MnliProcessor,
  130. "mrpc": run_classifier.MrpcProcessor,
  131. }
  132. if not FLAGS.do_train and not FLAGS.do_eval:
  133. raise ValueError("At least one of `do_train` or `do_eval` must be True.")
  134. tf.gfile.MakeDirs(FLAGS.output_dir)
  135. task_name = FLAGS.task_name.lower()
  136. if task_name not in processors:
  137. raise ValueError("Task not found: %s" % (task_name))
  138. processor = processors[task_name]()
  139. label_list = processor.get_labels()
  140. tokenizer = create_tokenizer_from_hub_module(FLAGS.bert_hub_module_handle)
  141. tpu_cluster_resolver = None
  142. if FLAGS.use_tpu and FLAGS.tpu_name:
  143. tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
  144. FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
  145. is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  146. run_config = tf.contrib.tpu.RunConfig(
  147. cluster=tpu_cluster_resolver,
  148. master=FLAGS.master,
  149. model_dir=FLAGS.output_dir,
  150. save_checkpoints_steps=FLAGS.save_checkpoints_steps,
  151. tpu_config=tf.contrib.tpu.TPUConfig(
  152. iterations_per_loop=FLAGS.iterations_per_loop,
  153. num_shards=FLAGS.num_tpu_cores,
  154. per_host_input_for_training=is_per_host))
  155. train_examples = None
  156. num_train_steps = None
  157. num_warmup_steps = None
  158. if FLAGS.do_train:
  159. train_examples = processor.get_train_examples(FLAGS.data_dir)
  160. num_train_steps = int(
  161. len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
  162. num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
  163. model_fn = model_fn_builder(
  164. num_labels=len(label_list),
  165. learning_rate=FLAGS.learning_rate,
  166. num_train_steps=num_train_steps,
  167. num_warmup_steps=num_warmup_steps,
  168. use_tpu=FLAGS.use_tpu,
  169. bert_hub_module_handle=FLAGS.bert_hub_module_handle)
  170. # If TPU is not available, this will fall back to normal Estimator on CPU
  171. # or GPU.
  172. estimator = tf.contrib.tpu.TPUEstimator(
  173. use_tpu=FLAGS.use_tpu,
  174. model_fn=model_fn,
  175. config=run_config,
  176. train_batch_size=FLAGS.train_batch_size,
  177. eval_batch_size=FLAGS.eval_batch_size,
  178. predict_batch_size=FLAGS.predict_batch_size)
  179. if FLAGS.do_train:
  180. train_features = run_classifier.convert_examples_to_features(
  181. train_examples, label_list, FLAGS.max_seq_length, tokenizer)
  182. tf.logging.info("***** Running training *****")
  183. tf.logging.info(" Num examples = %d", len(train_examples))
  184. tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
  185. tf.logging.info(" Num steps = %d", num_train_steps)
  186. train_input_fn = run_classifier.input_fn_builder(
  187. features=train_features,
  188. seq_length=FLAGS.max_seq_length,
  189. is_training=True,
  190. drop_remainder=True)
  191. estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
  192. if FLAGS.do_eval:
  193. eval_examples = processor.get_dev_examples(FLAGS.data_dir)
  194. eval_features = run_classifier.convert_examples_to_features(
  195. eval_examples, label_list, FLAGS.max_seq_length, tokenizer)
  196. tf.logging.info("***** Running evaluation *****")
  197. tf.logging.info(" Num examples = %d", len(eval_examples))
  198. tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
  199. # This tells the estimator to run through the entire set.
  200. eval_steps = None
  201. # However, if running eval on the TPU, you will need to specify the
  202. # number of steps.
  203. if FLAGS.use_tpu:
  204. # Eval will be slightly WRONG on the TPU because it will truncate
  205. # the last batch.
  206. eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size)
  207. eval_drop_remainder = True if FLAGS.use_tpu else False
  208. eval_input_fn = run_classifier.input_fn_builder(
  209. features=eval_features,
  210. seq_length=FLAGS.max_seq_length,
  211. is_training=False,
  212. drop_remainder=eval_drop_remainder)
  213. result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
  214. output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
  215. with tf.gfile.GFile(output_eval_file, "w") as writer:
  216. tf.logging.info("***** Eval results *****")
  217. for key in sorted(result.keys()):
  218. tf.logging.info(" %s = %s", key, str(result[key]))
  219. writer.write("%s = %s\n" % (key, str(result[key])))
  220. if FLAGS.do_predict:
  221. predict_examples = processor.get_test_examples(FLAGS.data_dir)
  222. if FLAGS.use_tpu:
  223. # Discard batch remainder if running on TPU
  224. n = len(predict_examples)
  225. predict_examples = predict_examples[:(n - n % FLAGS.predict_batch_size)]
  226. predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
  227. run_classifier.file_based_convert_examples_to_features(
  228. predict_examples, label_list, FLAGS.max_seq_length, tokenizer,
  229. predict_file)
  230. tf.logging.info("***** Running prediction*****")
  231. tf.logging.info(" Num examples = %d", len(predict_examples))
  232. tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)
  233. predict_input_fn = run_classifier.file_based_input_fn_builder(
  234. input_file=predict_file,
  235. seq_length=FLAGS.max_seq_length,
  236. is_training=False,
  237. drop_remainder=FLAGS.use_tpu)
  238. result = estimator.predict(input_fn=predict_input_fn)
  239. output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
  240. with tf.gfile.GFile(output_predict_file, "w") as writer:
  241. tf.logging.info("***** Predict results *****")
  242. for prediction in result:
  243. probabilities = prediction["probabilities"]
  244. output_line = "\t".join(
  245. str(class_probability)
  246. for class_probability in probabilities) + "\n"
  247. writer.write(output_line)
  248. if __name__ == "__main__":
  249. flags.mark_flag_as_required("data_dir")
  250. flags.mark_flag_as_required("task_name")
  251. flags.mark_flag_as_required("bert_hub_module_handle")
  252. flags.mark_flag_as_required("output_dir")
  253. tf.app.run()