Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| from spiga.models.cnn.layers import Conv, Deconv, Residual | |
class Hourglass(nn.Module):
    """Recursive hourglass block: a skip branch added to a nested lower branch.

    The lower branch runs ``pool1 -> low1 -> low2 -> low3 -> up2`` and its
    result is summed with the skip branch ``up1``.  ``low2`` is another
    ``Hourglass`` of depth ``n - 1`` while ``n > 1``, and a single
    ``Residual`` at the innermost level.

    Args:
        n: Recursion depth; nesting stops once ``n`` is no longer > 1.
        f: Channel count passed to the layers on the input/output side.
        bn: Forwarded unchanged to the nested ``Hourglass``.  The ``Conv``
            and ``Deconv`` layers built here always receive ``bn=True``.
        increase: Added to ``f`` to obtain the channel count used inside
            the lower branch at this level.  It is not forwarded to the
            nested ``Hourglass``.
    """

    def __init__(self, n, f, bn=None, increase=0):
        super().__init__()
        self.n = n
        inner_channels = f + increase

        # Skip branch.
        self.up1 = Residual(f, f)

        # Lower branch.
        # NOTE(review): the positional "2, 2" are presumably kernel size and
        # stride (2x down/up-sampling) -- confirm against spiga's layers.
        self.pool1 = Conv(f, f, 2, 2, bn=True, relu=True)
        self.low1 = Residual(f, inner_channels)
        # Recurse until the innermost level, which is a plain Residual.
        self.low2 = (
            Hourglass(n - 1, inner_channels, bn=bn)
            if n > 1
            else Residual(inner_channels, inner_channels)
        )
        self.low3 = Residual(inner_channels, f)
        self.up2 = Deconv(f, f, 2, 2, bn=True, relu=True)

    def forward(self, x):
        """Return ``up1(x)`` plus the output of the lower branch applied to ``x``."""
        skip = self.up1(x)
        lower = x
        # Apply the lower-branch stages strictly in sequence.
        for stage in (self.pool1, self.low1, self.low2, self.low3, self.up2):
            lower = stage(lower)
        return skip + lower
class HourglassCore(Hourglass):
    """Hourglass that also returns the features produced in its lower branches.

    It builds the same layers as ``Hourglass``, but ``forward`` additionally
    collects intermediate tensors in a list (``core``):

    * the innermost level (``n`` not greater than 1) appends its ``low2``
      output;
    * every enclosing level appends its ``low3`` output after the nested
      level has returned.

    The list is therefore ordered from the deepest level outwards.

    Args:
        n: Recursion depth (see ``Hourglass``).
        f: Channel count on the input/output side of this level.
        bn: Forwarded to the parent constructor and to the nested level.
        increase: Added to ``f`` for the lower branch of this level only.
    """

    def __init__(self, n, f, bn=None, increase=0):
        super().__init__(n, f, bn=bn, increase=increase)
        nf = f + increase
        # The parent constructor created a plain Hourglass for the nested
        # level; replace it so the nested forward() accepts and returns `core`.
        if self.n > 1:
            self.low2 = HourglassCore(n - 1, nf, bn=bn)

    def forward(self, x, core=None):
        """Run the hourglass and collect lower-branch features.

        Args:
            x: Input passed to both the skip branch and the lower branch.
            core: Optional list to extend.  When given, it is appended to in
                place and the same list object is returned.  When omitted, a
                new list is created for this call.

        Returns:
            A tuple ``(output, core)``: the sum of the skip branch and the
            lower branch, and the list of collected features.
        """
        # A `core=[]` default would be shared by every call that omits the
        # argument, so features from earlier forward passes would pile up
        # (and stay referenced).  Allocate a fresh list per call instead.
        if core is None:
            core = []

        up1 = self.up1(x)
        pool1 = self.pool1(x)
        low1 = self.low1(pool1)

        if self.n > 1:
            # The nested level appends its own features before returning.
            low2, core = self.low2(low1, core=core)
        else:
            low2 = self.low2(low1)
            core.append(low2)

        low3 = self.low3(low2)
        if self.n > 1:
            core.append(low3)

        up2 = self.up2(low3)
        return up1 + up2, core