# Spaces: Runtime error  (page-status banner captured along with the code; not part of the module)
import torch.nn as nn | |
from torch.nn import functional as F | |
import torch | |
import numpy as np | |
def tensor2array(tensors):
    """Convert a 4-D tensor to a NumPy array with axis 1 moved to the end.

    The tensor is detached from the autograd graph and moved to the CPU
    before conversion. The axes are then reordered as ``(0, 2, 3, 1)``
    (presumably ``(N, C, H, W)`` -> ``(N, H, W, C)``; confirm with callers).

    Args:
        tensors: A 4-D ``torch.Tensor``.

    Returns:
        A ``numpy.ndarray`` holding the same values with permuted axes.
    """
    host_tensor = tensors.detach().cpu()
    return host_tensor.numpy().transpose(0, 2, 3, 1)
class ResidualBlock(nn.Module):
    """Residual unit: ``x + conv(relu(conv(x)))``.

    Both convolutions are 3x3 with padding 1 and keep the channel count,
    so the branch output has the same shape as the input and can be added
    to it directly.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        first_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        activation = nn.ReLU(inplace=True)
        second_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        # Kept as a single Sequential named ``conv`` so parameter names
        # (``conv.0.*``, ``conv.2.*``) stay the same.
        self.conv = nn.Sequential(first_conv, activation, second_conv)

    def forward(self, x):
        """Return the input plus the output of the convolutional branch."""
        return x + self.conv(x)
class DownsampleBlock(nn.Module):
    """Stride-2 3x3 convolution, optionally followed by a conv + ReLU.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the block.
        withConvRelu: When true, the strided convolution is followed by a
            second 3x3 convolution and a ReLU; otherwise the block is the
            strided convolution alone.
    """

    def __init__(self, in_channels, out_channels, withConvRelu=True):
        super().__init__()
        strided_conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, padding=1, stride=2
        )
        if not withConvRelu:
            # Bare convolution: parameters are exposed as ``conv.weight`` / ``conv.bias``.
            self.conv = strided_conv
            return
        refine_conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.conv = nn.Sequential(strided_conv, refine_conv, nn.ReLU(inplace=True))

    def forward(self, x):
        """Apply the downsampling stack to ``x``."""
        return self.conv(x)
class ConvBlock(nn.Module):
    """A chain of 3x3 conv + ReLU stages.

    ``inConv`` maps ``inChannels`` to ``outChannels``; ``conv`` holds the
    remaining ``convNum - 1`` stages, each keeping ``outChannels``. When
    ``convNum`` is 1 or less, ``conv`` is an empty ``nn.Sequential``.

    Args:
        inChannels: Channels of the input feature map.
        outChannels: Channels produced by every stage.
        convNum: Total number of conv + ReLU stages, counting ``inConv``.
    """

    def __init__(self, inChannels, outChannels, convNum):
        super().__init__()
        self.inConv = nn.Sequential(
            nn.Conv2d(inChannels, outChannels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        # The inner tuple is rebuilt on every outer iteration, so each stage
        # gets its own freshly constructed modules.
        extra_stages = [
            layer
            for _ in range(convNum - 1)
            for layer in (
                nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )
        ]
        self.conv = nn.Sequential(*extra_stages)

    def forward(self, x):
        """Run the input stage, then the remaining stages."""
        return self.conv(self.inConv(x))
class UpsampleBlock(nn.Module):
    """2x nearest-neighbour upsampling followed by conv, conv, ReLU.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        project = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, padding=1, stride=1
        )
        refine = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.conv = nn.Sequential(project, refine, nn.ReLU(inplace=True))

    def forward(self, x):
        """Double the spatial size of ``x``, then apply the conv stack."""
        enlarged = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(enlarged)
class SkipConnection(nn.Module):
    """Fuse two feature maps: concatenate on dim 1, then a bias-free 1x1 conv.

    The convolution maps ``2 * channels`` inputs to ``channels`` outputs,
    so each of the two inputs is expected to carry ``channels`` channels.

    Args:
        channels: Channel count of each input and of the output.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(2 * channels, channels, 1, bias=False)

    def forward(self, x, y):
        """Concatenate ``x`` then ``y`` along dim 1 and project the result."""
        stacked = torch.cat([x, y], dim=1)
        return self.conv(stacked)