# layers.py
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
import torch.nn as nn
  17. normalizations = {
  18. "instancenorm3d": nn.InstanceNorm3d,
  19. "instancenorm2d": nn.InstanceNorm2d,
  20. "batchnorm3d": nn.BatchNorm3d,
  21. "batchnorm2d": nn.BatchNorm2d,
  22. }
  23. convolutions = {
  24. "Conv2d": nn.Conv2d,
  25. "Conv3d": nn.Conv3d,
  26. "ConvTranspose2d": nn.ConvTranspose2d,
  27. "ConvTranspose3d": nn.ConvTranspose3d,
  28. }
  29. def get_norm(name, out_channels):
  30. if "groupnorm" in name:
  31. return nn.GroupNorm(32, out_channels, affine=True)
  32. return normalizations[name](out_channels, affine=True)
  33. def get_conv(in_channels, out_channels, kernel_size, stride, dim, bias=False):
  34. conv = convolutions[f"Conv{dim}d"]
  35. padding = get_padding(kernel_size, stride)
  36. return conv(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
  37. def get_transp_conv(in_channels, out_channels, kernel_size, stride, dim):
  38. conv = convolutions[f"ConvTranspose{dim}d"]
  39. padding = get_padding(kernel_size, stride)
  40. output_padding = get_output_padding(kernel_size, stride, padding)
  41. return conv(in_channels, out_channels, kernel_size, stride, padding, output_padding, bias=True)
  42. def get_padding(kernel_size, stride):
  43. kernel_size_np = np.atleast_1d(kernel_size)
  44. stride_np = np.atleast_1d(stride)
  45. padding_np = (kernel_size_np - stride_np + 1) / 2
  46. padding = tuple(int(p) for p in padding_np)
  47. return padding if len(padding) > 1 else padding[0]
  48. def get_output_padding(kernel_size, stride, padding):
  49. kernel_size_np = np.atleast_1d(kernel_size)
  50. stride_np = np.atleast_1d(stride)
  51. padding_np = np.atleast_1d(padding)
  52. out_padding_np = 2 * padding_np + stride_np - kernel_size_np
  53. out_padding = tuple(int(p) for p in out_padding_np)
  54. return out_padding if len(out_padding) > 1 else out_padding[0]
  55. class ConvLayer(nn.Module):
  56. def __init__(self, in_channels, out_channels, kernel_size, stride, norm, negative_slope, dim):
  57. super(ConvLayer, self).__init__()
  58. self.conv = get_conv(in_channels, out_channels, kernel_size, stride, dim)
  59. self.norm = get_norm(norm, out_channels)
  60. self.lrelu = nn.LeakyReLU(negative_slope=negative_slope, inplace=True)
  61. def forward(self, input_data):
  62. return self.lrelu(self.norm(self.conv(input_data)))
  63. class ConvBlock(nn.Module):
  64. def __init__(self, in_channels, out_channels, kernel_size, stride, norm, negative_slope, dim):
  65. super(ConvBlock, self).__init__()
  66. self.conv1 = ConvLayer(in_channels, out_channels, kernel_size, stride, norm, negative_slope, dim)
  67. self.conv2 = ConvLayer(out_channels, out_channels, kernel_size, 1, norm, negative_slope, dim)
  68. def forward(self, input_data):
  69. out = self.conv1(input_data)
  70. out = self.conv2(out)
  71. return out
  72. class UpsampleBlock(nn.Module):
  73. def __init__(self, in_channels, out_channels, kernel_size, stride, norm, negative_slope, dim):
  74. super(UpsampleBlock, self).__init__()
  75. self.transp_conv = get_transp_conv(in_channels, out_channels, stride, stride, dim)
  76. self.conv_block = ConvBlock(2 * out_channels, out_channels, kernel_size, 1, norm, negative_slope, dim)
  77. def forward(self, input_data, skip_data):
  78. out = self.transp_conv(input_data)
  79. out = torch.cat((out, skip_data), dim=1)
  80. out = self.conv_block(out)
  81. return out
  82. class OutputBlock(nn.Module):
  83. def __init__(self, in_channels, out_channels, dim):
  84. super(OutputBlock, self).__init__()
  85. self.conv = get_conv(in_channels, out_channels, kernel_size=1, stride=1, dim=dim, bias=True)
  86. def forward(self, input_data):
  87. return self.conv(input_data)