# seokju cho
# initial commit
# f8f62f3
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert
def window_partition(x, window_size: int):
    """Split a channels-last feature map into non-overlapping square windows.

    Args:
        x: tensor of shape (B, H, W, C). H and W are expected to be multiples
            of ``window_size`` (the ``view`` below fails otherwise).
        window_size (int): side length of each window.

    Returns:
        Tensor of shape (num_windows * B, window_size, window_size, C), ordered
        batch-major, then row-major over the grid of windows.
    """
    batch, height, width, channels = x.shape
    rows, cols = height // window_size, width // window_size
    grid = x.view(batch, rows, window_size, cols, window_size, channels)
    # move the two window-grid axes next to the batch axis, then flatten them
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, window_size, window_size, channels)
def window_reverse(windows, window_size: int, H: int, W: int):
    """Inverse of ``window_partition``: stitch windows back into a feature map.

    Args:
        windows: tensor of shape (num_windows*B, window_size, window_size, C).
        window_size (int): side length of each window.
        H (int): height of the reconstructed feature map.
        W (int): width of the reconstructed feature map.

    Returns:
        Tensor of shape (B, H, W, C).
    """
    rows, cols = H // window_size, W // window_size
    channels = windows.shape[-1]
    # Let ``view`` infer the batch size instead of recovering it through float
    # division: the arithmetic stays exact, and ``view`` raises if the window
    # count is not a multiple of rows * cols.
    x = windows.view(-1, rows, cols, window_size, window_size, channels)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(-1, H, W, channels)
class WindowAttention(nn.Module):
    r"""Window-based multi-head self-attention (W-MSA) with appearance guidance.

    Queries and keys are projected from the window tokens concatenated with
    their appearance guidance. Values are projected from the first ``dim``
    channels (the token features) only. The module serves both shifted and
    non-shifted windows; the shifted case passes an additive mask.

    Args:
        dim (int): Number of token feature channels (also the output width).
        appearance_guidance_dim (int): Number of guidance channels appended
            after the token features in the input.
        num_heads (int): Number of attention heads.
        head_dim (int): Channels per head (``dim // num_heads`` when not set).
        window_size (int | tuple[int]): Window height and width.
        qkv_bias (bool, optional): Learnable bias on q, k, v. Default: True
        attn_drop (float, optional): Dropout on attention weights. Default: 0.0
        proj_drop (float, optional): Dropout on the output. Default: 0.0
    """

    def __init__(self, dim, appearance_guidance_dim, num_heads, head_dim=None, window_size=7, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = to_2tuple(window_size)  # (Wh, Ww)
        self.window_area = self.window_size[0] * self.window_size[1]
        self.num_heads = num_heads
        per_head = head_dim or dim // num_heads
        inner_dim = per_head * num_heads
        self.scale = per_head ** -0.5
        qk_in = dim + appearance_guidance_dim
        # module creation order is kept (q, k, v, ..., proj) so that
        # parameter initialisation consumes the RNG in the same sequence
        self.q = nn.Linear(qk_in, inner_dim, bias=qkv_bias)
        self.k = nn.Linear(qk_in, inner_dim, bias=qkv_bias)
        self.v = nn.Linear(dim, inner_dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(inner_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def _to_heads(self, t):
        # (B_, N, heads * d) -> (B_, heads, N, d)
        b, n, _ = t.shape
        return t.reshape(b, n, self.num_heads, -1).permute(0, 2, 1, 3)

    def forward(self, x, mask=None):
        """
        Args:
            x: window tokens of shape (num_windows*B, N, dim + appearance_guidance_dim)
            mask: additive mask of shape (num_windows, N, N) (0 for allowed
                pairs, a large negative value otherwise), or None
        Returns:
            tensor of shape (num_windows*B, N, dim)
        """
        batch_windows, tokens, _ = x.shape
        query = self._to_heads(self.q(x)) * self.scale
        key = self._to_heads(self.k(x))
        value = self._to_heads(self.v(x[:, :, :self.dim]))
        logits = query @ key.transpose(-2, -1)
        if mask is not None:
            n_win = mask.shape[0]
            logits = logits.view(batch_windows // n_win, n_win, self.num_heads, tokens, tokens)
            # mask: (num_windows, N, N) -> (1, num_windows, 1, N, N)
            logits = logits + mask[None, :, None]
            logits = logits.view(-1, self.num_heads, tokens, tokens)
        weights = self.attn_drop(self.softmax(logits))
        out = (weights @ value).transpose(1, 2).reshape(batch_windows, tokens, -1)
        return self.proj_drop(self.proj(out))
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block with appearance guidance.

    The normalized tokens are concatenated with the appearance guidance (when
    given) before window attention. Inside ``WindowAttention`` the guidance only
    feeds the query/key projections, so the residual and MLP paths stay ``dim``
    channels wide.

    Args:
        dim (int): Number of input channels.
        appearance_guidance_dim (int): Number of guidance channels concatenated to the
            tokens before attention (0 when no guidance is used).
        input_resolution (tuple[int]): Input resolution (H, W).
        num_heads (int): Number of attention heads.
        head_dim (int): Enforce the number of channels per head
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """
    def __init__(
            self, dim, appearance_guidance_dim, input_resolution, num_heads=4, head_dim=None, window_size=7, shift_size=0,
            mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
            act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if the window already covers the smaller input side, use a single
            # window of that size and disable the shift
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, appearance_guidance_dim=appearance_guidance_dim, num_heads=num_heads, head_dim=head_dim, window_size=to_2tuple(self.window_size),
            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA:
            # label the 3x3 grid of regions that the cyclic shift brings together,
            # then forbid attention between tokens carrying different labels
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            cnt = 0
            for h in (
                    slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None)):
                for w in (
                        slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None)):
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            mask_windows = window_partition(img_mask, self.window_size)  # num_win, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # pairwise label difference: 0 for same region, non-zero otherwise
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            # -100 acts as an additive "-inf" before the softmax
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        # registered as a buffer so it follows the module across devices
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x, appearance_guidance):
        """Window attention (optionally shifted) followed by an MLP, each with a residual.

        Args:
            x: tokens of shape (B, H*W, dim) with (H, W) = input_resolution.
            appearance_guidance: tensor reshapeable to (B, H, W, -1), or None.
                It must be given when appearance_guidance_dim > 0, because the
                q/k projections expect dim + appearance_guidance_dim channels.

        Returns:
            tensor of shape (B, H*W, dim).
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        if appearance_guidance is not None:
            appearance_guidance = appearance_guidance.view(B, H, W, -1)
            x = torch.cat([x, appearance_guidance], dim=-1)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # num_win*B, window_size, window_size, C (+ guidance channels)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, x_windows.shape[-1])  # num_win*B, window_size*window_size, C (+ guidance channels)

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # num_win*B, window_size*window_size, C

        # merge windows (attention output has C = dim channels; guidance is dropped here)
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x
class SwinTransformerBlockWrapper(nn.Module):
    """Two Swin blocks (regular windows, then windows shifted by half a window)
    applied independently to every (batch, class) slice of a cost volume.

    Args:
        dim (int): channels of the cost-volume embedding.
        appearance_guidance_dim (int): guidance channels (0 disables the guidance norm).
        input_resolution (tuple[int]): spatial size (H, W).
        nheads (int): attention heads per block.
        window_size (int): window side length.
    """

    def __init__(self, dim, appearance_guidance_dim, input_resolution, nheads=4, window_size=5):
        super().__init__()
        shared = dict(num_heads=nheads, head_dim=None, window_size=window_size)
        self.block_1 = SwinTransformerBlock(dim, appearance_guidance_dim, input_resolution, shift_size=0, **shared)
        self.block_2 = SwinTransformerBlock(dim, appearance_guidance_dim, input_resolution, shift_size=window_size // 2, **shared)
        if appearance_guidance_dim > 0:
            self.guidance_norm = nn.LayerNorm(appearance_guidance_dim)
        else:
            self.guidance_norm = None

    def forward(self, x, appearance_guidance):
        """
        Arguments:
            x: B C T H W
            appearance_guidance: B C H W, or None
        Returns:
            B C T H W
        """
        B, _, T, H, W = x.shape
        tokens = rearrange(x, 'B C T H W -> (B T) (H W) C')
        guidance = appearance_guidance
        if guidance is not None:
            # share the same guidance across all T class slices
            guidance = repeat(guidance, 'B C H W -> (B T) (H W) C', T=T)
            guidance = self.guidance_norm(guidance)
        for block in (self.block_1, self.block_2):
            tokens = block(tokens, guidance)
        return rearrange(tokens, '(B T) (H W) C -> B C T H W', B=B, T=T, H=H, W=W)
def elu_feature_map(x):
    """Non-negative feature map ``elu(x) + 1`` used by ``LinearAttention``."""
    shifted = F.elu(x)
    return shifted + 1
class LinearAttention(nn.Module):
    """Multi-head linear attention from "Transformers are RNNs".

    Uses the feature map ``elu(x) + 1`` and never materializes the L x S
    attention matrix. The ``forward`` signature matches ``FullAttention``, so
    the two can be swapped by ``AttentionLayer``.

    Args:
        eps (float): added to the normalizer to avoid division by zero.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.feature_map = elu_feature_map
        self.eps = eps

    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
        """ Multi-Head linear attention proposed in "Transformers are RNNs"
        Args:
            queries: [N, L, H, D]
            keys: [N, S, H, D]
            values: [N, S, H, D]
            q_mask: optional [N, L] mask (truthy = valid); masked query
                positions produce zero outputs.
            kv_mask: optional [N, S] mask (truthy = valid); masked key/value
                positions do not contribute.
        Returns:
            queried_values: (N, L, H, D)
        """
        Q = self.feature_map(queries)
        K = self.feature_map(keys)

        # set padded positions to zero (previously the documented masks were
        # not accepted by the signature at all)
        if q_mask is not None:
            Q = Q * q_mask[:, :, None, None]
        if kv_mask is not None:
            K = K * kv_mask[:, :, None, None]
            values = values * kv_mask[:, :, None, None]

        v_length = values.size(1)
        values = values / v_length  # prevent fp16 overflow
        KV = torch.einsum("nshd,nshv->nhdv", K, values)  # (S,D)' @ S,V
        Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
        queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
        return queried_values.contiguous()
class FullAttention(nn.Module):
    """Multi-head scaled dot-product ("full") attention with optional masking.

    Args:
        use_dropout (bool): apply dropout to the attention weights. Default: False
        attention_dropout (float): dropout probability used when ``use_dropout``. Default: 0.1
    """

    def __init__(self, use_dropout=False, attention_dropout=0.1):
        super().__init__()
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
        """ Multi-head scaled dot-product attention, a.k.a full attention.
        Args:
            queries: [N, L, H, D]
            keys: [N, S, H, D]
            values: [N, S, H, D]
            q_mask: optional [N, L] mask (truthy = valid query). Only applied
                together with ``kv_mask``. A masked query row gets -inf logits
                everywhere, so its softmax (and output) is NaN.
            kv_mask: optional [N, S] mask (truthy = valid key/value).
        Returns:
            queried_values: (N, L, H, D)
        """
        # Compute the unnormalized attention logits: (N, L, S, H)
        QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
        if kv_mask is not None:
            # (N, 1, S, 1), broadcast over queries and heads
            valid = kv_mask[:, None, :, None].bool()
            if q_mask is not None:
                # q_mask used to be indexed unconditionally, which crashed when
                # only kv_mask was supplied
                valid = valid & q_mask[:, :, None, None].bool()
            QK.masked_fill_(~valid, float('-inf'))

        # Compute the attention and the weighted average
        softmax_temp = 1. / queries.size(3) ** .5  # 1 / sqrt(D)
        A = torch.softmax(softmax_temp * QK, dim=2)  # normalize over keys (S)
        if self.use_dropout:
            A = self.dropout(A)

        queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
        return queried_values.contiguous()
class AttentionLayer(nn.Module):
    """Multi-head attention whose queries and keys can be conditioned on guidance.

    Args:
        hidden_dim (int): token channels (also the output width).
        guidance_dim (int): guidance channels concatenated to the tokens for q/k.
        nheads (int): number of heads; ``hidden_dim`` is split evenly across them.
        attention_type (str): ``'linear'`` or ``'full'``.
    """

    def __init__(self, hidden_dim, guidance_dim, nheads=8, attention_type='linear'):
        super().__init__()
        self.nheads = nheads
        qk_in = hidden_dim + guidance_dim
        self.q = nn.Linear(qk_in, hidden_dim)
        self.k = nn.Linear(qk_in, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)

        if attention_type == 'linear':
            attention_cls = LinearAttention
        elif attention_type == 'full':
            attention_cls = FullAttention
        else:
            raise NotImplementedError
        self.attention = attention_cls()

    def forward(self, x, guidance):
        """
        Arguments:
            x: B, L, C
            guidance: B, L, C_guidance (concatenated to x for q/k), or None
        Returns:
            B, L, C
        """
        qk_input = x if guidance is None else torch.cat([x, guidance], dim=-1)
        queries = rearrange(self.q(qk_input), 'B L (H D) -> B L H D', H=self.nheads)
        keys = rearrange(self.k(qk_input), 'B S (H D) -> B S H D', H=self.nheads)
        values = rearrange(self.v(x), 'B S (H D) -> B S H D', H=self.nheads)
        attended = self.attention(queries, keys, values)
        return rearrange(attended, 'B L H D -> B L (H D)')
class ClassTransformerLayer(nn.Module):
    """Transformer layer mixing information across the class axis (T) of a cost volume.

    The volume is average-pooled spatially, each pooled pixel attends over its
    T class tokens (optionally guided by text features), and the result is
    bilinearly upsampled back and added to the input as a residual.

    Args:
        hidden_dim (int): cost-volume channels.
        guidance_dim (int): text-guidance channels.
        nheads (int): attention heads.
        attention_type (str): ``'linear'`` or ``'full'``.
        pooling_size (tuple[int]): ``nn.AvgPool2d`` kernel.
    """

    def __init__(self, hidden_dim=64, guidance_dim=64, nheads=8, attention_type='linear', pooling_size=(4, 4)) -> None:
        super().__init__()
        self.pool = nn.AvgPool2d(pooling_size)
        self.attention = AttentionLayer(hidden_dim, guidance_dim, nheads=nheads, attention_type=attention_type)
        expanded = hidden_dim * 4
        self.MLP = nn.Sequential(
            nn.Linear(hidden_dim, expanded),
            nn.ReLU(),
            nn.Linear(expanded, hidden_dim),
        )
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def pool_features(self, x):
        """
        Spatially average-pool every (batch, class) slice to save compute.
        Arguments:
            x: B, C, T, H, W
        Returns:
            B, C, T, H', W'
        """
        batch = x.size(0)
        pooled = self.pool(rearrange(x, 'B C T H W -> (B T) C H W'))
        return rearrange(pooled, '(B T) C H W -> B C T H W', B=batch)

    def forward(self, x, guidance):
        """
        Arguments:
            x: B, C, T, H, W
            guidance: B, T, C, or None
        Returns:
            B, C, T, H, W
        """
        B, _, _, H, W = x.size()
        pooled = self.pool_features(x)
        ph, pw = pooled.shape[-2:]
        tokens = rearrange(pooled, 'B C T H W -> (B H W) T C')
        if guidance is not None:
            guidance = repeat(guidance, 'B T C -> (B H W) T C', H=ph, W=pw)

        tokens = tokens + self.attention(self.norm1(tokens), guidance)  # attention over T
        tokens = tokens + self.MLP(self.norm2(tokens))

        tokens = rearrange(tokens, '(B H W) T C -> (B T) C H W', H=ph, W=pw)
        tokens = F.interpolate(tokens, size=(H, W), mode='bilinear', align_corners=True)
        tokens = rearrange(tokens, '(B T) C H W -> B C T H W', B=B)
        return x + tokens
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Bias-free 3x3 convolution padded by ``dilation`` (size-preserving at stride 1)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
def conv1x1(in_planes, out_planes, stride=1):
    """Bias-free 1x1 (pointwise) convolution."""
    conv = nn.Conv2d(in_planes, out_planes, 1, stride=stride, bias=False)
    return conv
class Bottleneck(nn.Module):
    """ResNet-style bottleneck: 1x1 reduce -> 3x3 -> 1x1 expand, plus a residual.

    Output channels are ``planes * expansion``. Without ``downsample`` the
    input must already have that many channels for the residual sum to work.
    """
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion
        # Both conv2 and downsample reduce the spatial size when stride != 1.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(out + shortcut)
class AggregatorLayer(nn.Module):
    """One cost-aggregation stage: spatial aggregation with a regular and a
    shifted window-attention block, then attention across the class axis.

    Args:
        hidden_dim (int): channels of the cost-volume embedding.
        text_guidance_dim (int): channels of the text guidance for class attention.
        appearance_guidance (int): channels of the appearance guidance for the Swin blocks.
        nheads (int): attention heads in both sub-layers.
        input_resolution (tuple[int]): spatial size (H, W) of the cost volume.
        pooling_size (tuple[int]): average-pooling kernel used before class attention.
        window_size (int | tuple[int]): Swin window size. The Swin blocks need an
            ``int`` (they compute ``window_size // 2`` and compare it with ints),
            so a square tuple such as the default ``(10, 10)`` is reduced to its
            side length.
        attention_type (str): ``'linear'`` or ``'full'`` class attention.

    Raises:
        ValueError: if ``window_size`` is a non-square tuple/list.
    """

    def __init__(self, hidden_dim=64, text_guidance_dim=512, appearance_guidance=512, nheads=4, input_resolution=(20, 20), pooling_size=(5, 5), window_size=(10, 10), attention_type='linear') -> None:
        super().__init__()
        if isinstance(window_size, (tuple, list)):
            # previously the (default) tuple was passed through and crashed in
            # SwinTransformerBlockWrapper / SwinTransformerBlock
            if len(set(window_size)) != 1:
                raise ValueError(f"only square windows are supported, got {window_size}")
            window_size = int(window_size[0])
        self.swin_block = SwinTransformerBlockWrapper(hidden_dim, appearance_guidance, input_resolution, nheads, window_size)
        self.attention = ClassTransformerLayer(hidden_dim, text_guidance_dim, nheads=nheads, attention_type=attention_type, pooling_size=pooling_size)

    def forward(self, x, appearance_guidance, text_guidance):
        """
        Arguments:
            x: B C T H W
            appearance_guidance: B C_a H W, or None
            text_guidance: B T C_t, or None
        Returns:
            B C T H W
        """
        x = self.swin_block(x, appearance_guidance)
        x = self.attention(x, text_guidance)
        return x
class AggregatorResNetLayer(nn.Module):
    """Convolutional aggregation stage: fuse appearance guidance into every
    class slice with a 1x1 conv, then refine with a ``Bottleneck``.

    Args:
        hidden_dim (int): cost-volume channels (input and output).
        appearance_guidance (int): channels of the appearance guidance.
    """

    def __init__(self, hidden_dim=64, appearance_guidance=512) -> None:
        super().__init__()
        fused_channels = hidden_dim + appearance_guidance
        self.conv_linear = nn.Conv2d(fused_channels, hidden_dim, kernel_size=1, stride=1)
        self.conv_layer = Bottleneck(hidden_dim, hidden_dim // 4)

    def forward(self, x, appearance_guidance):
        """
        Arguments:
            x: B C T H W
            appearance_guidance: B C_a H W
        Returns:
            B C T H W
        """
        batch, num_classes = x.size(0), x.size(2)
        slices = rearrange(x, 'B C T H W -> (B T) C H W')
        guidance = repeat(appearance_guidance, 'B C H W -> (B T) C H W', T=num_classes)
        fused = self.conv_linear(torch.cat([slices, guidance], dim=1))
        fused = self.conv_layer(fused)
        return rearrange(fused, '(B T) C H W -> B C T H W', B=batch)
class DoubleConv(nn.Module):
    """(convolution => [GN] => ReLU) * 2

    Args:
        in_channels (int): channels of the input.
        out_channels (int): channels produced by the second convolution.
        mid_channels (int, optional): channels produced by the first
            convolution; defaults to ``out_channels``.

    Each GroupNorm normalizes the output of the convolution right before it,
    using ``channels // 16`` groups (at least one).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            # max(1, ...) keeps narrow layers (< 16 channels) from requesting 0 groups
            nn.GroupNorm(max(1, mid_channels // 16), mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            # Sized by out_channels: it normalizes the second conv's output.
            # (Sizing it by mid_channels crashed whenever mid_channels != out_channels.)
            nn.GroupNorm(max(1, out_channels // 16), out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)
class Up(nn.Module):
    """Upscaling then double conv.

    A stride-2 transposed convolution doubles the resolution and outputs
    ``in_channels - guidance_channels`` channels. Concatenating the guidance
    then restores ``in_channels`` channels for the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels, guidance_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels - guidance_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x, guidance=None):
        upsampled = self.up(x)
        if guidance is None:
            return self.conv(upsampled)
        # x stacks several slices per guidance image along dim 0; tile to match
        repeats = upsampled.size(0) // guidance.size(0)
        guidance = repeat(guidance, "B C H W -> (B T) C H W", T=repeats)
        return self.conv(torch.cat([upsampled, guidance], dim=1))
class Aggregator(nn.Module):
    """Aggregate an image-text cost volume and decode it into per-class logits.

    Pipeline (see ``forward``): a cosine-similarity cost volume between image
    features and per-prompt text features, then a 7x7 conv embedding, then
    ``num_layers`` ``AggregatorLayer`` stages (guided by projected appearance
    and text features), then two ``Up`` decoder stages (guided by projected
    ``appearance_guidance[1:]``), then a 1-channel head.

    Args:
        text_guidance_dim (int): channels C of the text features; 0 disables text guidance.
        text_guidance_proj_dim (int): channels of the projected text guidance.
        appearance_guidance_dim (int): channels of ``appearance_guidance[0]``; 0 disables it.
        appearance_guidance_proj_dim (int): channels of the projected appearance guidance.
        decoder_dims (tuple[int]): output channels of the two decoder stages.
        decoder_guidance_dims (tuple[int]): channels of ``appearance_guidance[1:]``;
            a first entry of 0 disables decoder guidance.
        decoder_guidance_proj_dims (tuple[int]): projected decoder-guidance channels.
        num_layers (int): number of ``AggregatorLayer`` stages.
        nheads (int): attention heads per stage.
        hidden_dim (int): channels of the cost-volume embedding.
        pooling_size (tuple[int]): pooling kernel before class attention.
        feature_resolution (tuple[int]): spatial size (H, W) of the cost volume.
        window_size (int): Swin window size.
        attention_type (str): ``'linear'`` or ``'full'`` class attention.
        prompt_channel (int): number of prompts P per class (input channels of ``conv1``).
    """
    def __init__(self,
        text_guidance_dim=512,
        text_guidance_proj_dim=128,
        appearance_guidance_dim=512,
        appearance_guidance_proj_dim=128,
        decoder_dims = (64, 32),
        decoder_guidance_dims=(256, 128),
        decoder_guidance_proj_dims=(32, 16),
        num_layers=4,
        nheads=4,
        hidden_dim=128,
        pooling_size=(6, 6),
        feature_resolution=(24, 24),
        window_size=12,
        attention_type='linear',
        prompt_channel=80,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        # stacked aggregation stages operating on the (B, hidden_dim, T, H, W) embedding
        self.layers = nn.ModuleList([
            AggregatorLayer(
                hidden_dim=hidden_dim, text_guidance_dim=text_guidance_proj_dim, appearance_guidance=appearance_guidance_proj_dim,
                nheads=nheads, input_resolution=feature_resolution, pooling_size=pooling_size, window_size=window_size, attention_type=attention_type
            ) for _ in range(num_layers)
        ])

        # embeds the P per-prompt correlation maps of each class into hidden_dim channels
        self.conv1 = nn.Conv2d(prompt_channel, hidden_dim, kernel_size=7, stride=1, padding=3)

        # projection of appearance_guidance[0] used by the Swin blocks
        self.guidance_projection = nn.Sequential(
            nn.Conv2d(appearance_guidance_dim, appearance_guidance_proj_dim, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        ) if appearance_guidance_dim > 0 else None

        # projection of the prompt-averaged text embeddings used by class attention
        self.text_guidance_projection = nn.Sequential(
            nn.Linear(text_guidance_dim, text_guidance_proj_dim),
            nn.ReLU(),
        ) if text_guidance_dim > 0 else None

        # one projection per decoder stage, applied to appearance_guidance[1:]
        self.decoder_guidance_projection = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(d, dp, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
            ) for d, dp in zip(decoder_guidance_dims, decoder_guidance_proj_dims)
        ]) if decoder_guidance_dims[0] > 0 else None

        # each Up stage doubles the spatial resolution; head outputs one logit map per slice
        self.decoder1 = Up(hidden_dim, decoder_dims[0], decoder_guidance_proj_dims[0])
        self.decoder2 = Up(decoder_dims[0], decoder_dims[1], decoder_guidance_proj_dims[1])
        self.head = nn.Conv2d(decoder_dims[1], 1, kernel_size=3, stride=1, padding=1)

    def feature_map(self, img_feats, text_feats):
        """Concatenate normalized image features with prompt-averaged text features.

        Alternative to ``correlation``; its call in ``forward`` is commented out.
        Returns a (B, 2C, T, H, W) tensor.
        """
        img_feats = F.normalize(img_feats, dim=1)  # B C H W
        img_feats = repeat(img_feats, "B C H W -> B C T H W", T=text_feats.shape[1])
        text_feats = F.normalize(text_feats, dim=-1)  # B T P C
        text_feats = text_feats.mean(dim=-2)
        text_feats = F.normalize(text_feats, dim=-1)  # B T C
        text_feats = repeat(text_feats, "B T C -> B C T H W", H=img_feats.shape[-2], W=img_feats.shape[-1])
        return torch.cat((img_feats, text_feats), dim=1)  # B 2C T H W

    def correlation(self, img_feats, text_feats):
        """Cosine similarity between every pixel and every (class, prompt) text embedding.

        img_feats: B C H W, text_feats: B T P C -> returns B P T H W.
        """
        img_feats = F.normalize(img_feats, dim=1)  # B C H W
        text_feats = F.normalize(text_feats, dim=-1)  # B T P C
        corr = torch.einsum('bchw, btpc -> bpthw', img_feats, text_feats)
        return corr

    def corr_embed(self, x):
        """Apply ``conv1`` to every class slice: B P T H W -> B hidden_dim T H W."""
        B = x.shape[0]
        corr_embed = rearrange(x, 'B P T H W -> (B T) P H W')
        corr_embed = self.conv1(corr_embed)
        corr_embed = rearrange(corr_embed, '(B T) C H W -> B C T H W', B=B)
        return corr_embed

    def corr_projection(self, x, proj):
        """Apply ``proj`` over the channel axis (channels-last). Not called elsewhere in this class."""
        corr_embed = rearrange(x, 'B C T H W -> B T H W C')
        corr_embed = proj(corr_embed)
        corr_embed = rearrange(corr_embed, 'B T H W C -> B C T H W')
        return corr_embed

    def upsample(self, x):
        """Bilinearly upsample every class slice by 2x. Not called elsewhere in this class."""
        B = x.shape[0]
        corr_embed = rearrange(x, 'B C T H W -> (B T) C H W')
        corr_embed = F.interpolate(corr_embed, scale_factor=2, mode='bilinear', align_corners=True)
        corr_embed = rearrange(corr_embed, '(B T) C H W -> B C T H W', B=B)
        return corr_embed

    def conv_decoder(self, x, guidance):
        """Decode B C T H W into B T H' W' logits (two Up stages + 1-channel head).

        ``guidance`` is a sequence of two per-stage guidance maps (entries may be None).
        """
        B = x.shape[0]
        corr_embed = rearrange(x, 'B C T H W -> (B T) C H W')
        corr_embed = self.decoder1(corr_embed, guidance[0])
        corr_embed = self.decoder2(corr_embed, guidance[1])
        corr_embed = self.head(corr_embed)
        corr_embed = rearrange(corr_embed, '(B T) () H W -> B T H W', B=B)
        return corr_embed

    def forward(self, img_feats, text_feats, appearance_guidance):
        """
        Arguments:
            img_feats: (B, C, H, W)
            text_feats: (B, T, P, C)
            appearance_guidance: tuple of (B, C, H, W); element 0 guides the
                aggregation layers, elements 1: guide the decoder stages
        Returns:
            logits of shape (B, T, H', W')
        """
        corr = self.correlation(img_feats, text_feats)
        #corr = self.feature_map(img_feats, text_feats)
        corr_embed = self.corr_embed(corr)

        projected_guidance, projected_text_guidance, projected_decoder_guidance = None, None, [None, None]
        if self.guidance_projection is not None:
            projected_guidance = self.guidance_projection(appearance_guidance[0])
        if self.decoder_guidance_projection is not None:
            projected_decoder_guidance = [proj(g) for proj, g in zip(self.decoder_guidance_projection, appearance_guidance[1:])]

        if self.text_guidance_projection is not None:
            # average over prompts, then L2-normalize: B T P C -> B T C
            # NOTE(review): unlike F.normalize used elsewhere, this division has
            # no epsilon, so an all-zero averaged embedding would produce NaN.
            text_feats = text_feats.mean(dim=-2)
            text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
            projected_text_guidance = self.text_guidance_projection(text_feats)

        for layer in self.layers:
            corr_embed = layer(corr_embed, projected_guidance, projected_text_guidance)

        logit = self.conv_decoder(corr_embed, projected_decoder_guidance)

        return logit