import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert

def window_partition(x, window_size: int):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size: int, H: int, W: int):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
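
# A quick shape check (illustrative): the two functions are exact inverses
# whenever H and W are divisible by the window size.
#   x = torch.randn(2, 14, 14, 32)
#   w = window_partition(x, 7)                       # (2*4, 7, 7, 32)
#   assert torch.equal(window_reverse(w, 7, 14, 14), x)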

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module, optionally
    conditioned on appearance guidance. It supports both shifted and
    non-shifted windows.
    Args:
        dim (int): Number of input channels.
        appearance_guidance_dim (int): Number of appearance guidance channels
            concatenated onto the input for computing queries and keys.
        num_heads (int): Number of attention heads.
        head_dim (int): Number of channels per head (dim // num_heads if not set)
        window_size (tuple[int]): The height and width of the window.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, appearance_guidance_dim, num_heads, head_dim=None, window_size=7, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = to_2tuple(window_size)  # Wh, Ww
        win_h, win_w = self.window_size
        self.window_area = win_h * win_w
        self.num_heads = num_heads
        head_dim = head_dim or dim // num_heads
        attn_dim = head_dim * num_heads
        self.scale = head_dim ** -0.5

        self.q = nn.Linear(dim + appearance_guidance_dim, attn_dim, bias=qkv_bias)
        self.k = nn.Linear(dim + appearance_guidance_dim, attn_dim, bias=qkv_bias)
        self.v = nn.Linear(dim, attn_dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(attn_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        q = self.q(x).reshape(B_, N, self.num_heads, -1).permute(0, 2, 1, 3)
        k = self.k(x).reshape(B_, N, self.num_heads, -1).permute(0, 2, 1, 3)
        v = self.v(x[:, :, :self.dim]).reshape(B_, N, self.num_heads, -1).permute(0, 2, 1, 3)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        if mask is not None:
            num_win = mask.shape[0]
            attn = attn.view(B_ // num_win, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
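
# Note the asymmetry above: queries and keys see the cost features concatenated
# with appearance guidance (dim + appearance_guidance_dim channels), while the
# values use only the first self.dim channels, so guidance steers where each
# window attends but is not mixed into the attended values themselves.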

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        appearance_guidance_dim (int): Number of appearance guidance channels.
        input_resolution (tuple[int]): Input resolution.
        window_size (int): Window size.
        num_heads (int): Number of attention heads.
        head_dim (int): Enforce the number of channels per head
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
            self, dim, appearance_guidance_dim, input_resolution, num_heads=4, head_dim=None, window_size=7, shift_size=0,
            mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
            act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, appearance_guidance_dim=appearance_guidance_dim, num_heads=num_heads, head_dim=head_dim,
            window_size=to_2tuple(self.window_size), qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            cnt = 0
            for h in (
                    slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None)):
                for w in (
                        slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None)):
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            mask_windows = window_partition(img_mask, self.window_size)  # num_win, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x, appearance_guidance):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        if appearance_guidance is not None:
            appearance_guidance = appearance_guidance.view(B, H, W, -1)
            x = torch.cat([x, appearance_guidance], dim=-1)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # num_win*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, x_windows.shape[-1])  # num_win*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # num_win*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class SwinTransformerBlockWrapper(nn.Module):
    def __init__(self, dim, appearance_guidance_dim, input_resolution, nheads=4, window_size=5):
        super().__init__()
        self.block_1 = SwinTransformerBlock(dim, appearance_guidance_dim, input_resolution, num_heads=nheads, head_dim=None, window_size=window_size, shift_size=0)
        self.block_2 = SwinTransformerBlock(dim, appearance_guidance_dim, input_resolution, num_heads=nheads, head_dim=None, window_size=window_size, shift_size=window_size // 2)
        self.guidance_norm = nn.LayerNorm(appearance_guidance_dim) if appearance_guidance_dim > 0 else None

    def forward(self, x, appearance_guidance):
        """
        Arguments:
            x: B C T H W
            appearance_guidance: B C H W
        """
        B, C, T, H, W = x.shape
        x = rearrange(x, 'B C T H W -> (B T) (H W) C')
        if appearance_guidance is not None:
            appearance_guidance = self.guidance_norm(repeat(appearance_guidance, 'B C H W -> (B T) (H W) C', T=T))
        x = self.block_1(x, appearance_guidance)
        x = self.block_2(x, appearance_guidance)
        x = rearrange(x, '(B T) (H W) C -> B C T H W', B=B, T=T, H=H, W=W)
        return x
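
# Usage sketch (illustrative shapes): the two blocks form the usual Swin
# W-MSA / SW-MSA pair (shift 0, then window_size // 2), applied to each of the
# T class slices independently by folding T into the batch dimension.
#   wrapper = SwinTransformerBlockWrapper(dim=64, appearance_guidance_dim=32,
#                                         input_resolution=(20, 20), window_size=5)
#   x = torch.randn(2, 64, 3, 20, 20)                # B C T H W
#   g = torch.randn(2, 32, 20, 20)                   # B C H W
#   out = wrapper(x, g)                              # (2, 64, 3, 20, 20)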

def elu_feature_map(x):
    return torch.nn.functional.elu(x) + 1

class LinearAttention(nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.feature_map = elu_feature_map
        self.eps = eps

    def forward(self, queries, keys, values):
        """ Multi-Head linear attention proposed in "Transformers are RNNs"
        Args:
            queries: [N, L, H, D]
            keys: [N, S, H, D]
            values: [N, S, H, D]
        Returns:
            queried_values: (N, L, H, D)
        """
        Q = self.feature_map(queries)
        K = self.feature_map(keys)

        v_length = values.size(1)
        values = values / v_length  # prevent fp16 overflow
        KV = torch.einsum("nshd,nshv->nhdv", K, values)  # (S,D)' @ S,V
        Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
        queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
        return queried_values.contiguous()
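
# Usage sketch (illustrative): factoring softmax(QK^T)V as phi(Q)(phi(K)^T V)
# with phi(x) = elu(x) + 1 makes the cost linear rather than quadratic in
# sequence length.
#   attn = LinearAttention()
#   q = k = v = torch.randn(2, 100, 4, 16)           # [N, L, H, D]
#   out = attn(q, k, v)                              # [N, L, H, D]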

class FullAttention(nn.Module):
    def __init__(self, use_dropout=False, attention_dropout=0.1):
        super().__init__()
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
        """ Multi-head scaled dot-product attention, a.k.a. full attention.
        Args:
            queries: [N, L, H, D]
            keys: [N, S, H, D]
            values: [N, S, H, D]
            q_mask: [N, L]
            kv_mask: [N, S]
        Returns:
            queried_values: (N, L, H, D)
        """
        # Compute the unnormalized attention and apply the masks
        QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
        if kv_mask is not None:
            QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))

        # Compute the attention and the weighted average
        softmax_temp = 1. / queries.size(3)**.5  # 1 / sqrt(D)
        A = torch.softmax(softmax_temp * QK, dim=2)
        if self.use_dropout:
            A = self.dropout(A)

        queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
        return queried_values.contiguous()

class AttentionLayer(nn.Module):
    def __init__(self, hidden_dim, guidance_dim, nheads=8, attention_type='linear'):
        super().__init__()
        self.nheads = nheads
        self.q = nn.Linear(hidden_dim + guidance_dim, hidden_dim)
        self.k = nn.Linear(hidden_dim + guidance_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)

        if attention_type == 'linear':
            self.attention = LinearAttention()
        elif attention_type == 'full':
            self.attention = FullAttention()
        else:
            raise NotImplementedError

    def forward(self, x, guidance):
        """
        Arguments:
            x: B, L, C
            guidance: B, L, C
        """
        q = self.q(torch.cat([x, guidance], dim=-1)) if guidance is not None else self.q(x)
        k = self.k(torch.cat([x, guidance], dim=-1)) if guidance is not None else self.k(x)
        v = self.v(x)

        q = rearrange(q, 'B L (H D) -> B L H D', H=self.nheads)
        k = rearrange(k, 'B S (H D) -> B S H D', H=self.nheads)
        v = rearrange(v, 'B S (H D) -> B S H D', H=self.nheads)

        out = self.attention(q, k, v)
        out = rearrange(out, 'B L H D -> B L (H D)')
        return out

class ClassTransformerLayer(nn.Module):
    def __init__(self, hidden_dim=64, guidance_dim=64, nheads=8, attention_type='linear', pooling_size=(4, 4)) -> None:
        super().__init__()
        self.pool = nn.AvgPool2d(pooling_size)
        self.attention = AttentionLayer(hidden_dim, guidance_dim, nheads=nheads, attention_type=attention_type)
        self.MLP = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def pool_features(self, x):
        """
        Intermediate pooling layer for computational efficiency.
        Arguments:
            x: B, C, T, H, W
        """
        B = x.size(0)
        x = rearrange(x, 'B C T H W -> (B T) C H W')
        x = self.pool(x)
        x = rearrange(x, '(B T) C H W -> B C T H W', B=B)
        return x

    def forward(self, x, guidance):
        """
        Arguments:
            x: B, C, T, H, W
            guidance: B, T, C
        """
        B, _, _, H, W = x.size()
        x_pool = self.pool_features(x)
        *_, H_pool, W_pool = x_pool.size()

        x_pool = rearrange(x_pool, 'B C T H W -> (B H W) T C')
        if guidance is not None:
            guidance = repeat(guidance, 'B T C -> (B H W) T C', H=H_pool, W=W_pool)

        x_pool = x_pool + self.attention(self.norm1(x_pool), guidance)  # Attention
        x_pool = x_pool + self.MLP(self.norm2(x_pool))  # MLP

        x_pool = rearrange(x_pool, '(B H W) T C -> (B T) C H W', H=H_pool, W=W_pool)
        x_pool = F.interpolate(x_pool, size=(H, W), mode='bilinear', align_corners=True)
        x_pool = rearrange(x_pool, '(B T) C H W -> B C T H W', B=B)

        x = x + x_pool  # Residual
        return x
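
# Note on the flow above: attention mixes information along the class dimension
# T, treating each pooled spatial location as an independent batch element; the
# result is bilinearly upsampled back to (H, W) and added as a residual, so the
# class-wise mixing stays cheap at pooled resolution.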

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class Bottleneck(nn.Module):
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out

class AggregatorLayer(nn.Module):
    def __init__(self, hidden_dim=64, text_guidance_dim=512, appearance_guidance=512, nheads=4, input_resolution=(20, 20), pooling_size=(5, 5), window_size=10, attention_type='linear') -> None:
        super().__init__()
        self.swin_block = SwinTransformerBlockWrapper(hidden_dim, appearance_guidance, input_resolution, nheads, window_size)
        self.attention = ClassTransformerLayer(hidden_dim, text_guidance_dim, nheads=nheads, attention_type=attention_type, pooling_size=pooling_size)

    def forward(self, x, appearance_guidance, text_guidance):
        """
        Arguments:
            x: B C T H W
        """
        x = self.swin_block(x, appearance_guidance)
        x = self.attention(x, text_guidance)
        return x

class AggregatorResNetLayer(nn.Module):
    def __init__(self, hidden_dim=64, appearance_guidance=512) -> None:
        super().__init__()
        self.conv_linear = nn.Conv2d(hidden_dim + appearance_guidance, hidden_dim, kernel_size=1, stride=1)
        self.conv_layer = Bottleneck(hidden_dim, hidden_dim // 4)

    def forward(self, x, appearance_guidance):
        """
        Arguments:
            x: B C T H W
        """
        B, T = x.size(0), x.size(2)
        x = rearrange(x, 'B C T H W -> (B T) C H W')
        appearance_guidance = repeat(appearance_guidance, 'B C H W -> (B T) C H W', T=T)

        x = self.conv_linear(torch.cat([x, appearance_guidance], dim=1))
        x = self.conv_layer(x)
        x = rearrange(x, '(B T) C H W -> B C T H W', B=B)
        return x

class DoubleConv(nn.Module):
    """(convolution => [GN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(mid_channels // 16, mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(out_channels // 16, out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, guidance_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels - guidance_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x, guidance=None):
        x = self.up(x)
        if guidance is not None:
            T = x.size(0) // guidance.size(0)
            guidance = repeat(guidance, "B C H W -> (B T) C H W", T=T)
            x = torch.cat([x, guidance], dim=1)
        return self.conv(x)
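
# Design note: the transposed convolution emits in_channels - guidance_channels
# maps, so concatenating the guidance restores exactly in_channels for
# DoubleConv. If guidance is None while guidance_channels > 0, DoubleConv would
# receive too few channels, so callers must pass guidance in that case.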

class Aggregator(nn.Module):
    def __init__(self,
        text_guidance_dim=512,
        text_guidance_proj_dim=128,
        appearance_guidance_dim=512,
        appearance_guidance_proj_dim=128,
        decoder_dims=(64, 32),
        decoder_guidance_dims=(256, 128),
        decoder_guidance_proj_dims=(32, 16),
        num_layers=4,
        nheads=4,
        hidden_dim=128,
        pooling_size=(6, 6),
        feature_resolution=(24, 24),
        window_size=12,
        attention_type='linear',
        prompt_channel=80,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        self.layers = nn.ModuleList([
            AggregatorLayer(
                hidden_dim=hidden_dim, text_guidance_dim=text_guidance_proj_dim, appearance_guidance=appearance_guidance_proj_dim,
                nheads=nheads, input_resolution=feature_resolution, pooling_size=pooling_size, window_size=window_size, attention_type=attention_type
            ) for _ in range(num_layers)
        ])

        self.conv1 = nn.Conv2d(prompt_channel, hidden_dim, kernel_size=7, stride=1, padding=3)

        self.guidance_projection = nn.Sequential(
            nn.Conv2d(appearance_guidance_dim, appearance_guidance_proj_dim, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        ) if appearance_guidance_dim > 0 else None

        self.text_guidance_projection = nn.Sequential(
            nn.Linear(text_guidance_dim, text_guidance_proj_dim),
            nn.ReLU(),
        ) if text_guidance_dim > 0 else None

        self.decoder_guidance_projection = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(d, dp, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
            ) for d, dp in zip(decoder_guidance_dims, decoder_guidance_proj_dims)
        ]) if decoder_guidance_dims[0] > 0 else None

        self.decoder1 = Up(hidden_dim, decoder_dims[0], decoder_guidance_proj_dims[0])
        self.decoder2 = Up(decoder_dims[0], decoder_dims[1], decoder_guidance_proj_dims[1])
        self.head = nn.Conv2d(decoder_dims[1], 1, kernel_size=3, stride=1, padding=1)

    def feature_map(self, img_feats, text_feats):
        img_feats = F.normalize(img_feats, dim=1)  # B C H W
        img_feats = repeat(img_feats, "B C H W -> B C T H W", T=text_feats.shape[1])
        text_feats = F.normalize(text_feats, dim=-1)  # B T P C
        text_feats = text_feats.mean(dim=-2)
        text_feats = F.normalize(text_feats, dim=-1)  # B T C
        text_feats = repeat(text_feats, "B T C -> B C T H W", H=img_feats.shape[-2], W=img_feats.shape[-1])
        return torch.cat((img_feats, text_feats), dim=1)  # B 2C T H W

    def correlation(self, img_feats, text_feats):
        img_feats = F.normalize(img_feats, dim=1)  # B C H W
        text_feats = F.normalize(text_feats, dim=-1)  # B T P C
        corr = torch.einsum('bchw, btpc -> bpthw', img_feats, text_feats)
        return corr
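
    # The einsum above builds a cost volume of cosine similarities with shape
    # (B, P, T, H, W): one similarity map per (prompt, class) pair, which
    # corr_embed below projects from P prompt channels to hidden_dim.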

    def corr_embed(self, x):
        B = x.shape[0]
        corr_embed = rearrange(x, 'B P T H W -> (B T) P H W')
        corr_embed = self.conv1(corr_embed)
        corr_embed = rearrange(corr_embed, '(B T) C H W -> B C T H W', B=B)
        return corr_embed

    def corr_projection(self, x, proj):
        corr_embed = rearrange(x, 'B C T H W -> B T H W C')
        corr_embed = proj(corr_embed)
        corr_embed = rearrange(corr_embed, 'B T H W C -> B C T H W')
        return corr_embed

    def upsample(self, x):
        B = x.shape[0]
        corr_embed = rearrange(x, 'B C T H W -> (B T) C H W')
        corr_embed = F.interpolate(corr_embed, scale_factor=2, mode='bilinear', align_corners=True)
        corr_embed = rearrange(corr_embed, '(B T) C H W -> B C T H W', B=B)
        return corr_embed

    def conv_decoder(self, x, guidance):
        B = x.shape[0]
        corr_embed = rearrange(x, 'B C T H W -> (B T) C H W')
        corr_embed = self.decoder1(corr_embed, guidance[0])
        corr_embed = self.decoder2(corr_embed, guidance[1])
        corr_embed = self.head(corr_embed)
        corr_embed = rearrange(corr_embed, '(B T) () H W -> B T H W', B=B)
        return corr_embed

    def forward(self, img_feats, text_feats, appearance_guidance):
        """
        Arguments:
            img_feats: (B, C, H, W)
            text_feats: (B, T, P, C)
            appearance_guidance: tuple of (B, C, H, W)
        """
        corr = self.correlation(img_feats, text_feats)
        # corr = self.feature_map(img_feats, text_feats)
        corr_embed = self.corr_embed(corr)

        projected_guidance, projected_text_guidance, projected_decoder_guidance = None, None, [None, None]
        if self.guidance_projection is not None:
            projected_guidance = self.guidance_projection(appearance_guidance[0])
        if self.decoder_guidance_projection is not None:
            projected_decoder_guidance = [proj(g) for proj, g in zip(self.decoder_guidance_projection, appearance_guidance[1:])]
        if self.text_guidance_projection is not None:
            text_feats = text_feats.mean(dim=-2)
            text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
            projected_text_guidance = self.text_guidance_projection(text_feats)

        for layer in self.layers:
            corr_embed = layer(corr_embed, projected_guidance, projected_text_guidance)

        logit = self.conv_decoder(corr_embed, projected_decoder_guidance)
        return logit
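

# Minimal smoke test (a sketch, assuming the default hyperparameters above and
# dummy inputs; not part of the original module). The decoder guidance maps
# must match the progressively upsampled resolutions: 48x48 and then 96x96
# for 24x24 input features.
if __name__ == '__main__':
    B, T = 1, 4  # batch size and number of classes (illustrative)
    model = Aggregator()
    img_feats = torch.randn(B, 512, 24, 24)    # B C H W image embeddings
    text_feats = torch.randn(B, T, 80, 512)    # B T P C text embeddings
    appearance_guidance = (
        torch.randn(B, 512, 24, 24),           # guidance for the aggregator layers
        torch.randn(B, 256, 48, 48),           # guidance for decoder1
        torch.randn(B, 128, 96, 96),           # guidance for decoder2
    )
    logits = model(img_feats, text_feats, appearance_guidance)
    print(logits.shape)  # expected: torch.Size([1, 4, 96, 96])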