Spaces:
Running
Running
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .common import CausalConv3d
class Downsampler(nn.Module):
    """Base class for downsampling modules.

    Records the channel counts and the per-axis downsampling factors so that
    callers can introspect a downsampler (e.g. to compute output shapes)
    without knowing its concrete type. Subclasses implement the actual
    convolution / pooling in ``forward``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        spatial_downsample_factor: int = 1,
        temporal_downsample_factor: int = 1,
    ):
        super().__init__()
        # Factors of 1 mean "this axis is left untouched".
        self.spatial_downsample_factor = spatial_downsample_factor
        self.temporal_downsample_factor = temporal_downsample_factor
        self.in_channels = in_channels
        self.out_channels = out_channels
class SpatialDownsampler3D(Downsampler):
    """Halve the spatial resolution (H, W) with a strided causal 3D conv.

    Temporal resolution is preserved. ``out_channels`` defaults to
    ``in_channels`` when omitted.
    """

    def __init__(self, in_channels: int, out_channels: Optional[int] = None):
        # Fix: give out_channels a real default instead of requiring callers
        # to pass None explicitly; behavior for existing callers is unchanged.
        if out_channels is None:
            out_channels = in_channels
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            spatial_downsample_factor=2,
            temporal_downsample_factor=1,
        )
        # padding=0 here: the asymmetric pad in forward() supplies the extra
        # right/bottom pixel a stride-2, kernel-3 window needs.
        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=(1, 2, 2),
            padding=0,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pad one pixel on the right (W) and bottom (H) edges only; F.pad's
        # tuple addresses the last two dims: (W_left, W_right, H_top, H_bottom).
        x = F.pad(x, (0, 1, 0, 1))
        return self.conv(x)
class TemporalDownsampler3D(Downsampler):
    """Halve the temporal resolution (T) with a strided causal 3D conv.

    Spatial resolution is preserved. ``out_channels`` defaults to
    ``in_channels`` when omitted.
    """

    def __init__(self, in_channels: int, out_channels: Optional[int] = None):
        # Fix: give out_channels a real default instead of requiring callers
        # to pass None explicitly; behavior for existing callers is unchanged.
        if out_channels is None:
            out_channels = in_channels
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            spatial_downsample_factor=1,
            temporal_downsample_factor=2,
        )
        # NOTE(review): no explicit padding here — this relies on
        # CausalConv3d's default padding (presumably "same" spatially plus
        # causal temporal padding); confirm against .common.CausalConv3d.
        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=(2, 1, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W) — the strided conv does all the work.
        return self.conv(x)
class SpatialTemporalDownsampler3D(Downsampler):
    """Halve both spatial (H, W) and temporal (T) resolution with one
    strided causal 3D convolution.

    ``out_channels`` defaults to ``in_channels`` when omitted.
    """

    def __init__(self, in_channels: int, out_channels: Optional[int] = None):
        # Fix: give out_channels a real default instead of requiring callers
        # to pass None explicitly; behavior for existing callers is unchanged.
        if out_channels is None:
            out_channels = in_channels
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            spatial_downsample_factor=2,
            temporal_downsample_factor=2,
        )
        # padding=0 here: the asymmetric pad in forward() supplies the extra
        # right/bottom pixel a stride-2, kernel-3 window needs spatially;
        # temporal padding is presumably handled causally by CausalConv3d.
        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=(2, 2, 2),
            padding=0,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pad one pixel on the right (W) and bottom (H) edges only; the
        # temporal dim is not padded here.
        x = F.pad(x, (0, 1, 0, 1))
        return self.conv(x)
class BlurPooling2D(Downsampler):
    """Anti-aliased 2x spatial downsampling (blur-pool) for 2D feature maps.

    Applies a fixed, non-learnable 3x3 binomial filter ([1, 2, 1] outer
    product, normalized) depthwise with stride 2, in the spirit of Zhang,
    "Making Convolutional Networks Shift-Invariant Again". The channel
    count cannot change; ``out_channels`` defaults to ``in_channels``.
    """

    def __init__(self, in_channels: int, out_channels: Optional[int] = None):
        # Fix: give out_channels a real default instead of requiring callers
        # to pass None explicitly; behavior for existing callers is unchanged.
        if out_channels is None:
            out_channels = in_channels
        # A fixed depthwise blur kernel cannot remap channels.
        assert in_channels == out_channels
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            spatial_downsample_factor=2,
            temporal_downsample_factor=1,
        )
        # Build the normalized separable binomial kernel, one copy per
        # channel: final shape (C, 1, 3, 3) for a depthwise (groups=C) conv.
        filt = torch.tensor([1, 2, 1], dtype=torch.float32)
        filt = torch.einsum("i,j -> ij", filt, filt)
        filt = filt / filt.sum()
        filt = filt[None, None].repeat(out_channels, 1, 1, 1)
        # register_buffer: moves with .to()/.cuda() and is saved in
        # state_dict, but receives no gradients.
        self.register_buffer("filt", filt)
        self.filt: torch.Tensor  # annotation only; the value is the buffer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W); stride 2 with padding 1 halves H and W (rounding up).
        return F.conv2d(x, self.filt, stride=2, padding=1, groups=self.in_channels)
class BlurPooling3D(Downsampler):
    """Anti-aliased 2x spatio-temporal downsampling (blur-pool) for 3D
    feature maps.

    Applies a fixed, non-learnable 3x3x3 binomial filter ([1, 2, 1] triple
    outer product, normalized) depthwise with stride 2 on T, H and W. The
    channel count cannot change; ``out_channels`` defaults to ``in_channels``.
    """

    def __init__(self, in_channels: int, out_channels: Optional[int] = None):
        # Fix: give out_channels a real default instead of requiring callers
        # to pass None explicitly; behavior for existing callers is unchanged.
        if out_channels is None:
            out_channels = in_channels
        # A fixed depthwise blur kernel cannot remap channels.
        assert in_channels == out_channels
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            spatial_downsample_factor=2,
            temporal_downsample_factor=2,
        )
        # Build the normalized separable binomial kernel, one copy per
        # channel: final shape (C, 1, 3, 3, 3) for a depthwise (groups=C) conv.
        filt = torch.tensor([1, 2, 1], dtype=torch.float32)
        filt = torch.einsum("i,j,k -> ijk", filt, filt, filt)
        filt = filt / filt.sum()
        filt = filt[None, None].repeat(out_channels, 1, 1, 1, 1)
        # register_buffer: moves with .to()/.cuda() and is saved in
        # state_dict, but receives no gradients.
        self.register_buffer("filt", filt)
        self.filt: torch.Tensor  # annotation only; the value is the buffer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W); stride 2 with padding 1 halves T, H and W
        # (rounding up). NOTE(review): symmetric temporal padding — this is
        # NOT causal, unlike the CausalConv3d-based downsamplers above.
        return F.conv3d(x, self.filt, stride=2, padding=1, groups=self.in_channels)