# model.py
  1. # Copyright 2017 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. #
  16. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  17. #
  18. # Licensed under the Apache License, Version 2.0 (the "License");
  19. # you may not use this file except in compliance with the License.
  20. # You may obtain a copy of the License at
  21. #
  22. # http://www.apache.org/licenses/LICENSE-2.0
  23. #
  24. # Unless required by applicable law or agreed to in writing, software
  25. # distributed under the License is distributed on an "AS IS" BASIS,
  26. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  27. # See the License for the specific language governing permissions and
  28. # limitations under the License.
  29. """Basic sequence-to-sequence model with dynamic RNN support."""
  30. from __future__ import absolute_import
  31. from __future__ import division
  32. from __future__ import print_function
  33. import abc
  34. import collections
  35. import os
  36. import tensorflow as tf
  37. import numpy as np
  38. from tensorflow.python.framework import function
  39. from tensorflow.python.ops import math_ops
  40. import attention_wrapper
  41. import model_helper
  42. import beam_search_decoder
  43. from utils import iterator_utils
  44. from utils import math_utils
  45. from utils import misc_utils as utils
  46. from utils import vocab_utils
# Fail fast at import time if the installed TensorFlow is older than the
# minimum version this code supports.
utils.check_tensorflow_version()

# Explicit public API of this module.
__all__ = ["BaseModel"]
  49. def create_attention_mechanism(
  50. num_units, memory, source_sequence_length, dtype=None):
  51. """Create attention mechanism based on the attention_option."""
  52. # Mechanism
  53. attention_mechanism = attention_wrapper.BahdanauAttention(
  54. num_units,
  55. memory,
  56. memory_sequence_length=tf.to_int64(source_sequence_length),
  57. normalize=True, dtype=dtype)
  58. return attention_mechanism
class BaseModel(object):
  """Sequence-to-sequence base class.

  Subclasses must implement `_build_encoder` and `_build_decoder_cell`.
  """

  def __init__(self, hparams, mode, features, scope=None, extra_args=None):
    """Create the model.

    Args:
      hparams: Hyperparameter configurations.
      mode: TRAIN | EVAL | INFER
      features: a dict of input features.
      scope: scope of the model.
      extra_args: model_helper.ExtraArgs, for passing customizable functions.
    """
    self.hparams = hparams
    # Set params
    self._set_params_initializer(hparams, mode, features, scope, extra_args)
    # Train graph: build_graph returns (logits, loss, sample_id), which
    # _set_train_or_infer unpacks by index.
    res = self.build_graph(hparams, scope=scope)
    self._set_train_or_infer(res, hparams)

  def _set_params_initializer(self,
                              hparams,
                              mode,
                              features,
                              scope,
                              extra_args=None):
    """Set various params for self and initialize."""
    self.mode = mode
    self.src_vocab_size = hparams.src_vocab_size
    self.tgt_vocab_size = hparams.tgt_vocab_size
    self.features = features
    self.time_major = hparams.time_major

    if hparams.use_char_encode:
      assert (not self.time_major), ("Can't use time major for"
                                     " char-level inputs.")

    # Compute dtype: fp16 when requested, otherwise fp32.
    self.dtype = tf.float16 if hparams.use_fp16 else tf.float32

    # extra_args: to make it flexible for adding external customizable code
    self.single_cell_fn = None
    if extra_args:
      self.single_cell_fn = extra_args.single_cell_fn

    # Set num units
    self.num_units = hparams.num_units

    # Set num layers
    self.num_encoder_layers = hparams.num_encoder_layers
    self.num_decoder_layers = hparams.num_decoder_layers
    assert self.num_encoder_layers
    assert self.num_decoder_layers

    # Set num residual layers
    if hasattr(hparams, "num_residual_layers"):  # compatible common_test_utils
      self.num_encoder_residual_layers = hparams.num_residual_layers
      self.num_decoder_residual_layers = hparams.num_residual_layers
    else:
      self.num_encoder_residual_layers = hparams.num_encoder_residual_layers
      self.num_decoder_residual_layers = hparams.num_decoder_residual_layers

    # Batch size: dynamic, derived from the per-example length vector.
    self.batch_size = tf.size(self.features["source_sequence_length"])

    # Global step
    global_step = tf.train.get_global_step()
    if global_step is not None:
      utils.print_out("global_step already created!")
    self.global_step = tf.train.get_or_create_global_step()
    utils.print_out("model.global_step.name: %s" % self.global_step.name)

    # Initializer: installed on the current variable scope so every variable
    # created below inherits it.
    self.random_seed = hparams.random_seed
    initializer = model_helper.get_initializer(
        hparams.init_op, self.random_seed, hparams.init_weight)
    tf.get_variable_scope().set_initializer(initializer)

    # Embeddings
    self.encoder_emb_lookup_fn = tf.nn.embedding_lookup
    self.init_embeddings(hparams, scope)

  def _set_train_or_infer(self, res, hparams):
    """Set up training.

    Args:
      res: the (logits, loss, sample_id) tuple returned by build_graph.
      hparams: Hyperparameter configurations.
    """
    loss = res[1]
    if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
      self.train_loss = loss
      # Total source + target tokens in the batch (for words/sec stats).
      self.word_count = tf.reduce_sum(
          self.features["source_sequence_length"]) + tf.reduce_sum(
              self.features["target_sequence_length"])
    elif self.mode == tf.contrib.learn.ModeKeys.EVAL:
      self.eval_loss = loss
    elif self.mode == tf.contrib.learn.ModeKeys.INFER:
      self.infer_logits = res[0]
      self.infer_loss = loss
      self.sample_id = res[2]

    if self.mode != tf.contrib.learn.ModeKeys.INFER:
      ## Count the number of predicted words for compute ppl.
      self.predict_count = tf.reduce_sum(
          self.features["target_sequence_length"])

    # Gradients and SGD update operation for training the model.
    # Arrange for the embedding vars to appear at the beginning.
    # Only build bprop if running on GPU and using dist_strategy, in which
    # case learning rate, grads and train_op are created in estimator model
    # function.
    with tf.name_scope("learning_rate"):
      self.learning_rate = tf.constant(hparams.learning_rate)
      # warm-up
      self.learning_rate = self._get_learning_rate_warmup(hparams)
      # decay
      self.learning_rate = self._get_learning_rate_decay(hparams)

    if (hparams.use_dist_strategy and
        self.mode == tf.contrib.learn.ModeKeys.TRAIN):
      # Gradients
      params = tf.trainable_variables()
      # Print trainable variables
      utils.print_out("# Trainable variables")
      utils.print_out(
          "Format: <name>, <shape>, <dtype>, <(soft) device placement>")
      for param in params:
        utils.print_out(
            "  %s, %s, %s, %s" % (param.name, str(param.get_shape()),
                                  param.dtype.name, param.op.device))
      # Size estimate assumes 4 bytes per parameter (fp32); /2**30 -> GiB.
      utils.print_out("Total params size: %.2f GB" % (4. * np.sum([
          p.get_shape().num_elements()
          for p in params
          if p.shape.is_fully_defined()
      ]) / 2**30))

      # Optimizer
      if hparams.optimizer == "sgd":
        opt = tf.train.GradientDescentOptimizer(self.learning_rate)
      elif hparams.optimizer == "adam":
        opt = tf.train.AdamOptimizer(self.learning_rate)
      else:
        raise ValueError("Unknown optimizer type %s" % hparams.optimizer)
      assert opt is not None

      grads_and_vars = opt.compute_gradients(
          self.train_loss,
          params,
          colocate_gradients_with_ops=hparams.colocate_gradients_with_ops)
      gradients = [x for (x, _) in grads_and_vars]

      # Clip by global norm before applying the update.
      clipped_grads, grad_norm = model_helper.gradient_clip(
          gradients, max_gradient_norm=hparams.max_gradient_norm)
      self.grad_norm = grad_norm
      self.params = params
      self.grads = clipped_grads

      self.update = opt.apply_gradients(
          list(zip(clipped_grads, params)), global_step=self.global_step)
    else:
      # Backprop is built elsewhere (estimator model fn) or not needed.
      self.grad_norm = None
      self.update = None
      self.params = None
      self.grads = None

  def _get_learning_rate_warmup(self, hparams):
    """Get learning rate warmup."""
    warmup_steps = hparams.warmup_steps
    warmup_scheme = hparams.warmup_scheme
    utils.print_out("  learning_rate=%g, warmup_steps=%d, warmup_scheme=%s" %
                    (hparams.learning_rate, warmup_steps, warmup_scheme))

    if not warmup_scheme:
      return self.learning_rate

    # Apply inverse decay if global steps less than warmup steps.
    # Inspired by https://arxiv.org/pdf/1706.03762.pdf (Section 5.3)
    # When step < warmup_steps,
    #   learing_rate *= warmup_factor ** (warmup_steps - step)
    if warmup_scheme == "t2t":
      # 0.01^(1/warmup_steps): we start with a lr, 100 times smaller
      warmup_factor = tf.exp(tf.log(0.01) / warmup_steps)
      inv_decay = warmup_factor**(tf.to_float(warmup_steps - self.global_step))
    else:
      raise ValueError("Unknown warmup scheme %s" % warmup_scheme)

    # NOTE: "warump" typo is preserved in the op name; renaming would change
    # graph node names.
    return tf.cond(
        self.global_step < hparams.warmup_steps,
        lambda: inv_decay * self.learning_rate,
        lambda: self.learning_rate,
        name="learning_rate_warump_cond")

  def _get_decay_info(self, hparams):
    """Return decay info based on decay_scheme.

    Returns:
      A (start_decay_step, decay_steps, decay_factor) tuple.

    Raises:
      ValueError: if hparams.decay_scheme is set but not a known scheme.
    """
    if hparams.decay_scheme in [
        "luong5", "luong10", "luong234", "jamesqin1616"
    ]:
      epoch_size, _, _ = iterator_utils.get_effective_epoch_size(hparams)
      num_train_steps = int(hparams.max_train_epochs * epoch_size / hparams.batch_size)
      decay_factor = 0.5
      if hparams.decay_scheme == "luong5":
        # Decay 5 times starting at the halfway point.
        start_decay_step = int(num_train_steps / 2)
        decay_times = 5
        remain_steps = num_train_steps - start_decay_step
      elif hparams.decay_scheme == "luong10":
        # Decay 10 times starting at the halfway point.
        start_decay_step = int(num_train_steps / 2)
        decay_times = 10
        remain_steps = num_train_steps - start_decay_step
      elif hparams.decay_scheme == "luong234":
        # Decay 4 times starting after 2/3 of training.
        start_decay_step = int(num_train_steps * 2 / 3)
        decay_times = 4
        remain_steps = num_train_steps - start_decay_step
      elif hparams.decay_scheme == "jamesqin1616":
        # dehao@ reported TPU setting max_epoch = 2 and use luong234.
        # They start decay after 2 * 2/3 epochs for 4 times.
        # If keep max_epochs = 8 then decay should start at 8 * 2/(3 * 4) epochs
        # and for (4 *4 = 16) times.
        decay_times = 16
        start_decay_step = int(num_train_steps / 16.)
        remain_steps = num_train_steps - start_decay_step
      decay_steps = int(remain_steps / decay_times)
    elif not hparams.decay_scheme:  # no decay
      # NOTE(review): num_train_steps is only bound inside the branch above,
      # so this no-decay path would raise NameError if ever reached — confirm
      # decay_scheme is always one of the named schemes in practice.
      start_decay_step = num_train_steps
      decay_steps = 0
      decay_factor = 1.0
    elif hparams.decay_scheme:
      raise ValueError("Unknown decay scheme %s" % hparams.decay_scheme)
    return start_decay_step, decay_steps, decay_factor

  def _get_learning_rate_decay(self, hparams):
    """Get learning rate decay."""
    start_decay_step, decay_steps, decay_factor = self._get_decay_info(hparams)
    utils.print_out("  decay_scheme=%s, start_decay_step=%d, decay_steps %d, "
                    "decay_factor %g" % (hparams.decay_scheme, start_decay_step,
                                         decay_steps, decay_factor))

    # Constant LR until start_decay_step, then staircase exponential decay.
    return tf.cond(
        self.global_step < start_decay_step,
        lambda: self.learning_rate,
        lambda: tf.train.exponential_decay(  # pylint: disable=g-long-lambda
            self.learning_rate,
            (self.global_step - start_decay_step),
            decay_steps, decay_factor, staircase=True),
        name="learning_rate_decay_cond")

  def init_embeddings(self, hparams, scope):
    """Init embeddings."""
    self.embedding_encoder, self.embedding_decoder = (
        model_helper.create_emb_for_encoder_and_decoder(
            share_vocab=hparams.share_vocab,
            src_vocab_size=self.src_vocab_size,
            tgt_vocab_size=self.tgt_vocab_size,
            src_embed_size=self.num_units,
            tgt_embed_size=self.num_units,
            dtype=self.dtype,
            num_enc_partitions=hparams.num_enc_emb_partitions,
            num_dec_partitions=hparams.num_dec_emb_partitions,
            src_vocab_file=hparams.src_vocab_file,
            tgt_vocab_file=hparams.tgt_vocab_file,
            src_embed_file=hparams.src_embed_file,
            tgt_embed_file=hparams.tgt_embed_file,
            use_char_encode=hparams.use_char_encode,
            scope=scope,
        ))

  def build_graph(self, hparams, scope=None):
    """Build the forward graph (encoder, decoder, loss).

    Creates a sequence-to-sequence model with dynamic RNN decoder API.

    Args:
      hparams: Hyperparameter configurations.
      scope: VariableScope for the created subgraph; default "dynamic_seq2seq".

    Returns:
      A tuple of the form (logits, loss, sample_id), where:
        logits: decoder output logits (a no-op placeholder in INFER mode).
        loss: the total loss / batch_size; tf.constant(0.0) in INFER mode.
        sample_id: sampling indices (None in TRAIN/EVAL mode).

    Raises:
      ValueError: if encoder_type differs from mono and bi, or
        attention_option is not (luong | scaled_luong |
        bahdanau | normed_bahdanau).
    """
    utils.print_out("# Creating %s graph ..." % self.mode)

    # Projection: output layer shared by train and inference decoding.
    with tf.variable_scope(scope or "build_network"):
      with tf.variable_scope("decoder/output_projection"):
        self.output_layer = tf.layers.Dense(
            self.tgt_vocab_size, use_bias=False, name="output_projection",
            dtype=self.dtype)

    with tf.variable_scope(scope or "dynamic_seq2seq", dtype=self.dtype):
      # Encoder
      if hparams.language_model:  # no encoder for language modeling
        utils.print_out("  language modeling: no encoder")
        self.encoder_outputs = None
        encoder_state = None
      else:
        self.encoder_outputs, encoder_state = self._build_encoder(hparams)

      ## Decoder
      logits, sample_id = (
          self._build_decoder(self.encoder_outputs, encoder_state, hparams))

      ## Loss
      if self.mode != tf.contrib.learn.ModeKeys.INFER:
        loss = self._compute_loss(logits, hparams.label_smoothing)
      else:
        loss = tf.constant(0.0)

      return logits, loss, sample_id

  @abc.abstractmethod
  def _build_encoder(self, hparams):
    """Subclass must implement this.

    Build and run an RNN encoder.

    Args:
      hparams: Hyperparameters configurations.

    Returns:
      A tuple of encoder_outputs and encoder_state.
    """
    pass

  def _get_infer_maximum_iterations(self, hparams, source_sequence_length):
    """Maximum decoding steps at inference time."""
    if hparams.tgt_max_len_infer:
      maximum_iterations = hparams.tgt_max_len_infer
      utils.print_out("  decoding maximum_iterations %d" % maximum_iterations)
    else:
      # TODO(thangluong): add decoding_length_factor flag
      # Default: decode up to 2x the longest source sequence in the batch.
      decoding_length_factor = 2.0
      max_encoder_length = tf.reduce_max(source_sequence_length)
      maximum_iterations = tf.to_int32(
          tf.round(tf.to_float(max_encoder_length) * decoding_length_factor))
    return maximum_iterations

  def _build_decoder(self, encoder_outputs, encoder_state, hparams):
    """Build and run a RNN decoder with a final projection layer.

    Args:
      encoder_outputs: The outputs of encoder for every time step.
      encoder_state: The final state of the encoder.
      hparams: The Hyperparameters configurations.

    Returns:
      A tuple of final logits and sample ids:
        logits: size [time, batch_size, vocab_size] when time_major=True
          (a tf.no_op() in INFER mode).
        sample_id: predicted ids in INFER mode, None otherwise.
    """
    ## Decoder.
    with tf.variable_scope("decoder") as decoder_scope:
      ## Train or eval
      if self.mode != tf.contrib.learn.ModeKeys.INFER:
        # [batch, time]
        target_input = self.features["target_input"]
        if self.time_major:
          # If using time_major mode, then target_input should be [time, batch]
          # then the decoder_emb_inp would be [time, batch, dim]
          target_input = tf.transpose(target_input)
        decoder_emb_inp = tf.cast(
            tf.nn.embedding_lookup(self.embedding_decoder, target_input),
            self.dtype)

        if not hparams.use_fused_lstm_dec:
          cell, decoder_initial_state = self._build_decoder_cell(
              hparams, encoder_outputs, encoder_state,
              self.features["source_sequence_length"])

          if hparams.use_dynamic_rnn:
            final_rnn_outputs, _ = tf.nn.dynamic_rnn(
                cell,
                decoder_emb_inp,
                sequence_length=self.features["target_sequence_length"],
                initial_state=decoder_initial_state,
                dtype=self.dtype,
                scope=decoder_scope,
                parallel_iterations=hparams.parallel_iterations,
                time_major=self.time_major)
          else:
            final_rnn_outputs, _ = tf.contrib.recurrent.functional_rnn(
                cell,
                decoder_emb_inp,
                sequence_length=tf.to_int32(
                    self.features["target_sequence_length"]),
                initial_state=decoder_initial_state,
                dtype=self.dtype,
                scope=decoder_scope,
                time_major=self.time_major,
                use_tpu=False)
        else:
          # Fused-LSTM decoder path: no attention cell is built here.
          if hparams.pass_hidden_state:
            decoder_initial_state = encoder_state
          else:
            # Zero initial state with the same structure as encoder_state.
            decoder_initial_state = tuple((tf.nn.rnn_cell.LSTMStateTuple(
                tf.zeros_like(s[0]), tf.zeros_like(s[1])) for s in encoder_state))
          # NOTE(review): _build_decoder_fused_for_training is expected to be
          # provided by a subclass; it is not defined in this class.
          final_rnn_outputs = self._build_decoder_fused_for_training(
              encoder_outputs, decoder_initial_state, decoder_emb_inp, self.hparams)

        # We chose to apply the output_layer to all timesteps for speed:
        #   10% improvements for small models & 20% for larger ones.
        # If memory is a concern, we should apply output_layer per timestep.
        logits = self.output_layer(final_rnn_outputs)
        sample_id = None
      ## Inference
      else:
        cell, decoder_initial_state = self._build_decoder_cell(
            hparams, encoder_outputs, encoder_state,
            self.features["source_sequence_length"])

        # Only beam search inference is supported.
        assert hparams.infer_mode == "beam_search"
        _, tgt_vocab_table = vocab_utils.create_vocab_tables(
            hparams.src_vocab_file, hparams.tgt_vocab_file, hparams.share_vocab)
        tgt_sos_id = tf.cast(
            tgt_vocab_table.lookup(tf.constant(hparams.sos)), tf.int32)
        tgt_eos_id = tf.cast(
            tgt_vocab_table.lookup(tf.constant(hparams.eos)), tf.int32)
        start_tokens = tf.fill([self.batch_size], tgt_sos_id)
        end_token = tgt_eos_id
        beam_width = hparams.beam_width
        length_penalty_weight = hparams.length_penalty_weight
        coverage_penalty_weight = hparams.coverage_penalty_weight

        my_decoder = beam_search_decoder.BeamSearchDecoder(
            cell=cell,
            embedding=self.embedding_decoder,
            start_tokens=start_tokens,
            end_token=end_token,
            initial_state=decoder_initial_state,
            beam_width=beam_width,
            output_layer=self.output_layer,
            length_penalty_weight=length_penalty_weight,
            coverage_penalty_weight=coverage_penalty_weight)

        # maximum_iteration: The maximum decoding steps.
        maximum_iterations = self._get_infer_maximum_iterations(
            hparams, self.features["source_sequence_length"])

        # Dynamic decoding
        outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
            my_decoder,
            maximum_iterations=maximum_iterations,
            output_time_major=self.time_major,
            swap_memory=True,
            scope=decoder_scope)
        # Beam search emits ids directly; there are no logits to return.
        logits = tf.no_op()
        sample_id = outputs.predicted_ids
    return logits, sample_id

  def get_max_time(self, tensor):
    """Return the size of `tensor`'s time axis (static if known, else dynamic)."""
    time_axis = 0 if self.time_major else 1
    return tensor.shape[time_axis].value or tf.shape(tensor)[time_axis]

  @abc.abstractmethod
  def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                          source_sequence_length):
    """Subclass must implement this.

    Args:
      hparams: Hyperparameters configurations.
      encoder_outputs: The outputs of encoder for every time step.
      encoder_state: The final state of the encoder.
      source_sequence_length: sequence length of encoder_outputs.

    Returns:
      A tuple of a multi-layer RNN cell used by decoder and the initial state of
      the decoder RNN.
    """
    pass

  def _softmax_cross_entropy_loss(self, logits, labels, label_smoothing):
    """Compute softmax loss or sampled softmax loss."""
    # NOTE(review): os.environ[...] raises KeyError if these variables are
    # unset — confirm the launcher always exports use_defun/use_xla.
    use_defun = os.environ["use_defun"] == "true"
    use_xla = os.environ["use_xla"] == "true"

    # @function.Defun(noinline=True, compiled=use_xla)
    def ComputePositiveCrossent(labels, logits):
      crossent = math_utils.sparse_softmax_crossent_with_logits(
          labels=labels, logits=logits)
      return crossent
    crossent = ComputePositiveCrossent(labels, logits)
    assert crossent.dtype == tf.float32

    def _safe_shape_div(x, y):
      """Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`."""
      return x // tf.maximum(y, 1)

    @function.Defun(tf.float32, tf.float32, compiled=use_xla)
    def ReduceSumGrad(x, grad):
      """Gradient of ReduceSum: broadcast `grad` back over the reduced axis."""
      input_shape = tf.shape(x)
      # TODO(apassos) remove this once device placement for eager ops makes more
      # sense.
      with tf.colocate_with(input_shape):
        output_shape_kept_dims = math_ops.reduced_shape(input_shape, -1)
        tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
      grad = tf.reshape(grad, output_shape_kept_dims)
      return tf.tile(grad, tile_scaling)

    def ReduceSum(x):
      """Sum over the last axis (optionally defun/XLA-compiled below)."""
      return tf.reduce_sum(x, axis=-1)
    if use_defun:
      ReduceSum = function.Defun(
          tf.float32,
          compiled=use_xla,
          noinline=True,
          grad_func=ReduceSumGrad)(ReduceSum)

    if abs(label_smoothing) > 1e-3:
      # pylint:disable=invalid-name
      def ComputeNegativeCrossentFwd(logits):
        """Numerically stable log-softmax sum used by label smoothing."""
        # [time, batch, dim]
        # [time, batch]
        max_logits = tf.reduce_max(logits, axis=-1)
        # [time, batch, dim]
        shifted_logits = logits - tf.expand_dims(max_logits, axis=-1)
        # Always compute loss in fp32
        shifted_logits = tf.to_float(shifted_logits)
        # [time, batch]
        log_sum_exp = tf.log(ReduceSum(tf.exp(shifted_logits)))
        # [time, batch, dim] - [time, batch, 1] --> reduce_sum(-1) -->
        # [time, batch]
        neg_crossent = ReduceSum(
            shifted_logits - tf.expand_dims(log_sum_exp, axis=-1))
        return neg_crossent

      def ComputeNegativeCrossent(logits):
        return ComputeNegativeCrossentFwd(logits)

      if use_defun:
        ComputeNegativeCrossent = function.Defun(
            compiled=use_xla)(ComputeNegativeCrossent)
      neg_crossent = ComputeNegativeCrossent(logits)
      neg_crossent = tf.to_float(neg_crossent)
      num_labels = logits.shape[-1].value

      # Blend the true-label crossent with the uniform-smoothing term.
      crossent = (1.0 - label_smoothing) * crossent - (
          label_smoothing / tf.to_float(num_labels) * neg_crossent)
      # pylint:enable=invalid-name
    return crossent

  def _compute_loss(self, logits, label_smoothing):
    """Compute optimization loss."""
    target_output = self.features["target_output"]
    if self.time_major:
      target_output = tf.transpose(target_output)
    max_time = self.get_max_time(target_output)
    self.batch_seq_len = max_time

    crossent = self._softmax_cross_entropy_loss(
        logits, target_output, label_smoothing)
    assert crossent.dtype == tf.float32

    # Mask out padding positions beyond each target sequence's length.
    target_weights = tf.sequence_mask(
        self.features["target_sequence_length"], max_time, dtype=crossent.dtype)
    if self.time_major:
      # [time, batch] if time_major, since the crossent is [time, batch] in this
      # case.
      target_weights = tf.transpose(target_weights)

    loss = tf.reduce_sum(crossent * target_weights) / tf.to_float(
        self.batch_size)
    return loss

  def build_encoder_states(self, include_embeddings=False):
    """Stack encoder states and return tensor [batch, length, layer, size].

    NOTE(review): reads self.encoder_state_list (and optionally
    self.encoder_emb_inp), which are not set in this class — presumably
    assigned by the subclass encoder; verify before calling.
    """
    assert self.mode == tf.contrib.learn.ModeKeys.INFER
    if include_embeddings:
      stack_state_list = tf.stack(
          [self.encoder_emb_inp] + self.encoder_state_list, 2)
    else:
      stack_state_list = tf.stack(self.encoder_state_list, 2)
    # transform from [length, batch, ...] -> [batch, length, ...]
    if self.time_major:
      stack_state_list = tf.transpose(stack_state_list, [1, 0, 2, 3])
    return stack_state_list