nmt.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128
  1. # Copyright 2017 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. #
  16. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  17. #
  18. # Licensed under the Apache License, Version 2.0 (the "License");
  19. # you may not use this file except in compliance with the License.
  20. # You may obtain a copy of the License at
  21. #
  22. # http://www.apache.org/licenses/LICENSE-2.0
  23. #
  24. # Unless required by applicable law or agreed to in writing, software
  25. # distributed under the License is distributed on an "AS IS" BASIS,
  26. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  27. # See the License for the specific language governing permissions and
  28. # limitations under the License.
  29. """TensorFlow NMT model implementation."""
  30. from __future__ import print_function
  31. import argparse
  32. import os
  33. import random
  34. import sys
  35. import subprocess
  36. # import matplotlib.image as mpimg
  37. import numpy as np
  38. import time
  39. import tensorflow as tf
  40. import dllogger
  41. import estimator
  42. from utils import evaluation_utils
  43. from utils import iterator_utils
  44. from utils import misc_utils as utils
  45. from utils import vocab_utils
  46. from variable_mgr import constants
  47. utils.check_tensorflow_version()
  48. FLAGS = None
  49. # LINT.IfChange
  50. def add_arguments(parser):
  51. """Build ArgumentParser."""
  52. parser.register("type", "bool", lambda v: v.lower() == "true")
  53. # network
  54. parser.add_argument(
  55. "--num_units", type=int, default=1024, help="Network size.")
  56. parser.add_argument(
  57. "--num_layers", type=int, default=4, help="Network depth.")
  58. parser.add_argument("--num_encoder_layers", type=int, default=None,
  59. help="Encoder depth, equal to num_layers if None.")
  60. parser.add_argument("--num_decoder_layers", type=int, default=None,
  61. help="Decoder depth, equal to num_layers if None.")
  62. parser.add_argument(
  63. "--encoder_type",
  64. type=str,
  65. default="gnmt",
  66. help="""\
  67. uni | bi | gnmt.
  68. For bi, we build num_encoder_layers/2 bi-directional layers.
  69. For gnmt, we build 1 bi-directional layer, and (num_encoder_layers - 1)
  70. uni-directional layers.\
  71. """)
  72. parser.add_argument(
  73. "--residual",
  74. type="bool",
  75. nargs="?",
  76. const=True,
  77. default=True,
  78. help="Whether to add residual connections.")
  79. parser.add_argument("--time_major", type="bool", nargs="?", const=True,
  80. default=True,
  81. help="Whether to use time-major mode for dynamic RNN.")
  82. parser.add_argument("--num_embeddings_partitions", type=int, default=0,
  83. help="Number of partitions for embedding vars.")
  84. # attention mechanisms
  85. parser.add_argument(
  86. "--attention",
  87. type=str,
  88. default="normed_bahdanau",
  89. help="""\
  90. luong | scaled_luong | bahdanau | normed_bahdanau or set to "" for no
  91. attention\
  92. """)
  93. parser.add_argument(
  94. "--attention_architecture",
  95. type=str,
  96. default="gnmt_v2",
  97. help="""\
  98. standard | gnmt | gnmt_v2.
  99. standard: use top layer to compute attention.
  100. gnmt: GNMT style of computing attention, use previous bottom layer to
  101. compute attention.
  102. gnmt_v2: similar to gnmt, but use current bottom layer to compute
  103. attention.\
  104. """)
  105. parser.add_argument(
  106. "--output_attention", type="bool", nargs="?", const=True,
  107. default=True,
  108. help="""\
  109. Only used in standard attention_architecture. Whether use attention as
  110. the cell output at each timestep.
  111. .\
  112. """)
  113. parser.add_argument(
  114. "--pass_hidden_state", type="bool", nargs="?", const=True,
  115. default=True,
  116. help="""\
  117. Whether to pass encoder's hidden state to decoder when using an attention
  118. based model.\
  119. """)
  120. # optimizer
  121. parser.add_argument(
  122. "--optimizer", type=str, default="adam", help="sgd | adam")
  123. parser.add_argument(
  124. "--learning_rate",
  125. type=float,
  126. default=5e-4,
  127. help="Learning rate. Adam: 0.001 | 0.0001")
  128. parser.add_argument("--warmup_steps", type=int, default=200,
  129. help="How many steps we inverse-decay learning.")
  130. parser.add_argument("--warmup_scheme", type=str, default="t2t", help="""\
  131. How to warmup learning rates. Options include:
  132. t2t: Tensor2Tensor's way, start with lr 100 times smaller, then
  133. exponentiate until the specified lr.\
  134. """)
  135. parser.add_argument(
  136. "--decay_scheme", type=str, default="luong234", help="""\
  137. How we decay learning rate. Options include:
  138. luong234: after 2/3 num train steps, we start halving the learning rate
  139. for 4 times before finishing.
  140. luong5: after 1/2 num train steps, we start halving the learning rate
  141. for 5 times before finishing.\
  142. luong10: after 1/2 num train steps, we start halving the learning rate
  143. for 10 times before finishing.\
  144. """)
  145. parser.add_argument(
  146. "--max_train_epochs", type=int, default=6, help="Max number of epochs.")
  147. parser.add_argument(
  148. "--target_bleu", type=float, default=None, help="Target bleu.")
  149. parser.add_argument("--colocate_gradients_with_ops", type="bool", nargs="?",
  150. const=True,
  151. default=True,
  152. help=("Whether try colocating gradients with "
  153. "corresponding op"))
  154. parser.add_argument("--label_smoothing", type=float, default=0.1,
  155. help=("If nonzero, smooth the labels towards "
  156. "1/num_classes."))
  157. # initializer
  158. parser.add_argument("--init_op", type=str, default="uniform",
  159. help="uniform | glorot_normal | glorot_uniform")
  160. parser.add_argument("--init_weight", type=float, default=0.1,
  161. help=("for uniform init_op, initialize weights "
  162. "between [-this, this]."))
  163. # data
  164. parser.add_argument(
  165. "--src", type=str, default="en", help="Source suffix, e.g., en.")
  166. parser.add_argument(
  167. "--tgt", type=str, default="de", help="Target suffix, e.g., de.")
  168. parser.add_argument(
  169. "--data_dir", type=str, default="data/wmt16_de_en",
  170. help="Training/eval data directory.")
  171. parser.add_argument(
  172. "--train_prefix",
  173. type=str,
  174. default="train.tok.clean.bpe.32000",
  175. help="Train prefix, expect files with src/tgt suffixes.")
  176. parser.add_argument(
  177. "--test_prefix",
  178. type=str,
  179. default="newstest2014.tok.bpe.32000",
  180. help="Test prefix, expect files with src/tgt suffixes.")
  181. parser.add_argument(
  182. "--translate_file",
  183. type=str,
  184. help="File to translate, works only with translate mode")
  185. parser.add_argument(
  186. "--output_dir", type=str, default="results",
  187. help="Store log/model files.")
  188. # Vocab
  189. parser.add_argument(
  190. "--vocab_prefix",
  191. type=str,
  192. default="vocab.bpe.32000",
  193. help="""\
  194. Vocab prefix, expect files with src/tgt suffixes.\
  195. """)
  196. parser.add_argument(
  197. "--embed_prefix",
  198. type=str,
  199. default=None,
  200. help="""\
  201. Pretrained embedding prefix, expect files with src/tgt suffixes.
  202. The embedding files should be Glove formatted txt files.\
  203. """)
  204. parser.add_argument("--sos", type=str, default="<s>",
  205. help="Start-of-sentence symbol.")
  206. parser.add_argument("--eos", type=str, default="</s>",
  207. help="End-of-sentence symbol.")
  208. parser.add_argument(
  209. "--share_vocab",
  210. type="bool",
  211. nargs="?",
  212. const=True,
  213. default=True,
  214. help="""\
  215. Whether to use the source vocab and embeddings for both source and
  216. target.\
  217. """)
  218. parser.add_argument("--check_special_token", type="bool", default=True,
  219. help="""\
  220. Whether check special sos, eos, unk tokens exist in the
  221. vocab files.\
  222. """)
  223. # Sequence lengths
  224. parser.add_argument(
  225. "--src_max_len",
  226. type=int,
  227. default=50,
  228. help="Max length of src sequences during training (including EOS).")
  229. parser.add_argument(
  230. "--tgt_max_len",
  231. type=int,
  232. default=50,
  233. help="Max length of tgt sequences during training (including BOS).")
  234. parser.add_argument("--src_max_len_infer", type=int, default=None,
  235. help="Max length of src sequences during inference (including EOS).")
  236. parser.add_argument("--tgt_max_len_infer", type=int, default=80,
  237. help="""\
  238. Max length of tgt sequences during inference (including BOS). Also use to restrict the
  239. maximum decoding length.\
  240. """)
  241. # Default settings works well (rarely need to change)
  242. parser.add_argument("--unit_type", type=str, default="lstm",
  243. help="lstm | gru | layer_norm_lstm | nas")
  244. parser.add_argument("--forget_bias", type=float, default=0.0,
  245. help="Forget bias for BasicLSTMCell.")
  246. parser.add_argument("--dropout", type=float, default=0.2,
  247. help="Dropout rate (not keep_prob)")
  248. parser.add_argument("--max_gradient_norm", type=float, default=5.0,
  249. help="Clip gradients to this norm.")
  250. parser.add_argument("--batch_size", type=int, default=128, help="Total batch size.")
  251. parser.add_argument(
  252. "--num_buckets",
  253. type=int,
  254. default=5,
  255. help="Put data into similar-length buckets (only for training).")
  256. # SPM
  257. parser.add_argument("--subword_option", type=str, default="bpe",
  258. choices=["", "bpe", "spm"],
  259. help="""\
  260. Set to bpe or spm to activate subword desegmentation.\
  261. """)
  262. # Experimental encoding feature.
  263. parser.add_argument("--use_char_encode", type="bool", default=False,
  264. help="""\
  265. Whether to split each word or bpe into character, and then
  266. generate the word-level representation from the character
  267. reprentation.
  268. """)
  269. # Misc
  270. parser.add_argument(
  271. "--save_checkpoints_steps", type=int, default=2000,
  272. help="save_checkpoints_steps")
  273. parser.add_argument(
  274. "--log_step_count_steps", type=int, default=10,
  275. help=("The frequency, in number of global steps, that the global step "
  276. "and the loss will be logged during training"))
  277. parser.add_argument(
  278. "--num_gpus", type=int, default=1, help="Number of gpus in each worker.")
  279. parser.add_argument("--hparams_path", type=str, default=None,
  280. help=("Path to standard hparams json file that overrides"
  281. "hparams values from FLAGS."))
  282. parser.add_argument(
  283. "--random_seed",
  284. type=int,
  285. default=1,
  286. help="Random seed (>0, set a specific seed).")
  287. parser.add_argument("--language_model", type="bool", nargs="?",
  288. const=True, default=False,
  289. help="True to train a language model, ignoring encoder")
  290. # Inference
  291. parser.add_argument("--ckpt", type=str, default=None,
  292. help="Checkpoint file to load a model for inference. (defaults to newest checkpoint)")
  293. parser.add_argument(
  294. "--infer_batch_size",
  295. type=int,
  296. default=128,
  297. help="Batch size for inference mode.")
  298. parser.add_argument("--detokenizer_file", type=str,
  299. default=None,
  300. help=("""Detokenizer script file. Default: DATA_DIR/mosesdecoder/scripts/tokenizer/detokenizer.perl"""))
  301. parser.add_argument("--tokenizer_file", type=str,
  302. default=None,
  303. help=("""Tokenizer script file. Default: DATA_DIR/mosesdecoder/scripts/tokenizer/tokenizer.perl"""))
  304. # Advanced inference arguments
  305. parser.add_argument("--infer_mode", type=str, default="beam_search",
  306. choices=["greedy", "beam_search"],
  307. help="Which type of decoder to use during inference.")
  308. parser.add_argument("--beam_width", type=int, default=5,
  309. help=("""\
  310. beam width when using beam search decoder. If 0, use standard
  311. decoder with greedy helper.\
  312. """))
  313. parser.add_argument(
  314. "--length_penalty_weight",
  315. type=float,
  316. default=0.6,
  317. help="Length penalty for beam search.")
  318. parser.add_argument(
  319. "--coverage_penalty_weight",
  320. type=float,
  321. default=0.1,
  322. help="Coverage penalty for beam search.")
  323. # Job info
  324. parser.add_argument("--num_workers", type=int, default=1,
  325. help="Number of workers (inference only).")
  326. parser.add_argument("--amp", action='store_true',
  327. help="use amp for training and inference")
  328. parser.add_argument("--use_fastmath", type="bool", default=False,
  329. help="use_fastmath for training and inference")
  330. parser.add_argument("--use_fp16", type="bool", default=False,
  331. help="use_fp16 for training and inference")
  332. parser.add_argument(
  333. "--fp16_loss_scale",
  334. type=float,
  335. default=128,
  336. help="If fp16 is enabled, the loss is multiplied by this amount "
  337. "right before gradients are computed, then each gradient "
  338. "is divided by this amount. Mathematically, this has no "
  339. "effect, but it helps avoid fp16 underflow. Set to 1 to "
  340. "effectively disable.")
  341. parser.add_argument(
  342. "--enable_auto_loss_scale",
  343. type="bool",
  344. default=True,
  345. help="If True and use_fp16 is True, automatically adjust the "
  346. "loss scale during training.")
  347. parser.add_argument(
  348. "--fp16_inc_loss_scale_every_n",
  349. type=int,
  350. default=128,
  351. help="If fp16 is enabled and enable_auto_loss_scale is "
  352. "True, increase the loss scale every n steps.")
  353. parser.add_argument(
  354. "--check_tower_loss_numerics",
  355. type="bool",
  356. default=False, # Set to false for xla.compile()
  357. help="whether to check tower loss numerics")
  358. parser.add_argument(
  359. "--use_fp32_batch_matmul",
  360. type="bool",
  361. default=False,
  362. help="Whether to use fp32 batch matmul")
  363. # Performance
  364. # XLA
  365. parser.add_argument(
  366. "--force_inputs_padding",
  367. type="bool",
  368. default=False,
  369. help="Force padding input batch to src_max_len and tgt_max_len")
  370. parser.add_argument(
  371. "--use_xla",
  372. type="bool",
  373. default=False,
  374. help="Use xla to compile a few selected locations, mostly Defuns.")
  375. parser.add_argument(
  376. "--xla_compile",
  377. type="bool",
  378. default=False,
  379. help="Use xla.compile() for each tower's fwd and bak pass.")
  380. parser.add_argument(
  381. "--use_autojit_xla",
  382. type="bool",
  383. default=False,
  384. help="Use auto jit xla.")
  385. # GPU knobs
  386. parser.add_argument(
  387. "--use_pintohost_optimizer",
  388. type="bool",
  389. default=False,
  390. help="whether to use PinToHost optimizer")
  391. parser.add_argument(
  392. "--use_cudnn_lstm",
  393. type="bool",
  394. default=False,
  395. help="whether to use cudnn_lstm for encoder, non residual layers")
  396. parser.add_argument(
  397. "--use_loose_bidi_cudnn_lstm",
  398. type="bool",
  399. default=False,
  400. help="whether to use loose bidi cudnn_lstm")
  401. parser.add_argument(
  402. "--use_fused_lstm",
  403. type="bool",
  404. default=True,
  405. help="whether to use fused lstm and variant. If enabled, training will "
  406. "use LSTMBlockFusedCell, infer will use LSTMBlockCell when appropriate.")
  407. parser.add_argument(
  408. "--use_fused_lstm_dec",
  409. type="bool",
  410. default=False,
  411. help="whether to use fused lstm for decoder (training only).")
  412. parser.add_argument(
  413. "--gpu_indices",
  414. type=str,
  415. default="",
  416. help="Indices of worker GPUs in ring order")
  417. # Graph knobs
  418. parser.add_argument("--parallel_iterations", type=int, default=10,
  419. help="number of parallel iterations in dynamic_rnn")
  420. parser.add_argument("--use_dist_strategy", type="bool", default=False,
  421. help="whether to use distribution strategy")
  422. parser.add_argument(
  423. "--hierarchical_copy",
  424. type="bool",
  425. default=False,
  426. help="Use hierarchical copies. Currently only optimized for "
  427. "use on a DGX-1 with 8 GPUs and may perform poorly on "
  428. "other hardware. Requires --num_gpus > 1, and only "
  429. "recommended when --num_gpus=8")
  430. parser.add_argument(
  431. "--network_topology",
  432. type=constants.NetworkTopology,
  433. default=constants.NetworkTopology.DGX1,
  434. choices=list(constants.NetworkTopology))
  435. parser.add_argument(
  436. "--use_block_lstm",
  437. type="bool",
  438. default=False,
  439. help="whether to use block lstm")
  440. parser.add_argument(
  441. "--use_defun",
  442. type="bool",
  443. default=False,
  444. help="whether to use Defun")
  445. # Gradient tricks
  446. parser.add_argument(
  447. "--gradient_repacking",
  448. type=int,
  449. default=0,
  450. help="Use gradient repacking. It"
  451. "currently only works with replicated mode. At the end of"
  452. "of each step, it repacks the gradients for more efficient"
  453. "cross-device transportation. A non-zero value specifies"
  454. "the number of split packs that will be formed.")
  455. parser.add_argument(
  456. "--compact_gradient_transfer",
  457. type="bool",
  458. default=True,
  459. help="Compact gradient as much as possible for cross-device transfer and "
  460. "aggregation.")
  461. parser.add_argument(
  462. "--all_reduce_spec",
  463. type=str,
  464. default="nccl",
  465. help="A specification of the all_reduce algorithm to be used "
  466. "for reducing gradients. For more details, see "
  467. "parse_all_reduce_spec in variable_mgr.py. An "
  468. "all_reduce_spec has BNF form:\n"
  469. "int ::= positive whole number\n"
  470. "g_int ::= int[KkMGT]?\n"
  471. "alg_spec ::= alg | alg#int\n"
  472. "range_spec ::= alg_spec | alg_spec/alg_spec\n"
  473. "spec ::= range_spec | range_spec:g_int:range_spec\n"
  474. "NOTE: not all syntactically correct constructs are "
  475. "supported.\n\n"
  476. "Examples:\n "
  477. "\"xring\" == use one global ring reduction for all "
  478. "tensors\n"
  479. "\"pscpu\" == use CPU at worker 0 to reduce all tensors\n"
  480. "\"nccl\" == use NCCL to locally reduce all tensors. "
  481. "Limited to 1 worker.\n"
  482. "\"nccl/xring\" == locally (to one worker) reduce values "
  483. "using NCCL then ring reduce across workers.\n"
  484. "\"pscpu:32k:xring\" == use pscpu algorithm for tensors of "
  485. "size up to 32kB, then xring for larger tensors.")
  486. parser.add_argument(
  487. "--agg_small_grads_max_bytes",
  488. type=int,
  489. default=0,
  490. help="If > 0, try to aggregate tensors of less than this "
  491. "number of bytes prior to all-reduce.")
  492. parser.add_argument(
  493. "--agg_small_grads_max_group",
  494. type=int,
  495. default=10,
  496. help="When aggregating small tensors for all-reduce do not "
  497. "aggregate more than this many into one new tensor.")
  498. parser.add_argument(
  499. "--allreduce_merge_scope",
  500. type=int,
  501. default=1,
  502. help="Establish a name scope around this many "
  503. "gradients prior to creating the all-reduce operations. "
  504. "It may affect the ability of the backend to merge "
  505. "parallel ops.")
  506. # Other knobs
  507. parser.add_argument(
  508. "--local_parameter_device",
  509. type=str,
  510. default="gpu",
  511. help="Device to use as parameter server: cpu or gpu. For "
  512. "distributed training, it can affect where caching of "
  513. "variables happens.")
  514. parser.add_argument(
  515. "--use_resource_vars",
  516. type="bool",
  517. default=False,
  518. help="Use resource variables instead of normal variables. "
  519. "Resource variables are slower, but this option is useful "
  520. "for debugging their performance.")
  521. parser.add_argument("--debug", type="bool", default=False,
  522. help="Debug train and eval")
  523. parser.add_argument(
  524. "--debug_num_train_steps", type=int, default=None, help="Num steps to train.")
  525. parser.add_argument("--show_metrics", type="bool", default=True,
  526. help="whether to show detailed metrics")
  527. parser.add_argument("--clip_grads", type="bool", default=True,
  528. help="whether to clip gradients")
  529. parser.add_argument("--profile", type="bool", default=False,
  530. help="If generate profile")
  531. parser.add_argument("--profile_save_steps", type=int, default=10,
  532. help="Save timeline every N steps.")
  533. parser.add_argument("--use_dynamic_rnn", type="bool", default=True)
  534. parser.add_argument("--use_synthetic_data", type="bool", default=False)
  535. parser.add_argument(
  536. "--mode", type=str, default="train_and_eval",
  537. choices=("train_and_eval", "infer", "translate"))
  538. def create_hparams(flags):
  539. """Create training hparams."""
  540. return tf.contrib.training.HParams(
  541. # Data
  542. src=flags.src,
  543. tgt=flags.tgt,
  544. train_prefix=os.path.join(flags.data_dir, flags.train_prefix),
  545. test_prefix=os.path.join(flags.data_dir, flags.test_prefix),
  546. translate_file=flags.translate_file,
  547. vocab_prefix=os.path.join(flags.data_dir, flags.vocab_prefix),
  548. embed_prefix=flags.embed_prefix,
  549. output_dir=flags.output_dir,
  550. # Networks
  551. num_units=flags.num_units,
  552. num_encoder_layers=(flags.num_encoder_layers or flags.num_layers),
  553. num_decoder_layers=(flags.num_decoder_layers or flags.num_layers),
  554. dropout=flags.dropout,
  555. unit_type=flags.unit_type,
  556. encoder_type=flags.encoder_type,
  557. residual=flags.residual,
  558. time_major=flags.time_major,
  559. num_embeddings_partitions=flags.num_embeddings_partitions,
  560. # Attention mechanisms
  561. attention=flags.attention,
  562. attention_architecture=flags.attention_architecture,
  563. output_attention=flags.output_attention,
  564. pass_hidden_state=flags.pass_hidden_state,
  565. # Train
  566. optimizer=flags.optimizer,
  567. max_train_epochs=flags.max_train_epochs,
  568. target_bleu=flags.target_bleu,
  569. label_smoothing=flags.label_smoothing,
  570. batch_size=flags.batch_size,
  571. init_op=flags.init_op,
  572. init_weight=flags.init_weight,
  573. max_gradient_norm=flags.max_gradient_norm,
  574. learning_rate=flags.learning_rate,
  575. warmup_steps=flags.warmup_steps,
  576. warmup_scheme=flags.warmup_scheme,
  577. decay_scheme=flags.decay_scheme,
  578. colocate_gradients_with_ops=flags.colocate_gradients_with_ops,
  579. # Data constraints
  580. num_buckets=flags.num_buckets,
  581. src_max_len=flags.src_max_len,
  582. tgt_max_len=flags.tgt_max_len,
  583. # Inference
  584. src_max_len_infer=flags.src_max_len_infer,
  585. tgt_max_len_infer=flags.tgt_max_len_infer,
  586. ckpt=flags.ckpt,
  587. infer_batch_size=flags.infer_batch_size,
  588. detokenizer_file=flags.detokenizer_file if flags.detokenizer_file is not None \
  589. else os.path.join(flags.data_dir, 'mosesdecoder/scripts/tokenizer/detokenizer.perl'),
  590. tokenizer_file=flags.tokenizer_file if flags.tokenizer_file is not None \
  591. else os.path.join(flags.data_dir, 'mosesdecoder/scripts/tokenizer/tokenizer.perl'),
  592. # Advanced inference arguments
  593. infer_mode=flags.infer_mode,
  594. beam_width=flags.beam_width,
  595. length_penalty_weight=flags.length_penalty_weight,
  596. coverage_penalty_weight=flags.coverage_penalty_weight,
  597. # Vocab
  598. sos=flags.sos if flags.sos else vocab_utils.SOS,
  599. eos=flags.eos if flags.eos else vocab_utils.EOS,
  600. subword_option=flags.subword_option,
  601. check_special_token=flags.check_special_token,
  602. use_char_encode=flags.use_char_encode,
  603. # Misc
  604. forget_bias=flags.forget_bias,
  605. num_gpus=flags.num_gpus,
  606. save_checkpoints_steps=flags.save_checkpoints_steps,
  607. log_step_count_steps=flags.log_step_count_steps,
  608. epoch_step=0, # record where we were within an epoch.
  609. share_vocab=flags.share_vocab,
  610. random_seed=flags.random_seed,
  611. language_model=flags.language_model,
  612. amp=flags.amp,
  613. use_fastmath=flags.use_fastmath,
  614. use_fp16=flags.use_fp16,
  615. fp16_loss_scale=flags.fp16_loss_scale,
  616. enable_auto_loss_scale=flags.enable_auto_loss_scale,
  617. fp16_inc_loss_scale_every_n=flags.fp16_inc_loss_scale_every_n,
  618. check_tower_loss_numerics=flags.check_tower_loss_numerics,
  619. use_fp32_batch_matmul=flags.use_fp32_batch_matmul,
  620. # Performance
  621. # GPU knbs
  622. force_inputs_padding=flags.force_inputs_padding,
  623. use_xla=flags.use_xla,
  624. xla_compile=flags.xla_compile,
  625. use_autojit_xla=flags.use_autojit_xla,
  626. use_pintohost_optimizer=flags.use_pintohost_optimizer,
  627. use_cudnn_lstm=flags.use_cudnn_lstm,
  628. use_loose_bidi_cudnn_lstm=flags.use_loose_bidi_cudnn_lstm,
  629. use_fused_lstm=flags.use_fused_lstm,
  630. use_fused_lstm_dec=flags.use_fused_lstm_dec,
  631. gpu_indices=flags.gpu_indices,
  632. # Graph knobs
  633. parallel_iterations=flags.parallel_iterations,
  634. use_dynamic_rnn=flags.use_dynamic_rnn,
  635. use_dist_strategy=flags.use_dist_strategy,
  636. hierarchical_copy=flags.hierarchical_copy,
  637. network_topology=flags.network_topology,
  638. use_block_lstm=flags.use_block_lstm,
  639. # Grad tricks
  640. gradient_repacking=flags.gradient_repacking,
  641. compact_gradient_transfer=flags.compact_gradient_transfer,
  642. all_reduce_spec=flags.all_reduce_spec,
  643. agg_small_grads_max_bytes=flags.agg_small_grads_max_bytes,
  644. agg_small_grads_max_group=flags.agg_small_grads_max_group,
  645. allreduce_merge_scope=flags.allreduce_merge_scope,
  646. # Other knobs
  647. local_parameter_device=("cpu" if flags.num_gpus ==0
  648. else flags.local_parameter_device),
  649. use_resource_vars=flags.use_resource_vars,
  650. debug=flags.debug,
  651. debug_num_train_steps=flags.debug_num_train_steps,
  652. clip_grads=flags.clip_grads,
  653. profile=flags.profile,
  654. profile_save_steps=flags.profile_save_steps,
  655. show_metrics=flags.show_metrics,
  656. use_synthetic_data=flags.use_synthetic_data,
  657. mode=flags.mode,
  658. )
  659. def _add_argument(hparams, key, value, update=True):
  660. """Add an argument to hparams; if exists, change the value if update==True."""
  661. if hasattr(hparams, key):
  662. if update:
  663. setattr(hparams, key, value)
  664. else:
  665. hparams.add_hparam(key, value)
def extend_hparams(hparams):
  """Validate hparams and add derived arguments in place.

  Runs sanity checks, then extends `hparams` with residual-layer counts,
  vocab files/sizes, embedding-partition counts, pretrained-embedding
  paths and best-metric bookkeeping. Returns the same (mutated) object.

  Raises:
    ValueError: on inconsistent layer counts, unknown subword option,
      non-positive beam width with beam search, a missing --translate_file
      in translate mode, or a missing vocab_prefix.
  """
  # Sanity checks
  if hparams.encoder_type == "bi" and hparams.num_encoder_layers % 2 != 0:
    raise ValueError("For bi, num_encoder_layers %d should be even" %
                     hparams.num_encoder_layers)
  if (hparams.attention_architecture in ["gnmt"] and
      hparams.num_encoder_layers < 2):
    raise ValueError("For gnmt attention architecture, "
                     "num_encoder_layers %d should be >= 2" %
                     hparams.num_encoder_layers)
  if hparams.subword_option and hparams.subword_option not in ["spm", "bpe"]:
    raise ValueError("subword option must be either spm, or bpe")
  if hparams.infer_mode == "beam_search" and hparams.beam_width <= 0:
    raise ValueError("beam_width must greater than 0 when using beam_search"
                     "decoder.")
  if hparams.mode == "translate" and not hparams.translate_file:
    raise ValueError("--translate_file flag must be specified in translate mode")

  # Different number of encoder / decoder layers
  assert hparams.num_encoder_layers and hparams.num_decoder_layers
  if hparams.num_encoder_layers != hparams.num_decoder_layers:
    # Encoder hidden state can only seed the decoder when depths match.
    hparams.pass_hidden_state = False
    utils.print_out("Num encoder layer %d is different from num decoder layer"
                    " %d, so set pass_hidden_state to False" % (
                        hparams.num_encoder_layers,
                        hparams.num_decoder_layers))

  # Set residual layers
  num_encoder_residual_layers = 0
  num_decoder_residual_layers = 0
  if hparams.residual:
    # With residual connections every layer but the first is residual.
    if hparams.num_encoder_layers > 1:
      num_encoder_residual_layers = hparams.num_encoder_layers - 1
    if hparams.num_decoder_layers > 1:
      num_decoder_residual_layers = hparams.num_decoder_layers - 1
    if hparams.encoder_type == "gnmt":
      # The first unidirectional layer (after the bi-directional layer) in
      # the GNMT encoder can't have residual connection due to the input is
      # the concatenation of fw_cell and bw_cell's outputs.
      num_encoder_residual_layers = hparams.num_encoder_layers - 2
      # Compatible for GNMT models
      if hparams.num_encoder_layers == hparams.num_decoder_layers:
        num_decoder_residual_layers = num_encoder_residual_layers
  _add_argument(hparams, "num_encoder_residual_layers",
                num_encoder_residual_layers)
  _add_argument(hparams, "num_decoder_residual_layers",
                num_decoder_residual_layers)

  # Language modeling
  if hparams.language_model:
    # Decoder-only setup: no attention, shared vocab, src mirrors tgt.
    hparams.attention = ""
    hparams.attention_architecture = ""
    hparams.pass_hidden_state = False
    hparams.share_vocab = True
    hparams.src = hparams.tgt
    utils.print_out("For language modeling, we turn off attention and "
                    "pass_hidden_state; turn on share_vocab; set src to tgt.")

  ## Vocab
  # Get vocab file names first
  if hparams.vocab_prefix:
    src_vocab_file = hparams.vocab_prefix + "." + hparams.src
    tgt_vocab_file = hparams.vocab_prefix + "." + hparams.tgt
  else:
    raise ValueError("hparams.vocab_prefix must be provided.")

  # Source vocab
  # NOTE(review): check_vocab returns a (size, path) pair and the returned
  # path may differ from the input — presumably it rewrites the vocab into
  # output_dir when special tokens/padding are added; confirm in vocab_utils.
  src_vocab_size, src_vocab_file = vocab_utils.check_vocab(
      src_vocab_file,
      hparams.output_dir,
      check_special_token=hparams.check_special_token,
      sos=hparams.sos,
      eos=hparams.eos,
      unk=vocab_utils.UNK,
      pad_vocab=True)

  # Target vocab
  if hparams.share_vocab:
    utils.print_out(" using source vocab for target")
    tgt_vocab_file = src_vocab_file
    tgt_vocab_size = src_vocab_size
  else:
    tgt_vocab_size, tgt_vocab_file = vocab_utils.check_vocab(
        tgt_vocab_file,
        hparams.output_dir,
        check_special_token=hparams.check_special_token,
        sos=hparams.sos,
        eos=hparams.eos,
        unk=vocab_utils.UNK)
  _add_argument(hparams, "src_vocab_size", src_vocab_size)
  _add_argument(hparams, "tgt_vocab_size", tgt_vocab_size)
  _add_argument(hparams, "src_vocab_file", src_vocab_file)
  _add_argument(hparams, "tgt_vocab_file", tgt_vocab_file)

  # Num embedding partitions
  _add_argument(
      hparams, "num_enc_emb_partitions", hparams.num_embeddings_partitions)
  _add_argument(
      hparams, "num_dec_emb_partitions", hparams.num_embeddings_partitions)

  # Pretrained Embeddings
  _add_argument(hparams, "src_embed_file", "")
  _add_argument(hparams, "tgt_embed_file", "")
  if hparams.embed_prefix:
    src_embed_file = hparams.embed_prefix + "." + hparams.src
    tgt_embed_file = hparams.embed_prefix + "." + hparams.tgt

    if tf.gfile.Exists(src_embed_file):
      utils.print_out(" src_embed_file %s exist" % src_embed_file)
      hparams.src_embed_file = src_embed_file

      # Partitioned embeddings are not combined with pretrained ones here.
      utils.print_out(
          "For pretrained embeddings, set num_enc_emb_partitions to 1")
      hparams.num_enc_emb_partitions = 1
    else:
      utils.print_out(" src_embed_file %s doesn't exist" % src_embed_file)

    if tf.gfile.Exists(tgt_embed_file):
      utils.print_out(" tgt_embed_file %s exist" % tgt_embed_file)
      hparams.tgt_embed_file = tgt_embed_file

      utils.print_out(
          "For pretrained embeddings, set num_dec_emb_partitions to 1")
      hparams.num_dec_emb_partitions = 1
    else:
      utils.print_out(" tgt_embed_file %s doesn't exist" % tgt_embed_file)

  # Evaluation
  metric = "bleu"
  best_metric_dir = os.path.join(hparams.output_dir, "best_" + metric)
  tf.gfile.MakeDirs(best_metric_dir)
  # update=False keeps a previously recorded best score when one exists.
  _add_argument(hparams, "best_" + metric, 0, update=False)
  _add_argument(hparams, "best_" + metric + "_dir", best_metric_dir)
  return hparams
  788. def create_or_load_hparams(default_hparams, hparams_path):
  789. """Create hparams or load hparams from output_dir."""
  790. hparams = utils.maybe_parse_standard_hparams(default_hparams, hparams_path)
  791. hparams = extend_hparams(hparams)
  792. # Print HParams
  793. utils.print_hparams(hparams)
  794. return hparams
  795. def run_main(flags, default_hparams, estimator_fn):
  796. """Run main."""
  797. # Random
  798. random_seed = flags.random_seed
  799. if random_seed is not None and random_seed > 0:
  800. utils.print_out("# Set random seed to %d" % random_seed)
  801. random.seed(random_seed)
  802. np.random.seed(random_seed)
  803. tf.set_random_seed(random_seed)
  804. # Model output directory
  805. output_dir = flags.output_dir
  806. if output_dir and not tf.gfile.Exists(output_dir):
  807. utils.print_out("# Creating output directory %s ..." % output_dir)
  808. tf.gfile.MakeDirs(output_dir)
  809. # Load hparams.
  810. hparams = create_or_load_hparams(default_hparams, flags.hparams_path)
  811. # Train or Evaluation
  812. estimator_fn(hparams)
  813. return hparams
  814. def tokenize(hparams, file, tokenized_file):
  815. utils.print_out("tokenizing {} -> {}".format(file, tokenized_file))
  816. with open(file, 'rb') as input_file:
  817. with open(tokenized_file, 'wb') as output_file:
  818. subprocess.run([hparams.tokenizer_file, '-l', hparams.src], stdin=input_file, stdout=output_file)
  819. def detokenize(hparams, file, detokenized_file):
  820. utils.print_out("detokenizing {} -> {}".format(file, detokenized_file))
  821. with open(file, 'rb') as input_file:
  822. with open(detokenized_file, 'wb') as output_file:
  823. subprocess.run([hparams.detokenizer_file, '-l', hparams.tgt], stdin=input_file, stdout=output_file)
def main(unused_argv):
  """Script entry point: dispatch to train / infer / translate / train_and_eval.

  Reads the parsed command line from the module-level FLAGS, validates
  mutually exclusive precision/XLA options, exports env vars consumed
  elsewhere (e.g. by attention_wrapper.py), then runs the selected mode.
  """
  experiment_start = time.time()
  tf.logging.set_verbosity(tf.logging.INFO)

  # Precision options are mutually exclusive.
  if FLAGS.use_fp16 and FLAGS.use_dist_strategy:
    raise ValueError("use_fp16 and use_dist_strategy aren't compatible")
  if FLAGS.use_fp16 + FLAGS.amp + FLAGS.use_fastmath > 1:
    raise ValueError("Only one of use_fp16, amp, use_fastmath can be set")

  if FLAGS.amp:
    utils.print_out('Enabling TF-AMP')
    os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'

  if FLAGS.use_fastmath:
    utils.print_out('Enabling FastMath')
    os.environ["TF_ENABLE_CUBLAS_TENSOR_OP_MATH_FP32"] = '1'
    os.environ["TF_ENABLE_CUDNN_TENSOR_OP_MATH_FP32"] = '1'
    os.environ["TF_ENABLE_CUDNN_RNN_TENSOR_OP_MATH_FP32"] = '1'

  # Set up hacky envvars.
  # Hack that affects Defun in attention_wrapper.py
  active_xla_option_nums = np.sum([FLAGS.use_xla, FLAGS.use_autojit_xla,
                                   FLAGS.xla_compile])
  if active_xla_option_nums > 1:
    raise ValueError(
        "Only one of use_xla, xla_compile, use_autojit_xla can be set")

  os.environ["use_xla"] = str(FLAGS.use_xla).lower()
  if FLAGS.use_xla:
    # use_xla forces Defun usage regardless of the use_defun flag.
    os.environ["use_defun"] = str(True).lower()
  else:
    os.environ["use_defun"] = str(FLAGS.use_defun).lower()
  utils.print_out("use_defun is %s for attention" % os.environ["use_defun"])

  # TODO(jamesqin): retire this config after Cuda9.1
  os.environ["use_fp32_batch_matmul"] = ("true" if FLAGS.use_fp32_batch_matmul
                                         else "false")
  os.environ["xla_compile"] = "true" if FLAGS.xla_compile else "false"
  os.environ["force_inputs_padding"] = (
      "true" if FLAGS.force_inputs_padding else "false")

  if FLAGS.mode == "train":
    utils.print_out("Running training mode.")
    default_hparams = create_hparams(FLAGS)
    run_main(FLAGS, default_hparams, estimator.train_fn)
  elif FLAGS.mode == "infer" or FLAGS.mode == "translate":
    # Inference and translate share the same path; translate additionally
    # tokenizes the input file first and detokenizes the output after.
    if FLAGS.mode == "infer":
      utils.print_out("Running inference mode.")
      translate_mode = False
    else:
      utils.print_out("Running translate mode on file {}.".format(FLAGS.translate_file))
      translate_mode = True

    # Random
    random_seed = FLAGS.random_seed
    if random_seed is not None and random_seed > 0:
      utils.print_out("# Set random seed to %d" % random_seed)
      random.seed(random_seed)
      np.random.seed(random_seed)
      tf.set_random_seed(random_seed)

    # Model output directory
    output_dir = FLAGS.output_dir
    if output_dir and not tf.gfile.Exists(output_dir):
      utils.print_out("# Creating output directory %s ..." % output_dir)
      tf.gfile.MakeDirs(output_dir)

    # DLLogger: stdout plus a JSON report file in the output directory.
    dllogger.init(backends=[
        dllogger.StdOutBackend(dllogger.Verbosity.DEFAULT),
        dllogger.JSONStreamBackend(dllogger.Verbosity.VERBOSE, os.path.join(FLAGS.output_dir, FLAGS.mode + '-report.json')),
    ])
    dllogger.log('PARAMETER', vars(FLAGS))

    # Load hparams.
    default_hparams = create_hparams(FLAGS)
    # Bucketing is disabled for inference.
    default_hparams.num_buckets = 1
    # The estimator model_fn is written in a way allowing train hparams to be
    # passed in infer mode.
    hparams = create_or_load_hparams(default_hparams, FLAGS.hparams_path)
    utils.print_out("infer_hparams:")
    utils.print_hparams(hparams)

    if translate_mode:
      tokenize(hparams, hparams.translate_file, hparams.translate_file + ".tok")

    eval_sentences, eval_src_tokens, _ = iterator_utils.get_effective_epoch_size(hparams, train=False)

    # Run evaluation when there's a new checkpoint
    tf.logging.info("Starting to evaluate...")
    eval_start = time.time()
    _, (eval_speed, eval_latencies), eval_output_tokens = estimator.eval_fn(hparams, hparams.ckpt, only_translate=translate_mode)
    eval_end = time.time()
    eval_delta = eval_end - eval_start
    # eval_speed is sentences/sec; tokens/sec is derived from the per-epoch
    # average tokens per sentence.
    utils.print_out("eval time for ckpt: %.2f mins (%.2f sent/sec, %.2f tokens/sec)" %
                    (eval_delta / 60., eval_speed, eval_speed * (eval_src_tokens + eval_output_tokens) / eval_sentences), f=sys.stderr)
    logging_data = {
        'infer_speed_sent': eval_speed,
        'infer_speed_toks': eval_speed * (eval_src_tokens + eval_output_tokens) / eval_sentences,
    }
    for lat in sorted(eval_latencies):
      utils.print_out("eval latency_%s for ckpt: %.2f ms" % (lat, eval_latencies[lat] * 1000))
      logging_data['infer_latency_{}'.format(lat)] = eval_latencies[lat] * 1000

    dllogger.log((), logging_data)
    dllogger.flush()

    if translate_mode:
      # The eval_fn output is expected at <translate_file>.trans.tok —
      # presumably written by estimator.eval_fn; confirm in estimator.py.
      detokenize(hparams, hparams.translate_file + ".trans.tok", hparams.translate_file + ".trans")
  else:
    assert FLAGS.mode == "train_and_eval"
    utils.print_out("Running train and eval mode.")

    # Random
    random_seed = FLAGS.random_seed
    if random_seed is not None and random_seed > 0:
      utils.print_out("# Set random seed to %d" % random_seed)
      random.seed(random_seed)
      np.random.seed(random_seed)
      tf.set_random_seed(random_seed)

    # Model output directory
    output_dir = FLAGS.output_dir
    if output_dir and not tf.gfile.Exists(output_dir):
      utils.print_out("# Creating output directory %s ..." % output_dir)
      tf.gfile.MakeDirs(output_dir)

    # DLLogger: stdout plus a JSON report file in the output directory.
    dllogger.init(backends=[
        dllogger.StdOutBackend(dllogger.Verbosity.DEFAULT),
        dllogger.JSONStreamBackend(dllogger.Verbosity.VERBOSE, os.path.join(FLAGS.output_dir, FLAGS.mode + '-report.json')),
    ])
    dllogger.log('PARAMETER', vars(FLAGS))
    dllogger.metadata("bleu", {"unit": None})
    dllogger.metadata("train_speed_sent", {"unit": "sequences/s"})
    dllogger.metadata("train_speed_toks", {"unit": "tokens/s"})

    # Load hparams.
    default_hparams = create_hparams(FLAGS)
    hparams = create_or_load_hparams(default_hparams, FLAGS.hparams_path)
    utils.print_out("training hparams:")
    utils.print_hparams(hparams)
    with tf.gfile.GFile(os.path.join(output_dir, "train_hparams.txt"), "w") as f:
      f.write(utils.serialize_hparams(hparams) + "\n")

    # The estimator model_fn is written in a way allowing train hparams to be
    # passed in infer mode.
    infer_hparams = tf.contrib.training.HParams(**hparams.values())
    infer_hparams.num_buckets = 1
    utils.print_out("infer_hparams:")
    utils.print_hparams(infer_hparams)
    with tf.gfile.GFile(os.path.join(output_dir, "infer_hparams.txt"), "w") as f:
      f.write(utils.serialize_hparams(infer_hparams) + "\n")

    epochs = 0
    should_stop = epochs >= FLAGS.max_train_epochs

    train_sentences, train_src_tokens, train_tgt_tokens = iterator_utils.get_effective_epoch_size(hparams)
    eval_sentences, eval_src_tokens, _ = iterator_utils.get_effective_epoch_size(hparams, train=False)

    while not should_stop:
      utils.print_out("Starting epoch %d" % epochs)
      try:
        train_start = time.time()
        train_speed, _ = estimator.train_fn(hparams)
      except tf.errors.OutOfRangeError:
        # NOTE(review): if this fires on the very first epoch, train_speed
        # (and train_start) may be unbound below — verify train_fn's
        # behavior before relying on this path.
        utils.print_out("training hits OutOfRangeError", f=sys.stderr)
      train_end = time.time()
      train_delta = train_end - train_start
      utils.print_out("training time for epoch %d: %.2f mins (%.2f sent/sec, %.2f tokens/sec)" %
                      (epochs + 1, train_delta / 60., train_speed, train_speed * (train_src_tokens + train_tgt_tokens) / train_sentences), f=sys.stderr)
      logging_data = {
          'train_speed_sent': train_speed,
          'train_speed_toks': train_speed * (train_src_tokens + train_tgt_tokens) / train_sentences,
      }

      # This is probably sub-optimal, doing eval per-epoch
      eval_start = time.time()
      bleu_score, (eval_speed, eval_latencies), eval_output_tokens = estimator.eval_fn(infer_hparams)
      eval_end = time.time()
      eval_delta = eval_end - eval_start
      utils.print_out("eval time for epoch %d: %.2f mins (%.2f sent/sec, %.2f tokens/sec)" %
                      (epochs + 1, eval_delta / 60., eval_speed, eval_speed * (eval_src_tokens + eval_output_tokens) / eval_sentences), f=sys.stderr)
      logging_data.update({
          'bleu': bleu_score,
          'infer_speed_sent': eval_speed,
          'infer_speed_toks': eval_speed * (eval_src_tokens + eval_output_tokens) / eval_sentences,
      })
      for lat in sorted(eval_latencies):
        utils.print_out("eval latency_%s for epoch %d: %.2f ms" % (lat, epochs + 1, eval_latencies[lat] * 1000))
        logging_data['eval_latency_{}'.format(lat)] = eval_latencies[lat] * 1000

      dllogger.log((epochs,), logging_data)
      dllogger.flush()

      # Stop early on target BLEU (or immediately in debug mode).
      if FLAGS.debug or (FLAGS.target_bleu is not None and bleu_score > FLAGS.target_bleu):
        should_stop = True
        utils.print_out(
            "Stop job since target bleu is reached at epoch %d ." % epochs,
            f=sys.stderr)

      epochs += 1
      if epochs >= FLAGS.max_train_epochs:
        should_stop = True
        utils.print_out("Stop job since max_train_epochs is reached.",
                        f=sys.stderr)

    # NOTE(review): if max_train_epochs is 0 the loop never runs and
    # logging_data is unbound here — confirm flag validation upstream.
    dllogger.log((), logging_data)
    dllogger.flush()

  experiment_end = time.time()
  utils.print_out('Experiment took {} min'.format((experiment_end - experiment_start) / 60))
  1004. if __name__ == "__main__":
  1005. nmt_parser = argparse.ArgumentParser()
  1006. add_arguments(nmt_parser)
  1007. FLAGS, unparsed = nmt_parser.parse_known_args()
  1008. tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)