| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
- import numpy as np
- import torch
- import torch.nn as nn
# Normalization layer classes keyed by their lowercase class name:
# "instancenorm3d", "instancenorm2d", "batchnorm3d", "batchnorm2d".
normalizations = {
    norm_cls.__name__.lower(): norm_cls
    for norm_cls in (
        nn.InstanceNorm3d,
        nn.InstanceNorm2d,
        nn.BatchNorm3d,
        nn.BatchNorm2d,
    )
}
# Convolution layer classes keyed by their exact class name:
# "Conv2d", "Conv3d", "ConvTranspose2d", "ConvTranspose3d".
convolutions = {
    conv_cls.__name__: conv_cls
    for conv_cls in (
        nn.Conv2d,
        nn.Conv3d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
    )
}
def get_norm(name, out_channels):
    """Build a normalization layer for ``out_channels`` feature channels.

    Args:
        name: Normalization identifier. Any name containing ``"groupnorm"``
            selects ``nn.GroupNorm``; otherwise it must be a key of the
            module-level ``normalizations`` table.
        out_channels: Number of channels the layer normalizes.

    Returns:
        The instantiated normalization module, created with ``affine=True``.

    Raises:
        KeyError: If ``name`` is neither a group norm nor a known table key.
    """
    if "groupnorm" not in name:
        norm_cls = normalizations[name]
        return norm_cls(out_channels, affine=True)
    # Group norm is always built with a fixed count of 32 groups.
    return nn.GroupNorm(32, out_channels, affine=True)
def get_conv(in_channels, out_channels, kernel_size, stride, dim, bias=False):
    """Build a ``dim``-dimensional convolution with padding derived from its geometry.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size (scalar or per-dimension sequence).
        stride: Stride (scalar or per-dimension sequence).
        dim: Spatial dimensionality; selects the ``Conv{dim}d`` entry of the
            module-level ``convolutions`` table.
        bias: Whether the convolution has a bias term.

    Returns:
        The instantiated convolution module.
    """
    conv_cls = convolutions[f"Conv{dim}d"]
    return conv_cls(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=get_padding(kernel_size, stride),
        bias=bias,
    )
def get_transp_conv(in_channels, out_channels, kernel_size, stride, dim):
    """Build a ``dim``-dimensional transposed convolution (always with bias).

    Padding and output padding are both derived from ``kernel_size`` and
    ``stride`` via ``get_padding`` and ``get_output_padding``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size (scalar or per-dimension sequence).
        stride: Stride (scalar or per-dimension sequence).
        dim: Spatial dimensionality; selects the ``ConvTranspose{dim}d`` entry
            of the module-level ``convolutions`` table.

    Returns:
        The instantiated transposed-convolution module.
    """
    transp_cls = convolutions[f"ConvTranspose{dim}d"]
    pad = get_padding(kernel_size, stride)
    out_pad = get_output_padding(kernel_size, stride, pad)
    return transp_cls(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=pad,
        output_padding=out_pad,
        bias=True,
    )
def get_padding(kernel_size, stride):
    """Compute per-dimension padding as ``(kernel_size - stride + 1) / 2``, truncated.

    Args:
        kernel_size: Scalar or sequence of kernel sizes.
        stride: Scalar or sequence of strides (broadcast against ``kernel_size``).

    Returns:
        A tuple of ints when more than one dimension is given, otherwise a
        single int.
    """
    kernel = np.atleast_1d(kernel_size)
    strides = np.atleast_1d(stride)
    # True division followed by int() truncates toward zero.
    pads = tuple(map(int, (kernel - strides + 1) / 2))
    if len(pads) > 1:
        return pads
    return pads[0]
def get_output_padding(kernel_size, stride, padding):
    """Compute per-dimension output padding as ``2 * padding + stride - kernel_size``.

    Args:
        kernel_size: Scalar or sequence of kernel sizes.
        stride: Scalar or sequence of strides.
        padding: Scalar or sequence of paddings.

    Returns:
        A tuple of ints when more than one dimension is given, otherwise a
        single int.
    """
    kernel, strides, pads = (np.atleast_1d(v) for v in (kernel_size, stride, padding))
    extra = tuple(map(int, 2 * pads + strides - kernel))
    if len(extra) > 1:
        return extra
    return extra[0]
class ConvLayer(nn.Module):
    """Convolution followed by normalization and an in-place LeakyReLU."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, norm, negative_slope, dim):
        """Create the conv, norm and activation sub-modules.

        Args:
            in_channels: Number of input channels of the convolution.
            out_channels: Number of output channels (also normalized channels).
            kernel_size: Convolution kernel size.
            stride: Convolution stride.
            norm: Normalization name passed to ``get_norm``.
            negative_slope: Negative slope of the LeakyReLU.
            dim: Spatial dimensionality passed to ``get_conv``.
        """
        super().__init__()
        # The convolution uses get_conv's default bias (False).
        self.conv = get_conv(in_channels, out_channels, kernel_size, stride, dim)
        self.norm = get_norm(norm, out_channels)
        self.lrelu = nn.LeakyReLU(negative_slope=negative_slope, inplace=True)

    def forward(self, input_data):
        """Apply conv -> norm -> LeakyReLU to ``input_data``."""
        features = self.conv(input_data)
        features = self.norm(features)
        return self.lrelu(features)
class ConvBlock(nn.Module):
    """Two stacked ``ConvLayer`` modules; only the first applies ``stride``."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, norm, negative_slope, dim):
        """Create both conv layers.

        Args:
            in_channels: Input channels of the first layer.
            out_channels: Output channels of both layers.
            kernel_size: Kernel size shared by both layers.
            stride: Stride of the first layer (the second uses stride 1).
            norm: Normalization name forwarded to each ``ConvLayer``.
            negative_slope: LeakyReLU negative slope forwarded to each layer.
            dim: Spatial dimensionality forwarded to each layer.
        """
        super().__init__()
        shared_args = (norm, negative_slope, dim)
        self.conv1 = ConvLayer(in_channels, out_channels, kernel_size, stride, *shared_args)
        self.conv2 = ConvLayer(out_channels, out_channels, kernel_size, 1, *shared_args)

    def forward(self, input_data):
        """Run ``input_data`` through ``conv1`` and then ``conv2``."""
        return self.conv2(self.conv1(input_data))
class UpsampleBlock(nn.Module):
    """Transposed convolution, concatenation with a skip tensor, then a ``ConvBlock``."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, norm, negative_slope, dim):
        """Create the transposed convolution and the following conv block.

        Args:
            in_channels: Input channels of the transposed convolution.
            out_channels: Output channels of both sub-modules.
            kernel_size: Kernel size of the conv block.
            stride: Stride of the transposed convolution.
            norm: Normalization name forwarded to the conv block.
            negative_slope: LeakyReLU negative slope forwarded to the conv block.
            dim: Spatial dimensionality.
        """
        super().__init__()
        # The transposed convolution uses ``stride`` as its kernel size as well.
        self.transp_conv = get_transp_conv(in_channels, out_channels, stride, stride, dim)
        # The conv block is built for 2 * out_channels inputs because forward()
        # concatenates the skip tensor along dim 1 (assumes skip_data also has
        # out_channels channels; confirm against the caller).
        self.conv_block = ConvBlock(2 * out_channels, out_channels, kernel_size, 1, norm, negative_slope, dim)

    def forward(self, input_data, skip_data):
        """Upsample ``input_data``, concatenate ``skip_data`` on dim 1, and convolve."""
        upsampled = self.transp_conv(input_data)
        merged = torch.cat((upsampled, skip_data), dim=1)
        return self.conv_block(merged)
class OutputBlock(nn.Module):
    """Single kernel-size-1, stride-1 convolution with bias."""

    def __init__(self, in_channels, out_channels, dim):
        """Create the convolution.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            dim: Spatial dimensionality passed to ``get_conv``.
        """
        super().__init__()
        self.conv = get_conv(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            dim=dim,
            bias=True,
        )

    def forward(self, input_data):
        """Apply the convolution to ``input_data``."""
        projected = self.conv(input_data)
        return projected
|