# modeling_test.py
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import json
import random
import re

import modeling
import six
import tensorflow as tf
class BertModelTest(tf.test.TestCase):
  """Unit tests for `modeling.BertConfig` and `modeling.BertModel`."""

  class BertModelTester(object):
    """Builds a small `modeling.BertModel` over random ids and checks shapes.

    The hyperparameter defaults are tiny; presumably the mostly-distinct
    sizes (13, 7, 99, 32, 5, 4, 37, ...) make dimension mix-ups surface as
    shape errors — TODO confirm intent.
    """

    def __init__(self,
                 parent,
                 batch_size=13,
                 seq_length=7,
                 is_training=True,
                 use_input_mask=True,
                 use_token_type_ids=True,
                 vocab_size=99,
                 hidden_size=32,
                 num_hidden_layers=5,
                 num_attention_heads=4,
                 intermediate_size=37,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=16,
                 initializer_range=0.02,
                 scope=None):
      """Stores the parent test case and every model hyperparameter.

      Args:
        parent: The enclosing `tf.test.TestCase`; `check_output` uses it for
          assertions.
        batch_size: Batch dimension of the random input ids.
        seq_length: Sequence dimension of the random input ids.
        is_training: Passed through to `modeling.BertModel`.
        use_input_mask: If True, a random 0/1 mask is fed to the model.
        use_token_type_ids: If True, random token type ids are fed.
        vocab_size: Size of the id space the random inputs are drawn from;
          also the `BertConfig` vocab size.
        hidden_size: `BertConfig.hidden_size`.
        num_hidden_layers: `BertConfig.num_hidden_layers`.
        num_attention_heads: `BertConfig.num_attention_heads`.
        intermediate_size: `BertConfig.intermediate_size`.
        hidden_act: `BertConfig.hidden_act`.
        hidden_dropout_prob: `BertConfig.hidden_dropout_prob`.
        attention_probs_dropout_prob: `BertConfig` field of the same name.
        max_position_embeddings: `BertConfig.max_position_embeddings`.
        type_vocab_size: `BertConfig.type_vocab_size`; also the id space for
          the random token type ids.
        initializer_range: `BertConfig.initializer_range`.
        scope: Optional variable scope passed to `modeling.BertModel`.
      """
      self.parent = parent
      self.batch_size = batch_size
      self.seq_length = seq_length
      self.is_training = is_training
      self.use_input_mask = use_input_mask
      self.use_token_type_ids = use_token_type_ids
      self.vocab_size = vocab_size
      self.hidden_size = hidden_size
      self.num_hidden_layers = num_hidden_layers
      self.num_attention_heads = num_attention_heads
      self.intermediate_size = intermediate_size
      self.hidden_act = hidden_act
      self.hidden_dropout_prob = hidden_dropout_prob
      self.attention_probs_dropout_prob = attention_probs_dropout_prob
      self.max_position_embeddings = max_position_embeddings
      self.type_vocab_size = type_vocab_size
      self.initializer_range = initializer_range
      self.scope = scope

    def create_model(self):
      """Builds random inputs, a BertConfig and a BertModel.

      Returns:
        A dict of the model's fetchable outputs: "embedding_output",
        "sequence_output", "pooled_output" and "all_encoder_layers".
      """
      input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length],
                                           self.vocab_size)

      # Optional inputs stay None when disabled, exercising the model's
      # default-input code paths.
      input_mask = None
      if self.use_input_mask:
        input_mask = BertModelTest.ids_tensor(
            [self.batch_size, self.seq_length], vocab_size=2)

      token_type_ids = None
      if self.use_token_type_ids:
        token_type_ids = BertModelTest.ids_tensor(
            [self.batch_size, self.seq_length], self.type_vocab_size)

      config = modeling.BertConfig(
          vocab_size=self.vocab_size,
          hidden_size=self.hidden_size,
          num_hidden_layers=self.num_hidden_layers,
          num_attention_heads=self.num_attention_heads,
          intermediate_size=self.intermediate_size,
          hidden_act=self.hidden_act,
          hidden_dropout_prob=self.hidden_dropout_prob,
          attention_probs_dropout_prob=self.attention_probs_dropout_prob,
          max_position_embeddings=self.max_position_embeddings,
          type_vocab_size=self.type_vocab_size,
          initializer_range=self.initializer_range)

      model = modeling.BertModel(
          config=config,
          is_training=self.is_training,
          input_ids=input_ids,
          input_mask=input_mask,
          token_type_ids=token_type_ids,
          scope=self.scope)

      outputs = {
          "embedding_output": model.get_embedding_output(),
          "sequence_output": model.get_sequence_output(),
          "pooled_output": model.get_pooled_output(),
          "all_encoder_layers": model.get_all_encoder_layers(),
      }
      return outputs

    def check_output(self, result):
      """Asserts the static shapes of the fetched model outputs.

      NOTE(review): "all_encoder_layers" is fetched in `create_model` but its
      shape is not checked here.

      Args:
        result: The dict produced by running `create_model`'s outputs.
      """
      self.parent.assertAllEqual(
          result["embedding_output"].shape,
          [self.batch_size, self.seq_length, self.hidden_size])

      self.parent.assertAllEqual(
          result["sequence_output"].shape,
          [self.batch_size, self.seq_length, self.hidden_size])

      self.parent.assertAllEqual(result["pooled_output"].shape,
                                 [self.batch_size, self.hidden_size])
  def test_default(self):
    """Runs the full model-construction check with default hyperparameters."""
    default_tester = BertModelTest.BertModelTester(self)
    self.run_tester(default_tester)
  def test_config_to_json_string(self):
    """Round-trips a BertConfig through its JSON string representation."""
    config = modeling.BertConfig(vocab_size=99, hidden_size=37)
    parsed = json.loads(config.to_json_string())
    # The two explicitly-set fields must survive serialization.
    self.assertEqual(parsed["vocab_size"], 99)
    self.assertEqual(parsed["hidden_size"], 37)
  def run_tester(self, tester):
    """Builds the tester's model, runs it, and validates the results.

    Args:
      tester: A `BertModelTester` providing `create_model` / `check_output`.
    """
    with self.test_session() as sess:
      fetches = tester.create_model()

      initializers = [
          tf.global_variables_initializer(),
          tf.local_variables_initializer(),
      ]
      init_op = tf.group(*initializers)
      sess.run(init_op)

      tester.check_output(sess.run(fetches))

      # Everything in the graph should be reachable from either the
      # initializer or the model outputs.
      self.assert_all_tensors_reachable(sess, [init_op, fetches])
  @classmethod
  def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
    """Creates a random int32 tensor of the shape within the vocab size."""
    if rng is None:
      rng = random.Random()

    num_elements = 1
    for dim in shape:
      num_elements *= dim

    # Draw each id uniformly from [0, vocab_size); tf.constant reshapes the
    # flat list to `shape`.
    flat_ids = [rng.randint(0, vocab_size - 1) for _ in range(num_elements)]
    return tf.constant(value=flat_ids, dtype=tf.int32, shape=shape, name=name)
  def assert_all_tensors_reachable(self, sess, outputs):
    """Checks that all the tensors in the graph are reachable from outputs."""
    graph = sess.graph

    # Ops matching these patterns (shape assertions, conv metadata, test
    # scaffolding) are expected to be disconnected and are not flagged.
    ignore_strings = [
        "^.*/assert_less_equal/.*$",
        "^.*/dilation_rate$",
        "^.*/Tensordot/concat$",
        "^.*/Tensordot/concat/axis$",
        "^testing/.*$",
    ]
    ignore_regexes = [re.compile(pattern) for pattern in ignore_strings]

    unreachable = [
        op for op in self.get_unreachable_ops(graph, outputs)
        if not any(regex.match(op.name) for regex in ignore_regexes)
    ]

    self.assertEqual(
        len(unreachable), 0, "The following ops are unreachable: %s" %
        (" ".join([x.name for x in unreachable])))
  @classmethod
  def get_unreachable_ops(cls, graph, outputs):
    """Finds all of the tensors in graph that are unreachable from outputs.

    Reachability is computed by walking backwards from `outputs`: a tensor
    reaches the ops that produce it, and through each such op every other
    tensor that op touches. Tensors tied together by `Assign` ops are
    treated as one reachable group, so reaching a variable also reaches the
    ops that assign to it.

    Args:
      graph: A `tf.Graph` to scan (anything exposing `get_operations()`).
      outputs: A (possibly nested) structure of tensors; flattened with
        `flatten_recursive` before the walk starts.

    Returns:
      A list of operations that touch at least one tensor not reachable
      from `outputs`.
    """
    outputs = cls.flatten_recursive(outputs)

    # Tensor name -> names of ops that produce it.
    output_to_op = collections.defaultdict(list)
    # Op name -> every tensor name the op touches (inputs and outputs).
    op_to_all = collections.defaultdict(list)
    # For Assign ops only: output tensor name -> the op's input tensor names.
    assign_out_to_in = collections.defaultdict(list)

    for op in graph.get_operations():
      for x in op.inputs:
        op_to_all[op.name].append(x.name)
      for y in op.outputs:
        output_to_op[y.name].append(op.name)
        op_to_all[op.name].append(y.name)
      if str(op.type) == "Assign":
        for y in op.outputs:
          for x in op.inputs:
            assign_out_to_in[y.name].append(x.name)

    # Link every tensor in an Assign relationship to all of its peers, so
    # reaching any member of the group reaches the whole group.
    assign_groups = collections.defaultdict(list)
    for out_name in assign_out_to_in.keys():
      name_group = assign_out_to_in[out_name]
      for n1 in name_group:
        assign_groups[n1].append(out_name)
        for n2 in name_group:
          if n1 != n2:
            assign_groups[n1].append(n2)

    # Depth-first walk backwards from the outputs. `pending` mirrors the
    # stack's contents so membership tests are O(1) instead of the original
    # O(len(stack)) list scans, which were quadratic on large graphs.
    seen_tensors = {}
    stack = [x.name for x in outputs]
    pending = set(stack)
    while stack:
      name = stack.pop()
      pending.discard(name)
      if name in seen_tensors:
        continue
      seen_tensors[name] = True

      if name in output_to_op:
        for op_name in output_to_op[name]:
          if op_name in op_to_all:
            for input_name in op_to_all[op_name]:
              if input_name not in pending:
                stack.append(input_name)
                pending.add(input_name)

      expanded_names = []
      if name in assign_groups:
        for assign_name in assign_groups[name]:
          expanded_names.append(assign_name)

      for expanded_name in expanded_names:
        if expanded_name not in pending:
          stack.append(expanded_name)
          pending.add(expanded_name)

    # An op is unreachable if any tensor it touches was never seen.
    unreachable_ops = []
    for op in graph.get_operations():
      is_unreachable = False
      all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs]
      for name in all_names:
        if name not in seen_tensors:
          is_unreachable = True
      if is_unreachable:
        unreachable_ops.append(op)
    return unreachable_ops
  219. @classmethod
  220. def flatten_recursive(cls, item):
  221. """Flattens (potentially nested) a tuple/dictionary/list to a list."""
  222. output = []
  223. if isinstance(item, list):
  224. output.extend(item)
  225. elif isinstance(item, tuple):
  226. output.extend(list(item))
  227. elif isinstance(item, dict):
  228. for (_, v) in six.iteritems(item):
  229. output.append(v)
  230. else:
  231. return [item]
  232. flat_output = []
  233. for x in output:
  234. flat_output.extend(cls.flatten_recursive(x))
  235. return flat_output
# Run all TestCase methods in this file under the TensorFlow test runner.
if __name__ == "__main__":
  tf.test.main()