Spaces:
Runtime error
Runtime error
| from torch import nn | |
class LRASPP(nn.Module):
    """Two-branch gated projection head (name presumably refers to LR-ASPP).

    The output is the product of two branches computed from the same input:

    * ``aspp1`` -- a per-pixel 1x1 convolution, then BatchNorm and ReLU.
    * ``aspp2`` -- global average pooling to 1x1, then a 1x1 convolution and
      a sigmoid. This yields one value in (0, 1) per output channel, which
      broadcasts over the spatial dimensions of ``aspp1``'s output.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both branches.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Feature branch. Submodules are created in this exact order (and
        # before the gate branch) so that seeded weight initialisation and the
        # state-dict keys (aspp1.0.*, aspp1.1.*) stay stable.
        self.aspp1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        # Gate branch: squeeze spatially to 1x1, then map to per-channel
        # weights in (0, 1). The pooling layer has no parameters, so its conv
        # is stored under the key aspp2.1.*.
        self.aspp2 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.Sigmoid(),
        )

    def forward_single_frame(self, x):
        """Apply the head to a batch of single frames.

        Args:
            x: Feature tensor passed to ``Conv2d`` and ``BatchNorm2d``; the
                latter requires a 4-D ``(N, C, H, W)`` input.

        Returns:
            Tensor of shape ``(N, out_channels, H, W)``: the feature branch
            multiplied by the broadcast ``(N, out_channels, 1, 1)`` gate.
        """
        # The feature branch is evaluated first, matching the original
        # left-to-right evaluation order (relevant for forward hooks).
        features = self.aspp1(x)
        gate = self.aspp2(x)
        return features * gate

    def forward_time_series(self, x):
        """Apply the head to a 5-D clip tensor frame by frame.

        The first two dimensions (named B and T here) are merged into one
        batch dimension, processed by ``forward_single_frame``, and then split
        back out.

        Args:
            x: Tensor whose leading two dimensions are ``(B, T)``, e.g.
                ``(B, T, C, H, W)``.

        Returns:
            Tensor of shape ``(B, T, out_channels, H, W)``.
        """
        batch_size, num_frames = x.shape[:2]
        frames = x.flatten(0, 1)
        projected = self.forward_single_frame(frames)
        return projected.unflatten(0, (batch_size, num_frames))

    def forward(self, x):
        """Dispatch on input rank.

        A 5-D input is treated as a time series; any other input goes straight
        to ``forward_single_frame``.
        """
        if x.ndim != 5:
            return self.forward_single_frame(x)
        return self.forward_time_series(x)