Spaces:
Runtime error
Runtime error
import numpy as np | |
import torch | |
import torch.nn as nn | |
class ResidualConvUnit(nn.Module):
    """Two 3x3 convolutions, each preceded by a ReLU, plus a skip connection.

    Both convolutions keep the channel count (``features``) and, with
    ``stride=1`` and ``padding=1``, the spatial size, so the output has the
    same shape as the input.
    """

    def __init__(self, features):
        """Build the unit.

        Args:
            features (int): number of input and output channels of both
                convolutions.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True)
        # In-place activation; only ever applied to tensors created inside
        # ``forward`` so the caller's data is never overwritten.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input whose channel dimension equals ``features``.
                It is left unmodified.

        Returns:
            tensor: ``conv2(relu(conv1(relu(x)))) + relu(x)``, same shape
            as ``x``.
        """
        # Out-of-place on purpose: an in-place ReLU here would silently
        # overwrite the caller's tensor and raises a RuntimeError when ``x``
        # is a leaf tensor that requires grad.
        activated = nn.functional.relu(x)
        out = self.conv1(activated)
        # ``out`` is a fresh tensor owned by this call, so in-place is safe.
        out = self.relu(out)
        out = self.conv2(out)
        # NOTE: the skip path adds the *activated* input. The previous
        # implementation overwrote ``x`` with relu(x) before this addition,
        # so adding ``activated`` keeps the returned values (and therefore
        # the behaviour of existing checkpoints) unchanged.
        return out + activated
class Fusion(nn.Module):
    """Refine a feature map, optionally merge a previous stage, and upsample x2.

    The input goes through ``res_conv1``; if a previous-stage tensor is
    supplied it is added to that result; the sum goes through ``res_conv2``
    and is finally upsampled by a factor of two with bilinear interpolation.
    """

    def __init__(self, resample_dim):
        """Build the fusion block.

        Args:
            resample_dim (int): channel count passed to both
                ``ResidualConvUnit`` sub-modules.
        """
        super().__init__()
        self.res_conv1 = ResidualConvUnit(resample_dim)
        self.res_conv2 = ResidualConvUnit(resample_dim)

    def forward(self, x, previous_stage=None):
        """Fuse ``x`` with an optional previous stage and upsample.

        Args:
            x (tensor): feature map for this stage; bilinear interpolation
                requires a 4-D (N, C, H, W) tensor.
            previous_stage (tensor, optional): tensor added to the output of
                the first residual unit. ``None`` (the default) means there
                is nothing to merge.

        Returns:
            tensor: fused features with both spatial dimensions doubled.
        """
        output_stage1 = self.res_conv1(x)
        # Identity check: ``== None`` on a tensor depends on Tensor.__eq__
        # fallback behaviour. Skipping the addition is equivalent to adding
        # an all-zeros tensor without allocating one.
        if previous_stage is not None:
            output_stage1 = output_stage1 + previous_stage
        output_stage2 = self.res_conv2(output_stage1)
        return nn.functional.interpolate(
            output_stage2, scale_factor=2, mode="bilinear", align_corners=True)