# layers.py
  1. import numpy as np
  2. import torch
  3. import torch.nn as nn
  4. normalizations = {
  5. "instancenorm3d": nn.InstanceNorm3d,
  6. "instancenorm2d": nn.InstanceNorm2d,
  7. "batchnorm3d": nn.BatchNorm3d,
  8. "batchnorm2d": nn.BatchNorm2d,
  9. }
  10. convolutions = {
  11. "Conv2d": nn.Conv2d,
  12. "Conv3d": nn.Conv3d,
  13. "ConvTranspose2d": nn.ConvTranspose2d,
  14. "ConvTranspose3d": nn.ConvTranspose3d,
  15. }
  16. def get_norm(name, out_channels):
  17. if "groupnorm" in name:
  18. return nn.GroupNorm(32, out_channels, affine=True)
  19. return normalizations[name](out_channels, affine=True)
  20. def get_conv(in_channels, out_channels, kernel_size, stride, dim, bias=False):
  21. conv = convolutions[f"Conv{dim}d"]
  22. padding = get_padding(kernel_size, stride)
  23. return conv(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
  24. def get_transp_conv(in_channels, out_channels, kernel_size, stride, dim):
  25. conv = convolutions[f"ConvTranspose{dim}d"]
  26. padding = get_padding(kernel_size, stride)
  27. output_padding = get_output_padding(kernel_size, stride, padding)
  28. return conv(in_channels, out_channels, kernel_size, stride, padding, output_padding, bias=True)
  29. def get_padding(kernel_size, stride):
  30. kernel_size_np = np.atleast_1d(kernel_size)
  31. stride_np = np.atleast_1d(stride)
  32. padding_np = (kernel_size_np - stride_np + 1) / 2
  33. padding = tuple(int(p) for p in padding_np)
  34. return padding if len(padding) > 1 else padding[0]
  35. def get_output_padding(kernel_size, stride, padding):
  36. kernel_size_np = np.atleast_1d(kernel_size)
  37. stride_np = np.atleast_1d(stride)
  38. padding_np = np.atleast_1d(padding)
  39. out_padding_np = 2 * padding_np + stride_np - kernel_size_np
  40. out_padding = tuple(int(p) for p in out_padding_np)
  41. return out_padding if len(out_padding) > 1 else out_padding[0]
  42. class ConvLayer(nn.Module):
  43. def __init__(self, in_channels, out_channels, kernel_size, stride, norm, negative_slope, dim):
  44. super(ConvLayer, self).__init__()
  45. self.conv = get_conv(in_channels, out_channels, kernel_size, stride, dim)
  46. self.norm = get_norm(norm, out_channels)
  47. self.lrelu = nn.LeakyReLU(negative_slope=negative_slope, inplace=True)
  48. def forward(self, input_data):
  49. return self.lrelu(self.norm(self.conv(input_data)))
  50. class ConvBlock(nn.Module):
  51. def __init__(self, in_channels, out_channels, kernel_size, stride, norm, negative_slope, dim):
  52. super(ConvBlock, self).__init__()
  53. self.conv1 = ConvLayer(in_channels, out_channels, kernel_size, stride, norm, negative_slope, dim)
  54. self.conv2 = ConvLayer(out_channels, out_channels, kernel_size, 1, norm, negative_slope, dim)
  55. def forward(self, input_data):
  56. out = self.conv1(input_data)
  57. out = self.conv2(out)
  58. return out
  59. class UpsampleBlock(nn.Module):
  60. def __init__(self, in_channels, out_channels, kernel_size, stride, norm, negative_slope, dim):
  61. super(UpsampleBlock, self).__init__()
  62. self.transp_conv = get_transp_conv(in_channels, out_channels, stride, stride, dim)
  63. self.conv_block = ConvBlock(2 * out_channels, out_channels, kernel_size, 1, norm, negative_slope, dim)
  64. def forward(self, input_data, skip_data):
  65. out = self.transp_conv(input_data)
  66. out = torch.cat((out, skip_data), dim=1)
  67. out = self.conv_block(out)
  68. return out
  69. class OutputBlock(nn.Module):
  70. def __init__(self, in_channels, out_channels, dim):
  71. super(OutputBlock, self).__init__()
  72. self.conv = get_conv(in_channels, out_channels, kernel_size=1, stride=1, dim=dim, bias=True)
  73. def forward(self, input_data):
  74. return self.conv(input_data)