layers.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
# Third-party deps: nv_norms provides NVIDIA ATEX fused normalization kernels;
# tensorflow_addons supplies Group/Instance normalization layers.
import nv_norms
import tensorflow as tf
import tensorflow_addons as tfa

# Lookup table from a dimensionality-tagged name ("Conv2d", "Conv3d",
# "ConvTranspose2d", "ConvTranspose3d") to the Keras layer class; used by
# get_conv/get_transp_conv below to pick the 2D or 3D variant at runtime.
convolutions = {
    "Conv2d": tf.keras.layers.Conv2D,
    "Conv3d": tf.keras.layers.Conv3D,
    "ConvTranspose2d": tf.keras.layers.Conv2DTranspose,
    "ConvTranspose3d": tf.keras.layers.Conv3DTranspose,
}
  23. class KaimingNormal(tf.keras.initializers.VarianceScaling):
  24. def __init__(self, negative_slope, seed=None):
  25. super().__init__(
  26. scale=2.0 / (1 + negative_slope**2), mode="fan_in", distribution="untruncated_normal", seed=seed
  27. )
  28. def get_config(self):
  29. return {"seed": self.seed}
  30. def get_norm(name):
  31. if "group" in name:
  32. return tfa.layers.GroupNormalization(32, axis=-1, center=True, scale=True)
  33. elif "batch" in name:
  34. return tf.keras.layers.BatchNormalization(axis=-1, center=True, scale=True)
  35. elif "atex_instance" in name:
  36. return nv_norms.InstanceNormalization(axis=-1)
  37. elif "instance" in name:
  38. return tfa.layers.InstanceNormalization(axis=-1, center=True, scale=True)
  39. elif "none" in name:
  40. return tf.identity
  41. else:
  42. raise ValueError("Invalid normalization layer")
  43. def extract_args(kwargs):
  44. args = {}
  45. if "input_shape" in kwargs:
  46. args["input_shape"] = kwargs["input_shape"]
  47. return args
  48. def get_conv(filters, kernel_size, stride, dim, use_bias=False, **kwargs):
  49. conv = convolutions[f"Conv{dim}d"]
  50. return conv(
  51. filters=filters,
  52. kernel_size=kernel_size,
  53. strides=stride,
  54. padding="same",
  55. use_bias=use_bias,
  56. kernel_initializer=KaimingNormal(kwargs["negative_slope"]),
  57. data_format="channels_last",
  58. **extract_args(kwargs),
  59. )
  60. def get_transp_conv(filters, kernel_size, stride, dim, **kwargs):
  61. conv = convolutions[f"ConvTranspose{dim}d"]
  62. return conv(
  63. filters=filters,
  64. kernel_size=kernel_size,
  65. strides=stride,
  66. padding="same",
  67. use_bias=True,
  68. data_format="channels_last",
  69. **extract_args(kwargs),
  70. )
  71. class ConvLayer(tf.keras.layers.Layer):
  72. def __init__(self, filters, kernel_size, stride, **kwargs):
  73. super().__init__()
  74. self.conv = get_conv(filters, kernel_size, stride, **kwargs)
  75. self.norm = get_norm(kwargs["norm"])
  76. self.lrelu = tf.keras.layers.LeakyReLU(alpha=kwargs["negative_slope"])
  77. def call(self, data):
  78. out = self.conv(data)
  79. out = self.norm(out)
  80. out = self.lrelu(out)
  81. return out
  82. class ConvBlock(tf.keras.layers.Layer):
  83. def __init__(self, filters, kernel_size, stride, **kwargs):
  84. super().__init__()
  85. self.conv1 = ConvLayer(filters, kernel_size, stride, **kwargs)
  86. kwargs.pop("input_shape", None)
  87. self.conv2 = ConvLayer(filters, kernel_size, 1, **kwargs)
  88. def call(self, input_data):
  89. out = self.conv1(input_data)
  90. out = self.conv2(out)
  91. return out
  92. class UpsampleBlock(tf.keras.layers.Layer):
  93. def __init__(self, filters, kernel_size, stride, **kwargs):
  94. super().__init__()
  95. self.transp_conv = get_transp_conv(filters, stride, stride, **kwargs)
  96. self.conv_block = ConvBlock(filters, kernel_size, 1, **kwargs)
  97. def call(self, input_data, skip_data):
  98. out = self.transp_conv(input_data)
  99. out = tf.concat((out, skip_data), axis=-1)
  100. out = self.conv_block(out)
  101. return out
  102. class OutputBlock(tf.keras.layers.Layer):
  103. def __init__(self, filters, dim, negative_slope):
  104. super().__init__()
  105. self.conv = get_conv(
  106. filters,
  107. kernel_size=1,
  108. stride=1,
  109. dim=dim,
  110. use_bias=True,
  111. negative_slope=negative_slope,
  112. )
  113. def call(self, data):
  114. return self.conv(data)