import torch
import torch.nn as nn


class Conv2d(nn.Module):
    """Conv + BatchNorm + ReLU block with an optional residual connection."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, padding_mode='zeros', bias=True, residual=False):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                      padding, padding_mode=padding_mode, bias=bias),
            nn.BatchNorm2d(out_channels)
        )
        self.residual = residual
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x  # requires input and output shapes to match
        out = self.act(out)
        return out
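
# A minimal shape-check sketch (not part of the original file): the dummy sizes
# and the _demo_* helper name are assumptions for illustration only.
def _demo_conv2d():
    x = torch.randn(2, 32, 16, 16)  # B x C x H x W
    block = Conv2d(32, 32, kernel_size=3, padding=1, residual=True)
    out = block(x)  # stride 1 + padding 1 keeps H, W, so the residual add is valid
    assert out.shape == x.shape
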
class ResnetBlock(nn.Module):
    """Single residual block: 3x3 conv + norm, added back to the input."""

    def __init__(self, channel, padding_mode, norm_layer=nn.BatchNorm2d, bias=False):
        super().__init__()
        # nn.Conv2d expects 'zeros' (not 'zero'); only these two modes are supported here.
        if padding_mode not in ['reflect', 'zeros']:
            raise NotImplementedError(f"{padding_mode} is not supported!")
        self.block = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=3, padding=1, padding_mode=padding_mode, bias=bias),
            norm_layer(channel)
        )
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.block(x)
        out = out + x  # residual connection
        out = self.act(out)
        return out
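
# Hedged usage sketch (sizes assumed): the 3x3 conv with padding=1 preserves the
# spatial size, so the residual sum is always well-defined.
def _demo_resnet_block():
    x = torch.randn(2, 64, 16, 16)
    block = ResnetBlock(64, padding_mode='reflect')
    assert block(x).shape == x.shape
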
class ResidualBlocks(nn.Module):
    def __init__(self, channel, n_blocks=6):
        super().__init__()
        model = []
        for _ in range(n_blocks):  # add ResNet blocks
            model += [ResnetBlock(channel, padding_mode='reflect')]
        self.module = nn.Sequential(*model)

    def forward(self, x):
        return self.module(x)
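
# Quick sketch (assumed sizes): stacking n_blocks ResnetBlocks changes neither
# channels nor resolution, so the stack drops into any same-size pipeline.
def _demo_residual_blocks():
    x = torch.randn(2, 64, 16, 16)
    stack = ResidualBlocks(64, n_blocks=6)
    assert stack(x).shape == x.shape
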
class SelfAttentionBlock(nn.Module):
    """Self-attention over spatial positions, with queries/keys reduced to in_dim // 8 channels."""

    def __init__(self, in_dim):
        super().__init__()
        self.feature_dim = in_dim // 8
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.feature_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.feature_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # learned residual weight, starts at 0
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B, C, H, W = x.size()
        _query = self.query_conv(x).view(B, -1, H * W).permute(0, 2, 1)  # B x (H*W) x C', C' = in_dim // 8
        _key = self.key_conv(x).view(B, -1, H * W)  # B x C' x (H*W)
        attn_matrix = torch.bmm(_query, _key)
        attention = self.softmax(attn_matrix)  # B x (H*W) x (H*W)
        _value = self.value_conv(x).view(B, -1, H * W)  # B x C x (H*W)
        out = torch.bmm(_value, attention.permute(0, 2, 1))
        out = out.view(B, C, H, W)
        out = self.gamma * out + x
        return out
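
# Shape sketch (dummy sizes are assumptions): the B x (H*W) x (H*W) attention map
# is computed in the reduced query/key space, while the output keeps in_dim channels.
def _demo_self_attention():
    x = torch.randn(2, 64, 16, 16)
    attn = SelfAttentionBlock(64)  # query/key dim = 64 // 8 = 8
    assert attn(x).shape == x.shape  # gamma starts at 0, so out == x at initialization
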
class ContextAwareAttentionBlock(nn.Module):
    """Pools a style feature map into a single vector via learned spatial attention."""

    def __init__(self, in_channels, hidden_dim=128):
        super().__init__()
        self.self_attn = SelfAttentionBlock(in_channels)
        self.fc = nn.Linear(in_channels, hidden_dim)
        self.context_vector = nn.Linear(hidden_dim, 1, bias=False)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, style_features):
        B, C, H, W = style_features.size()
        h = self.self_attn(style_features)
        h = h.permute(0, 2, 3, 1).reshape(-1, C)
        h = torch.tanh(self.fc(h))  # (B*H*W) x hidden_dim
        h = self.context_vector(h)  # (B*H*W) x 1
        attention_score = self.softmax(h.view(B, H * W)).view(B, 1, H, W)  # B x 1 x H x W
        return torch.sum(style_features * attention_score, dim=[2, 3])  # B x C
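
# Shape sketch (assumed sizes): the block attends over all H*W positions and
# collapses a B x C x H x W style map into one B x C vector.
def _demo_context_aware_attention():
    style = torch.randn(2, 64, 16, 16)
    block = ContextAwareAttentionBlock(64)
    assert block(style).shape == (2, 64)  # one C-dim vector per image
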
class LayerAttentionBlock(nn.Module):
    """From FTransGAN."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.width_feat = 4
        self.height_feat = 4
        self.fc = nn.Linear(self.in_channels * self.width_feat * self.height_feat, 3)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, style_features, style_features_1, style_features_2, style_features_3, B, K):
        # Average the K style references, then predict one weight per attention
        # level from the flattened 4x4 style map.
        style_features = torch.mean(style_features.view(B, K, self.in_channels, self.height_feat, self.width_feat), dim=1)
        style_features = style_features.view(B, -1)
        weight = self.softmax(self.fc(style_features))
        style_features_1 = torch.mean(style_features_1.view(B, K, self.in_channels), dim=1)
        style_features_2 = torch.mean(style_features_2.view(B, K, self.in_channels), dim=1)
        style_features_3 = torch.mean(style_features_3.view(B, K, self.in_channels), dim=1)
        style_features = (style_features_1 * weight.narrow(1, 0, 1) +
                          style_features_2 * weight.narrow(1, 1, 1) +
                          style_features_3 * weight.narrow(1, 2, 1))
        style_features = style_features.view(B, self.in_channels, 1, 1)
        return style_features
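
# Shape sketch (B, K, and channel sizes are assumptions): the block expects the
# twice-downsampled 4x4 style map plus the three pooled (B*K) x C feature vectors,
# and returns one B x C x 1 x 1 style code per content image.
def _demo_layer_attention():
    B, K, C = 2, 3, 64
    style_map = torch.randn(B * K, C, 4, 4)
    f1 = f2 = f3 = torch.randn(B * K, C)
    block = LayerAttentionBlock(C)
    assert block(style_map, f1, f2, f3, B, K).shape == (B, C, 1, 1)
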
class StyleAttentionBlock(nn.Module):
    """From FTransGAN."""

    def __init__(self, in_channels):
        super().__init__()
        self.num_local_attention = 3
        for module_idx in range(1, self.num_local_attention + 1):
            self.add_module(f"local_attention_{module_idx}",
                            ContextAwareAttentionBlock(in_channels))
        for module_idx in range(1, self.num_local_attention):
            self.add_module(f"downsample_{module_idx}",
                            Conv2d(in_channels, in_channels,
                                   kernel_size=3, stride=2, padding=1, bias=False))
        self.add_module("layer_attention", LayerAttentionBlock(in_channels))

    def forward(self, x, B, K):
        feature_1 = self.local_attention_1(x)
        x = self.downsample_1(x)
        feature_2 = self.local_attention_2(x)
        x = self.downsample_2(x)
        feature_3 = self.local_attention_3(x)
        out = self.layer_attention(x, feature_1, feature_2, feature_3, B, K)
        return out
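
# End-to-end sketch (all sizes assumed): two stride-2 downsamples take a 16x16
# input to the 4x4 resolution LayerAttentionBlock expects, and K style references
# per content image are aggregated into one B x C x 1 x 1 style code.
def _demo_style_attention():
    B, K, C = 2, 3, 64
    x = torch.randn(B * K, C, 16, 16)  # K style images per sample
    block = StyleAttentionBlock(C)
    assert block(x, B, K).shape == (B, C, 1, 1)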