from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from matanyone.model.channel_attn import CAResBlock
def interpolate_groups(g: torch.Tensor, ratio: float, mode: str,
                       align_corners: Optional[bool]) -> torch.Tensor:
    """Interpolate a grouped tensor of shape (B, num_objects, C, H, W).

    The batch and object dimensions are flattened so F.interpolate sees a
    standard 4D tensor, then restored afterwards.
    """
    batch_size, num_objects = g.shape[:2]
    g = F.interpolate(g.flatten(start_dim=0, end_dim=1),
                      scale_factor=ratio,
                      mode=mode,
                      align_corners=align_corners)
    g = g.view(batch_size, num_objects, *g.shape[1:])
    return g
def upsample_groups(g: torch.Tensor,
                    ratio: float = 2.0,
                    mode: str = 'bilinear',
                    align_corners: bool = False) -> torch.Tensor:
    return interpolate_groups(g, ratio, mode, align_corners)
def downsample_groups(g: torch.Tensor,
                      ratio: float = 0.5,
                      mode: str = 'area',
                      align_corners: Optional[bool] = None) -> torch.Tensor:
    return interpolate_groups(g, ratio, mode, align_corners)
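
# Shape sketch for the helpers above (illustrative values, assuming a grouped
# tensor of shape (B, num_objects, C, H, W)):
#   g = torch.zeros(2, 3, 64, 16, 16)
#   upsample_groups(g).shape    # -> (2, 3, 64, 32, 32)
#   downsample_groups(g).shape  # -> (2, 3, 64, 8, 8)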
class GConv2d(nn.Conv2d):
    """Conv2d applied to grouped tensors of shape (B, num_objects, C, H, W)."""

    def forward(self, g: torch.Tensor) -> torch.Tensor:
        batch_size, num_objects = g.shape[:2]
        g = super().forward(g.flatten(start_dim=0, end_dim=1))
        return g.view(batch_size, num_objects, *g.shape[1:])
class GroupResBlock(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()

        if in_dim == out_dim:
            self.downsample = nn.Identity()
        else:
            # 1x1 projection so the residual matches the output channel count
            self.downsample = GConv2d(in_dim, out_dim, kernel_size=1)

        self.conv1 = GConv2d(in_dim, out_dim, kernel_size=3, padding=1)
        self.conv2 = GConv2d(out_dim, out_dim, kernel_size=3, padding=1)

    def forward(self, g: torch.Tensor) -> torch.Tensor:
        # Pre-activation residual block: ReLU before each convolution
        out_g = self.conv1(F.relu(g))
        out_g = self.conv2(F.relu(out_g))

        g = self.downsample(g)

        return out_g + g
class MainToGroupDistributor(nn.Module):
    """Distributes a shared feature map x to every object group in g."""

    def __init__(self,
                 x_transform: Optional[nn.Module] = None,
                 g_transform: Optional[nn.Module] = None,
                 method: str = 'cat',
                 reverse_order: bool = False):
        super().__init__()

        self.x_transform = x_transform
        self.g_transform = g_transform
        self.method = method
        self.reverse_order = reverse_order

    def forward(self, x: torch.Tensor, g: torch.Tensor, skip_expand: bool = False) -> torch.Tensor:
        num_objects = g.shape[1]

        if self.x_transform is not None:
            x = self.x_transform(x)
        if self.g_transform is not None:
            g = self.g_transform(g)

        if not skip_expand:
            # Broadcast x from (B, C, H, W) to (B, num_objects, C, H, W)
            x = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
        if self.method == 'cat':
            if self.reverse_order:
                g = torch.cat([g, x], 2)
            else:
                g = torch.cat([x, g], 2)
        elif self.method == 'add':
            g = x + g
        elif self.method == 'mulcat':
            g = torch.cat([x * g, g], dim=2)
        elif self.method == 'muladd':
            g = x * g + g
        else:
            raise NotImplementedError(f'Unknown method: {self.method}')

        return g
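
# Channel-count sketch (illustrative): with x of shape (B, C_x, H, W) and g of
# shape (B, N, C_g, H, W), 'cat' yields (B, N, C_x + C_g, H, W); 'mulcat'
# requires C_x == C_g and yields (B, N, 2 * C_g, H, W); 'add' and 'muladd'
# require C_x == C_g and keep (B, N, C_g, H, W).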
class GroupFeatureFusionBlock(nn.Module):
    def __init__(self, x_in_dim: int, g_in_dim: int, out_dim: int):
        super().__init__()

        x_transform = nn.Conv2d(x_in_dim, out_dim, kernel_size=1)
        g_transform = GConv2d(g_in_dim, out_dim, kernel_size=1)

        self.distributor = MainToGroupDistributor(x_transform=x_transform,
                                                  g_transform=g_transform,
                                                  method='add')
        self.block1 = CAResBlock(out_dim, out_dim)
        self.block2 = CAResBlock(out_dim, out_dim)

    def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        batch_size, num_objects = g.shape[:2]

        # Project both inputs to out_dim and fuse by addition, then refine
        # with channel-attention residual blocks on the flattened (B*N) batch
        g = self.distributor(x, g)

        g = g.flatten(start_dim=0, end_dim=1)
        g = self.block1(g)
        g = self.block2(g)
        g = g.view(batch_size, num_objects, *g.shape[1:])

        return g
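
# Minimal shape smoke test (illustrative, not part of the library): the dims
# below are hypothetical, and CAResBlock is assumed to preserve spatial size,
# which the flatten/view round trip in the fusion block relies on.
if __name__ == '__main__':
    fuser = GroupFeatureFusionBlock(x_in_dim=1024, g_in_dim=256, out_dim=256)
    x = torch.randn(2, 1024, 16, 16)    # shared pixel features (B, C_x, H, W)
    g = torch.randn(2, 3, 256, 16, 16)  # per-object features (B, N, C_g, H, W)
    print(fuser(x, g).shape)            # expected: (2, 3, 256, 16, 16)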