Spaces:
Runtime error
Runtime error
import torch.nn as nn | |
from .base_module import ConvBlock, DownsampleBlock, ResidualBlock, SkipConnection, UpsampleBlock | |
class HourGlass(nn.Module): | |
def __init__(self, convNum=4, resNum=4, inChannel=6, outChannel=3): | |
super(HourGlass, self).__init__() | |
self.inConv = ConvBlock(inChannel, 64, convNum=2) | |
self.down1 = nn.Sequential(*[DownsampleBlock(64, 128, withConvRelu=False), ConvBlock(128, 128, convNum=2)]) | |
self.down2 = nn.Sequential( | |
*[DownsampleBlock(128, 256, withConvRelu=False), ConvBlock(256, 256, convNum=convNum)]) | |
self.down3 = nn.Sequential( | |
*[DownsampleBlock(256, 512, withConvRelu=False), ConvBlock(512, 512, convNum=convNum)]) | |
self.residual = nn.Sequential(*[ResidualBlock(512) for _ in range(resNum)]) | |
self.up3 = nn.Sequential(*[UpsampleBlock(512, 256), ConvBlock(256, 256, convNum=convNum)]) | |
self.skip3 = SkipConnection(256) | |
self.up2 = nn.Sequential(*[UpsampleBlock(256, 128), ConvBlock(128, 128, convNum=2)]) | |
self.skip2 = SkipConnection(128) | |
self.up1 = nn.Sequential(*[UpsampleBlock(128, 64), ConvBlock(64, 64, convNum=2)]) | |
self.skip1 = SkipConnection(64) | |
self.outConv = nn.Sequential( | |
nn.Conv2d(64, 64, kernel_size=3, padding=1), | |
nn.ReLU(inplace=True), | |
nn.Conv2d(64, outChannel, kernel_size=1, padding=0) | |
) | |
def forward(self, x): | |
f1 = self.inConv(x) | |
f2 = self.down1(f1) | |
f3 = self.down2(f2) | |
f4 = self.down3(f3) | |
r4 = self.residual(f4) | |
r3 = self.skip3(self.up3(r4), f3) | |
r2 = self.skip2(self.up2(r3), f2) | |
r1 = self.skip1(self.up1(r2), f1) | |
y = self.outConv(r1) | |
return y | |
class ResidualHourGlass(nn.Module): | |
def __init__(self, resNum=4, inChannel=6, outChannel=3): | |
super(ResidualHourGlass, self).__init__() | |
self.inConv = nn.Conv2d(inChannel, 64, kernel_size=3, padding=1) | |
self.residualBefore = nn.Sequential(*[ResidualBlock(64) for _ in range(2)]) | |
self.down1 = nn.Sequential( | |
*[DownsampleBlock(64, 128, withConvRelu=False), ConvBlock(128, 128, convNum=2)]) | |
self.down2 = nn.Sequential( | |
*[DownsampleBlock(128, 256, withConvRelu=False), ConvBlock(256, 256, convNum=2)]) | |
self.residual = nn.Sequential(*[ResidualBlock(256) for _ in range(resNum)]) | |
self.up2 = nn.Sequential(*[UpsampleBlock(256, 128), ConvBlock(128, 128, convNum=2)]) | |
self.skip2 = SkipConnection(128) | |
self.up1 = nn.Sequential(*[UpsampleBlock(128, 64), ConvBlock(64, 64, convNum=2)]) | |
self.skip1 = SkipConnection(64) | |
self.residualAfter = nn.Sequential(*[ResidualBlock(64) for _ in range(2)]) | |
self.outConv = nn.Sequential( | |
nn.Conv2d(64, outChannel, kernel_size=3, padding=1), | |
nn.Tanh() | |
) | |
def forward(self, x): | |
f1 = self.inConv(x) | |
f1 = self.residualBefore(f1) | |
f2 = self.down1(f1) | |
f3 = self.down2(f2) | |
r3 = self.residual(f3) | |
r2 = self.skip2(self.up2(r3), f2) | |
r1 = self.skip1(self.up1(r2), f1) | |
y = self.residualAfter(r1) | |
y = self.outConv(y) | |
return y | |