import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_layer(in_dim, out_dim, kernel_size=1, padding=0, stride=1):
    return nn.Sequential(
        nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False),
        nn.BatchNorm2d(out_dim), nn.ReLU(True))


def linear_layer(in_dim, out_dim, bias=False):
    return nn.Sequential(nn.Linear(in_dim, out_dim, bias),
                         nn.BatchNorm1d(out_dim), nn.ReLU(True))

class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.flatten(start_dim=2).permute(2, 0, 1)
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)
        x = x + self.positional_embedding[:, None, :].to(x.dtype)
        x, _ = F.multi_head_attention_forward(
            query=x[:1], key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return x.squeeze(0)

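# Example usage (a minimal sketch, not from the original code): pooling a CLIP-RN50-style
# feature map to a single embedding by letting the prepended mean token attend over all
# spatial positions. The shapes (7x7 grid, embed_dim=2048, 32 heads, output_dim=1024) are
# illustrative assumptions.
#
#   pool = AttentionPool2d(spacial_dim=7, embed_dim=2048, num_heads=32, output_dim=1024)
#   feats = torch.randn(2, 2048, 7, 7)
#   pooled = pool(feats)  # (2, 1024)
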
class CoordConv(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 padding=1,
                 stride=1):
        super().__init__()
        self.conv1 = conv_layer(in_channels + 2, out_channels, kernel_size,
                                padding, stride)

    def add_coord(self, input):
        b, _, h, w = input.size()
        x_range = torch.linspace(-1, 1, w, device=input.device)
        y_range = torch.linspace(-1, 1, h, device=input.device)
        y, x = torch.meshgrid(y_range, x_range)
        y = y.expand([b, 1, -1, -1])
        x = x.expand([b, 1, -1, -1])
        coord_feat = torch.cat([x, y], 1)
        input = torch.cat([input, coord_feat], 1)
        return input

    def forward(self, x):
        x = self.add_coord(x)
        x = self.conv1(x)
        return x

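# Example usage (a minimal sketch): CoordConv concatenates normalized x/y coordinate maps to
# the input, so the convolution actually sees in_channels + 2 channels. The shapes below are
# illustrative assumptions.
#
#   layer = CoordConv(512, 512, kernel_size=3, padding=1)
#   out = layer(torch.randn(2, 512, 26, 26))  # (2, 512, 26, 26)
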
class TransformerDecoder(nn.Module):
    def __init__(self,
                 num_layers,
                 d_model,
                 nhead,
                 dim_ffn,
                 dropout,
                 return_intermediate=False):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(d_model=d_model,
                                    nhead=nhead,
                                    dim_feedforward=dim_ffn,
                                    dropout=dropout) for _ in range(num_layers)
        ])
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.return_intermediate = return_intermediate

    @staticmethod
    def pos1d(d_model, length):
        """
        :param d_model: dimension of the model
        :param length: length of positions
        :return: length*d_model position matrix
        """
        if d_model % 2 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dim (got dim={:d})".format(d_model))
        pe = torch.zeros(length, d_model)
        position = torch.arange(0, length).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
                              -(math.log(10000.0) / d_model)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)

        return pe.unsqueeze(1)

    @staticmethod
    def pos2d(d_model, height, width):
        """
        :param d_model: dimension of the model
        :param height: height of the positions
        :param width: width of the positions
        :return: d_model*height*width position matrix
        """
        if d_model % 4 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "a dimension not divisible by 4 (got dim={:d})".format(d_model))
        pe = torch.zeros(d_model, height, width)

        d_model = int(d_model / 2)
        div_term = torch.exp(
            torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
        pos_w = torch.arange(0., width).unsqueeze(1)
        pos_h = torch.arange(0., height).unsqueeze(1)
        pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(
            0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(
            0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(
            0, 1).unsqueeze(2).repeat(1, 1, width)
        pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(
            0, 1).unsqueeze(2).repeat(1, 1, width)

        return pe.reshape(-1, 1, height * width).permute(2, 1, 0)

    def forward(self, vis, txt, pad_mask):
        '''
        vis: b, 512, h, w
        txt: b, L, 512
        pad_mask: b, L
        '''
        B, C, H, W = vis.size()
        _, L, D = txt.size()

        vis_pos = self.pos2d(C, H, W)
        txt_pos = self.pos1d(D, L)

        vis = vis.reshape(B, C, -1).permute(2, 0, 1)
        txt = txt.permute(1, 0, 2)

        output = vis
        intermediate = []
        for layer in self.layers:
            output = layer(output, txt, vis_pos, txt_pos, pad_mask)
            if self.return_intermediate:
                intermediate.append(self.norm(output).permute(1, 2, 0))

        if self.norm is not None:
            output = self.norm(output).permute(1, 2, 0)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)
                return intermediate
            return output
        return output

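# Example usage (a minimal sketch): fusing a 26x26 visual feature map with word features.
# The batch size, sequence length, and nhead=8 below are illustrative assumptions.
#
#   decoder = TransformerDecoder(num_layers=3, d_model=512, nhead=8, dim_ffn=2048,
#                                dropout=0.1, return_intermediate=False)
#   vis = torch.randn(2, 512, 26, 26)
#   txt = torch.randn(2, 17, 512)
#   pad_mask = torch.zeros(2, 17, dtype=torch.bool)  # True marks padded word positions
#   out = decoder(vis, txt, pad_mask)  # (2, 512, 26*26)
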
class TransformerDecoderLayer(nn.Module):
    def __init__(self,
                 d_model=512,
                 nhead=9,
                 dim_feedforward=2048,
                 dropout=0.1):
        super().__init__()

        self.self_attn_norm = nn.LayerNorm(d_model)
        self.cross_attn_norm = nn.LayerNorm(d_model)

        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model,
                                                    nhead,
                                                    dropout=dropout,
                                                    kdim=d_model,
                                                    vdim=d_model)

        self.ffn = nn.Sequential(nn.Linear(d_model, dim_feedforward),
                                 nn.ReLU(True), nn.Dropout(dropout),
                                 nn.LayerNorm(dim_feedforward),
                                 nn.Linear(dim_feedforward, d_model))

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos.to(tensor.device)

    def forward(self, vis, txt, vis_pos, txt_pos, pad_mask):
        '''
        vis: 26*26, b, 512
        txt: L, b, 512
        vis_pos: 26*26, 1, 512
        txt_pos: L, 1, 512
        pad_mask: b, L
        '''
        vis2 = self.norm1(vis)
        q = k = self.with_pos_embed(vis2, vis_pos)
        vis2 = self.self_attn(q, k, value=vis2)[0]
        vis2 = self.self_attn_norm(vis2)
        vis = vis + self.dropout1(vis2)

        vis2 = self.norm2(vis)
        vis2 = self.multihead_attn(query=self.with_pos_embed(vis2, vis_pos),
                                   key=self.with_pos_embed(txt, txt_pos),
                                   value=txt,
                                   key_padding_mask=pad_mask)[0]
        vis2 = self.cross_attn_norm(vis2)
        vis = vis + self.dropout2(vis2)

        vis2 = self.norm3(vis)
        vis2 = self.ffn(vis2)
        vis = vis + self.dropout3(vis2)
        return vis

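# Example usage (a minimal sketch): a single pre-norm layer with visual self-attention,
# vision-to-text cross-attention, and an FFN. Inputs follow the decoder's
# (sequence, batch, channel) layout; shapes and nhead=8 are illustrative assumptions.
#
#   layer = TransformerDecoderLayer(d_model=512, nhead=8)
#   out = layer(torch.randn(676, 2, 512),              # vis tokens
#               torch.randn(17, 2, 512),               # txt tokens
#               torch.randn(676, 1, 512),              # vis_pos
#               torch.randn(17, 1, 512),               # txt_pos
#               torch.zeros(2, 17, dtype=torch.bool))  # pad_mask -> (676, 2, 512)
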
class Text_Projector(nn.Module):
    def __init__(self, args, in_channels=[512, 1024, 1024],
                 out_channels=[256, 512, 1024]):
        super(Text_Projector, self).__init__()

        self.proj = linear_layer(in_channels[2], out_channels[2])
        self.ReLU = nn.ReLU(True)

    def forward(self, text):
        text = self.ReLU(text + self.proj(text))
        return text

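# Example usage (a minimal sketch): a residual text projection, so the input width must match
# out_channels[2] (1024 by default). `args` is not used by the layer and is passed as None
# purely for illustration; BatchNorm1d inside linear_layer needs batch size > 1 in training.
#
#   proj = Text_Projector(args=None)
#   out = proj(torch.randn(4, 1024))  # (4, 1024)
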
class Image_Projector(nn.Module):
    def __init__(self, args, in_channels=[512, 1024, 1024],
                 out_channels=[256, 512, 1024]):
        super(Image_Projector, self).__init__()

        self.proj = linear_layer(in_channels[0], out_channels[2])
        self.ReLU = nn.ReLU(True)

    def forward(self, image):
        image = self.ReLU(image + self.proj(image))
        return image

class Adapter(nn.Module):
    def __init__(self, c_in, reduction=4):
        super(Adapter, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(c_in, c_in // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(c_in // reduction, c_in, bias=False),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.fc(x)
        return x

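# Example usage (a minimal sketch): a CLIP-Adapter-style bottleneck. With reduction=4 a 1024-d
# feature is squeezed to 256 dims and expanded back; the feature width is an assumption.
#
#   adapter = Adapter(c_in=1024)
#   out = adapter(torch.randn(2, 1024))  # (2, 1024)
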
class GAP(nn.Module):
    def __init__(self, kernel):
        super(GAP, self).__init__()
        self.k = kernel

    def forward(self, x):
        x = F.adaptive_avg_pool2d(x, self.k)

        return x.squeeze(-1).squeeze(-1)

class AdaptiveSpatialFeatureFusion(nn.Module):
    def __init__(self, args, in_channels=[512, 1024, 1024],
                 out_channels=[256, 512, 1024]):
        super(AdaptiveSpatialFeatureFusion, self).__init__()
        self.weight = nn.LayerNorm(out_channels[2])
        self.proj = linear_layer(in_channels[0], out_channels[2])

    def forward(self, feature_map1, feature_map2):
        feature_map2 = self.proj(feature_map2.squeeze(-1).squeeze(-1))
        feature_map1 = feature_map1.squeeze(-1).squeeze(-1)
        # Fusion weights are proportional to each feature's L2 norm.
        weights1 = torch.norm(feature_map1, dim=1).unsqueeze(-1)
        weights2 = torch.norm(feature_map2, dim=1).unsqueeze(-1)
        weights1 = weights1 / (weights1 + weights2)
        weights2 = 1 - weights1

        fused_feature_map = weights1 * feature_map1 + weights2 * feature_map2

        return fused_feature_map

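# Example usage (a minimal sketch): fusing two global features with norm-proportional weights.
# feature_map2 is projected from in_channels[0] (512) to out_channels[2] (1024) before fusion;
# the shapes and args=None below are illustrative assumptions.
#
#   asff = AdaptiveSpatialFeatureFusion(args=None)
#   fused = asff(torch.randn(4, 1024), torch.randn(4, 512, 1, 1))  # (4, 1024)
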
class ModifiedAttentionPool2d(nn.Module):
    def __init__(self,
                 spacial_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 output_dim: int = None):
        super().__init__()
        self.spacial_dim = spacial_dim
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

        self.connect = nn.Sequential(
            nn.Conv2d(embed_dim, output_dim, 1, stride=1, bias=False),
            nn.BatchNorm2d(output_dim))

    def resize_pos_embed(self, pos_embed, input_shape):
        """Resize pos_embed weights.

        Resize pos_embed using bicubic interpolation.

        Args:
            pos_embed (torch.Tensor): Position embedding weights of shape
                [B, L, C].
            input_shape (tuple): Tuple of (downsampled input image height,
                downsampled input image width).
        Return:
            torch.Tensor: The resized pos_embed of shape [B, C, L_new].
        """
        assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
        pos_h = pos_w = self.spacial_dim
        cls_token_weight = pos_embed[:, 0]
        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
        pos_embed_weight = pos_embed_weight.reshape(
            1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
        pos_embed_weight = F.interpolate(pos_embed_weight,
                                         size=input_shape,
                                         align_corners=False,
                                         mode='bicubic')
        cls_token_weight = cls_token_weight.unsqueeze(1)
        pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)

        return pos_embed_weight.transpose(-2, -1)

    def forward(self, x):
        B, C, H, W = x.size()
        res = self.connect(x)
        x = x.reshape(B, C, -1)

        # Resize the learned positional embedding to the current feature size.
        pos_embed = self.positional_embedding.unsqueeze(0)
        pos_embed = self.resize_pos_embed(pos_embed, (H, W))
        x = x + pos_embed.to(x.dtype)
        x = x.permute(2, 0, 1)
        x, _ = F.multi_head_attention_forward(
            query=x,
            key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat(
                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False)
        xt = x[0]
        x = x.permute(1, 2, 0).reshape(B, -1, H, W)
        x = x + res
        x = F.relu(x, True)

        return x, xt

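# Example usage (a minimal sketch): attention pooling that also returns the attended spatial
# map. The positional embedding is resized bicubically, so H and W need not equal spacial_dim.
# The shapes below are illustrative assumptions.
#
#   pool = ModifiedAttentionPool2d(spacial_dim=7, embed_dim=2048, num_heads=32, output_dim=1024)
#   x, xt = pool(torch.randn(2, 2048, 13, 13))  # x: (2, 1024, 13, 13), xt: (2, 1024)
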
class FPN(nn.Module):
    def __init__(self, args,
                 in_channels=[512, 1024, 1024],
                 out_channels=[256, 512, 1024, 1024]):
        super(FPN, self).__init__()
        input_resolution = args.input_size
        heads = args.heads
        output_dim = args.output_dim
        embed_dim = args.emb_dim

        self.attn = ModifiedAttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

        self.txt_proj = linear_layer(in_channels[2], out_channels[2])

        self.f1_v_proj = conv_layer(in_channels[2], out_channels[2], 1, 0)
        self.norm_layer = nn.Sequential(nn.BatchNorm2d(out_channels[2]),
                                        nn.ReLU(True))

        self.f2_v_proj = conv_layer(in_channels[1], out_channels[1], 3, 1)
        self.f2_cat = conv_layer(out_channels[2] + out_channels[1],
                                 out_channels[1], 1, 0)

        self.f3_v_proj = conv_layer(in_channels[0], out_channels[0], 3, 1)
        self.f3_cat = conv_layer(out_channels[0] + out_channels[1],
                                 out_channels[1], 1, 0)

        self.f4_proj5 = conv_layer(out_channels[2], out_channels[1], 3, 1)
        self.f4_proj4 = conv_layer(out_channels[1], out_channels[1], 3, 1)
        self.f4_proj3 = conv_layer(out_channels[1], out_channels[1], 3, 1)

        self.aggr = conv_layer(3 * out_channels[1], out_channels[1], 1, 0)
        self.coordconv = nn.Sequential(
            CoordConv(out_channels[1], out_channels[1], 3, 1),
            conv_layer(out_channels[1], out_channels[3], 3, 1))

    def forward(self, imgs, text):
        v3, v4, v5 = imgs

        # Pool the deepest feature map and modulate it with the text embedding.
        v5, _ = self.attn(v5)
        text_ = self.txt_proj(text)
        state = text_.unsqueeze(-1).unsqueeze(-1)

        f5 = self.f1_v_proj(v5)
        f5 = self.norm_layer(f5 * state)

        f4 = self.f2_v_proj(v4)
        f5_ = F.interpolate(f5, scale_factor=2, mode='bilinear')
        f4 = self.f2_cat(torch.cat([f4, f5_], dim=1))

        f3 = self.f3_v_proj(v3)
        f3 = F.avg_pool2d(f3, 2, 2)
        f3 = self.f3_cat(torch.cat([f3, f4], dim=1))

        # Project all levels to a common width and aggregate at the middle scale.
        fq5 = self.f4_proj5(f5)
        fq4 = self.f4_proj4(f4)
        fq3 = self.f4_proj3(f3)

        fq5 = F.interpolate(fq5, scale_factor=2, mode='bilinear')
        fq = torch.cat([fq3, fq4, fq5], dim=1)
        fq = self.aggr(fq)
        fq = self.coordconv(fq)

        return fq

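# Example usage (a minimal sketch): a smoke test with CLIP-RN50-like multi-scale features for a
# 416x416 input (strides 8/16/32). The `args` fields (input_size, heads, output_dim, emb_dim)
# are the ones this module reads; all values below are illustrative assumptions.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(input_size=416, heads=32, output_dim=1024, emb_dim=2048)
#   fpn = FPN(args)
#   v3 = torch.randn(2, 512, 52, 52)
#   v4 = torch.randn(2, 1024, 26, 26)
#   v5 = torch.randn(2, 2048, 13, 13)
#   fq = fpn((v3, v4, v5), torch.randn(2, 1024))  # (2, 1024, 26, 26)
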
class ViTFPN(nn.Module):
    def __init__(self, image_resolution,
                 in_channels=[512, 768, 768],
                 out_channels=[768, 768, 768, 512]):
        super(ViTFPN, self).__init__()

        self.txt_proj = linear_layer(in_channels[0], out_channels[1])

        self.f1_v_proj = conv_layer(in_channels[1], out_channels[1], 1, 0)
        self.norm_layer = nn.Sequential(nn.BatchNorm2d(out_channels[1]),
                                        nn.ReLU(True))

        self.f2_v_proj = conv_layer(in_channels[1], out_channels[1], 3, 1)
        self.f2_cat = conv_layer(out_channels[0] + out_channels[0],
                                 out_channels[0], 1, 0)

        self.f3_v_proj = conv_layer(in_channels[1], out_channels[1], 3, 1)
        self.f3_cat = conv_layer(out_channels[0] + out_channels[1],
                                 out_channels[1], 1, 0)

        self.f4_proj5 = conv_layer(out_channels[1], out_channels[0], 3, 1)
        self.f4_proj4 = conv_layer(out_channels[0], out_channels[0], 3, 1)
        self.f4_proj3 = conv_layer(out_channels[1], out_channels[1], 3, 1)

        self.aggr = conv_layer(3 * out_channels[0], out_channels[0], 1, 0)
        self.coordconv = nn.Sequential(
            CoordConv(out_channels[0], out_channels[0], 3, 1),
            conv_layer(out_channels[0], out_channels[-1], 3, 1))

        self.attnpool = AttentionPool2d(image_resolution // 32, out_channels[-1],
                                        8, out_channels[-1])

    def forward(self, imgs, state, vis):
        v3, v4, v5 = imgs

        state = self.txt_proj(state)
        state = state.unsqueeze(-1).unsqueeze(-1)
        f5 = self.f1_v_proj(v5)
        f5 = self.norm_layer(f5 * state)

        f4 = self.f2_v_proj(v4)
        b, c, h, w = f4.size()
        f5_ = F.interpolate(f5, (h, w), mode='bilinear')
        f4 = self.f2_cat(torch.cat([f4, f5_], dim=1))

        f3 = self.f3_v_proj(v3)
        f3 = F.avg_pool2d(f3, 2, 2)
        f3 = self.f3_cat(torch.cat([f3, f4], dim=1))

        fq5 = self.f4_proj5(f5)
        fq4 = self.f4_proj4(f4)
        fq3 = self.f4_proj3(f3)

        fq5 = F.interpolate(fq5, (h, w), mode='bilinear')
        fq = torch.cat([fq3, fq4, fq5], dim=1)
        fq = self.aggr(fq)
        if not vis:
            fq = self.coordconv(fq)
            fq = self.attnpool(fq)

        return fq

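# Example usage (a minimal sketch): ViT-style features where v3 is at twice the resolution of
# v4/v5 and the final attention pool expects image_resolution // 32 spatial positions. All
# shapes below are illustrative assumptions.
#
#   fpn = ViTFPN(image_resolution=416)
#   v3 = torch.randn(2, 768, 26, 26)
#   v4 = torch.randn(2, 768, 13, 13)
#   v5 = torch.randn(2, 768, 13, 13)
#   fq = fpn((v3, v4, v5), torch.randn(2, 512), vis=False)  # (2, 512)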