|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from einops import rearrange, repeat |
|
from einops.layers.torch import Rearrange |
|
|
|
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert |
|
|
|
def window_partition(x, window_size: int):
    """Split a channels-last feature map into non-overlapping square windows.

    Args:
        x: tensor of shape (B, H, W, C). H and W must be multiples of
            ``window_size`` for the reshaping below to succeed.
        window_size (int): side length of each square window.

    Returns:
        windows: (num_windows*B, window_size, window_size, C), ordered by
        batch index, then window row, then window column.
    """
    batch, height, width, channels = x.shape
    rows, cols = height // window_size, width // window_size
    grid = x.view(batch, rows, window_size, cols, window_size, channels)
    # Put the two window-index axes side by side so that flattening them
    # together enumerates windows in row-major order.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, window_size, window_size, channels)
|
|
|
|
|
def window_reverse(windows, window_size: int, H: int, W: int):
    """Inverse of ``window_partition``: stitch windows back into a feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)

    Raises:
        ValueError: if ``H`` or ``W`` is smaller than ``window_size``, since no
            complete window fits and the batch size cannot be recovered.
    """
    windows_per_image = (H // window_size) * (W // window_size)
    if windows_per_image == 0:
        raise ValueError(
            f"window_size={window_size} is larger than the feature map ({H}x{W})")
    # Recover the batch size with exact integer arithmetic instead of the
    # float expression int(n / (H * W / window_size / window_size)).
    B = windows.shape[0] // windows_per_image
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
|
|
|
|
|
|
|
class WindowAttention(nn.Module):
    r"""Window-based multi-head self-attention (W-MSA) with optional appearance guidance.

    Works for both shifted and non-shifted windows; the caller performs the
    cyclic shift and supplies the matching attention mask. Queries and keys
    are projected from the features concatenated with the appearance
    guidance, while values are projected from the first ``dim`` channels
    only. No relative position bias is used.

    Args:
        dim (int): Number of input feature channels (also the output width).
        appearance_guidance_dim (int): Number of guidance channels appended
            after the ``dim`` feature channels and fed to the query/key
            projections (0 for no guidance).
        num_heads (int): Number of attention heads.
        head_dim (int): Number of channels per head (dim // num_heads if not set)
        window_size (int | tuple[int]): The height and width of the window.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, appearance_guidance_dim, num_heads, head_dim=None, window_size=7, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = to_2tuple(window_size)
        win_h, win_w = self.window_size
        self.window_area = win_h * win_w
        self.num_heads = num_heads
        head_dim = head_dim or dim // num_heads
        attn_dim = head_dim * num_heads
        self.scale = head_dim ** -0.5

        # q/k see features + guidance; v sees the features only.
        self.q = nn.Linear(dim + appearance_guidance_dim, attn_dim, bias=qkv_bias)
        self.k = nn.Linear(dim + appearance_guidance_dim, attn_dim, bias=qkv_bias)
        self.v = nn.Linear(dim, attn_dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(attn_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape (num_windows*B, N, C), where C is
                ``dim`` plus ``appearance_guidance_dim`` (guidance channels last).
            mask: additive mask of shape (num_windows, N, N) or None; 0 for
                allowed token pairs, a large negative value for blocked ones.

        Returns:
            Tensor of shape (num_windows*B, N, dim).
        """
        B_, N, _ = x.shape

        q = self.q(x).reshape(B_, N, self.num_heads, -1).permute(0, 2, 1, 3)
        k = self.k(x).reshape(B_, N, self.num_heads, -1).permute(0, 2, 1, 3)
        # Values are computed from the feature channels only, not the guidance.
        v = self.v(x[:, :, :self.dim]).reshape(B_, N, self.num_heads, -1).permute(0, 2, 1, 3)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        if mask is not None:
            # Broadcast the per-window mask over the batch and head dimensions.
            num_win = mask.shape[0]
            attn = attn.view(B_ // num_win, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
|
|
|
|
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block with optional appearance guidance.

    Args:
        dim (int): Number of input channels.
        appearance_guidance_dim (int): Number of guidance channels concatenated
            to the normalized features before window attention (0 for none).
        input_resolution (tuple[int]): Input resolution (H, W).
        num_heads (int): Number of attention heads.
        head_dim (int): Enforce the number of channels per head
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
            self, dim, appearance_guidance_dim, input_resolution, num_heads=4, head_dim=None, window_size=7, shift_size=0,
            mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
            act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # The window already covers the whole input: use a single
            # window and disable shifting.
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, appearance_guidance_dim=appearance_guidance_dim, num_heads=num_heads, head_dim=head_dim,
            window_size=to_2tuple(self.window_size), qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

        self.register_buffer("attn_mask", self._calculate_attn_mask())

    def _calculate_attn_mask(self):
        """Build the additive SW-MSA mask, or return None when not shifting.

        After the cyclic shift, windows on the bottom/right border contain
        tokens from up to 3x3 different regions of the original map. Each
        region gets an id; token pairs with different ids receive -100 so
        they effectively cannot attend to each other. Shape of the result:
        (num_windows, window_size**2, window_size**2).
        """
        if self.shift_size <= 0:
            return None
        H, W = self.input_resolution
        img_mask = torch.zeros((1, H, W, 1))
        regions = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        cnt = 0
        for h in regions:
            for w in regions:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        mask_windows = window_partition(img_mask, self.window_size)
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x, appearance_guidance):
        """
        Args:
            x: tokens of shape (B, H*W, dim), with (H, W) = input_resolution.
            appearance_guidance: (B, H*W, appearance_guidance_dim) or None;
                it is concatenated to the normalized tokens before attention.

        Returns:
            Tensor of shape (B, H*W, dim).
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        if appearance_guidance is not None:
            # reshape (not view) so non-contiguous guidance tensors are accepted.
            appearance_guidance = appearance_guidance.reshape(B, H, W, -1)
            x = torch.cat([x, appearance_guidance], dim=-1)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows: (nW*B, window_size*window_size, channels incl. guidance)
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, x_windows.shape[-1])

        # W-MSA/SW-MSA; the attention output has `dim` channels (guidance is dropped)
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # residual + FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
|
|
|
|
|
class SwinTransformerBlockWrapper(nn.Module):
    """Two Swin blocks (regular windows, then shifted windows) over a cost volume.

    The class axis T of the input is folded into the batch, so every class
    slice is processed independently by the same pair of blocks.
    """

    def __init__(self, dim, appearance_guidance_dim, input_resolution, nheads=4, window_size=5):
        super().__init__()
        shared = dict(num_heads=nheads, head_dim=None, window_size=window_size)
        self.block_1 = SwinTransformerBlock(
            dim, appearance_guidance_dim, input_resolution, shift_size=0, **shared)
        self.block_2 = SwinTransformerBlock(
            dim, appearance_guidance_dim, input_resolution, shift_size=window_size // 2, **shared)
        if appearance_guidance_dim > 0:
            self.guidance_norm = nn.LayerNorm(appearance_guidance_dim)
        else:
            self.guidance_norm = None

    def forward(self, x, appearance_guidance):
        """
        Arguments:
            x: B C T H W
            appearance_guidance: B C H W (or None)
        """
        B, _, T, H, W = x.shape
        tokens = rearrange(x, 'B C T H W -> (B T) (H W) C')

        guidance = appearance_guidance
        if guidance is not None:
            # Share the same guidance map across all T class slices.
            guidance = repeat(guidance, 'B C H W -> (B T) (H W) C', T=T)
            guidance = self.guidance_norm(guidance)

        for block in (self.block_1, self.block_2):
            tokens = block(tokens, guidance)

        return rearrange(tokens, '(B T) (H W) C -> B C T H W', B=B, T=T, H=H, W=W)
|
|
|
|
|
def elu_feature_map(x):
    """Kernel feature map phi(x) = elu(x) + 1 (non-negative) for linear attention."""
    shifted = F.elu(x)
    return shifted + 1
|
|
|
|
|
class LinearAttention(nn.Module):
    """Multi-head linear attention ("Transformers are RNNs") with elu+1 features.

    Args:
        eps (float): Added to the normalizer to avoid division by zero.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.feature_map = elu_feature_map
        self.eps = eps

    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
        """ Multi-Head linear attention proposed in "Transformers are RNNs"

        Args:
            queries: [N, L, H, D]
            keys: [N, S, H, D]
            values: [N, S, H, D]
            q_mask: [N, L] or None; masked queries produce all-zero outputs.
            kv_mask: [N, S] or None; masked keys/values are excluded.

        Returns:
            queried_values: (N, L, H, D)
        """
        Q = self.feature_map(queries)
        K = self.feature_map(keys)

        # Masking by zeroing: a zero feature contributes nothing to either the
        # numerator (K^T V) or the normalizer (Q . sum K).
        if q_mask is not None:
            Q = Q * q_mask[:, :, None, None]
        if kv_mask is not None:
            K = K * kv_mask[:, :, None, None]
            values = values * kv_mask[:, :, None, None]

        # Scale values down before the sum and back up afterwards; the two
        # factors cancel mathematically.
        v_length = values.size(1)
        values = values / v_length
        KV = torch.einsum("nshd,nshv->nhdv", K, values)
        Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
        queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length

        return queried_values.contiguous()
|
|
|
|
|
class FullAttention(nn.Module):
    """Multi-head scaled dot-product ("full") attention with optional masking.

    Args:
        use_dropout (bool): Apply dropout to the attention weights.
        attention_dropout (float): Dropout probability when enabled.
    """

    def __init__(self, use_dropout=False, attention_dropout=0.1):
        super().__init__()
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
        """ Multi-head scaled dot-product attention, a.k.a full attention.

        Args:
            queries: [N, L, H, D]
            keys: [N, S, H, D]
            values: [N, S, H, D]
            q_mask: [N, L] or None; only used together with ``kv_mask``.
            kv_mask: [N, S] or None; masked keys receive zero weight.

        Returns:
            queried_values: (N, L, H, D)

        Note:
            A query row whose keys are all masked (or that is itself masked
            via ``q_mask``) gets an all -inf score row, so its output is NaN.
        """
        QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
        if kv_mask is not None:
            # Build the validity mask (broadcastable to N, L, S, H). q_mask is
            # optional here; previously passing only kv_mask crashed.
            valid = kv_mask[:, None, :, None].bool()
            if q_mask is not None:
                valid = valid & q_mask[:, :, None, None].bool()
            QK.masked_fill_(~valid, float('-inf'))

        softmax_temp = 1. / queries.size(3)**.5
        A = torch.softmax(softmax_temp * QK, dim=2)
        if self.use_dropout:
            A = self.dropout(A)

        queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)

        return queried_values.contiguous()
|
|
|
|
|
class AttentionLayer(nn.Module):
    """Multi-head attention whose queries/keys can be conditioned on guidance.

    q and k are projected from ``[x, guidance]`` concatenated on the last
    axis (or from ``x`` alone when guidance is None); v is projected from
    ``x`` only. ``attention_type`` selects LinearAttention ('linear') or
    FullAttention ('full').
    """

    def __init__(self, hidden_dim, guidance_dim, nheads=8, attention_type='linear'):
        super().__init__()
        self.nheads = nheads
        qk_in_dim = hidden_dim + guidance_dim
        self.q = nn.Linear(qk_in_dim, hidden_dim)
        self.k = nn.Linear(qk_in_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)

        if attention_type == 'linear':
            attention = LinearAttention()
        elif attention_type == 'full':
            attention = FullAttention()
        else:
            raise NotImplementedError
        self.attention = attention

    def forward(self, x, guidance):
        """
        Arguments:
            x: B, L, C
            guidance: B, L, C (or None)
        """
        # Build the q/k input once and reuse it for both projections.
        qk_input = x if guidance is None else torch.cat([x, guidance], dim=-1)

        q = rearrange(self.q(qk_input), 'B L (H D) -> B L H D', H=self.nheads)
        k = rearrange(self.k(qk_input), 'B S (H D) -> B S H D', H=self.nheads)
        v = rearrange(self.v(x), 'B S (H D) -> B S H D', H=self.nheads)

        attended = self.attention(q, k, v)
        return rearrange(attended, 'B L H D -> B L (H D)')
|
|
|
|
|
class ClassTransformerLayer(nn.Module):
    """Attention across the class axis T, computed on a spatially pooled cost volume.

    The input is average-pooled spatially. At every pooled location, a
    pre-norm attention + MLP block runs over the T tokens. The result is
    bilinearly upsampled back to the input resolution and added to the input
    as a residual.
    """

    def __init__(self, hidden_dim=64, guidance_dim=64, nheads=8, attention_type='linear', pooling_size=(4, 4)) -> None:
        super().__init__()
        self.pool = nn.AvgPool2d(pooling_size)
        self.attention = AttentionLayer(hidden_dim, guidance_dim, nheads=nheads, attention_type=attention_type)
        expanded_dim = hidden_dim * 4
        self.MLP = nn.Sequential(
            nn.Linear(hidden_dim, expanded_dim),
            nn.ReLU(),
            nn.Linear(expanded_dim, hidden_dim),
        )

        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def pool_features(self, x):
        """
        Intermediate pooling layer for computational efficiency.

        Arguments:
            x: B, C, T, H, W
        """
        batch = x.size(0)
        frames = rearrange(x, 'B C T H W -> (B T) C H W')
        pooled = self.pool(frames)
        return rearrange(pooled, '(B T) C H W -> B C T H W', B=batch)

    def forward(self, x, guidance):
        """
        Arguments:
            x: B, C, T, H, W
            guidance: B, T, C
        """
        batch, _, _, height, width = x.size()
        pooled = self.pool_features(x)
        pooled_h, pooled_w = pooled.size()[-2:]

        # One sequence of T class tokens per pooled spatial location.
        tokens = rearrange(pooled, 'B C T H W -> (B H W) T C')
        if guidance is not None:
            guidance = repeat(guidance, 'B T C -> (B H W) T C', H=pooled_h, W=pooled_w)

        tokens = tokens + self.attention(self.norm1(tokens), guidance)
        tokens = tokens + self.MLP(self.norm2(tokens))

        frames = rearrange(tokens, '(B H W) T C -> (B T) C H W', H=pooled_h, W=pooled_w)
        frames = F.interpolate(frames, size=(height, width), mode='bilinear', align_corners=True)
        residual = rearrange(frames, '(B T) C H W -> B C T H W', B=batch)

        return x + residual
|
|
|
|
|
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding

    The padding equals the dilation, which keeps the spatial size unchanged at
    stride 1. The convolution has no bias term.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
|
|
|
|
|
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 (pointwise) convolution without bias."""
    return nn.Conv2d(in_planes, out_planes, 1, stride, bias=False)
|
|
|
|
|
class Bottleneck(nn.Module):
    """ResNet-style bottleneck residual block: 1x1 reduce, 3x3, 1x1 expand.

    The block outputs ``planes * expansion`` channels. The residual addition
    needs the identity branch to match that shape, so when it differs from
    the input shape a ``downsample`` module must be supplied.
    """
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion

        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
|
|
|
|
|
class AggregatorLayer(nn.Module):
    """One aggregation stage: spatial Swin aggregation, then class-axis attention.

    Args:
        hidden_dim (int): Channels of the cost-volume embedding.
        text_guidance_dim (int): Channels of the text guidance (B, T, C).
        appearance_guidance (int): Channels of the appearance guidance map.
        nheads (int): Attention heads for both sub-modules.
        input_resolution (tuple[int, int]): Spatial size (H, W) of the input.
        pooling_size (tuple[int, int]): Spatial pooling for class attention.
        window_size (int | tuple[int, int]): Swin window size. A tuple must be
            square and is reduced to its side length, because the Swin blocks
            take an int window size (and halve it for the shift).
        attention_type (str): 'linear' or 'full'.
    """

    def __init__(self, hidden_dim=64, text_guidance_dim=512, appearance_guidance=512, nheads=4, input_resolution=(20, 20), pooling_size=(5, 5), window_size=(10, 10), attention_type='linear') -> None:
        super().__init__()
        # The default (10, 10) used to crash in `window_size // 2`.
        window_size = self._as_square_window(window_size)
        self.swin_block = SwinTransformerBlockWrapper(hidden_dim, appearance_guidance, input_resolution, nheads, window_size)
        self.attention = ClassTransformerLayer(hidden_dim, text_guidance_dim, nheads=nheads, attention_type=attention_type, pooling_size=pooling_size)

    @staticmethod
    def _as_square_window(window_size):
        """Return ``window_size`` as an int, accepting a square (h, w) tuple/list."""
        if isinstance(window_size, (tuple, list)):
            win_h, win_w = window_size
            if win_h != win_w:
                raise ValueError(f"only square windows are supported, got {tuple(window_size)}")
            return win_h
        return window_size

    def forward(self, x, appearance_guidance, text_guidance):
        """
        Arguments:
            x: B C T H W
            appearance_guidance: B C H W (or None)
            text_guidance: B T C (or None)

        Returns:
            Tensor of the same shape as ``x``.
        """
        x = self.swin_block(x, appearance_guidance)
        x = self.attention(x, text_guidance)
        return x
|
|
|
|
|
class AggregatorResNetLayer(nn.Module):
    """Convolutional aggregation stage: fuse guidance with a 1x1 conv, then a Bottleneck.

    The Bottleneck outputs ``(hidden_dim // 4) * 4`` channels and has no
    downsample branch, so ``hidden_dim`` should be divisible by 4 for its
    residual addition to line up.
    """

    def __init__(self, hidden_dim=64, appearance_guidance=512) -> None:
        super().__init__()
        fused_channels = hidden_dim + appearance_guidance
        self.conv_linear = nn.Conv2d(fused_channels, hidden_dim, kernel_size=1, stride=1)
        self.conv_layer = Bottleneck(hidden_dim, hidden_dim // 4)

    def forward(self, x, appearance_guidance):
        """
        Arguments:
            x: B C T H W
            appearance_guidance: B C H W
        """
        batch, num_classes = x.size(0), x.size(2)
        frames = rearrange(x, 'B C T H W -> (B T) C H W')
        guidance = repeat(appearance_guidance, 'B C H W -> (B T) C H W', T=num_classes)

        fused = self.conv_linear(torch.cat([frames, guidance], dim=1))
        fused = self.conv_layer(fused)
        return rearrange(fused, '(B T) C H W -> B C T H W', B=batch)
|
|
|
|
|
class DoubleConv(nn.Module):
    """(convolution => [GN] => ReLU) * 2

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        mid_channels (int, optional): Channels between the two convolutions;
            defaults to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(self._num_groups(mid_channels), mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            # Normalizes the second conv's output, so it must be sized by
            # out_channels (it previously used mid_channels and crashed
            # whenever mid_channels != out_channels).
            nn.GroupNorm(self._num_groups(out_channels), out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _num_groups(channels):
        """One GroupNorm group per 16 channels, but never fewer than one."""
        return max(1, channels // 16)

    def forward(self, x):
        return self.double_conv(x)
|
|
|
|
|
class Up(nn.Module):
    """Upscaling then double conv

    A stride-2 transposed convolution doubles the spatial size and emits
    ``in_channels - guidance_channels`` channels. The guidance map (if any) is
    then concatenated, so the DoubleConv again receives ``in_channels``.
    """

    def __init__(self, in_channels, out_channels, guidance_channels):
        super().__init__()
        upsampled_channels = in_channels - guidance_channels
        self.up = nn.ConvTranspose2d(in_channels, upsampled_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x, guidance=None):
        upsampled = self.up(x)
        if guidance is None:
            return self.conv(upsampled)
        # x may stack several classes per guidance sample along the batch axis.
        per_sample = upsampled.size(0) // guidance.size(0)
        expanded = repeat(guidance, "B C H W -> (B T) C H W", T=per_sample)
        return self.conv(torch.cat([upsampled, expanded], dim=1))
|
|
|
|
|
class Aggregator(nn.Module):
    """Cost-volume aggregator producing one logit map per text class.

    Pipeline used by ``forward``:
      1. cosine-similarity cost volume between image and text features,
      2. per-class 7x7 conv embedding to ``hidden_dim`` channels,
      3. ``num_layers`` AggregatorLayer stages (optionally guided by projected
         appearance and text features),
      4. two Up stages (each doubling the spatial size) and a 1-channel head.
    """

    def __init__(self,
                 text_guidance_dim=512,
                 text_guidance_proj_dim=128,
                 appearance_guidance_dim=512,
                 appearance_guidance_proj_dim=128,
                 decoder_dims=(64, 32),
                 decoder_guidance_dims=(256, 128),
                 decoder_guidance_proj_dims=(32, 16),
                 num_layers=4,
                 nheads=4,
                 hidden_dim=128,
                 pooling_size=(6, 6),
                 feature_resolution=(24, 24),
                 window_size=12,
                 attention_type='linear',
                 prompt_channel=80,
                 ) -> None:
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        self.layers = nn.ModuleList([
            AggregatorLayer(
                hidden_dim=hidden_dim, text_guidance_dim=text_guidance_proj_dim, appearance_guidance=appearance_guidance_proj_dim,
                nheads=nheads, input_resolution=feature_resolution, pooling_size=pooling_size, window_size=window_size, attention_type=attention_type
            ) for _ in range(num_layers)
        ])

        # Embeds the P prompt-similarity channels into hidden_dim channels.
        self.conv1 = nn.Conv2d(prompt_channel, hidden_dim, kernel_size=7, stride=1, padding=3)

        self.guidance_projection = nn.Sequential(
            nn.Conv2d(appearance_guidance_dim, appearance_guidance_proj_dim, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        ) if appearance_guidance_dim > 0 else None

        self.text_guidance_projection = nn.Sequential(
            nn.Linear(text_guidance_dim, text_guidance_proj_dim),
            nn.ReLU(),
        ) if text_guidance_dim > 0 else None

        self.decoder_guidance_projection = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(d, dp, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
            ) for d, dp in zip(decoder_guidance_dims, decoder_guidance_proj_dims)
        ]) if decoder_guidance_dims[0] > 0 else None

        self.decoder1 = Up(hidden_dim, decoder_dims[0], decoder_guidance_proj_dims[0])
        self.decoder2 = Up(decoder_dims[0], decoder_dims[1], decoder_guidance_proj_dims[1])
        self.head = nn.Conv2d(decoder_dims[1], 1, kernel_size=3, stride=1, padding=1)

    def feature_map(self, img_feats, text_feats):
        """Concatenate normalized image features with prompt-averaged text features.

        Alternative cost-volume input (not used by ``forward``).
        img_feats: (B, C, H, W); text_feats: (B, T, P, C').
        Returns (B, C + C', T, H, W).
        """
        img_feats = F.normalize(img_feats, dim=1)
        img_feats = repeat(img_feats, "B C H W -> B C T H W", T=text_feats.shape[1])
        text_feats = F.normalize(text_feats, dim=-1)
        text_feats = text_feats.mean(dim=-2)
        text_feats = F.normalize(text_feats, dim=-1)
        text_feats = repeat(text_feats, "B T C -> B C T H W", H=img_feats.shape[-2], W=img_feats.shape[-1])
        return torch.cat((img_feats, text_feats), dim=1)

    def correlation(self, img_feats, text_feats):
        """Cosine similarity between every pixel and every (class, prompt) embedding.

        img_feats: (B, C, H, W); text_feats: (B, T, P, C).
        Returns (B, P, T, H, W).
        """
        img_feats = F.normalize(img_feats, dim=1)
        text_feats = F.normalize(text_feats, dim=-1)
        corr = torch.einsum('bchw, btpc -> bpthw', img_feats, text_feats)
        return corr

    def corr_embed(self, x):
        """Apply ``conv1`` per class: (B, P, T, H, W) -> (B, hidden_dim, T, H, W).

        P must equal ``prompt_channel``.
        """
        B = x.shape[0]
        corr_embed = rearrange(x, 'B P T H W -> (B T) P H W')
        corr_embed = self.conv1(corr_embed)
        corr_embed = rearrange(corr_embed, '(B T) C H W -> B C T H W', B=B)
        return corr_embed

    def corr_projection(self, x, proj):
        """Apply a channel-last module ``proj`` to a (B, C, T, H, W) volume (not used by ``forward``)."""
        corr_embed = rearrange(x, 'B C T H W -> B T H W C')
        corr_embed = proj(corr_embed)
        corr_embed = rearrange(corr_embed, 'B T H W C -> B C T H W')
        return corr_embed

    def upsample(self, x):
        """Bilinearly upsample each class slice by 2x (not used by ``forward``)."""
        B = x.shape[0]
        corr_embed = rearrange(x, 'B C T H W -> (B T) C H W')
        corr_embed = F.interpolate(corr_embed, scale_factor=2, mode='bilinear', align_corners=True)
        corr_embed = rearrange(corr_embed, '(B T) C H W -> B C T H W', B=B)
        return corr_embed

    def conv_decoder(self, x, guidance):
        """Decode (B, hidden_dim, T, H, W) into logits of shape (B, T, 4H, 4W).

        ``guidance`` is a pair of decoder guidance maps (entries may be None).
        """
        B = x.shape[0]
        corr_embed = rearrange(x, 'B C T H W -> (B T) C H W')
        corr_embed = self.decoder1(corr_embed, guidance[0])
        corr_embed = self.decoder2(corr_embed, guidance[1])
        corr_embed = self.head(corr_embed)
        corr_embed = rearrange(corr_embed, '(B T) () H W -> B T H W', B=B)
        return corr_embed

    def forward(self, img_feats, text_feats, appearance_guidance):
        """
        Arguments:
            img_feats: (B, C, H, W)
            text_feats: (B, T, P, C)
            appearance_guidance: tuple of (B, C, H, W); element 0 guides the
                aggregation layers, elements 1 and 2 guide the decoder stages.

        Returns:
            logit: (B, T, 4H, 4W)
        """
        corr = self.correlation(img_feats, text_feats)
        corr_embed = self.corr_embed(corr)

        projected_guidance, projected_text_guidance, projected_decoder_guidance = None, None, [None, None]
        if self.guidance_projection is not None:
            projected_guidance = self.guidance_projection(appearance_guidance[0])
        # Decoder guidance is independent of the aggregation-layer guidance;
        # the Up decoders always reserve channels for it.
        if self.decoder_guidance_projection is not None:
            projected_decoder_guidance = [proj(g) for proj, g in zip(self.decoder_guidance_projection, appearance_guidance[1:])]

        if self.text_guidance_projection is not None:
            text_feats = text_feats.mean(dim=-2)
            # F.normalize clamps the norm, avoiding NaN for all-zero embeddings.
            text_feats = F.normalize(text_feats, dim=-1)
            projected_text_guidance = self.text_guidance_projection(text_feats)

        for layer in self.layers:
            corr_embed = layer(corr_embed, projected_guidance, projected_text_guidance)

        logit = self.conv_decoder(corr_embed, projected_decoder_guidance)

        return logit
|
|