import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# import open_clip
def conv_layer(in_dim, out_dim, kernel_size=1, padding=0, stride=1):
    return nn.Sequential(
        nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False),
        nn.BatchNorm2d(out_dim), nn.ReLU(True))
def linear_layer(in_dim, out_dim, bias=False):
    return nn.Sequential(nn.Linear(in_dim, out_dim, bias),
                         nn.BatchNorm1d(out_dim), nn.ReLU(True))
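# Note on the helpers above: conv_layer's parameter order is (kernel_size,
# padding, stride), while nn.Conv2d takes (kernel_size, stride, padding);
# the call inside conv_layer already reorders them. linear_layer contains
# BatchNorm1d, so it expects (N, C) inputs with N > 1 in training mode.
# Minimal sketch (shapes are illustrative assumptions):
#
#   conv = conv_layer(512, 256, kernel_size=3, padding=1)
#   conv(torch.randn(2, 512, 26, 26)).shape  # -> (2, 256, 26, 26)
#   lin = linear_layer(1024, 1024)
#   lin(torch.randn(2, 1024)).shape          # -> (2, 1024)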
class AttentionPool2d(nn.Module):

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x[:1], key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return x.squeeze(0)
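# Usage sketch (shapes are assumptions for illustration, not from the training
# config): CLIP-style attention pooling of a 13x13 feature map down to a
# single embedding.
#
#   pool = AttentionPool2d(spacial_dim=13, embed_dim=512, num_heads=8,
#                          output_dim=512)
#   feats = torch.randn(2, 512, 13, 13)  # NCHW; HW must equal spacial_dim ** 2
#   pooled = pool(feats)                 # -> (2, 512): the attended mean token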
class CoordConv(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 padding=1,
                 stride=1):
        super().__init__()
        self.conv1 = conv_layer(in_channels + 2, out_channels, kernel_size,
                                padding, stride)

    def add_coord(self, input):
        b, _, h, w = input.size()
        x_range = torch.linspace(-1, 1, w, device=input.device)
        y_range = torch.linspace(-1, 1, h, device=input.device)
        # 'ij' indexing matches the historical torch.meshgrid default and
        # silences the deprecation warning on recent PyTorch versions
        y, x = torch.meshgrid(y_range, x_range, indexing='ij')
        y = y.expand([b, 1, -1, -1])
        x = x.expand([b, 1, -1, -1])
        coord_feat = torch.cat([x, y], 1)
        input = torch.cat([input, coord_feat], 1)
        return input

    def forward(self, x):
        x = self.add_coord(x)
        x = self.conv1(x)
        return x
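# Usage sketch (assumed shapes): CoordConv appends two channels holding the
# normalized x / y coordinates in [-1, 1] before the convolution, so the
# filter can condition on absolute position.
#
#   cc = CoordConv(in_channels=512, out_channels=512)
#   x = torch.randn(2, 512, 26, 26)
#   y = cc(x)  # -> (2, 512, 26, 26); the conv internally sees 512 + 2 channels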
class TransformerDecoder(nn.Module):

    def __init__(self,
                 num_layers,
                 d_model,
                 nhead,
                 dim_ffn,
                 dropout,
                 return_intermediate=False):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(d_model=d_model,
                                    nhead=nhead,
                                    dim_feedforward=dim_ffn,
                                    dropout=dropout) for _ in range(num_layers)
        ])
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.return_intermediate = return_intermediate
    @staticmethod
    def pos1d(d_model, length):
        """
        :param d_model: dimension of the model
        :param length: length of positions
        :return: length * d_model position matrix
        """
        if d_model % 2 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dim (got dim={:d})".format(d_model))
        pe = torch.zeros(length, d_model)
        position = torch.arange(0, length).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
                              -(math.log(10000.0) / d_model)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        return pe.unsqueeze(1)  # length, 1, d_model
    @staticmethod
    def pos2d(d_model, height, width):
        """
        :param d_model: dimension of the model
        :param height: height of the positions
        :param width: width of the positions
        :return: d_model * height * width position matrix
        """
        if d_model % 4 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "a dimension not divisible by 4 "
                             "(got dim={:d})".format(d_model))
        pe = torch.zeros(d_model, height, width)
        # each spatial axis uses half of d_model
        d_model = int(d_model / 2)
        div_term = torch.exp(
            torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
        pos_w = torch.arange(0., width).unsqueeze(1)
        pos_h = torch.arange(0., height).unsqueeze(1)
        pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(
            0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(
            0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(
            0, 1).unsqueeze(2).repeat(1, 1, width)
        pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(
            0, 1).unsqueeze(2).repeat(1, 1, width)
        return pe.reshape(-1, 1, height * width).permute(2, 1, 0)  # hw, 1, d_model
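    # Shape sketch (illustrative values, not from the original config):
    #
    #   TransformerDecoder.pos1d(512, 17).shape      # -> (17, 1, 512)
    #   TransformerDecoder.pos2d(512, 26, 26).shape  # -> (676, 1, 512)
    #
    # Both are broadcast over the batch dimension inside
    # TransformerDecoderLayer.with_pos_embed.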
    def forward(self, vis, txt, pad_mask):
        '''
        vis: b, 512, h, w
        txt: b, L, 512
        pad_mask: b, L
        '''
        B, C, H, W = vis.size()
        _, L, D = txt.size()
        # position encoding
        vis_pos = self.pos2d(C, H, W)
        txt_pos = self.pos1d(D, L)
        # reshape & permute
        vis = vis.reshape(B, C, -1).permute(2, 0, 1)
        txt = txt.permute(1, 0, 2)
        # forward
        output = vis
        intermediate = []
        for layer in self.layers:
            output = layer(output, txt, vis_pos, txt_pos, pad_mask)
            if self.return_intermediate:
                # HW, b, 512 -> b, 512, HW
                intermediate.append(self.norm(output).permute(1, 2, 0))
        if self.norm is not None:
            # HW, b, 512 -> b, 512, HW
            output = self.norm(output).permute(1, 2, 0)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)
                # [output1, output2, ..., output_n]
                return intermediate
            else:
                # b, 512, HW
                return output
        return output
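# Usage sketch (batch size, sequence length and spatial size are assumptions
# for illustration):
#
#   dec = TransformerDecoder(num_layers=3, d_model=512, nhead=8,
#                            dim_ffn=2048, dropout=0.1)
#   vis = torch.randn(2, 512, 26, 26)                # b, 512, h, w
#   txt = torch.randn(2, 17, 512)                    # b, L, 512
#   pad_mask = torch.zeros(2, 17, dtype=torch.bool)  # True marks padding
#   out = dec(vis, txt, pad_mask)                    # -> (2, 512, 676)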
class TransformerDecoderLayer(nn.Module):

    def __init__(self,
                 d_model=512,
                 nhead=9,
                 dim_feedforward=2048,
                 dropout=0.1):
        super().__init__()
        # Normalization Layer
        self.self_attn_norm = nn.LayerNorm(d_model)
        self.cross_attn_norm = nn.LayerNorm(d_model)
        # Attention Layer (note: the default nhead=9 does not evenly divide
        # the default d_model=512, so callers must pass a compatible nhead)
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model,
                                                    nhead,
                                                    dropout=dropout,
                                                    kdim=d_model,
                                                    vdim=d_model)
        # FFN
        self.ffn = nn.Sequential(nn.Linear(d_model, dim_feedforward),
                                 nn.ReLU(True), nn.Dropout(dropout),
                                 nn.LayerNorm(dim_feedforward),
                                 nn.Linear(dim_feedforward, d_model))
        # LayerNorm & Dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos.to(tensor.device)
    def forward(self, vis, txt, vis_pos, txt_pos, pad_mask):
        '''
        vis: 26*26, b, 512
        txt: L, b, 512
        vis_pos: 26*26, 1, 512
        txt_pos: L, 1, 512
        pad_mask: b, L
        '''
        # Self-Attention
        vis2 = self.norm1(vis)
        q = k = self.with_pos_embed(vis2, vis_pos)
        vis2 = self.self_attn(q, k, value=vis2)[0]
        vis2 = self.self_attn_norm(vis2)
        vis = vis + self.dropout1(vis2)
        # Cross-Attention
        vis2 = self.norm2(vis)
        vis2 = self.multihead_attn(query=self.with_pos_embed(vis2, vis_pos),
                                   key=self.with_pos_embed(txt, txt_pos),
                                   value=txt,
                                   key_padding_mask=pad_mask)[0]
        vis2 = self.cross_attn_norm(vis2)
        vis = vis + self.dropout2(vis2)
        # FFN
        vis2 = self.norm3(vis)
        vis2 = self.ffn(vis2)
        vis = vis + self.dropout3(vis2)
        return vis
class Text_Projector(nn.Module):

    def __init__(self, args, in_channels=[512, 1024, 1024],
                 out_channels=[256, 512, 1024]):
        super(Text_Projector, self).__init__()
        self.proj = linear_layer(in_channels[2], out_channels[2])
        self.ReLU = nn.ReLU(True)

    def forward(self, text):
        text = self.ReLU(text + self.proj(text))
        return text
class Image_Projector(nn.Module):

    def __init__(self, args, in_channels=[512, 1024, 1024],
                 out_channels=[256, 512, 1024]):
        super(Image_Projector, self).__init__()
        self.proj = linear_layer(in_channels[0], out_channels[2])
        self.ReLU = nn.ReLU(True)

    def forward(self, image):
        image = self.ReLU(image + self.proj(image))
        return image
class Adapter(nn.Module):

    def __init__(self, c_in, reduction=4):
        super(Adapter, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(c_in, c_in // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(c_in // reduction, c_in, bias=False),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.fc(x)
        return x
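# Usage sketch (CLIP-Adapter-style bottleneck; the feature width is an
# assumption):
#
#   adapter = Adapter(c_in=1024, reduction=4)
#   feat = torch.randn(2, 1024)
#   out = adapter(feat)  # -> (2, 1024), via a 1024 -> 256 -> 1024 bottleneck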
class GAP(nn.Module):

    def __init__(self, kernel):
        super(GAP, self).__init__()
        self.k = kernel
        # self.fc = nn.Linear(512, 1024)

    def forward(self, x):
        x = F.adaptive_avg_pool2d(x, self.k)
        return x.squeeze(-1).squeeze(-1)
class AdaptiveSpatialFeatureFusion(nn.Module):

    def __init__(self, args, in_channels=[512, 1024, 1024],
                 out_channels=[256, 512, 1024]):
        super(AdaptiveSpatialFeatureFusion, self).__init__()
        self.weight = nn.LayerNorm(out_channels[2])
        self.proj = linear_layer(in_channels[0], out_channels[2])

    def forward(self, feature_map1, feature_map2):
        # feature_map1: b, 1024, 1, 1
        # feature_map2: b, 512, 1, 1
        feature_map2 = self.proj(feature_map2.squeeze(-1).squeeze(-1))
        feature_map1 = feature_map1.squeeze(-1).squeeze(-1)
        weights1 = torch.norm(feature_map1, dim=1).unsqueeze(-1)
        weights2 = torch.norm(feature_map2, dim=1).unsqueeze(-1)
        weights1 = weights1 / (weights1 + weights2)
        weights2 = 1 - weights1
        fused_feature_map = weights1 * feature_map1 + weights2 * feature_map2
        # b, 1024
        return fused_feature_map
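# Usage sketch (batch size assumed, and batch > 1 because linear_layer
# contains BatchNorm1d): the two pooled features are blended with weights
# proportional to their L2 norms.
#
#   fuse = AdaptiveSpatialFeatureFusion(args=None)
#   f1 = torch.randn(2, 1024, 1, 1)
#   f2 = torch.randn(2, 512, 1, 1)
#   fused = fuse(f1, f2)  # -> (2, 1024)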
class ModifiedAttentionPool2d(nn.Module):

    def __init__(self,
                 spacial_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 output_dim: int = None):
        super().__init__()
        self.spacial_dim = spacial_dim
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads
        # residual branch; note output_dim must be given explicitly here,
        # since self.connect does not fall back to embed_dim
        self.connect = nn.Sequential(
            nn.Conv2d(embed_dim, output_dim, 1, stride=1, bias=False),
            nn.BatchNorm2d(output_dim))
    def resize_pos_embed(self, pos_embed, input_shape):
        """Resize pos_embed weights.

        Resize pos_embed using bicubic interpolation.

        Args:
            pos_embed (torch.Tensor): Position embedding weights of shape
                [B, L, C].
            input_shape (tuple): Tuple of (downsampled input image height,
                downsampled input image width).

        Return:
            torch.Tensor: The resized pos_embed of shape [B, C, L_new].
        """
        assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
        pos_h = pos_w = self.spacial_dim
        cls_token_weight = pos_embed[:, 0]
        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
        pos_embed_weight = pos_embed_weight.reshape(
            1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
        pos_embed_weight = F.interpolate(pos_embed_weight,
                                         size=input_shape,
                                         align_corners=False,
                                         mode='bicubic')
        cls_token_weight = cls_token_weight.unsqueeze(1)
        pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
        # the cls token embedding is dropped here:
        # pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
        return pos_embed_weight.transpose(-2, -1)
    def forward(self, x):
        B, C, H, W = x.size()
        res = self.connect(x)
        x = x.reshape(B, C, -1)  # NC(HW)
        # x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(1+HW)
        pos_embed = self.positional_embedding.unsqueeze(0)
        pos_embed = self.resize_pos_embed(pos_embed, (H, W))  # NC(HW)
        x = x + pos_embed.to(x.dtype)  # NC(HW)
        x = x.permute(2, 0, 1)  # (HW)NC
        x, _ = F.multi_head_attention_forward(
            query=x,
            key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat(
                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False)
        xt = x[0]
        x = x.permute(1, 2, 0).reshape(B, -1, H, W)
        x = x + res
        x = F.relu(x, True)
        return x, xt
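# Usage sketch (sizes are assumptions): unlike AttentionPool2d, this variant
# keeps the full spatial map (plus a residual 1x1-conv branch) and also
# returns the first spatial token as a global feature.
#
#   mpool = ModifiedAttentionPool2d(spacial_dim=13, embed_dim=1024,
#                                   num_heads=8, output_dim=1024)
#   v5 = torch.randn(2, 1024, 13, 13)
#   spatial, token = mpool(v5)  # -> (2, 1024, 13, 13), (2, 1024)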
# modified
class FPN(nn.Module):

    def __init__(self, args,
                 in_channels=[512, 1024, 1024],
                 out_channels=[256, 512, 1024, 1024]):
        super(FPN, self).__init__()
        input_resolution = args.input_size
        heads = args.heads
        output_dim = args.output_dim
        embed_dim = args.emb_dim
        # image projection
        self.attn = ModifiedAttentionPool2d(input_resolution // 32, embed_dim,
                                            heads, output_dim)
        # text projection
        self.txt_proj = linear_layer(in_channels[2], out_channels[2])
        # fusion 1: v5 & seq -> f_5: b, 1024, 13, 13
        self.f1_v_proj = conv_layer(in_channels[2], out_channels[2], 1, 0)
        self.norm_layer = nn.Sequential(nn.BatchNorm2d(out_channels[2]),
                                        nn.ReLU(True))
        # fusion 2: v4 & fm -> f_4: b, 512, 26, 26
        self.f2_v_proj = conv_layer(in_channels[1], out_channels[1], 3, 1)
        self.f2_cat = conv_layer(out_channels[2] + out_channels[1],
                                 out_channels[1], 1, 0)
        # fusion 3: v3 & fm_mid -> f_3: b, 512, 26, 26
        self.f3_v_proj = conv_layer(in_channels[0], out_channels[0], 3, 1)
        self.f3_cat = conv_layer(out_channels[0] + out_channels[1],
                                 out_channels[1], 1, 0)
        # fusion 4: f_3 & f_4 & f_5 -> fq: b, 512, 26, 26
        self.f4_proj5 = conv_layer(out_channels[2], out_channels[1], 3, 1)
        self.f4_proj4 = conv_layer(out_channels[1], out_channels[1], 3, 1)
        self.f4_proj3 = conv_layer(out_channels[1], out_channels[1], 3, 1)
        # aggregation
        self.aggr = conv_layer(3 * out_channels[1], out_channels[1], 1, 0)
        self.coordconv = nn.Sequential(
            CoordConv(out_channels[1], out_channels[1], 3, 1),
            conv_layer(out_channels[1], out_channels[3], 3, 1))
    def forward(self, imgs, text):
        # v3, v4, v5: b, 512, 52, 52 / b, 1024, 26, 26 / b, emb_dim, 13, 13
        v3, v4, v5 = imgs
        # fusion 1: b, 1024, 13, 13
        # text projection: b, 1024 -> b, 1024
        v5, _ = self.attn(v5)
        text_ = self.txt_proj(text)
        state = text_.unsqueeze(-1).unsqueeze(-1)  # b, 1024, 1, 1
        f5 = self.f1_v_proj(v5)
        f5 = self.norm_layer(f5 * state)
        # fusion 2: b, 512, 26, 26
        f4 = self.f2_v_proj(v4)
        f5_ = F.interpolate(f5, scale_factor=2, mode='bilinear')
        f4 = self.f2_cat(torch.cat([f4, f5_], dim=1))
        # fusion 3: b, 512, 26, 26
        f3 = self.f3_v_proj(v3)
        f3 = F.avg_pool2d(f3, 2, 2)
        f3 = self.f3_cat(torch.cat([f3, f4], dim=1))
        # fusion 4: b, 512, 13, 13 / b, 512, 26, 26 / b, 512, 26, 26
        fq5 = self.f4_proj5(f5)
        fq4 = self.f4_proj4(f4)
        fq3 = self.f4_proj3(f3)
        # query
        fq5 = F.interpolate(fq5, scale_factor=2, mode='bilinear')
        fq = torch.cat([fq3, fq4, fq5], dim=1)
        fq = self.aggr(fq)
        fq = self.coordconv(fq)
        # b, out_channels[3], 26, 26
        return fq
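# Usage sketch (the args fields and shapes are assumptions matching the
# default in_channels and a 416x416 input, so input_size // 32 == 13):
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(input_size=416, heads=8, output_dim=1024,
#                          emb_dim=1024)
#   fpn = FPN(args)
#   v3 = torch.randn(2, 512, 52, 52)
#   v4 = torch.randn(2, 1024, 26, 26)
#   v5 = torch.randn(2, 1024, 13, 13)
#   text = torch.randn(2, 1024)
#   fq = fpn((v3, v4, v5), text)  # -> (2, 1024, 26, 26)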
class ViTFPN(nn.Module):

    def __init__(self, image_resolution,
                 in_channels=[512, 768, 768],
                 out_channels=[768, 768, 768, 512]):
        super(ViTFPN, self).__init__()
        # text projection
        self.txt_proj = linear_layer(in_channels[0], out_channels[1])
        # fusion 1: v5 & seq -> f_5
        self.f1_v_proj = conv_layer(in_channels[1], out_channels[1], 1, 0)
        self.norm_layer = nn.Sequential(nn.BatchNorm2d(out_channels[1]),
                                        nn.ReLU(True))
        # fusion 2: v4 & fm -> f_4
        self.f2_v_proj = conv_layer(in_channels[1], out_channels[1], 3, 1)
        self.f2_cat = conv_layer(out_channels[0] + out_channels[0],
                                 out_channels[0], 1, 0)
        # fusion 3: v3 & fm_mid -> f_3
        self.f3_v_proj = conv_layer(in_channels[1], out_channels[1], 3, 1)
        self.f3_cat = conv_layer(out_channels[0] + out_channels[1],
                                 out_channels[1], 1, 0)
        # fusion 4: f_3 & f_4 & f_5 -> fq
        self.f4_proj5 = conv_layer(out_channels[1], out_channels[0], 3, 1)
        self.f4_proj4 = conv_layer(out_channels[0], out_channels[0], 3, 1)
        self.f4_proj3 = conv_layer(out_channels[1], out_channels[1], 3, 1)
        # aggregation
        self.aggr = conv_layer(3 * out_channels[0], out_channels[0], 1, 0)
        self.coordconv = nn.Sequential(
            CoordConv(out_channels[0], out_channels[0], 3, 1),
            conv_layer(out_channels[0], out_channels[-1], 3, 1))
        self.attnpool = AttentionPool2d(image_resolution // 32,
                                        out_channels[-1], 8, out_channels[-1])
    def forward(self, imgs, state, vis):
        # v3, v4, v5: ViT feature maps reshaped to NCHW (in_channels[1] channels)
        v3, v4, v5 = imgs
        # fusion 1
        # text projection: b, in_channels[0] -> b, out_channels[1]
        state = self.txt_proj(state)
        state = state.unsqueeze(-1).unsqueeze(-1)  # b, C, 1, 1
        f5 = self.f1_v_proj(v5)
        f5 = self.norm_layer(f5 * state)
        # fusion 2
        f4 = self.f2_v_proj(v4)
        b, c, h, w = f4.size()
        f5_ = F.interpolate(f5, (h, w), mode='bilinear')
        f4 = self.f2_cat(torch.cat([f4, f5_], dim=1))
        # fusion 3
        f3 = self.f3_v_proj(v3)
        f3 = F.avg_pool2d(f3, 2, 2)
        f3 = self.f3_cat(torch.cat([f3, f4], dim=1))
        # fusion 4
        fq5 = self.f4_proj5(f5)
        fq4 = self.f4_proj4(f4)
        fq3 = self.f4_proj3(f3)
        # query
        fq5 = F.interpolate(fq5, (h, w), mode='bilinear')
        fq = torch.cat([fq3, fq4, fq5], dim=1)
        fq = self.aggr(fq)
        if not vis:
            fq = self.coordconv(fq)
            fq = self.attnpool(fq)
        return fq
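# Usage sketch (shapes are assumptions: a 416x416 input gives a 13x13 grid
# for the attention pool, and the ViT features are assumed to be already
# reshaped to NCHW):
#
#   vfpn = ViTFPN(image_resolution=416)
#   v3 = torch.randn(2, 768, 26, 26)
#   v4 = torch.randn(2, 768, 13, 13)
#   v5 = torch.randn(2, 768, 13, 13)
#   state = torch.randn(2, 512)
#   fq = vfpn((v3, v4, v5), state, vis=True)    # -> (2, 768, 13, 13)
#   emb = vfpn((v3, v4, v5), state, vis=False)  # -> (2, 512)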