- # Copyright 2017 Google Inc. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- #
- # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """TensorFlow NMT model implementation."""
- from __future__ import print_function
- import argparse
- import os
- import random
- import sys
- import subprocess
- import numpy as np
- import time
- import tensorflow as tf
- import dllogger
- import estimator
- from utils import evaluation_utils
- from utils import iterator_utils
- from utils import misc_utils as utils
- from utils import vocab_utils
- from variable_mgr import constants
- utils.check_tensorflow_version()
- FLAGS = None
- # LINT.IfChange
- def add_arguments(parser):
- """Build ArgumentParser."""
- parser.register("type", "bool", lambda v: v.lower() == "true")
- # network
- parser.add_argument(
- "--num_units", type=int, default=1024, help="Network size.")
- parser.add_argument(
- "--num_layers", type=int, default=4, help="Network depth.")
- parser.add_argument("--num_encoder_layers", type=int, default=None,
- help="Encoder depth, equal to num_layers if None.")
- parser.add_argument("--num_decoder_layers", type=int, default=None,
- help="Decoder depth, equal to num_layers if None.")
- parser.add_argument(
- "--encoder_type",
- type=str,
- default="gnmt",
- help="""\
- uni | bi | gnmt.
- For bi, we build num_encoder_layers/2 bi-directional layers.
- For gnmt, we build 1 bi-directional layer, and (num_encoder_layers - 1)
- uni-directional layers.\
- """)
- parser.add_argument(
- "--residual",
- type="bool",
- nargs="?",
- const=True,
- default=True,
- help="Whether to add residual connections.")
- parser.add_argument("--time_major", type="bool", nargs="?", const=True,
- default=True,
- help="Whether to use time-major mode for dynamic RNN.")
- parser.add_argument("--num_embeddings_partitions", type=int, default=0,
- help="Number of partitions for embedding vars.")
- # attention mechanisms
- parser.add_argument(
- "--attention",
- type=str,
- default="normed_bahdanau",
- help="""\
- luong | scaled_luong | bahdanau | normed_bahdanau or set to "" for no
- attention\
- """)
- parser.add_argument(
- "--attention_architecture",
- type=str,
- default="gnmt_v2",
- help="""\
- standard | gnmt | gnmt_v2.
- standard: use top layer to compute attention.
- gnmt: GNMT style of computing attention, use previous bottom layer to
- compute attention.
- gnmt_v2: similar to gnmt, but use current bottom layer to compute
- attention.\
- """)
- parser.add_argument(
- "--output_attention", type="bool", nargs="?", const=True,
- default=True,
- help="""\
- Only used in the standard attention_architecture. Whether to use attention
- as the cell output at each time step.\
- """)
- parser.add_argument(
- "--pass_hidden_state", type="bool", nargs="?", const=True,
- default=True,
- help="""\
- Whether to pass encoder's hidden state to decoder when using an attention
- based model.\
- """)
- # optimizer
- parser.add_argument(
- "--optimizer", type=str, default="adam", help="sgd | adam")
- parser.add_argument(
- "--learning_rate",
- type=float,
- default=5e-4,
- help="Learning rate. Adam: 0.001 | 0.0001")
- parser.add_argument("--warmup_steps", type=int, default=200,
- help="How many steps we inverse-decay learning.")
- parser.add_argument("--warmup_scheme", type=str, default="t2t", help="""\
- How to warm up the learning rate. Options include:
- t2t: Tensor2Tensor's scheme; start with a learning rate 100x smaller,
- then increase it exponentially until the specified rate is reached.\
- """)
- parser.add_argument(
- "--decay_scheme", type=str, default="luong234", help="""\
- How to decay the learning rate. Options include:
- luong234: after 2/3 of num train steps, halve the learning rate
- 4 times before finishing.
- luong5: after 1/2 of num train steps, halve the learning rate
- 5 times before finishing.
- luong10: after 1/2 of num train steps, halve the learning rate
- 10 times before finishing.\
- """)
- parser.add_argument(
- "--max_train_epochs", type=int, default=6, help="Max number of epochs.")
- parser.add_argument(
- "--target_bleu", type=float, default=None, help="Target bleu.")
- parser.add_argument("--colocate_gradients_with_ops", type="bool", nargs="?",
- const=True,
- default=True,
- help=("Whether try colocating gradients with "
- "corresponding op"))
- parser.add_argument("--label_smoothing", type=float, default=0.1,
- help=("If nonzero, smooth the labels towards "
- "1/num_classes."))
- # initializer
- parser.add_argument("--init_op", type=str, default="uniform",
- help="uniform | glorot_normal | glorot_uniform")
- parser.add_argument("--init_weight", type=float, default=0.1,
- help=("for uniform init_op, initialize weights "
- "between [-this, this]."))
- # data
- parser.add_argument(
- "--src", type=str, default="en", help="Source suffix, e.g., en.")
- parser.add_argument(
- "--tgt", type=str, default="de", help="Target suffix, e.g., de.")
- parser.add_argument(
- "--data_dir", type=str, default="data/wmt16_de_en",
- help="Training/eval data directory.")
- parser.add_argument(
- "--train_prefix",
- type=str,
- default="train.tok.clean.bpe.32000",
- help="Train prefix, expect files with src/tgt suffixes.")
- parser.add_argument(
- "--test_prefix",
- type=str,
- default="newstest2014.tok.bpe.32000",
- help="Test prefix, expect files with src/tgt suffixes.")
- parser.add_argument(
- "--translate_file",
- type=str,
- help="File to translate, works only with translate mode")
- parser.add_argument(
- "--output_dir", type=str, default="results",
- help="Store log/model files.")
- # Vocab
- parser.add_argument(
- "--vocab_prefix",
- type=str,
- default="vocab.bpe.32000",
- help="""\
- Vocab prefix, expect files with src/tgt suffixes.\
- """)
- parser.add_argument(
- "--embed_prefix",
- type=str,
- default=None,
- help="""\
- Pretrained embedding prefix, expect files with src/tgt suffixes.
- The embedding files should be Glove formatted txt files.\
- """)
- parser.add_argument("--sos", type=str, default="<s>",
- help="Start-of-sentence symbol.")
- parser.add_argument("--eos", type=str, default="</s>",
- help="End-of-sentence symbol.")
- parser.add_argument(
- "--share_vocab",
- type="bool",
- nargs="?",
- const=True,
- default=True,
- help="""\
- Whether to use the source vocab and embeddings for both source and
- target.\
- """)
- parser.add_argument("--check_special_token", type="bool", default=True,
- help="""\
- Whether to check that the special sos, eos and unk tokens exist in the
- vocab files.\
- """)
- # Sequence lengths
- parser.add_argument(
- "--src_max_len",
- type=int,
- default=50,
- help="Max length of src sequences during training (including EOS).")
- parser.add_argument(
- "--tgt_max_len",
- type=int,
- default=50,
- help="Max length of tgt sequences during training (including BOS).")
- parser.add_argument("--src_max_len_infer", type=int, default=None,
- help="Max length of src sequences during inference (including EOS).")
- parser.add_argument("--tgt_max_len_infer", type=int, default=80,
- help="""\
- Max length of tgt sequences during inference (including BOS). Also used to restrict the
- maximum decoding length.\
- """)
- # Default settings works well (rarely need to change)
- parser.add_argument("--unit_type", type=str, default="lstm",
- help="lstm | gru | layer_norm_lstm | nas")
- parser.add_argument("--forget_bias", type=float, default=0.0,
- help="Forget bias for BasicLSTMCell.")
- parser.add_argument("--dropout", type=float, default=0.2,
- help="Dropout rate (not keep_prob)")
- parser.add_argument("--max_gradient_norm", type=float, default=5.0,
- help="Clip gradients to this norm.")
- parser.add_argument("--batch_size", type=int, default=128, help="Total batch size.")
- parser.add_argument(
- "--num_buckets",
- type=int,
- default=5,
- help="Put data into similar-length buckets (only for training).")
- # SPM
- parser.add_argument("--subword_option", type=str, default="bpe",
- choices=["", "bpe", "spm"],
- help="""\
- Set to bpe or spm to activate subword desegmentation.\
- """)
- # Experimental encoding feature.
- parser.add_argument("--use_char_encode", type="bool", default=False,
- help="""\
- Whether to split each word or BPE token into characters, and then
- generate the word-level representation from the character
- representation.
- """)
- # Misc
- parser.add_argument(
- "--save_checkpoints_steps", type=int, default=2000,
- help="save_checkpoints_steps")
- parser.add_argument(
- "--log_step_count_steps", type=int, default=10,
- help=("The frequency, in number of global steps, that the global step "
- "and the loss will be logged during training"))
- parser.add_argument(
- "--num_gpus", type=int, default=1, help="Number of gpus in each worker.")
- parser.add_argument("--hparams_path", type=str, default=None,
- help=("Path to standard hparams json file that overrides"
- "hparams values from FLAGS."))
- parser.add_argument(
- "--random_seed",
- type=int,
- default=1,
- help="Random seed (>0, set a specific seed).")
- parser.add_argument("--language_model", type="bool", nargs="?",
- const=True, default=False,
- help="True to train a language model, ignoring encoder")
- # Inference
- parser.add_argument("--ckpt", type=str, default=None,
- help="Checkpoint file to load a model for inference. (defaults to newest checkpoint)")
- parser.add_argument(
- "--infer_batch_size",
- type=int,
- default=128,
- help="Batch size for inference mode.")
- parser.add_argument("--detokenizer_file", type=str,
- default=None,
- help=("""Detokenizer script file. Default: DATA_DIR/mosesdecoder/scripts/tokenizer/detokenizer.perl"""))
- parser.add_argument("--tokenizer_file", type=str,
- default=None,
- help=("""Tokenizer script file. Default: DATA_DIR/mosesdecoder/scripts/tokenizer/tokenizer.perl"""))
- # Advanced inference arguments
- parser.add_argument("--infer_mode", type=str, default="beam_search",
- choices=["greedy", "beam_search"],
- help="Which type of decoder to use during inference.")
- parser.add_argument("--beam_width", type=int, default=5,
- help=("""\
- beam width when using beam search decoder. If 0, use standard
- decoder with greedy helper.\
- """))
- parser.add_argument(
- "--length_penalty_weight",
- type=float,
- default=0.6,
- help="Length penalty for beam search.")
- parser.add_argument(
- "--coverage_penalty_weight",
- type=float,
- default=0.1,
- help="Coverage penalty for beam search.")
- # Job info
- parser.add_argument("--num_workers", type=int, default=1,
- help="Number of workers (inference only).")
- parser.add_argument("--amp", action='store_true',
- help="use amp for training and inference")
- parser.add_argument("--use_fastmath", type="bool", default=False,
- help="use_fastmath for training and inference")
- parser.add_argument("--use_fp16", type="bool", default=False,
- help="use_fp16 for training and inference")
- parser.add_argument(
- "--fp16_loss_scale",
- type=float,
- default=128,
- help="If fp16 is enabled, the loss is multiplied by this amount "
- "right before gradients are computed, then each gradient "
- "is divided by this amount. Mathematically, this has no "
- "effect, but it helps avoid fp16 underflow. Set to 1 to "
- "effectively disable.")
- parser.add_argument(
- "--enable_auto_loss_scale",
- type="bool",
- default=True,
- help="If True and use_fp16 is True, automatically adjust the "
- "loss scale during training.")
- parser.add_argument(
- "--fp16_inc_loss_scale_every_n",
- type=int,
- default=128,
- help="If fp16 is enabled and enable_auto_loss_scale is "
- "True, increase the loss scale every n steps.")
- parser.add_argument(
- "--check_tower_loss_numerics",
- type="bool",
- default=False, # Set to false for xla.compile()
- help="whether to check tower loss numerics")
- parser.add_argument(
- "--use_fp32_batch_matmul",
- type="bool",
- default=False,
- help="Whether to use fp32 batch matmul")
- # Performance
- # XLA
- parser.add_argument(
- "--force_inputs_padding",
- type="bool",
- default=False,
- help="Force padding input batch to src_max_len and tgt_max_len")
- parser.add_argument(
- "--use_xla",
- type="bool",
- default=False,
- help="Use xla to compile a few selected locations, mostly Defuns.")
- parser.add_argument(
- "--xla_compile",
- type="bool",
- default=False,
- help="Use xla.compile() for each tower's fwd and bak pass.")
- parser.add_argument(
- "--use_autojit_xla",
- type="bool",
- default=False,
- help="Use auto jit xla.")
- # GPU knobs
- parser.add_argument(
- "--use_pintohost_optimizer",
- type="bool",
- default=False,
- help="whether to use PinToHost optimizer")
- parser.add_argument(
- "--use_cudnn_lstm",
- type="bool",
- default=False,
- help="whether to use cudnn_lstm for encoder, non residual layers")
- parser.add_argument(
- "--use_loose_bidi_cudnn_lstm",
- type="bool",
- default=False,
- help="whether to use loose bidi cudnn_lstm")
- parser.add_argument(
- "--use_fused_lstm",
- type="bool",
- default=True,
- help="whether to use fused lstm and variant. If enabled, training will "
- "use LSTMBlockFusedCell, infer will use LSTMBlockCell when appropriate.")
- parser.add_argument(
- "--use_fused_lstm_dec",
- type="bool",
- default=False,
- help="whether to use fused lstm for decoder (training only).")
- parser.add_argument(
- "--gpu_indices",
- type=str,
- default="",
- help="Indices of worker GPUs in ring order")
- # Graph knobs
- parser.add_argument("--parallel_iterations", type=int, default=10,
- help="number of parallel iterations in dynamic_rnn")
- parser.add_argument("--use_dist_strategy", type="bool", default=False,
- help="whether to use distribution strategy")
- parser.add_argument(
- "--hierarchical_copy",
- type="bool",
- default=False,
- help="Use hierarchical copies. Currently only optimized for "
- "use on a DGX-1 with 8 GPUs and may perform poorly on "
- "other hardware. Requires --num_gpus > 1, and only "
- "recommended when --num_gpus=8")
- parser.add_argument(
- "--network_topology",
- type=constants.NetworkTopology,
- default=constants.NetworkTopology.DGX1,
- choices=list(constants.NetworkTopology))
- parser.add_argument(
- "--use_block_lstm",
- type="bool",
- default=False,
- help="whether to use block lstm")
- parser.add_argument(
- "--use_defun",
- type="bool",
- default=False,
- help="whether to use Defun")
- # Gradient tricks
- parser.add_argument(
- "--gradient_repacking",
- type=int,
- default=0,
- help="Use gradient repacking. It"
- "currently only works with replicated mode. At the end of"
- "of each step, it repacks the gradients for more efficient"
- "cross-device transportation. A non-zero value specifies"
- "the number of split packs that will be formed.")
- parser.add_argument(
- "--compact_gradient_transfer",
- type="bool",
- default=True,
- help="Compact gradient as much as possible for cross-device transfer and "
- "aggregation.")
- parser.add_argument(
- "--all_reduce_spec",
- type=str,
- default="nccl",
- help="A specification of the all_reduce algorithm to be used "
- "for reducing gradients. For more details, see "
- "parse_all_reduce_spec in variable_mgr.py. An "
- "all_reduce_spec has BNF form:\n"
- "int ::= positive whole number\n"
- "g_int ::= int[KkMGT]?\n"
- "alg_spec ::= alg | alg#int\n"
- "range_spec ::= alg_spec | alg_spec/alg_spec\n"
- "spec ::= range_spec | range_spec:g_int:range_spec\n"
- "NOTE: not all syntactically correct constructs are "
- "supported.\n\n"
- "Examples:\n "
- "\"xring\" == use one global ring reduction for all "
- "tensors\n"
- "\"pscpu\" == use CPU at worker 0 to reduce all tensors\n"
- "\"nccl\" == use NCCL to locally reduce all tensors. "
- "Limited to 1 worker.\n"
- "\"nccl/xring\" == locally (to one worker) reduce values "
- "using NCCL then ring reduce across workers.\n"
- "\"pscpu:32k:xring\" == use pscpu algorithm for tensors of "
- "size up to 32kB, then xring for larger tensors.")
- parser.add_argument(
- "--agg_small_grads_max_bytes",
- type=int,
- default=0,
- help="If > 0, try to aggregate tensors of less than this "
- "number of bytes prior to all-reduce.")
- parser.add_argument(
- "--agg_small_grads_max_group",
- type=int,
- default=10,
- help="When aggregating small tensors for all-reduce do not "
- "aggregate more than this many into one new tensor.")
- parser.add_argument(
- "--allreduce_merge_scope",
- type=int,
- default=1,
- help="Establish a name scope around this many "
- "gradients prior to creating the all-reduce operations. "
- "It may affect the ability of the backend to merge "
- "parallel ops.")
- # Other knobs
- parser.add_argument(
- "--local_parameter_device",
- type=str,
- default="gpu",
- help="Device to use as parameter server: cpu or gpu. For "
- "distributed training, it can affect where caching of "
- "variables happens.")
- parser.add_argument(
- "--use_resource_vars",
- type="bool",
- default=False,
- help="Use resource variables instead of normal variables. "
- "Resource variables are slower, but this option is useful "
- "for debugging their performance.")
- parser.add_argument("--debug", type="bool", default=False,
- help="Debug train and eval")
- parser.add_argument(
- "--debug_num_train_steps", type=int, default=None, help="Num steps to train.")
- parser.add_argument("--show_metrics", type="bool", default=True,
- help="whether to show detailed metrics")
- parser.add_argument("--clip_grads", type="bool", default=True,
- help="whether to clip gradients")
- parser.add_argument("--profile", type="bool", default=False,
- help="If generate profile")
- parser.add_argument("--profile_save_steps", type=int, default=10,
- help="Save timeline every N steps.")
- parser.add_argument("--use_dynamic_rnn", type="bool", default=True)
- parser.add_argument("--use_synthetic_data", type="bool", default=False)
- parser.add_argument(
- "--mode", type=str, default="train_and_eval",
- choices=("train_and_eval", "infer", "translate"))
- def create_hparams(flags):
- """Create training hparams."""
- return tf.contrib.training.HParams(
- # Data
- src=flags.src,
- tgt=flags.tgt,
- train_prefix=os.path.join(flags.data_dir, flags.train_prefix),
- test_prefix=os.path.join(flags.data_dir, flags.test_prefix),
- translate_file=flags.translate_file,
- vocab_prefix=os.path.join(flags.data_dir, flags.vocab_prefix),
- embed_prefix=flags.embed_prefix,
- output_dir=flags.output_dir,
- # Networks
- num_units=flags.num_units,
- num_encoder_layers=(flags.num_encoder_layers or flags.num_layers),
- num_decoder_layers=(flags.num_decoder_layers or flags.num_layers),
- dropout=flags.dropout,
- unit_type=flags.unit_type,
- encoder_type=flags.encoder_type,
- residual=flags.residual,
- time_major=flags.time_major,
- num_embeddings_partitions=flags.num_embeddings_partitions,
- # Attention mechanisms
- attention=flags.attention,
- attention_architecture=flags.attention_architecture,
- output_attention=flags.output_attention,
- pass_hidden_state=flags.pass_hidden_state,
- # Train
- optimizer=flags.optimizer,
- max_train_epochs=flags.max_train_epochs,
- target_bleu=flags.target_bleu,
- label_smoothing=flags.label_smoothing,
- batch_size=flags.batch_size,
- init_op=flags.init_op,
- init_weight=flags.init_weight,
- max_gradient_norm=flags.max_gradient_norm,
- learning_rate=flags.learning_rate,
- warmup_steps=flags.warmup_steps,
- warmup_scheme=flags.warmup_scheme,
- decay_scheme=flags.decay_scheme,
- colocate_gradients_with_ops=flags.colocate_gradients_with_ops,
- # Data constraints
- num_buckets=flags.num_buckets,
- src_max_len=flags.src_max_len,
- tgt_max_len=flags.tgt_max_len,
- # Inference
- src_max_len_infer=flags.src_max_len_infer,
- tgt_max_len_infer=flags.tgt_max_len_infer,
- ckpt=flags.ckpt,
- infer_batch_size=flags.infer_batch_size,
- detokenizer_file=flags.detokenizer_file if flags.detokenizer_file is not None \
- else os.path.join(flags.data_dir, 'mosesdecoder/scripts/tokenizer/detokenizer.perl'),
- tokenizer_file=flags.tokenizer_file if flags.tokenizer_file is not None \
- else os.path.join(flags.data_dir, 'mosesdecoder/scripts/tokenizer/tokenizer.perl'),
- # Advanced inference arguments
- infer_mode=flags.infer_mode,
- beam_width=flags.beam_width,
- length_penalty_weight=flags.length_penalty_weight,
- coverage_penalty_weight=flags.coverage_penalty_weight,
- # Vocab
- sos=flags.sos if flags.sos else vocab_utils.SOS,
- eos=flags.eos if flags.eos else vocab_utils.EOS,
- subword_option=flags.subword_option,
- check_special_token=flags.check_special_token,
- use_char_encode=flags.use_char_encode,
- # Misc
- forget_bias=flags.forget_bias,
- num_gpus=flags.num_gpus,
- save_checkpoints_steps=flags.save_checkpoints_steps,
- log_step_count_steps=flags.log_step_count_steps,
- epoch_step=0, # record where we were within an epoch.
- share_vocab=flags.share_vocab,
- random_seed=flags.random_seed,
- language_model=flags.language_model,
- amp=flags.amp,
- use_fastmath=flags.use_fastmath,
- use_fp16=flags.use_fp16,
- fp16_loss_scale=flags.fp16_loss_scale,
- enable_auto_loss_scale=flags.enable_auto_loss_scale,
- fp16_inc_loss_scale_every_n=flags.fp16_inc_loss_scale_every_n,
- check_tower_loss_numerics=flags.check_tower_loss_numerics,
- use_fp32_batch_matmul=flags.use_fp32_batch_matmul,
- # Performance
- # GPU knobs
- force_inputs_padding=flags.force_inputs_padding,
- use_xla=flags.use_xla,
- xla_compile=flags.xla_compile,
- use_autojit_xla=flags.use_autojit_xla,
- use_pintohost_optimizer=flags.use_pintohost_optimizer,
- use_cudnn_lstm=flags.use_cudnn_lstm,
- use_loose_bidi_cudnn_lstm=flags.use_loose_bidi_cudnn_lstm,
- use_fused_lstm=flags.use_fused_lstm,
- use_fused_lstm_dec=flags.use_fused_lstm_dec,
- gpu_indices=flags.gpu_indices,
- # Graph knobs
- parallel_iterations=flags.parallel_iterations,
- use_dynamic_rnn=flags.use_dynamic_rnn,
- use_dist_strategy=flags.use_dist_strategy,
- hierarchical_copy=flags.hierarchical_copy,
- network_topology=flags.network_topology,
- use_block_lstm=flags.use_block_lstm,
- # Grad tricks
- gradient_repacking=flags.gradient_repacking,
- compact_gradient_transfer=flags.compact_gradient_transfer,
- all_reduce_spec=flags.all_reduce_spec,
- agg_small_grads_max_bytes=flags.agg_small_grads_max_bytes,
- agg_small_grads_max_group=flags.agg_small_grads_max_group,
- allreduce_merge_scope=flags.allreduce_merge_scope,
- # Other knobs
- local_parameter_device=("cpu" if flags.num_gpus ==0
- else flags.local_parameter_device),
- use_resource_vars=flags.use_resource_vars,
- debug=flags.debug,
- debug_num_train_steps=flags.debug_num_train_steps,
- clip_grads=flags.clip_grads,
- profile=flags.profile,
- profile_save_steps=flags.profile_save_steps,
- show_metrics=flags.show_metrics,
- use_synthetic_data=flags.use_synthetic_data,
- mode=flags.mode,
- )
- def _add_argument(hparams, key, value, update=True):
- """Add an argument to hparams; if exists, change the value if update==True."""
- if hasattr(hparams, key):
- if update:
- setattr(hparams, key, value)
- else:
- hparams.add_hparam(key, value)
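- # Example (illustrative): _add_argument(hparams, "beam_width", 10) overwrites
- # an existing hparam, while update=False only adds the hparam when it is
- # missing and leaves any existing value untouched.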
- def extend_hparams(hparams):
- """Add new arguments to hparams."""
- # Sanity checks
- if hparams.encoder_type == "bi" and hparams.num_encoder_layers % 2 != 0:
- raise ValueError("For bi, num_encoder_layers %d should be even" %
- hparams.num_encoder_layers)
- if (hparams.attention_architecture in ["gnmt"] and
- hparams.num_encoder_layers < 2):
- raise ValueError("For gnmt attention architecture, "
- "num_encoder_layers %d should be >= 2" %
- hparams.num_encoder_layers)
- if hparams.subword_option and hparams.subword_option not in ["spm", "bpe"]:
- raise ValueError("subword option must be either spm, or bpe")
- if hparams.infer_mode == "beam_search" and hparams.beam_width <= 0:
- raise ValueError("beam_width must greater than 0 when using beam_search"
- "decoder.")
- if hparams.mode == "translate" and not hparams.translate_file:
- raise ValueError("--translate_file flag must be specified in translate mode")
- # Different number of encoder / decoder layers
- assert hparams.num_encoder_layers and hparams.num_decoder_layers
- if hparams.num_encoder_layers != hparams.num_decoder_layers:
- hparams.pass_hidden_state = False
- utils.print_out("Num encoder layer %d is different from num decoder layer"
- " %d, so set pass_hidden_state to False" % (
- hparams.num_encoder_layers,
- hparams.num_decoder_layers))
- # Set residual layers
- num_encoder_residual_layers = 0
- num_decoder_residual_layers = 0
- if hparams.residual:
- if hparams.num_encoder_layers > 1:
- num_encoder_residual_layers = hparams.num_encoder_layers - 1
- if hparams.num_decoder_layers > 1:
- num_decoder_residual_layers = hparams.num_decoder_layers - 1
- if hparams.encoder_type == "gnmt":
- # The first unidirectional layer (after the bi-directional layer) in
- # the GNMT encoder can't have a residual connection because its input
- # is the concatenation of the fw_cell's and bw_cell's outputs.
- num_encoder_residual_layers = hparams.num_encoder_layers - 2
- # For compatibility with GNMT models
- if hparams.num_encoder_layers == hparams.num_decoder_layers:
- num_decoder_residual_layers = num_encoder_residual_layers
- _add_argument(hparams, "num_encoder_residual_layers",
- num_encoder_residual_layers)
- _add_argument(hparams, "num_decoder_residual_layers",
- num_decoder_residual_layers)
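- # With the defaults (num_layers=4, residual=True, encoder_type="gnmt", equal
- # encoder/decoder depth) this yields 2 residual layers on each side; in the
- # encoder, the bi-directional layer and the first uni-directional layer are
- # the ones without residual connections.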
- # Language modeling
- if hparams.language_model:
- hparams.attention = ""
- hparams.attention_architecture = ""
- hparams.pass_hidden_state = False
- hparams.share_vocab = True
- hparams.src = hparams.tgt
- utils.print_out("For language modeling, we turn off attention and "
- "pass_hidden_state; turn on share_vocab; set src to tgt.")
- ## Vocab
- # Get vocab file names first
- if hparams.vocab_prefix:
- src_vocab_file = hparams.vocab_prefix + "." + hparams.src
- tgt_vocab_file = hparams.vocab_prefix + "." + hparams.tgt
- else:
- raise ValueError("hparams.vocab_prefix must be provided.")
- # Source vocab
- src_vocab_size, src_vocab_file = vocab_utils.check_vocab(
- src_vocab_file,
- hparams.output_dir,
- check_special_token=hparams.check_special_token,
- sos=hparams.sos,
- eos=hparams.eos,
- unk=vocab_utils.UNK,
- pad_vocab=True)
- # Target vocab
- if hparams.share_vocab:
- utils.print_out(" using source vocab for target")
- tgt_vocab_file = src_vocab_file
- tgt_vocab_size = src_vocab_size
- else:
- tgt_vocab_size, tgt_vocab_file = vocab_utils.check_vocab(
- tgt_vocab_file,
- hparams.output_dir,
- check_special_token=hparams.check_special_token,
- sos=hparams.sos,
- eos=hparams.eos,
- unk=vocab_utils.UNK)
- _add_argument(hparams, "src_vocab_size", src_vocab_size)
- _add_argument(hparams, "tgt_vocab_size", tgt_vocab_size)
- _add_argument(hparams, "src_vocab_file", src_vocab_file)
- _add_argument(hparams, "tgt_vocab_file", tgt_vocab_file)
- # Num embedding partitions
- _add_argument(
- hparams, "num_enc_emb_partitions", hparams.num_embeddings_partitions)
- _add_argument(
- hparams, "num_dec_emb_partitions", hparams.num_embeddings_partitions)
- # Pretrained Embeddings
- _add_argument(hparams, "src_embed_file", "")
- _add_argument(hparams, "tgt_embed_file", "")
- if hparams.embed_prefix:
- src_embed_file = hparams.embed_prefix + "." + hparams.src
- tgt_embed_file = hparams.embed_prefix + "." + hparams.tgt
- if tf.gfile.Exists(src_embed_file):
- utils.print_out(" src_embed_file %s exist" % src_embed_file)
- hparams.src_embed_file = src_embed_file
- utils.print_out(
- "For pretrained embeddings, set num_enc_emb_partitions to 1")
- hparams.num_enc_emb_partitions = 1
- else:
- utils.print_out(" src_embed_file %s doesn't exist" % src_embed_file)
- if tf.gfile.Exists(tgt_embed_file):
- utils.print_out(" tgt_embed_file %s exist" % tgt_embed_file)
- hparams.tgt_embed_file = tgt_embed_file
- utils.print_out(
- "For pretrained embeddings, set num_dec_emb_partitions to 1")
- hparams.num_dec_emb_partitions = 1
- else:
- utils.print_out(" tgt_embed_file %s doesn't exist" % tgt_embed_file)
- # Evaluation
- metric = "bleu"
- best_metric_dir = os.path.join(hparams.output_dir, "best_" + metric)
- tf.gfile.MakeDirs(best_metric_dir)
- _add_argument(hparams, "best_" + metric, 0, update=False)
- _add_argument(hparams, "best_" + metric + "_dir", best_metric_dir)
- return hparams
- def create_or_load_hparams(default_hparams, hparams_path):
- """Create hparams or load hparams from output_dir."""
- hparams = utils.maybe_parse_standard_hparams(default_hparams, hparams_path)
- hparams = extend_hparams(hparams)
- # Print HParams
- utils.print_hparams(hparams)
- return hparams
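- # A --hparams_path file is plain JSON mapping hparam names to values, e.g.
- # (hypothetical): {"learning_rate": 1e-3, "beam_width": 10}; values parsed
- # from it take precedence over the FLAGS-derived defaults.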
- def run_main(flags, default_hparams, estimator_fn):
- """Run main."""
- # Random
- random_seed = flags.random_seed
- if random_seed is not None and random_seed > 0:
- utils.print_out("# Set random seed to %d" % random_seed)
- random.seed(random_seed)
- np.random.seed(random_seed)
- tf.set_random_seed(random_seed)
- # Model output directory
- output_dir = flags.output_dir
- if output_dir and not tf.gfile.Exists(output_dir):
- utils.print_out("# Creating output directory %s ..." % output_dir)
- tf.gfile.MakeDirs(output_dir)
- # Load hparams.
- hparams = create_or_load_hparams(default_hparams, flags.hparams_path)
- # Train or Evaluation
- estimator_fn(hparams)
- return hparams
- def tokenize(hparams, file, tokenized_file):
- utils.print_out("tokenizing {} -> {}".format(file, tokenized_file))
- with open(file, 'rb') as input_file:
- with open(tokenized_file, 'wb') as output_file:
- subprocess.run([hparams.tokenizer_file, '-l', hparams.src], stdin=input_file, stdout=output_file)
- def detokenize(hparams, file, detokenized_file):
- utils.print_out("detokenizing {} -> {}".format(file, detokenized_file))
- with open(file, 'rb') as input_file:
- with open(detokenized_file, 'wb') as output_file:
- subprocess.run([hparams.detokenizer_file, '-l', hparams.tgt], stdin=input_file, stdout=output_file)
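- # The two helpers above are thin wrappers around the Moses scripts; they are
- # equivalent to shelling out, e.g.:
- # tokenizer.perl -l en < translate_file > translate_file.tok
- # with script paths resolved from --tokenizer_file / --detokenizer_file or
- # the DATA_DIR defaults set in create_hparams.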
- def main(unused_argv):
- experiment_start = time.time()
- tf.logging.set_verbosity(tf.logging.INFO)
- if FLAGS.use_fp16 and FLAGS.use_dist_strategy:
- raise ValueError("use_fp16 and use_dist_strategy aren't compatible")
- if FLAGS.use_fp16 + FLAGS.amp + FLAGS.use_fastmath > 1:
- raise ValueError("Only one of use_fp16, amp, use_fastmath can be set")
- if FLAGS.amp:
- utils.print_out('Enabling TF-AMP')
- os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
- if FLAGS.use_fastmath:
- utils.print_out('Enabling FastMath')
- os.environ["TF_ENABLE_CUBLAS_TENSOR_OP_MATH_FP32"] = '1'
- os.environ["TF_ENABLE_CUDNN_TENSOR_OP_MATH_FP32"] = '1'
- os.environ["TF_ENABLE_CUDNN_RNN_TENSOR_OP_MATH_FP32"] = '1'
- # Set up hacky envvars.
- # Hack that affects Defun in attention_wrapper.py
- active_xla_option_nums = np.sum([FLAGS.use_xla, FLAGS.use_autojit_xla,
- FLAGS.xla_compile])
- if active_xla_option_nums > 1:
- raise ValueError(
- "Only one of use_xla, xla_compile, use_autojit_xla can be set")
- os.environ["use_xla"] = str(FLAGS.use_xla).lower()
- if FLAGS.use_xla:
- os.environ["use_defun"] = str(True).lower()
- else:
- os.environ["use_defun"] = str(FLAGS.use_defun).lower()
- utils.print_out("use_defun is %s for attention" % os.environ["use_defun"])
- # TODO(jamesqin): retire this config after Cuda9.1
- os.environ["use_fp32_batch_matmul"] = ("true" if FLAGS.use_fp32_batch_matmul
- else "false")
- os.environ["xla_compile"] = "true" if FLAGS.xla_compile else "false"
- os.environ["force_inputs_padding"] = (
- "true" if FLAGS.force_inputs_padding else "false")
- if FLAGS.mode == "train":
- utils.print_out("Running training mode.")
- default_hparams = create_hparams(FLAGS)
- run_main(FLAGS, default_hparams, estimator.train_fn)
- elif FLAGS.mode == "infer" or FLAGS.mode == "translate":
- if FLAGS.mode == "infer":
- utils.print_out("Running inference mode.")
- translate_mode = False
- else:
- utils.print_out("Running translate mode on file {}.".format(FLAGS.translate_file))
- translate_mode = True
- # Random
- random_seed = FLAGS.random_seed
- if random_seed is not None and random_seed > 0:
- utils.print_out("# Set random seed to %d" % random_seed)
- random.seed(random_seed)
- np.random.seed(random_seed)
- tf.set_random_seed(random_seed)
- # Model output directory
- output_dir = FLAGS.output_dir
- if output_dir and not tf.gfile.Exists(output_dir):
- utils.print_out("# Creating output directory %s ..." % output_dir)
- tf.gfile.MakeDirs(output_dir)
- dllogger.init(backends=[
- dllogger.StdOutBackend(dllogger.Verbosity.DEFAULT),
- dllogger.JSONStreamBackend(dllogger.Verbosity.VERBOSE, os.path.join(FLAGS.output_dir, FLAGS.mode + '-report.json')),
- ])
- dllogger.log('PARAMETER', vars(FLAGS))
- # Load hparams.
- default_hparams = create_hparams(FLAGS)
- default_hparams.num_buckets = 1
- # The estimator model_fn is written in a way allowing train hparams to be
- # passed in infer mode.
- hparams = create_or_load_hparams(default_hparams, FLAGS.hparams_path)
- utils.print_out("infer_hparams:")
- utils.print_hparams(hparams)
- if translate_mode:
- tokenize(hparams, hparams.translate_file, hparams.translate_file + ".tok")
- eval_sentences, eval_src_tokens, _ = iterator_utils.get_effective_epoch_size(hparams, train=False)
- # Run evaluation when there's a new checkpoint
- tf.logging.info("Starting to evaluate...")
- eval_start = time.time()
- _, (eval_speed, eval_latencies), eval_output_tokens = estimator.eval_fn(hparams, hparams.ckpt, only_translate=translate_mode)
- eval_end = time.time()
- eval_delta = eval_end - eval_start
- utils.print_out("eval time for ckpt: %.2f mins (%.2f sent/sec, %.2f tokens/sec)" %
- (eval_delta / 60., eval_speed, eval_speed * (eval_src_tokens + eval_output_tokens) / eval_sentences), f=sys.stderr)
- logging_data = {
- 'infer_speed_sent': eval_speed,
- 'infer_speed_toks': eval_speed * (eval_src_tokens + eval_output_tokens) / eval_sentences,
- }
- for lat in sorted(eval_latencies):
- utils.print_out("eval latency_%s for ckpt: %.2f ms" % (lat, eval_latencies[lat] * 1000))
- logging_data['infer_latency_{}'.format(lat)] = eval_latencies[lat] * 1000
- dllogger.log((), logging_data)
- dllogger.flush()
- if translate_mode:
- detokenize(hparams, hparams.translate_file + ".trans.tok", hparams.translate_file + ".trans")
- else:
- assert FLAGS.mode == "train_and_eval"
- utils.print_out("Running train and eval mode.")
- # Random
- random_seed = FLAGS.random_seed
- if random_seed is not None and random_seed > 0:
- utils.print_out("# Set random seed to %d" % random_seed)
- random.seed(random_seed)
- np.random.seed(random_seed)
- tf.set_random_seed(random_seed)
- # Model output directory
- output_dir = FLAGS.output_dir
- if output_dir and not tf.gfile.Exists(output_dir):
- utils.print_out("# Creating output directory %s ..." % output_dir)
- tf.gfile.MakeDirs(output_dir)
- dllogger.init(backends=[
- dllogger.StdOutBackend(dllogger.Verbosity.DEFAULT),
- dllogger.JSONStreamBackend(dllogger.Verbosity.VERBOSE, os.path.join(FLAGS.output_dir, FLAGS.mode + '-report.json')),
- ])
- dllogger.log('PARAMETER', vars(FLAGS))
- dllogger.metadata("bleu", {"unit": None})
- dllogger.metadata("train_speed_sent", {"unit": "sequences/s"})
- dllogger.metadata("train_speed_toks", {"unit": "tokens/s"})
- # Load hparams.
- default_hparams = create_hparams(FLAGS)
- hparams = create_or_load_hparams(default_hparams, FLAGS.hparams_path)
- utils.print_out("training hparams:")
- utils.print_hparams(hparams)
- with tf.gfile.GFile(os.path.join(output_dir, "train_hparams.txt"), "w") as f:
- f.write(utils.serialize_hparams(hparams) + "\n")
- # The estimator model_fn is written in a way allowing train hparams to be
- # passed in infer mode.
- infer_hparams = tf.contrib.training.HParams(**hparams.values())
- infer_hparams.num_buckets = 1
- utils.print_out("infer_hparams:")
- utils.print_hparams(infer_hparams)
- with tf.gfile.GFile(os.path.join(output_dir, "infer_hparams.txt"), "w") as f:
- f.write(utils.serialize_hparams(infer_hparams) + "\n")
- epochs = 0
- should_stop = epochs >= FLAGS.max_train_epochs
- train_sentences, train_src_tokens, train_tgt_tokens = iterator_utils.get_effective_epoch_size(hparams)
- eval_sentences, eval_src_tokens, _ = iterator_utils.get_effective_epoch_size(hparams, train=False)
- while not should_stop:
- utils.print_out("Starting epoch %d" % epochs)
- try:
- train_start = time.time()
- train_speed, _ = estimator.train_fn(hparams)
- except tf.errors.OutOfRangeError:
- utils.print_out("training hits OutOfRangeError", f=sys.stderr)
- train_end = time.time()
- train_delta = train_end - train_start
- utils.print_out("training time for epoch %d: %.2f mins (%.2f sent/sec, %.2f tokens/sec)" %
- (epochs + 1, train_delta / 60., train_speed, train_speed * (train_src_tokens + train_tgt_tokens) / train_sentences), f=sys.stderr)
- logging_data = {
- 'train_speed_sent': train_speed,
- 'train_speed_toks': train_speed * (train_src_tokens + train_tgt_tokens) / train_sentences,
- }
- # Running eval once per epoch is probably sub-optimal.
- eval_start = time.time()
- bleu_score, (eval_speed, eval_latencies), eval_output_tokens = estimator.eval_fn(infer_hparams)
- eval_end = time.time()
- eval_delta = eval_end - eval_start
- utils.print_out("eval time for epoch %d: %.2f mins (%.2f sent/sec, %.2f tokens/sec)" %
- (epochs + 1, eval_delta / 60., eval_speed, eval_speed * (eval_src_tokens + eval_output_tokens) / eval_sentences), f=sys.stderr)
- logging_data.update({
- 'bleu': bleu_score,
- 'infer_speed_sent': eval_speed,
- 'infer_speed_toks': eval_speed * (eval_src_tokens + eval_output_tokens) / eval_sentences,
- })
- for lat in sorted(eval_latencies):
- utils.print_out("eval latency_%s for epoch %d: %.2f ms" % (lat, epochs + 1, eval_latencies[lat] * 1000))
- logging_data['eval_latency_{}'.format(lat)] = eval_latencies[lat] * 1000
- dllogger.log((epochs,), logging_data)
- dllogger.flush()
- if FLAGS.debug or (FLAGS.target_bleu is not None and bleu_score > FLAGS.target_bleu):
- should_stop = True
- utils.print_out(
- "Stop job since target bleu is reached at epoch %d ." % epochs,
- f=sys.stderr)
- epochs += 1
- if epochs >= FLAGS.max_train_epochs:
- should_stop = True
- utils.print_out("Stop job since max_train_epochs is reached.",
- f=sys.stderr)
- dllogger.log((), logging_data)
- dllogger.flush()
- experiment_end = time.time()
- utils.print_out('Experiment took {} min'.format((experiment_end - experiment_start) / 60))
- if __name__ == "__main__":
- nmt_parser = argparse.ArgumentParser()
- add_arguments(nmt_parser)
- FLAGS, unparsed = nmt_parser.parse_known_args()
- tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)