import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np


class DotAttn(nn.Module):
    """ Dot-Attention """

    def forward(self, inp, h):
        score = self.softmax(inp, h)
        return score.expand_as(inp).mul(inp).sum(1), score

    def softmax(self, inp, h):
        raw_score = inp.bmm(h.unsqueeze(2))
        score = F.softmax(raw_score, dim=1)
        return score


class ScaledDotAttn(nn.Module):
    """ Scaled Dot-Attention """

    def forward(self, inp, h):
        score = self.softmax(inp, h)
        return score.expand_as(inp).mul(inp).sum(1), score

    def softmax(self, inp, h):
        raw_score = inp.bmm(h.unsqueeze(2)) / np.sqrt(h.shape[-1])
        score = F.softmax(raw_score, dim=1)
        return score
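
# Usage sketch (illustrative only; the [B, T, D] / [B, D] shapes below are
# assumptions about how the callers in this file use these modules):
#   inp = torch.randn(4, 10, 64)          # [B, T, D] per-token features
#   h = torch.randn(4, 64)                # [B, D] query vector
#   ctx, score = ScaledDotAttn()(inp, h)  # ctx: [B, D], score: [B, T, 1]
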
class Fusion(nn.Module):
    """ Base Fusion Class """

    def __init__(self, input_dim=3):
        super().__init__()
        self.input_dim = input_dim

    def tile_x2(self, x1, x2, x2_proj=None):
        if x2_proj:
            x2 = x2_proj(x2)

        x2 = x2.unsqueeze(-1).unsqueeze(-1)
        x2 = x2.repeat(x1.shape[0], 1, x1.shape[-2], x1.shape[-1])
        return x2

    def batch_tile_x2(self, x1, x2, x2_proj=None):
        if x2_proj:
            x2 = x2_proj(x2)

        x2 = x2.unsqueeze(-1).unsqueeze(-1)
        x2 = x2.repeat(1, 1, x1.shape[-2], x1.shape[-1])
        return x2

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        raise NotImplementedError()
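
# Tiling sketch (my reading of the two helpers above, not documented upstream):
#   tile_x2:       x2 [D] (or [1, D]) -> [1, D, 1, 1] -> repeated to [B, D, H, W]
#   batch_tile_x2: x2 [B, D]          -> [B, D, 1, 1] -> repeated to [B, D, H, W]
# i.e. tile_x2 broadcasts a single un-batched vector across the batch, while
# batch_tile_x2 keeps one vector per sample and only tiles it spatially.
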
class FusionAdd(Fusion):
    """ x1 + x2 """

    def __init__(self, input_dim=3):
        super(FusionAdd, self).__init__(input_dim=input_dim)

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        if x1.shape != x2.shape and len(x1.shape) != len(x2.shape):
            x2 = self.tile_x2(x1, x2, x2_proj)
        return x1 + x2


class FusionMult(Fusion):
    """ x1 * x2 """

    def __init__(self, input_dim=3):
        super(FusionMult, self).__init__(input_dim=input_dim)

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        if x1.shape != x2.shape and len(x1.shape) != len(x2.shape):
            x2 = self.batch_tile_x2(x1, x2, x2_proj)
        return x1 * x2


class FusionMax(Fusion):
    """ max(x1, x2) """

    def __init__(self, input_dim=3):
        super(FusionMax, self).__init__(input_dim=input_dim)

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        if x1.shape != x2.shape and len(x1.shape) != len(x2.shape):
            x2 = self.tile_x2(x1, x2, x2_proj)
        return torch.max(x1, x2)


class FusionConcat(Fusion):
    """ [x1; x2] """

    def __init__(self, input_dim=3):
        super(FusionConcat, self).__init__(input_dim=input_dim)

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        if x1.shape != x2.shape and len(x1.shape) != len(x2.shape):
            x2 = self.tile_x2(x1, x2, x2_proj)
        return torch.cat([x1, x2], dim=1)


class FusionConv(Fusion):
    """ 1x1 convs after [x1; x2] """

    def __init__(self, input_dim=3):
        super(FusionConv, self).__init__(input_dim=input_dim)
        self.conv = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(input_dim * 2, input_dim, kernel_size=1, bias=False)
        )

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        if x1.shape != x2.shape and len(x1.shape) != len(x2.shape):
            x2 = self.tile_x2(x1, x2, x2_proj)
        x = torch.cat([x1, x2], dim=1)   # [B, 2C, H, W]
        x = self.conv(x)                 # [B, C, H, W]
        return x


class FusionConvLat(Fusion):
    """ 1x1 convs after [x1; x2] for lateral fusion """

    def __init__(self, input_dim=3, output_dim=3):
        super(FusionConvLat, self).__init__(input_dim=input_dim)
        self.conv = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(input_dim, output_dim, kernel_size=1, bias=False)
        )

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        if x1.shape != x2.shape and len(x1.shape) != len(x2.shape):
            x2 = self.tile_x2(x1, x2, x2_proj)
        x = torch.cat([x1, x2], dim=1)   # [B, input_dim, H, W]
        x = self.conv(x)                 # [B, output_dim, H, W]
        return x
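
# Usage sketch for the elementwise fusions (illustrative only; the 512-dim
# sentence embedding and the projection layer are assumptions, not fixed API):
#   fuse = FusionMult(input_dim=256)
#   x1 = torch.randn(4, 256, 32, 32)   # [B, C, H, W] visual features
#   x2 = torch.randn(4, 512)           # [B, D] sentence embedding
#   proj = nn.Linear(512, 256)         # maps language dim -> C
#   out = fuse(x1, x2, x2_proj=proj)   # x2 is projected, tiled to
#                                      # [B, 256, 32, 32], then multiplied into x1
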
## ------------- NOTE ----------------
## The following are various fusion types I experimented with.
## Most of them didn't work well ¯\_(ツ)_/¯
## But that doesn't mean there isn't a better way of
## doing lateral and multi-modal (language+vision) fusion.


class FusionFiLM(Fusion):
    """ FiLM (Perez et al., https://arxiv.org/abs/1709.07871).
    Note: this is not used inside a Residual block before ReLU.
    I had a version of this in UpBlock with FiLM, which didn't seem to work at all.
    """

    def __init__(self, input_dim=3, output_dim=3):
        super(FusionFiLM, self).__init__(input_dim=input_dim)

    def forward(self, x1, x2, gamma, beta):
        g = self.tile_x2(x1, x2, gamma)
        b = self.tile_x2(x1, x2, beta)
        return x1 * g + b
class FusionDeepConv(Fusion):
    """ Multi-Layer 1x1 convs after [x1; x2] """

    def __init__(self, input_dim=3):
        super(FusionDeepConv, self).__init__(input_dim=input_dim)
        self.conv = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(input_dim * 2, input_dim, kernel_size=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(input_dim, input_dim, kernel_size=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(input_dim, input_dim, kernel_size=1, bias=False),
        )

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        if x1.shape != x2.shape and len(x1.shape) != len(x2.shape):
            x2 = self.tile_x2(x1, x2, x2_proj)
        x = torch.cat([x1, x2], dim=1)   # [B, 2C, H, W]
        x = self.conv(x)                 # [B, C, H, W]
        return x


class FusionMultWord(nn.Module):
    """ Product with weighted-sum of words """

    def __init__(self, input_dim=3):
        super().__init__()
        self.input_dim = input_dim

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        B, D, H, W = x1.shape
        x2_len = int(x2_mask.count_nonzero())

        weighted_x1 = torch.zeros_like(x1)
        for t in range(x2_len):
            x2_t = x2_proj(x2[:, t]) if x2_proj else x2[:, t]
            x2_t = x2_t.unsqueeze(-1).unsqueeze(-1).repeat(B, 1, H, W)
            weighted_x1 += x1 * x2_t
        weighted_x1 /= x2_len
        return weighted_x1


class FusionWordAttention(nn.Module):
    """ Word Attention """

    def __init__(self, input_dim=3):
        super().__init__()
        self.input_dim = input_dim
        self.dot_attn = DotAttn()

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        B, D, H, W = x1.shape
        x1_flat = x1.reshape(B, D, H * W)
        x2_len = int(x2_mask.count_nonzero())

        # TODO: batch this unrolling?
        weight_sum_x1_flat = torch.zeros_like(x1_flat)
        for t in range(x2_len):
            x2_t = x2_proj(x2[:, t]) if x2_proj else x2[:, t]
            x2_t = x2_t.repeat(B, 1)

            _, attn_x1 = self.dot_attn(x1_flat.transpose(1, 2), x2_t)
            weight_sum_x1_flat += x1_flat * attn_x1.transpose(1, 2)

        weight_sum_x1_flat /= x2_len
        x2 = weight_sum_x1_flat.reshape(B, D, H, W)

        return x2
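
# Usage sketch for the word-level fusions (illustrative only; batch size 1,
# the token count, and the shapes are assumptions about the caller):
#   fuse = FusionWordAttention(input_dim=256)
#   x1 = torch.randn(1, 256, 14, 14)             # [B, C, H, W] visual features
#   x2 = torch.randn(1, 77, 256)                 # [B, T, C] per-token embeddings
#   x2_mask = torch.zeros(77); x2_mask[:12] = 1  # first 12 tokens are real
#   out = fuse(x1, x2, x2_mask=x2_mask)          # [B, C, H, W], averaged over
#                                                # per-word spatial attention maps
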
class FusionSentenceAttention(nn.Module):
    """ Sentence Attention """

    def __init__(self, input_dim=3):
        super().__init__()
        self.input_dim = input_dim
        self.dot_attn = ScaledDotAttn()

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        B, D, H, W = x1.shape
        x1_flat = x1.reshape(B, D, H * W)

        x2_t = x2_proj(x2) if x2_proj else x2
        x2_t = x2_t.repeat(B, 1)

        _, attn_x1 = self.dot_attn(x1_flat.transpose(1, 2), x2_t)
        weight_sum_x1_flat = x1_flat * attn_x1.transpose(1, 2)

        x2 = weight_sum_x1_flat.reshape(B, D, H, W)

        return x2


class CrossModalAttention2d(nn.Module):
    """ Cross-Modal Attention. Adapted from: https://github.com/openai/CLIP/blob/main/clip/model.py#L56 """

    def __init__(self, spacial_dim=7, embed_dim=1024, num_heads=32,
                 output_dim=1024, lang_dim=512, lang_max_tokens=77):
        super().__init__()
        self.embed_dim = embed_dim
        self.lang_dim = lang_dim
        self.lang_max_tokens = lang_max_tokens
        self.num_heads = num_heads

        self.lang_proj = nn.Linear(self.lang_dim, embed_dim)
        self.vision_positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2, embed_dim) / embed_dim ** 0.5)
        self.lang_positional_embedding = nn.Parameter(torch.randn(lang_max_tokens, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)

    def forward(self, x, l, l_mask):
        # reshape vision features
        x_shape = x.shape
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = x + self.vision_positional_embedding[:x.shape[0], None, :].to(x.dtype)  # (HW)NC

        # project language
        l = l.permute(1, 0, 2)
        l_shape = l.shape
        l = l.reshape(-1, self.lang_dim)
        l = self.lang_proj(l)
        l = l.reshape(l_shape[0], l_shape[1], self.embed_dim)
        l = l + self.lang_positional_embedding[:, None, :].to(l.dtype)

        # hard language mask
        l_len = int(l_mask.count_nonzero())
        l = l[:l_len]
        l = l.repeat(1, x.shape[1], 1)

        x, _ = F.multi_head_attention_forward(
            query=x, key=l, value=l,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        x = x.permute(1, 2, 0)
        x = x.reshape(x_shape)
        return x
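
# Usage sketch (illustrative only; batch size 1 and the CLIP-style 7x7x1024
# visual grid / 77-token, 512-dim language encoding are assumptions):
#   attn = CrossModalAttention2d(spacial_dim=7, embed_dim=1024, output_dim=1024)
#   x = torch.randn(1, 1024, 7, 7)             # [B, C, H, W] visual features
#   l = torch.randn(1, 77, 512)                # [B, T, lang_dim] token embeddings
#   l_mask = torch.zeros(77); l_mask[:12] = 1  # first 12 tokens are real
#   out = attn(x, l, l_mask)                   # [B, 1024, 7, 7]
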
class FusionMultiHeadedWordAttention(nn.Module):
    """ Multi-Headed Word Attention that uses Cross-Modal Attention at different scales """

    def __init__(self, input_dim=3):
        super().__init__()
        self.input_dim = input_dim
        self.attn1 = CrossModalAttention2d(spacial_dim=7, embed_dim=1024, output_dim=1024)
        self.attn2 = CrossModalAttention2d(spacial_dim=14, embed_dim=512, output_dim=512)
        self.attn3 = CrossModalAttention2d(spacial_dim=28, embed_dim=256, output_dim=256)

        self.multi_headed_attns = {
            1024: self.attn1,
            512: self.attn2,
            256: self.attn3,
        }

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        emb_dim = x1.shape[1]
        x = self.multi_headed_attns[emb_dim](x1, x2, x2_mask)
        return x


names = {
    'add': FusionAdd,
    'mult': FusionMult,
    'mult_word': FusionMultWord,
    'film': FusionFiLM,
    'max': FusionMax,
    'concat': FusionConcat,
    'conv': FusionConv,
    'deep_conv': FusionDeepConv,
    'word_attn': FusionWordAttention,
    'sent_attn': FusionSentenceAttention,
    'multi_headed_word_attn': FusionMultiHeadedWordAttention,
}
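
# Registry usage sketch (illustrative; the 'mult' key and input_dim value are
# just one example of looking a fusion class up by name):
#   fusion_cls = names['mult']
#   fusion = fusion_cls(input_dim=256)
#   # fusion(x1, x2, x2_proj=...) is then called as in the sketches above.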