- # Copyright 2017 Google Inc. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Utility functions for building models."""
- from __future__ import print_function
- import collections
- import os
- import time
- import numpy as np
- import six
- import tensorflow as tf
- from utils import math_utils
- from utils import misc_utils as utils
- from utils import vocab_utils
- __all__ = [
- "get_initializer", "create_emb_for_encoder_and_decoder", "create_rnn_cell",
- "gradient_clip", "create_or_load_model", "load_model", "avg_checkpoints",
- ]
- # If a vocab size is greater than this value, put the embedding on the CPU instead
- VOCAB_SIZE_THRESHOLD_CPU = 50000
- def get_initializer(init_op, seed=None, init_weight=0):
- """Create an initializer. init_weight is used only for "uniform"."""
- if init_op == "uniform":
- assert init_weight
- return tf.random_uniform_initializer(
- -init_weight, init_weight, seed=seed)
- elif init_op == "glorot_normal":
- return tf.keras.initializers.glorot_normal(
- seed=seed)
- elif init_op == "glorot_uniform":
- return tf.keras.initializers.glorot_uniform(
- seed=seed)
- elif init_op.isdigit():
- # dtype is default fp32 for variables.
- val = int(init_op)
- return tf.constant_initializer(val)
- else:
- raise ValueError("Unknown init_op %s" % init_op)
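- # Example usage (minimal sketch; the seed and weight values are illustrative):
- #   initializer = get_initializer("uniform", seed=1, init_weight=0.1)
- #   tf.get_variable_scope().set_initializer(initializer)
- #   # Variables created under this scope now draw from U(-0.1, 0.1).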
- class ExtraArgs(collections.namedtuple(
- "ExtraArgs", ("single_cell_fn", "model_device_fn",
- "attention_mechanism_fn", "encoder_emb_lookup_fn"))):
- pass
- class TrainModel(
- collections.namedtuple("TrainModel", ("graph", "model", "iterator",
- "skip_count_placeholder"))):
- pass
- def _get_embed_device(vocab_size):
- """Decide on which device to place an embed matrix given its vocab size."""
- if vocab_size > VOCAB_SIZE_THRESHOLD_CPU:
- return "/cpu:0"
- else:
- return "/gpu:0"
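- # For example, with VOCAB_SIZE_THRESHOLD_CPU = 50000:
- #   _get_embed_device(32000)  -> "/gpu:0"
- #   _get_embed_device(100000) -> "/cpu:0", keeping the large embedding
- #                                out of GPU memory.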
- def _create_pretrained_emb_from_txt(
- vocab_file, embed_file, num_trainable_tokens=3, dtype=tf.float32,
- scope=None):
- """Load a pretrained embedding from embed_file and return an embedding matrix.
- Args:
- vocab_file: Path to the vocab file.
- embed_file: Path to a GloVe-formatted embedding txt file.
- num_trainable_tokens: Make the first n tokens in the vocab file trainable
- variables. Defaults to 3, i.e. "<unk>", "<s>" and "</s>".
- dtype: data type.
- scope: tf scope name.
- Returns:
- pretrained embedding table variable.
- """
- vocab, _ = vocab_utils.load_vocab(vocab_file)
- trainable_tokens = vocab[:num_trainable_tokens]
- utils.print_out("# Using pretrained embedding: %s." % embed_file)
- utils.print_out(" with trainable tokens: ")
- emb_dict, emb_size = vocab_utils.load_embed_txt(embed_file)
- for token in trainable_tokens:
- utils.print_out(" %s" % token)
- if token not in emb_dict:
- emb_dict[token] = [0.0] * emb_size
- emb_mat = np.array(
- [emb_dict[token] for token in vocab], dtype=dtype.as_numpy_dtype())
- emb_mat = tf.constant(emb_mat)
- emb_mat_const = tf.slice(emb_mat, [num_trainable_tokens, 0], [-1, -1])
- with tf.variable_scope(scope or "pretrain_embeddings", dtype=dtype) as scope:
- emb_mat_var = tf.get_variable(
- "emb_mat_var", [num_trainable_tokens, emb_size])
- return tf.concat([emb_mat_var, emb_mat_const], 0)
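- # Example usage (minimal sketch; file names are illustrative):
- #   emb = _create_pretrained_emb_from_txt("vocab.txt", "glove.6B.300d.txt")
- #   # The first num_trainable_tokens rows ("<unk>", "<s>", "</s>") form a
- #   # trainable variable; the remaining rows stay fixed at their
- #   # pretrained values.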
- def _create_or_load_embed(embed_name, vocab_file, embed_file,
- vocab_size, embed_size, dtype):
- """Create a new or load an existing embedding matrix."""
- if vocab_file and embed_file:
- embedding = _create_pretrained_emb_from_txt(vocab_file, embed_file)
- else:
- embedding = tf.get_variable(
- embed_name, [vocab_size, embed_size], dtype)
- return embedding
- def create_emb_for_encoder_and_decoder(share_vocab,
- src_vocab_size,
- tgt_vocab_size,
- src_embed_size,
- tgt_embed_size,
- dtype=tf.float32,
- num_enc_partitions=0,
- num_dec_partitions=0,
- src_vocab_file=None,
- tgt_vocab_file=None,
- src_embed_file=None,
- tgt_embed_file=None,
- use_char_encode=False,
- scope=None):
- """Create embedding matrix for both encoder and decoder.
- Args:
- share_vocab: A boolean. Whether to share embedding matrix for both
- encoder and decoder.
- src_vocab_size: An integer. The source vocab size.
- tgt_vocab_size: An integer. The target vocab size.
- src_embed_size: An integer. The embedding dimension for the encoder's
- embedding.
- tgt_embed_size: An integer. The embedding dimension for the decoder's
- embedding.
- dtype: dtype of the embedding matrix. Defaults to tf.float32.
- num_enc_partitions: number of partitions used for the encoder's embedding
- vars.
- num_dec_partitions: number of partitions used for the decoder's embedding
- vars.
- src_vocab_file: A string. The source vocabulary file.
- tgt_vocab_file: A string. The target vocabulary file.
- src_embed_file: A string. The source embedding file.
- tgt_embed_file: A string. The target embedding file.
- use_char_encode: A boolean. If true, use char encoder.
- scope: VariableScope for the created subgraph. Defaults to "embeddings".
- Returns:
- embedding_encoder: Encoder's embedding matrix.
- embedding_decoder: Decoder's embedding matrix.
- Raises:
- ValueError: if share_vocab is set but the source and target vocab sizes
- differ.
- """
- if num_enc_partitions <= 1:
- enc_partitioner = None
- else:
- # Note: num_partitions > 1 is required for distributed training because
- # embedding_lookup tries to colocate a single-partition embedding variable
- # with the lookup ops, which may cause the embedding variable to be placed
- # on worker jobs.
- enc_partitioner = tf.fixed_size_partitioner(num_enc_partitions)
- if num_dec_partitions <= 1:
- dec_partitioner = None
- else:
- # See the note on num_partitions > 1 above.
- dec_partitioner = tf.fixed_size_partitioner(num_dec_partitions)
- if src_embed_file and enc_partitioner:
- raise ValueError(
- "Can't set num_enc_partitions > 1 when using pretrained encoder "
- "embedding")
- if tgt_embed_file and dec_partitioner:
- raise ValueError(
- "Can't set num_dec_partitions > 1 when using pretrained decoder "
- "embedding")
- with tf.variable_scope(
- scope or "embeddings", dtype=dtype, partitioner=enc_partitioner) as scope:
- # Share embedding
- if share_vocab:
- if src_vocab_size != tgt_vocab_size:
- raise ValueError("Share embedding but different src/tgt vocab sizes"
- " %d vs. %d" % (src_vocab_size, tgt_vocab_size))
- assert src_embed_size == tgt_embed_size
- utils.print_out("# Use the same embedding for source and target")
- vocab_file = src_vocab_file or tgt_vocab_file
- embed_file = src_embed_file or tgt_embed_file
- embedding_encoder = _create_or_load_embed(
- "embedding_share", vocab_file, embed_file,
- src_vocab_size, src_embed_size, dtype)
- embedding_decoder = embedding_encoder
- else:
- if not use_char_encode:
- with tf.variable_scope("encoder", partitioner=enc_partitioner):
- embedding_encoder = _create_or_load_embed(
- "embedding_encoder", src_vocab_file, src_embed_file,
- src_vocab_size, src_embed_size, dtype)
- else:
- embedding_encoder = None
- with tf.variable_scope("decoder", partitioner=dec_partitioner):
- embedding_decoder = _create_or_load_embed(
- "embedding_decoder", tgt_vocab_file, tgt_embed_file,
- tgt_vocab_size, tgt_embed_size, dtype)
- return embedding_encoder, embedding_decoder
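- # Example usage (minimal sketch; sizes are illustrative and `source_ids`
- # is an assumed int tensor of token ids):
- #   embedding_encoder, embedding_decoder = create_emb_for_encoder_and_decoder(
- #       share_vocab=False,
- #       src_vocab_size=32000, tgt_vocab_size=32000,
- #       src_embed_size=512, tgt_embed_size=512)
- #   source_embedded = tf.nn.embedding_lookup(embedding_encoder, source_ids)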
- def build_cell(cell, input_shape):
- """Eagerly build `cell` (including wrapped cells) for a given input size."""
- if isinstance(cell, tf.contrib.rnn.MultiRNNCell):
- assert isinstance(input_shape, collections.Sequence)
- for i, c in enumerate(cell._cells):
- if i == 0:
- c.build((None, input_shape))
- else:
- c.build((None, c.num_units))
- return
- if isinstance(cell, tf.nn.rnn_cell.DropoutWrapper):
- build_cell(cell._cell, input_shape)
- elif isinstance(cell, tf.nn.rnn_cell.ResidualWrapper):
- build_cell(cell._cell, input_shape)
- elif isinstance(cell, tf.nn.rnn_cell.LSTMCell):
- cell.build(input_shape)
- else:
- raise ValueError("%s not supported" % type(cell))
- def _single_cell(unit_type, num_units, forget_bias, dropout, mode,
- dtype=None, residual_connection=False, residual_fn=None,
- use_block_lstm=False):
- """Create an instance of a single RNN cell."""
- # dropout (= 1 - keep_prob) is set to 0 during eval and infer
- dropout = dropout if mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0
- # Cell Type
- if unit_type == "lstm":
- utils.print_out(" LSTM, forget_bias=%g" % forget_bias, new_line=False)
- if not use_block_lstm:
- single_cell = tf.nn.rnn_cell.LSTMCell(
- num_units, dtype=dtype, forget_bias=forget_bias)
- else:
- single_cell = tf.contrib.rnn.LSTMBlockCell(
- num_units, forget_bias=forget_bias)
- elif unit_type == "gru":
- utils.print_out(" GRU", new_line=False)
- single_cell = tf.contrib.rnn.GRUCell(num_units)
- elif unit_type == "layer_norm_lstm":
- utils.print_out(" Layer Normalized LSTM, forget_bias=%g" % forget_bias,
- new_line=False)
- single_cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
- num_units,
- forget_bias=forget_bias,
- layer_norm=True)
- elif unit_type == "nas":
- utils.print_out(" NASCell", new_line=False)
- single_cell = tf.contrib.rnn.NASCell(num_units)
- else:
- raise ValueError("Unknown unit type %s!" % unit_type)
- # Dropout (= 1 - keep_prob)
- if dropout > 0.0:
- single_cell = tf.contrib.rnn.DropoutWrapper(
- cell=single_cell, input_keep_prob=(1.0 - dropout))
- utils.print_out(" %s, dropout=%g " % (type(single_cell).__name__, dropout),
- new_line=False)
- # Residual
- if residual_connection:
- single_cell = tf.contrib.rnn.ResidualWrapper(
- single_cell, residual_fn=residual_fn)
- utils.print_out(" %s" % type(single_cell).__name__, new_line=False)
- return single_cell
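- # Example usage (minimal sketch; hyperparameter values are illustrative):
- #   cell = _single_cell("lstm", num_units=512, forget_bias=1.0,
- #                       dropout=0.2, mode=tf.contrib.learn.ModeKeys.TRAIN,
- #                       residual_connection=True)
- #   # Wrapping order: LSTMCell -> DropoutWrapper -> ResidualWrapper.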
- def _cell_list(unit_type, num_units, num_layers, num_residual_layers,
- forget_bias, dropout, mode, dtype=None,
- single_cell_fn=None, residual_fn=None, use_block_lstm=False):
- """Create a list of RNN cells."""
- if not single_cell_fn:
- single_cell_fn = _single_cell
- cell_list = []
- for i in range(num_layers):
- utils.print_out(" cell %d" % i, new_line=False)
- single_cell = single_cell_fn(
- unit_type=unit_type,
- num_units=num_units,
- forget_bias=forget_bias,
- dropout=dropout,
- mode=mode,
- dtype=dtype,
- residual_connection=(i >= num_layers - num_residual_layers),
- residual_fn=residual_fn,
- use_block_lstm=use_block_lstm
- )
- utils.print_out("")
- cell_list.append(single_cell)
- return cell_list
- def create_rnn_cell(unit_type, num_units, num_layers, num_residual_layers,
- forget_bias, dropout, mode, dtype=None,
- single_cell_fn=None, use_block_lstm=False):
- """Create multi-layer RNN cell.
- Args:
- unit_type: string representing the unit type, e.g. "lstm".
- num_units: the number of hidden units in each cell.
- num_layers: number of cells.
- num_residual_layers: Number of residual layers from top to bottom. For
- example, if `num_layers=4` and `num_residual_layers=2`, the last 2 RNN
- cells in the returned list will be wrapped with `ResidualWrapper`.
- forget_bias: the initial forget bias of the RNNCell(s).
- dropout: floating point value between 0.0 and 1.0: the probability of
- dropout. This is ignored if `mode != TRAIN`.
- mode: either tf.contrib.learn.ModeKeys.TRAIN/EVAL/INFER.
- dtype: dtype of the cell.
- single_cell_fn: allows for adding a customized cell. When not specified,
- defaults to model_helper._single_cell.
- use_block_lstm: whether to use tf.contrib.rnn.LSTMBlockCell for "lstm" cells.
- Returns:
- An `RNNCell` instance.
- """
- cell_list = _cell_list(unit_type=unit_type,
- num_units=num_units,
- num_layers=num_layers,
- num_residual_layers=num_residual_layers,
- forget_bias=forget_bias,
- dropout=dropout,
- mode=mode,
- dtype=dtype,
- single_cell_fn=single_cell_fn,
- use_block_lstm=use_block_lstm)
- if len(cell_list) == 1: # Single layer.
- return cell_list[0]
- else: # Multiple layers.
- return tf.contrib.rnn.MultiRNNCell(cell_list)
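- # Example usage (minimal sketch; `encoder_emb_inp` is an assumed
- # [batch, time, embed_size] tensor):
- #   cell = create_rnn_cell(
- #       unit_type="lstm", num_units=512, num_layers=4,
- #       num_residual_layers=2, forget_bias=1.0, dropout=0.2,
- #       mode=tf.contrib.learn.ModeKeys.TRAIN)
- #   outputs, state = tf.nn.dynamic_rnn(cell, encoder_emb_inp,
- #                                      dtype=tf.float32)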
- def gradient_clip(gradients, max_gradient_norm):
- """Clip gradients of a model by global norm."""
- clipped_gradients, gradient_norm = math_utils.clip_by_global_norm(
- gradients, max_gradient_norm)
- return clipped_gradients, gradient_norm
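- # Example usage (minimal sketch; `loss` and the learning rate are assumed):
- #   params = tf.trainable_variables()
- #   gradients = tf.gradients(loss, params)
- #   clipped_grads, grad_norm = gradient_clip(gradients, max_gradient_norm=5.0)
- #   update = tf.train.GradientDescentOptimizer(1.0).apply_gradients(
- #       zip(clipped_grads, params))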
- def print_variables_in_ckpt(ckpt_path):
- """Print a list of variables in a checkpoint together with their shapes."""
- utils.print_out("# Variables in ckpt %s" % ckpt_path)
- reader = tf.train.NewCheckpointReader(ckpt_path)
- variable_map = reader.get_variable_to_shape_map()
- for key in sorted(variable_map.keys()):
- utils.print_out(" %s: %s" % (key, variable_map[key]))
- def load_model(model, ckpt_path, session, name):
- """Load model from a checkpoint."""
- start_time = time.time()
- try:
- model.saver.restore(session, ckpt_path)
- except tf.errors.NotFoundError as e:
- utils.print_out("Can't load checkpoint")
- print_variables_in_ckpt(ckpt_path)
- utils.print_out("%s" % str(e))
- session.run(tf.tables_initializer())
- utils.print_out(
- " loaded %s model parameters from %s, time %.2fs" %
- (name, ckpt_path, time.time() - start_time))
- return model
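- # Example usage (minimal sketch; the graph, model, and checkpoint path are
- # assumed):
- #   with tf.Session(graph=infer_graph) as sess:
- #     infer_model = load_model(infer_model, "/tmp/nmt/translate.ckpt-12000",
- #                              sess, "infer")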
- def avg_checkpoints(model_dir, num_last_checkpoints, global_step_name):
- """Average the last N checkpoints in the model_dir."""
- checkpoint_state = tf.train.get_checkpoint_state(model_dir)
- if not checkpoint_state:
- utils.print_out("# No checkpoint file found in directory: %s" % model_dir)
- return None
- # Checkpoints are ordered from oldest to newest.
- checkpoints = (
- checkpoint_state.all_model_checkpoint_paths[-num_last_checkpoints:])
- if len(checkpoints) < num_last_checkpoints:
- utils.print_out(
- "# Skipping averaging checkpoints because not enough checkpoints are "
- "available.")
- return None
- avg_model_dir = os.path.join(model_dir, "avg_checkpoints")
- if not tf.gfile.Exists(avg_model_dir):
- utils.print_out(
- "# Creating new directory %s for saving averaged checkpoints." %
- avg_model_dir)
- tf.gfile.MakeDirs(avg_model_dir)
- utils.print_out("# Reading and averaging variables in checkpoints:")
- var_list = tf.contrib.framework.list_variables(checkpoints[0])
- var_values, var_dtypes = {}, {}
- for (name, shape) in var_list:
- if name != global_step_name:
- var_values[name] = np.zeros(shape)
- for checkpoint in checkpoints:
- utils.print_out(" %s" % checkpoint)
- reader = tf.contrib.framework.load_checkpoint(checkpoint)
- for name in var_values:
- tensor = reader.get_tensor(name)
- var_dtypes[name] = tensor.dtype
- var_values[name] += tensor
- for name in var_values:
- var_values[name] /= len(checkpoints)
- # Build a graph with the same variables as in the checkpoints, and save
- # the averaged variables into avg_model_dir.
- with tf.Graph().as_default():
- tf_vars = [
- tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v])
- for v in var_values
- ]
- placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
- assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
- saver = tf.train.Saver(tf.global_variables(), save_relative_paths=True)
- with tf.Session() as sess:
- sess.run(tf.global_variables_initializer())
- for p, assign_op, (name, value) in zip(placeholders, assign_ops,
- six.iteritems(var_values)):
- sess.run(assign_op, {p: value})
- # Use the built saver to save the averaged checkpoint. Only one
- # checkpoint is kept; the best one will be moved to avg_best_metric_dir.
- saver.save(
- sess,
- os.path.join(avg_model_dir, "translate.ckpt"))
- return avg_model_dir
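- # Example usage (minimal sketch; the model directory is illustrative):
- #   avg_dir = avg_checkpoints("/tmp/nmt_model", num_last_checkpoints=5,
- #                             global_step_name="global_step")
- #   # The averaged checkpoint is saved as
- #   # /tmp/nmt_model/avg_checkpoints/translate.ckpt.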
- def create_or_load_model(model, model_dir, session, name):
- """Create translation model and initialize or load parameters in session."""
- latest_ckpt = tf.train.latest_checkpoint(model_dir)
- if latest_ckpt:
- model = load_model(model, latest_ckpt, session, name)
- else:
- start_time = time.time()
- session.run(tf.global_variables_initializer())
- session.run(tf.tables_initializer())
- utils.print_out(" created %s model with fresh parameters, time %.2fs" %
- (name, time.time() - start_time))
- global_step = model.global_step.eval(session=session)
- return model, global_step
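- # Example usage (minimal sketch; `train_model` is an assumed TrainModel
- # instance whose graph owns the session):
- #   with tf.Session(graph=train_model.graph) as sess:
- #     loaded_model, global_step = create_or_load_model(
- #         train_model.model, "/tmp/nmt_model", sess, "train")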
|