# fused_layer_norm.py
  1. # coding=utf-8
  2. # Copyright 2018 The Google AI Language Team Authors.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import collections
  17. import copy
  18. import json
  19. import math
  20. import re
  21. import six
  22. import tensorflow as tf
  23. from tensorflow.python.framework import ops
  24. from tensorflow.contrib.layers.python.layers import utils
  25. from tensorflow.contrib.framework.python.ops import variables
  26. from tensorflow.python.ops import init_ops
  27. import numpy
  28. from tensorflow.python.ops import array_ops
  29. from tensorflow.python.framework import dtypes
  30. from tensorflow.python.ops import nn
  31. def fused_layer_norm(inputs,
  32. center=True,
  33. scale=True,
  34. activation_fn=None,
  35. reuse=None,
  36. variables_collections=None,
  37. outputs_collections=None,
  38. trainable=True,
  39. begin_norm_axis=1,
  40. begin_params_axis=-1,
  41. scope=None,
  42. use_fused_batch_norm=False):
  43. with tf.variable_scope(
  44. scope, 'LayerNorm', [inputs], reuse=reuse) as sc:
  45. inputs = ops.convert_to_tensor(inputs)
  46. inputs_shape = inputs.shape
  47. inputs_rank = inputs_shape.ndims
  48. if inputs_rank is None:
  49. raise ValueError('Inputs %s has undefined rank.' % inputs.name)
  50. dtype = inputs.dtype.base_dtype
  51. if begin_norm_axis < 0:
  52. begin_norm_axis = inputs_rank + begin_norm_axis
  53. if begin_params_axis >= inputs_rank or begin_norm_axis >= inputs_rank:
  54. raise ValueError('begin_params_axis (%d) and begin_norm_axis (%d) '
  55. 'must be < rank(inputs) (%d)' %
  56. (begin_params_axis, begin_norm_axis, inputs_rank))
  57. params_shape = inputs_shape[begin_params_axis:]
  58. if not params_shape.is_fully_defined():
  59. raise ValueError(
  60. 'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' %
  61. (inputs.name, begin_params_axis, inputs_shape))
  62. # Allocate parameters for the beta and gamma of the normalization.
  63. beta, gamma = None, None
  64. if center:
  65. beta_collections = utils.get_variable_collections(variables_collections,
  66. 'beta')
  67. beta = variables.model_variable(
  68. 'beta',
  69. shape=params_shape,
  70. dtype=dtype,
  71. initializer=init_ops.zeros_initializer(),
  72. collections=beta_collections,
  73. trainable=trainable)
  74. if scale:
  75. gamma_collections = utils.get_variable_collections(
  76. variables_collections, 'gamma')
  77. gamma = variables.model_variable(
  78. 'gamma',
  79. shape=params_shape,
  80. dtype=dtype,
  81. initializer=init_ops.ones_initializer(),
  82. collections=gamma_collections,
  83. trainable=trainable)
  84. if use_fused_batch_norm:
  85. # get static TensorShape if fully defined,
  86. # otherwise retrieve shape tensor
  87. norm_shape = inputs.shape[begin_norm_axis:]
  88. if norm_shape.is_fully_defined():
  89. bn_shape = [1, -1, 1, numpy.prod(norm_shape.as_list())]
  90. else:
  91. norm_shape = tf.shape(inputs)[begin_norm_axis:]
  92. bn_shape = [1, -1, 1, tf.reduce_prod(norm_shape)]
  93. if inputs.get_shape().is_fully_defined():
  94. outputs_shape = inputs.get_shape()
  95. else:
  96. outputs_shape = tf.shape(inputs)
  97. inputs = array_ops.reshape(inputs, bn_shape)
  98. if inputs.get_shape().is_fully_defined():
  99. # static inputs TensorShape fully defined after reshape.
  100. ones = array_ops.ones(inputs.get_shape()[1], dtype=dtypes.float32)
  101. zeros = array_ops.zeros(inputs.get_shape()[1], dtype=dtypes.float32)
  102. else:
  103. # static inputs TensorShape NOT fully defined after reshape.
  104. # must use dynamic shape, which means these input tensors
  105. # have to be created at runtime, which causes a slowdown.
  106. scale_shape = tf.shape(inputs)[1]
  107. ones = array_ops.ones(scale_shape, dtype=dtypes.float32)
  108. zeros = array_ops.zeros(scale_shape, dtype=dtypes.float32)
  109. outputs, mean, variance = nn.fused_batch_norm(
  110. inputs,
  111. ones, zeros,
  112. epsilon=1e-4,
  113. data_format="NCHW")
  114. outputs = array_ops.reshape(outputs, outputs_shape)
  115. if center and scale:
  116. outputs = outputs * gamma + beta
  117. elif center:
  118. outputs = outputs + beta
  119. elif scale:
  120. outputs = outputs * gamma
  121. else:
  122. # Calculate the moments on the last axis (layer activations).
  123. norm_axes = list(range(begin_norm_axis, inputs_rank))
  124. mean, variance = nn.moments(inputs, norm_axes, keep_dims=True)
  125. # Compute layer normalization using the batch_normalization function.
  126. variance_epsilon = 1e-4
  127. outputs = nn.batch_normalization(
  128. inputs,
  129. mean,
  130. variance,
  131. offset=beta,
  132. scale=gamma,
  133. variance_epsilon=variance_epsilon)
  134. outputs.set_shape(inputs_shape)
  135. if activation_fn is not None:
  136. outputs = activation_fn(outputs)
  137. return utils.collect_named_outputs(outputs_collections, sc.name, outputs)