import numpy as np
import torch
import torch.nn as nn


class ResidualConvUnit(nn.Module):
    """Pre-activation residual block: ReLU -> conv3x3 -> ReLU -> conv3x3, plus skip.

    Both convolutions map ``features`` channels to ``features`` channels with
    stride 1 and padding 1, so the output has the same shape as the input.
    """

    def __init__(self, features):
        """Build the two 3x3 convolutions.

        Args:
            features (int): number of input and output channels.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )
        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )
        # In-place is only safe on tensors this module created itself; it is
        # applied below to conv1's output, never to the caller's input.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input feature map; must be accepted by
                ``nn.Conv2d(features, ...)`` (presumably (N, features, H, W)).

        Returns:
            tensor: ``conv2(relu(conv1(relu(x)))) + x``, same shape as ``x``.
            ``x`` itself is not modified.
        """
        # Out-of-place on purpose: an in-place ReLU here overwrote the caller's
        # tensor, made the skip connection add relu(x) instead of x, and raised
        # for leaf tensors that require grad.
        out = torch.relu(x)
        out = self.conv1(out)
        out = self.relu(out)  # safe in place: `out` is a fresh conv output
        out = self.conv2(out)
        return out + x


class Fusion(nn.Module):
    """Refine a feature map, add an optional previous stage, and upsample 2x.

    Computes ``up2(res_conv2(res_conv1(x) + previous_stage))``, where ``up2``
    is bilinear interpolation with ``scale_factor=2`` and ``align_corners=True``.
    """

    def __init__(self, resample_dim):
        """Build the two residual units.

        Args:
            resample_dim (int): channel count of ``x`` (and ``previous_stage``).
        """
        super().__init__()
        self.res_conv1 = ResidualConvUnit(resample_dim)
        self.res_conv2 = ResidualConvUnit(resample_dim)

    def forward(self, x, previous_stage=None):
        """Forward pass.

        Args:
            x (tensor): input feature map with ``resample_dim`` channels; must
                be 4-D for bilinear interpolation.
            previous_stage (tensor | None): optional tensor added to
                ``res_conv1(x)``; must be broadcast-compatible with it.
                Presumably the output of the preceding Fusion stage — confirm
                against the caller.

        Returns:
            tensor: fused features with spatial size doubled.
        """
        output_stage1 = self.res_conv1(x)
        # `is None` avoids Tensor.__eq__; skipping the add replaces the old
        # zeros_like allocation, since adding zeros is a no-op.
        if previous_stage is not None:
            output_stage1 = output_stage1 + previous_stage
        output_stage2 = self.res_conv2(output_stage1)
        return nn.functional.interpolate(
            output_stage2, scale_factor=2, mode="bilinear", align_corners=True
        )