# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
import torch.nn as nn
# Registry mapping a normalization-layer name to its torch.nn constructor.
normalizations = dict(
    instancenorm3d=nn.InstanceNorm3d,
    instancenorm2d=nn.InstanceNorm2d,
    batchnorm3d=nn.BatchNorm3d,
    batchnorm2d=nn.BatchNorm2d,
)
# Registry mapping a convolution-layer name to its torch.nn constructor.
convolutions = dict(
    Conv2d=nn.Conv2d,
    Conv3d=nn.Conv3d,
    ConvTranspose2d=nn.ConvTranspose2d,
    ConvTranspose3d=nn.ConvTranspose3d,
)
def get_norm(name, out_channels, groups=32):
    """Build a normalization layer by name.

    Args:
        name: Any name containing "groupnorm" selects ``nn.GroupNorm``;
            otherwise ``name`` must be a key of ``normalizations``
            (e.g. "batchnorm2d", "instancenorm3d").
        out_channels: Number of channels/features the layer normalizes.
        groups: Number of groups for GroupNorm. Defaults to 32, matching
            the previously hard-coded value, so existing callers are
            unaffected.

    Returns:
        An ``nn.Module`` normalization layer with learnable affine params.

    Raises:
        KeyError: If ``name`` does not contain "groupnorm" and is not a
            key of ``normalizations``.
    """
    if "groupnorm" in name:
        return nn.GroupNorm(groups, out_channels, affine=True)
    return normalizations[name](out_channels, affine=True)
def get_conv(in_channels, out_channels, kernel_size, stride, dim, bias=False):
    """Create an N-dimensional convolution with "same"-style padding.

    ``dim`` selects between Conv2d and Conv3d via the ``convolutions``
    registry; padding is derived from kernel size and stride so spatial
    dims shrink only by the stride factor.
    """
    padding = get_padding(kernel_size, stride)
    conv_cls = convolutions[f"Conv{dim}d"]
    return conv_cls(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
def get_transp_conv(in_channels, out_channels, kernel_size, stride, dim):
    """Create an N-dimensional transposed convolution (always with bias).

    Padding and output_padding are computed so the transposed conv
    inverts the spatial-size change of a matching forward convolution.
    """
    conv_cls = convolutions[f"ConvTranspose{dim}d"]
    pad = get_padding(kernel_size, stride)
    out_pad = get_output_padding(kernel_size, stride, pad)
    return conv_cls(in_channels, out_channels, kernel_size, stride, pad, out_pad, bias=True)
def get_padding(kernel_size, stride):
    """Compute per-dimension "same"-style padding for a convolution.

    Accepts scalars or sequences for both arguments; numpy broadcasting
    rules apply when they are mixed. Returns a plain int when the result
    is one-dimensional, otherwise a tuple of ints.
    """
    pad = (np.atleast_1d(kernel_size) - np.atleast_1d(stride) + 1) / 2
    result = tuple(int(v) for v in pad)
    return result[0] if len(result) == 1 else result
def get_output_padding(kernel_size, stride, padding):
    """Compute the ``output_padding`` for a transposed convolution.

    Chosen so that the transposed conv exactly inverts the spatial-size
    change of a forward conv built with the same kernel/stride/padding.
    Returns a plain int for one-dimensional input, else a tuple of ints.
    """
    k = np.atleast_1d(kernel_size)
    s = np.atleast_1d(stride)
    p = np.atleast_1d(padding)
    vals = tuple(int(v) for v in 2 * p + s - k)
    return vals[0] if len(vals) == 1 else vals
class ConvLayer(nn.Module):
    """Convolution followed by normalization and LeakyReLU activation."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, norm, negative_slope, dim):
        super().__init__()
        self.conv = get_conv(in_channels, out_channels, kernel_size, stride, dim)
        self.norm = get_norm(norm, out_channels)
        self.lrelu = nn.LeakyReLU(negative_slope=negative_slope, inplace=True)

    def forward(self, input_data):
        out = self.conv(input_data)
        out = self.norm(out)
        return self.lrelu(out)
class ConvBlock(nn.Module):
    """Two stacked ConvLayers; only the first may downsample via stride."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, norm, negative_slope, dim):
        super().__init__()
        self.conv1 = ConvLayer(in_channels, out_channels, kernel_size, stride, norm, negative_slope, dim)
        self.conv2 = ConvLayer(out_channels, out_channels, kernel_size, 1, norm, negative_slope, dim)

    def forward(self, input_data):
        return self.conv2(self.conv1(input_data))
class UpsampleBlock(nn.Module):
    """Transposed-conv upsampling, skip concatenation, then a ConvBlock.

    Note: the transposed conv deliberately uses ``stride`` as both its
    kernel size and stride, giving non-overlapping upsampling.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, norm, negative_slope, dim):
        super().__init__()
        self.transp_conv = get_transp_conv(in_channels, out_channels, stride, stride, dim)
        self.conv_block = ConvBlock(2 * out_channels, out_channels, kernel_size, 1, norm, negative_slope, dim)

    def forward(self, input_data, skip_data):
        upsampled = self.transp_conv(input_data)
        merged = torch.cat((upsampled, skip_data), dim=1)
        return self.conv_block(merged)
class OutputBlock(nn.Module):
    """Final 1x1 convolution mapping features to output channels (with bias)."""

    def __init__(self, in_channels, out_channels, dim):
        super().__init__()
        self.conv = get_conv(in_channels, out_channels, kernel_size=1, stride=1, dim=dim, bias=True)

    def forward(self, input_data):
        return self.conv(input_data)