# layers.py
  1. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """ High level definition of layers for model construction """
  15. import tensorflow as tf
  16. def _normalization(inputs, name, mode):
  17. """ Choose a normalization layer
  18. :param inputs: Input node from the graph
  19. :param name: Name of layer
  20. :param mode: Estimator's execution mode
  21. :return: Normalized output
  22. """
  23. training = mode == tf.estimator.ModeKeys.TRAIN
  24. if name == 'instancenorm':
  25. gamma_initializer = tf.constant_initializer(1.0)
  26. return tf.contrib.layers.instance_norm(
  27. inputs,
  28. center=True,
  29. scale=True,
  30. epsilon=1e-6,
  31. param_initializers={'gamma': gamma_initializer},
  32. reuse=None,
  33. variables_collections=None,
  34. outputs_collections=None,
  35. trainable=True,
  36. data_format='NHWC',
  37. scope=None)
  38. if name == 'groupnorm':
  39. return tf.contrib.layers.group_norm(inputs=inputs,
  40. groups=16,
  41. channels_axis=-1,
  42. reduction_axes=(-4, -3, -2),
  43. activation_fn=None,
  44. trainable=True)
  45. if name == 'batchnorm':
  46. return tf.keras.layers.BatchNormalization(axis=-1,
  47. trainable=True,
  48. virtual_batch_size=None)(inputs, training=training)
  49. if name == 'none':
  50. return inputs
  51. raise ValueError('Invalid normalization layer')
  52. def _activation(out, activation):
  53. """ Choose an activation layer
  54. :param out: Input node from the graph
  55. :param activation: Name of layer
  56. :return: Activation output
  57. """
  58. if activation == 'relu':
  59. return tf.nn.relu(out)
  60. if activation == 'leaky_relu':
  61. return tf.nn.leaky_relu(out, alpha=0.01)
  62. if activation == 'sigmoid':
  63. return tf.nn.sigmoid(out)
  64. if activation == 'softmax':
  65. return tf.nn.softmax(out, axis=-1)
  66. if activation == 'none':
  67. return out
  68. raise ValueError("Unknown activation {}".format(activation))
  69. def convolution(inputs, # pylint: disable=R0913
  70. out_channels,
  71. kernel_size=3,
  72. stride=1,
  73. mode=tf.estimator.ModeKeys.TRAIN,
  74. normalization='batchnorm',
  75. activation='leaky_relu',
  76. transpose=False):
  77. """ Create a convolution layer
  78. :param inputs: Input node from graph
  79. :param out_channels: Output number of channels
  80. :param kernel_size: Size of the kernel
  81. :param stride: Stride of the kernel
  82. :param mode: Estimator's execution mode
  83. :param normalization: Name of the normalization layer
  84. :param activation: Name of the activation layer
  85. :param transpose: Select between regular and transposed convolution
  86. :return: Convolution output
  87. """
  88. if transpose:
  89. conv = tf.keras.layers.Conv3DTranspose
  90. else:
  91. conv = tf.keras.layers.Conv3D
  92. regularizer = None # tf.keras.regularizers.l2(1e-5)
  93. use_bias = normalization == "none"
  94. inputs = conv(filters=out_channels,
  95. kernel_size=kernel_size,
  96. strides=stride,
  97. activation=None,
  98. padding='same',
  99. data_format='channels_last',
  100. kernel_initializer=tf.compat.v1.glorot_uniform_initializer(),
  101. kernel_regularizer=regularizer,
  102. bias_initializer=tf.zeros_initializer(),
  103. bias_regularizer=regularizer,
  104. use_bias=use_bias)(inputs)
  105. inputs = _normalization(inputs, normalization, mode)
  106. return _activation(inputs, activation)
  107. def upsample_block(inputs, skip_connection, out_channels, normalization, mode):
  108. """ Create a block for upsampling
  109. :param inputs: Input node from the graph
  110. :param skip_connection: Choose whether or not to use skip connection
  111. :param out_channels: Number of output channels
  112. :param normalization: Name of the normalizaiton layer
  113. :param mode: Estimator's execution mode
  114. :return: Output from the upsample block
  115. """
  116. inputs = convolution(inputs, kernel_size=2, out_channels=out_channels, stride=2,
  117. normalization='none', activation='none', transpose=True)
  118. inputs = tf.keras.layers.Concatenate(axis=-1)([inputs, skip_connection])
  119. inputs = convolution(inputs, out_channels=out_channels, normalization=normalization, mode=mode)
  120. inputs = convolution(inputs, out_channels=out_channels, normalization=normalization, mode=mode)
  121. return inputs
  122. def input_block(inputs, out_channels, normalization, mode):
  123. """ Create the input block
  124. :param inputs: Input node from the graph
  125. :param out_channels: Number of output channels
  126. :param normalization: Name of the normalization layer
  127. :param mode: Estimator's execution mode
  128. :return: Output from the input block
  129. """
  130. inputs = convolution(inputs, out_channels=out_channels, normalization=normalization, mode=mode)
  131. inputs = convolution(inputs, out_channels=out_channels, normalization=normalization, mode=mode)
  132. return inputs
  133. def downsample_block(inputs, out_channels, normalization, mode):
  134. """ Create a downsample block
  135. :param inputs: Input node from the graph
  136. :param out_channels: Number of output channels
  137. :param normalization: Name of the normalization layer
  138. :param mode: Estimator's execution mode
  139. :return: Output from the downsample block
  140. """
  141. inputs = convolution(inputs, out_channels=out_channels, normalization=normalization, mode=mode, stride=2)
  142. return convolution(inputs, out_channels=out_channels, normalization=normalization, mode=mode)
  143. def output_layer(inputs, out_channels, activation):
  144. """ Create the output layer
  145. :param inputs: Input node from the graph
  146. :param out_channels: Number of output channels
  147. :param activation: Name of the activation layer
  148. :return: Output from the output block
  149. """
  150. return convolution(inputs, out_channels=out_channels, kernel_size=3, normalization='none', activation=activation)