""" Group-specific modules They handle features that also depends on the mask. Features are typically of shape batch_size * num_objects * num_channels * H * W All of them are permutation equivariant w.r.t. to the num_objects dimension """ import torch import torch.nn as nn import torch.nn.functional as F def interpolate_groups(g, ratio, mode, align_corners): batch_size, num_objects = g.shape[:2] g = F.interpolate( g.flatten(start_dim=0, end_dim=1), scale_factor=ratio, mode=mode, align_corners=align_corners, ) g = g.view(batch_size, num_objects, *g.shape[1:]) return g def upsample_groups(g, ratio=2, mode="bilinear", align_corners=False): return interpolate_groups(g, ratio, mode, align_corners) def downsample_groups(g, ratio=1 / 2, mode="area", align_corners=None): return interpolate_groups(g, ratio, mode, align_corners) class GConv2D(nn.Conv2d): def forward(self, g): batch_size, num_objects = g.shape[:2] g = super().forward(g.flatten(start_dim=0, end_dim=1)) return g.view(batch_size, num_objects, *g.shape[1:]) class GroupResBlock(nn.Module): def __init__(self, in_dim, out_dim): super().__init__() if in_dim == out_dim: self.downsample = None else: self.downsample = GConv2D(in_dim, out_dim, kernel_size=3, padding=1) self.conv1 = GConv2D(in_dim, out_dim, kernel_size=3, padding=1) self.conv2 = GConv2D(out_dim, out_dim, kernel_size=3, padding=1) def forward(self, g): out_g = self.conv1(F.relu(g)) out_g = self.conv2(F.relu(out_g)) if self.downsample is not None: g = self.downsample(g) return out_g + g class MainToGroupDistributor(nn.Module): def __init__(self, x_transform=None, method="cat", reverse_order=False): super().__init__() self.x_transform = x_transform self.method = method self.reverse_order = reverse_order def forward(self, x, g): num_objects = g.shape[1] if self.x_transform is not None: x = self.x_transform(x) if self.method == "cat": if self.reverse_order: g = torch.cat( [g, x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)], 2 ) else: g = torch.cat( [x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1), g], 2 ) elif self.method == "add": g = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1) + g else: raise NotImplementedError return g