Spaces:
Build error
Build error
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
class ConcatUpConv(nn.Module):
    """Fuse two feature maps by channel concatenation, then optionally upsample.

    The two inputs are concatenated along the channel axis and passed through
    a 1x1 conv -> BatchNorm -> ReLU stage. When ``upsample`` is true, the
    fused map is additionally passed through a 3x3 conv -> BatchNorm -> ReLU
    stage that outputs ``outplanes // 2`` channels and is then bilinearly
    upsampled by a factor of 2.

    Args:
        inplanes: Channel count of the concatenated input, i.e. the sum of
            the channel counts of ``x1`` and ``x2`` passed to ``forward``.
        outplanes: Channel count produced by the 1x1 fusion stage.
        upsample: Whether to build and run the 3x3 conv + upsampling branch.
    """

    def __init__(self, inplanes, outplanes, upsample=True):
        super(ConcatUpConv, self).__init__()
        self.upsample = upsample

        # NOTE: the activation attributes are called "leakyrelu_*" but hold
        # plain ReLU layers; the names are kept so existing attribute access
        # and checkpoints keep working.
        self.con_1x1 = nn.Conv2d(inplanes, outplanes, 1, bias=False)
        nn.init.kaiming_uniform_(self.con_1x1.weight, a=1)
        self.nor_1 = nn.BatchNorm2d(outplanes)
        self.leakyrelu_1 = nn.ReLU()

        # The 3x3 branch only exists when upsampling is requested.
        if self.upsample:
            self.con_3x3 = nn.Conv2d(outplanes, outplanes // 2,
                                     kernel_size=3, stride=1, padding=1,
                                     bias=False)
            nn.init.kaiming_uniform_(self.con_3x3.weight, a=1)
            self.nor_3 = nn.BatchNorm2d(outplanes // 2)
            self.leakyrelu_3 = nn.ReLU()

    def forward(self, x1, x2):
        """Fuse ``x1`` and ``x2`` and return ``(upsampled, fused)``.

        Args:
            x1: Feature map whose spatial size defines the fusion resolution.
            x2: Feature map to fuse with ``x1``. If its spatial size differs
                from that of ``x1`` it is bilinearly resized to match.

        Returns:
            A tuple ``(out, out_1)``. ``out_1`` is the output of the 1x1
            fusion stage. ``out`` is the 2x-upsampled output of the 3x3
            branch, or ``None`` when the module was built with
            ``upsample=False``.
        """
        # A 2x-upsampled map does not line up with the next finer level when
        # that level has an odd height/width (e.g. 4 -> 8 vs. 7); torch.cat
        # would raise on the mismatch, so align x2 to x1 first. Inputs that
        # already match are left untouched.
        if x2.shape[2:] != x1.shape[2:]:
            x2 = F.interpolate(x2, size=x1.shape[2:],
                               mode='bilinear', align_corners=False)

        fusion = torch.cat([x1, x2], dim=1)
        out_1 = self.leakyrelu_1(self.nor_1(self.con_1x1(fusion)))

        if not self.upsample:
            return None, out_1

        out = self.leakyrelu_3(self.nor_3(self.con_3x3(out_1)))
        out = F.interpolate(out, scale_factor=2,
                            mode='bilinear', align_corners=False)
        return out, out_1
class MSR(nn.Module):
    """Wrap a backbone with top-down fusion driven by a half-resolution pass.

    The backbone ``body`` is run on the input and on a 0.5x bilinearly
    downscaled copy of it. The last feature map of the downscaled pass seeds
    a coarse-to-fine chain of ``ConcatUpConv`` modules, one per backbone
    level, each fusing the backbone feature of that level with the map
    handed down from the previous step.

    Args:
        body: Backbone module. Its output is indexed with ``[-1]``, sliced
            with ``[::-1]`` and its items expose ``.shape``; presumably a
            list/tuple of 4-D feature maps ordered fine to coarse -- TODO
            confirm against the backbone in use.
        channels: Per-level channel counts, one entry per fusion module.
            Each fusion module is built for ``2 * channels[i]`` input
            channels (NOTE(review): this looks like it expects channel
            counts to double from one level to the next -- confirm).
        fpn: Optional module applied to the list of fused features.
        pan: Optional module applied to the list of fused features; when
            both are given, ``pan`` runs before ``fpn``.
    """

    def __init__(self, body, channels, fpn=None, pan=None):
        super(MSR, self).__init__()
        self.body = body

        # The first level is the end of the top-down chain, so it has no
        # upsampling branch; every other level hands a 2x map downwards.
        base = channels[0]
        fusers = [ConcatUpConv(base * 2, base, upsample=False)]
        fusers.extend(ConcatUpConv(width * 2, width) for width in channels[1:])
        self.cucs = nn.ModuleList(fusers)

        # Necks are only registered when supplied; forward() checks for the
        # attribute's presence.
        if fpn is not None:
            self.fpn = fpn
        if pan is not None:
            self.pan = pan

    def forward(self, x):
        """Return the fused features, passed through ``pan``/``fpn`` if set."""
        feats = self.body(x)

        # Second backbone pass on a half-resolution copy of the input; only
        # its last feature map is used.
        half = F.interpolate(x, scale_factor=0.5,
                             mode='bilinear', align_corners=False)
        coarse = self.body(half)[-1]

        # Bring that map to the spatial size of the last full-resolution
        # feature so the two can be concatenated.
        target_hw = feats[-1].shape[2:]
        low = F.interpolate(coarse, size=target_hw,
                            mode='bilinear', align_corners=False)

        # Walk the levels from last to first; each step yields the fused map
        # for its level and the map to fuse into the next one.
        fused = []
        for cuc, high in zip(reversed(self.cucs), feats[::-1]):
            low, merged = cuc(high, low)
            fused.append(merged)
        fused.reverse()  # restore first-to-last level order

        outs = fused
        for neck in ('pan', 'fpn'):
            if hasattr(self, neck):
                outs = getattr(self, neck)(outs)
        return outs