model_helper.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. # Copyright 2017 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Utility functions for building models."""
  16. from __future__ import print_function
  17. import collections
  18. import os
  19. import time
  20. import numpy as np
  21. import six
  22. import tensorflow as tf
  23. from utils import math_utils
  24. from utils import misc_utils as utils
  25. from utils import vocab_utils
  26. __all__ = [
  27. "get_initializer", "create_emb_for_encoder_and_decoder", "create_rnn_cell",
  28. "gradient_clip", "create_or_load_model", "load_model", "avg_checkpoints",
  29. ]
  30. # If a vocab size is greater than this value, put the embedding on cpu instead
  31. VOCAB_SIZE_THRESHOLD_CPU = 50000
  32. def get_initializer(init_op, seed=None, init_weight=0):
  33. """Create an initializer. init_weight is only for uniform."""
  34. if init_op == "uniform":
  35. assert init_weight
  36. return tf.random_uniform_initializer(
  37. -init_weight, init_weight, seed=seed)
  38. elif init_op == "glorot_normal":
  39. return tf.keras.initializers.glorot_normal(
  40. seed=seed)
  41. elif init_op == "glorot_uniform":
  42. return tf.keras.initializers.glorot_uniform(
  43. seed=seed)
  44. elif init_op.isdigit():
  45. # dtype is default fp32 for variables.
  46. val = int(init_op)
  47. return tf.constant_initializer(val)
  48. else:
  49. raise ValueError("Unknown init_op %s" % init_op)
  50. class ExtraArgs(collections.namedtuple(
  51. "ExtraArgs", ("single_cell_fn", "model_device_fn",
  52. "attention_mechanism_fn", "encoder_emb_lookup_fn"))):
  53. pass
  54. class TrainModel(
  55. collections.namedtuple("TrainModel", ("graph", "model", "iterator",
  56. "skip_count_placeholder"))):
  57. pass
  58. def _get_embed_device(vocab_size):
  59. """Decide on which device to place an embed matrix given its vocab size."""
  60. if vocab_size > VOCAB_SIZE_THRESHOLD_CPU:
  61. return "/cpu:0"
  62. else:
  63. return "/gpu:0"
  64. def _create_pretrained_emb_from_txt(
  65. vocab_file, embed_file, num_trainable_tokens=3, dtype=tf.float32,
  66. scope=None):
  67. """Load pretrain embeding from embed_file, and return an embedding matrix.
  68. Args:
  69. vocab_file: Path to vocab file.
  70. embed_file: Path to a Glove formmated embedding txt file.
  71. num_trainable_tokens: Make the first n tokens in the vocab file as trainable
  72. variables. Default is 3, which is "<unk>", "<s>" and "</s>".
  73. dtype: data type.
  74. scope: tf scope name.
  75. Returns:
  76. pretrained embedding table variable.
  77. """
  78. vocab, _ = vocab_utils.load_vocab(vocab_file)
  79. trainable_tokens = vocab[:num_trainable_tokens]
  80. utils.print_out("# Using pretrained embedding: %s." % embed_file)
  81. utils.print_out(" with trainable tokens: ")
  82. emb_dict, emb_size = vocab_utils.load_embed_txt(embed_file)
  83. for token in trainable_tokens:
  84. utils.print_out(" %s" % token)
  85. if token not in emb_dict:
  86. emb_dict[token] = [0.0] * emb_size
  87. emb_mat = np.array(
  88. [emb_dict[token] for token in vocab], dtype=dtype.as_numpy_dtype())
  89. emb_mat = tf.constant(emb_mat)
  90. emb_mat_const = tf.slice(emb_mat, [num_trainable_tokens, 0], [-1, -1])
  91. with tf.variable_scope(scope or "pretrain_embeddings", dtype=dtype) as scope:
  92. emb_mat_var = tf.get_variable(
  93. "emb_mat_var", [num_trainable_tokens, emb_size])
  94. return tf.concat([emb_mat_var, emb_mat_const], 0)
  95. def _create_or_load_embed(embed_name, vocab_file, embed_file,
  96. vocab_size, embed_size, dtype):
  97. """Create a new or load an existing embedding matrix."""
  98. if vocab_file and embed_file:
  99. embedding = _create_pretrained_emb_from_txt(vocab_file, embed_file)
  100. else:
  101. embedding = tf.get_variable(
  102. embed_name, [vocab_size, embed_size], dtype)
  103. return embedding
  104. def create_emb_for_encoder_and_decoder(share_vocab,
  105. src_vocab_size,
  106. tgt_vocab_size,
  107. src_embed_size,
  108. tgt_embed_size,
  109. dtype=tf.float32,
  110. num_enc_partitions=0,
  111. num_dec_partitions=0,
  112. src_vocab_file=None,
  113. tgt_vocab_file=None,
  114. src_embed_file=None,
  115. tgt_embed_file=None,
  116. use_char_encode=False,
  117. scope=None):
  118. """Create embedding matrix for both encoder and decoder.
  119. Args:
  120. share_vocab: A boolean. Whether to share embedding matrix for both
  121. encoder and decoder.
  122. src_vocab_size: An integer. The source vocab size.
  123. tgt_vocab_size: An integer. The target vocab size.
  124. src_embed_size: An integer. The embedding dimension for the encoder's
  125. embedding.
  126. tgt_embed_size: An integer. The embedding dimension for the decoder's
  127. embedding.
  128. dtype: dtype of the embedding matrix. Default to float32.
  129. num_enc_partitions: number of partitions used for the encoder's embedding
  130. vars.
  131. num_dec_partitions: number of partitions used for the decoder's embedding
  132. vars.
  133. src_vocab_file: A string. The source vocabulary file.
  134. tgt_vocab_file: A string. The target vocabulary file.
  135. src_embed_file: A string. The source embedding file.
  136. tgt_embed_file: A string. The target embedding file.
  137. use_char_encode: A boolean. If true, use char encoder.
  138. scope: VariableScope for the created subgraph. Default to "embedding".
  139. Returns:
  140. embedding_encoder: Encoder's embedding matrix.
  141. embedding_decoder: Decoder's embedding matrix.
  142. Raises:
  143. ValueError: if use share_vocab but source and target have different vocab
  144. size.
  145. """
  146. if num_enc_partitions <= 1:
  147. enc_partitioner = None
  148. else:
  149. # Note: num_partitions > 1 is required for distributed training due to
  150. # embedding_lookup tries to colocate single partition-ed embedding variable
  151. # with lookup ops. This may cause embedding variables being placed on worker
  152. # jobs.
  153. enc_partitioner = tf.fixed_size_partitioner(num_enc_partitions)
  154. if num_dec_partitions <= 1:
  155. dec_partitioner = None
  156. else:
  157. # Note: num_partitions > 1 is required for distributed training due to
  158. # embedding_lookup tries to colocate single partition-ed embedding variable
  159. # with lookup ops. This may cause embedding variables being placed on worker
  160. # jobs.
  161. dec_partitioner = tf.fixed_size_partitioner(num_dec_partitions)
  162. if src_embed_file and enc_partitioner:
  163. raise ValueError(
  164. "Can't set num_enc_partitions > 1 when using pretrained encoder "
  165. "embedding")
  166. if tgt_embed_file and dec_partitioner:
  167. raise ValueError(
  168. "Can't set num_dec_partitions > 1 when using pretrained decdoer "
  169. "embedding")
  170. with tf.variable_scope(
  171. scope or "embeddings", dtype=dtype, partitioner=enc_partitioner) as scope:
  172. # Share embedding
  173. if share_vocab:
  174. if src_vocab_size != tgt_vocab_size:
  175. raise ValueError("Share embedding but different src/tgt vocab sizes"
  176. " %d vs. %d" % (src_vocab_size, tgt_vocab_size))
  177. assert src_embed_size == tgt_embed_size
  178. utils.print_out("# Use the same embedding for source and target")
  179. vocab_file = src_vocab_file or tgt_vocab_file
  180. embed_file = src_embed_file or tgt_embed_file
  181. embedding_encoder = _create_or_load_embed(
  182. "embedding_share", vocab_file, embed_file,
  183. src_vocab_size, src_embed_size, dtype)
  184. embedding_decoder = embedding_encoder
  185. else:
  186. if not use_char_encode:
  187. with tf.variable_scope("encoder", partitioner=enc_partitioner):
  188. embedding_encoder = _create_or_load_embed(
  189. "embedding_encoder", src_vocab_file, src_embed_file,
  190. src_vocab_size, src_embed_size, dtype)
  191. else:
  192. embedding_encoder = None
  193. with tf.variable_scope("decoder", partitioner=dec_partitioner):
  194. embedding_decoder = _create_or_load_embed(
  195. "embedding_decoder", tgt_vocab_file, tgt_embed_file,
  196. tgt_vocab_size, tgt_embed_size, dtype)
  197. return embedding_encoder, embedding_decoder
  198. def build_cell(cell, input_shape):
  199. if isinstance(cell, tf.contrib.rnn.MultiRNNCell):
  200. assert isinstance(input_shape, collections.Sequence)
  201. for i, c in enumerate(cell._cells):
  202. if i == 0:
  203. c.build((None, input_shape))
  204. else:
  205. c.build((None, c.num_units))
  206. return
  207. if isinstance(cell, tf.nn.rnn_cell.DropoutWrapper):
  208. build_cell(cell._cell, input_shape)
  209. elif isinstance(cell, tf.nn.rnn_cell.ResidualWrapper):
  210. build_cell(cell._cell, input_shape)
  211. elif isinstance(cell, tf.nn.rnn_cell.LSTMCell):
  212. cell.build(input_shape)
  213. else:
  214. raise ValueError("%s not supported" % type(cell))
def _single_cell(unit_type, num_units, forget_bias, dropout, mode,
                 dtype=None, residual_connection=False, residual_fn=None,
                 use_block_lstm=False):
  """Create an instance of a single RNN cell.

  Args:
    unit_type: one of "lstm", "gru", "layer_norm_lstm", "nas".
    num_units: width of the cell.
    forget_bias: forget-gate bias for the LSTM variants.
    dropout: input-dropout probability (1 - keep_prob); forced to 0 unless
      mode is TRAIN.
    mode: a tf.contrib.learn.ModeKeys value for the current run.
    dtype: dtype passed to the (non-block) LSTM cell constructor.
    residual_connection: if True, wrap the cell in a ResidualWrapper.
    residual_fn: optional custom residual function for the wrapper.
    use_block_lstm: if True, use the fused LSTMBlockCell instead of LSTMCell.

  Returns:
    A (possibly Dropout-/Residual-wrapped) RNNCell instance.

  Raises:
    ValueError: if `unit_type` is not recognized.
  """
  # dropout (= 1 - keep_prob) is set to 0 during eval and infer
  dropout = dropout if mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0

  # Cell Type
  if unit_type == "lstm":
    utils.print_out(" LSTM, forget_bias=%g" % forget_bias, new_line=False)
    if not use_block_lstm:
      single_cell = tf.nn.rnn_cell.LSTMCell(
          num_units, dtype=dtype, forget_bias=forget_bias)
    else:
      # LSTMBlockCell does not take a dtype argument here.
      single_cell = tf.contrib.rnn.LSTMBlockCell(
          num_units, forget_bias=forget_bias)
  elif unit_type == "gru":
    utils.print_out(" GRU", new_line=False)
    single_cell = tf.contrib.rnn.GRUCell(num_units)
  elif unit_type == "layer_norm_lstm":
    utils.print_out(" Layer Normalized LSTM, forget_bias=%g" % forget_bias,
                    new_line=False)
    single_cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
        num_units,
        forget_bias=forget_bias,
        layer_norm=True)
  elif unit_type == "nas":
    utils.print_out(" NASCell", new_line=False)
    single_cell = tf.contrib.rnn.NASCell(num_units)
  else:
    raise ValueError("Unknown unit type %s!" % unit_type)

  # Dropout (= 1 - keep_prob); only applied when dropout was kept non-zero
  # above, i.e. in TRAIN mode.
  if dropout > 0.0:
    single_cell = tf.contrib.rnn.DropoutWrapper(
        cell=single_cell, input_keep_prob=(1.0 - dropout))
    utils.print_out(" %s, dropout=%g " %(type(single_cell).__name__, dropout),
                    new_line=False)

  # Residual
  if residual_connection:
    single_cell = tf.contrib.rnn.ResidualWrapper(
        single_cell, residual_fn=residual_fn)
    utils.print_out(" %s" % type(single_cell).__name__, new_line=False)

  return single_cell
  257. def _cell_list(unit_type, num_units, num_layers, num_residual_layers,
  258. forget_bias, dropout, mode, dtype=None,
  259. single_cell_fn=None, residual_fn=None, use_block_lstm=False):
  260. """Create a list of RNN cells."""
  261. if not single_cell_fn:
  262. single_cell_fn = _single_cell
  263. # Multi-GPU
  264. cell_list = []
  265. for i in range(num_layers):
  266. utils.print_out(" cell %d" % i, new_line=False)
  267. single_cell = single_cell_fn(
  268. unit_type=unit_type,
  269. num_units=num_units,
  270. forget_bias=forget_bias,
  271. dropout=dropout,
  272. mode=mode,
  273. dtype=dtype,
  274. residual_connection=(i >= num_layers - num_residual_layers),
  275. residual_fn=residual_fn,
  276. use_block_lstm=use_block_lstm
  277. )
  278. utils.print_out("")
  279. cell_list.append(single_cell)
  280. return cell_list
  281. def create_rnn_cell(unit_type, num_units, num_layers, num_residual_layers,
  282. forget_bias, dropout, mode, dtype=None,
  283. single_cell_fn=None, use_block_lstm=False):
  284. """Create multi-layer RNN cell.
  285. Args:
  286. unit_type: string representing the unit type, i.e. "lstm".
  287. num_units: the depth of each unit.
  288. num_layers: number of cells.
  289. num_residual_layers: Number of residual layers from top to bottom. For
  290. example, if `num_layers=4` and `num_residual_layers=2`, the last 2 RNN
  291. cells in the returned list will be wrapped with `ResidualWrapper`.
  292. forget_bias: the initial forget bias of the RNNCell(s).
  293. dropout: floating point value between 0.0 and 1.0:
  294. the probability of dropout. this is ignored if `mode != TRAIN`.
  295. mode: either tf.contrib.learn.TRAIN/EVAL/INFER
  296. single_cell_fn: allow for adding customized cell.
  297. When not specified, we default to model_helper._single_cell
  298. Returns:
  299. An `RNNCell` instance.
  300. """
  301. cell_list = _cell_list(unit_type=unit_type,
  302. num_units=num_units,
  303. num_layers=num_layers,
  304. num_residual_layers=num_residual_layers,
  305. forget_bias=forget_bias,
  306. dropout=dropout,
  307. mode=mode,
  308. dtype=dtype,
  309. single_cell_fn=single_cell_fn,
  310. use_block_lstm=use_block_lstm)
  311. if len(cell_list) == 1: # Single layer.
  312. return cell_list[0]
  313. else: # Multi layers
  314. return tf.contrib.rnn.MultiRNNCell(cell_list)
  315. def gradient_clip(gradients, max_gradient_norm):
  316. """Clipping gradients of a model."""
  317. clipped_gradients, gradient_norm = math_utils.clip_by_global_norm(
  318. gradients, max_gradient_norm)
  319. return clipped_gradients, gradient_norm
  320. def print_variables_in_ckpt(ckpt_path):
  321. """Print a list of variables in a checkpoint together with their shapes."""
  322. utils.print_out("# Variables in ckpt %s" % ckpt_path)
  323. reader = tf.train.NewCheckpointReader(ckpt_path)
  324. variable_map = reader.get_variable_to_shape_map()
  325. for key in sorted(variable_map.keys()):
  326. utils.print_out(" %s: %s" % (key, variable_map[key]))
def load_model(model, ckpt_path, session, name):
  """Load model from a checkpoint.

  Args:
    model: model object whose `saver` restores the variables.
    ckpt_path: checkpoint path to restore from.
    session: tf.Session in which restore and table-init ops run.
    name: label used only in the log output.

  Returns:
    The same `model`, after attempting the restore in `session`.
  """
  start_time = time.time()
  try:
    model.saver.restore(session, ckpt_path)
  except tf.errors.NotFoundError as e:
    # NOTE(review): the restore failure is logged but not re-raised, so
    # execution continues and the "loaded" message below is still printed —
    # confirm this best-effort/debugging behavior is intended.
    utils.print_out("Can't load checkpoint")
    print_variables_in_ckpt(ckpt_path)
    utils.print_out("%s" % str(e))
  # Table initialization runs regardless of whether the restore succeeded.
  session.run(tf.tables_initializer())
  utils.print_out(
      " loaded %s model parameters from %s, time %.2fs" %
      (name, ckpt_path, time.time() - start_time))
  return model
  341. def avg_checkpoints(model_dir, num_last_checkpoints, global_step_name):
  342. """Average the last N checkpoints in the model_dir."""
  343. checkpoint_state = tf.train.get_checkpoint_state(model_dir)
  344. if not checkpoint_state:
  345. utils.print_out("# No checkpoint file found in directory: %s" % model_dir)
  346. return None
  347. # Checkpoints are ordered from oldest to newest.
  348. checkpoints = (
  349. checkpoint_state.all_model_checkpoint_paths[-num_last_checkpoints:])
  350. if len(checkpoints) < num_last_checkpoints:
  351. utils.print_out(
  352. "# Skipping averaging checkpoints because not enough checkpoints is "
  353. "available.")
  354. return None
  355. avg_model_dir = os.path.join(model_dir, "avg_checkpoints")
  356. if not tf.gfile.Exists(avg_model_dir):
  357. utils.print_out(
  358. "# Creating new directory %s for saving averaged checkpoints." %
  359. avg_model_dir)
  360. tf.gfile.MakeDirs(avg_model_dir)
  361. utils.print_out("# Reading and averaging variables in checkpoints:")
  362. var_list = tf.contrib.framework.list_variables(checkpoints[0])
  363. var_values, var_dtypes = {}, {}
  364. for (name, shape) in var_list:
  365. if name != global_step_name:
  366. var_values[name] = np.zeros(shape)
  367. for checkpoint in checkpoints:
  368. utils.print_out(" %s" % checkpoint)
  369. reader = tf.contrib.framework.load_checkpoint(checkpoint)
  370. for name in var_values:
  371. tensor = reader.get_tensor(name)
  372. var_dtypes[name] = tensor.dtype
  373. var_values[name] += tensor
  374. for name in var_values:
  375. var_values[name] /= len(checkpoints)
  376. # Build a graph with same variables in the checkpoints, and save the averaged
  377. # variables into the avg_model_dir.
  378. with tf.Graph().as_default():
  379. tf_vars = [
  380. tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[name])
  381. for v in var_values
  382. ]
  383. placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
  384. assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
  385. saver = tf.train.Saver(tf.all_variables(), save_relative_paths=True)
  386. with tf.Session() as sess:
  387. sess.run(tf.initialize_all_variables())
  388. for p, assign_op, (name, value) in zip(placeholders, assign_ops,
  389. six.iteritems(var_values)):
  390. sess.run(assign_op, {p: value})
  391. # Use the built saver to save the averaged checkpoint. Only keep 1
  392. # checkpoint and the best checkpoint will be moved to avg_best_metric_dir.
  393. saver.save(
  394. sess,
  395. os.path.join(avg_model_dir, "translate.ckpt"))
  396. return avg_model_dir
  397. def create_or_load_model(model, model_dir, session, name):
  398. """Create translation model and initialize or load parameters in session."""
  399. latest_ckpt = tf.train.latest_checkpoint(model_dir)
  400. if latest_ckpt:
  401. model = load_model(model, latest_ckpt, session, name)
  402. else:
  403. start_time = time.time()
  404. session.run(tf.global_variables_initializer())
  405. session.run(tf.tables_initializer())
  406. utils.print_out(" created %s model with fresh parameters, time %.2fs" %
  407. (name, time.time() - start_time))
  408. global_step = model.global_step.eval(session=session)
  409. return model, global_step