import torch
import torch.nn as nn
class Conv2d(nn.Module):
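    """Conv2d + BatchNorm2d + ReLU block with an optional identity residual
    connection (added before the activation; requires matching shapes).
    """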
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, padding_mode='zeros', bias=True, residual=False):
        super().__init__()
self.conv_block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride,
padding, padding_mode=padding_mode, bias=bias),
nn.BatchNorm2d(out_channels)
)
self.residual = residual
self.act = nn.ReLU()
def forward(self, x):
out = self.conv_block(x)
if self.residual:
out += x
out = self.act(out)
return out
class ResnetBlock(nn.Module):
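    """Single-convolution residual block: Conv2d + norm layer plus an identity
    skip connection, followed by ReLU.
    """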
def __init__(self, channel, padding_mode, norm_layer=nn.BatchNorm2d, bias=False):
super().__init__()
        if padding_mode not in ['reflect', 'zeros']:
            raise NotImplementedError(f"padding_mode '{padding_mode}' is not supported!")
self.block = nn.Sequential(
nn.Conv2d(channel, channel, kernel_size=3, padding=1, padding_mode=padding_mode, bias=bias),
norm_layer(channel)
)
self.act = nn.ReLU()
def forward(self, x):
out = self.block(x)
out = out + x
out = self.act(out)
return out
class ResidualBlocks(nn.Module):
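    """A stack of ``n_blocks`` ResnetBlock modules with reflection padding."""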
def __init__(self, channel, n_blocks=6):
super().__init__()
model = []
        for _ in range(n_blocks):  # add ResNet blocks
model += [ResnetBlock(channel, padding_mode='reflect')]
self.module = nn.Sequential(*model)
def forward(self, x):
return self.module(x)
class SelfAttentionBlock(nn.Module):
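    """SAGAN-style self-attention over spatial positions: 1x1 convolutions
    produce query/key/value maps, and the attended output is blended with the
    input through a learnable scalar ``gamma`` initialized to zero.
    """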
def __init__(self, in_dim):
super().__init__()
self.feature_dim = in_dim // 8
self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.feature_dim, kernel_size=1)
self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.feature_dim, kernel_size=1)
self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
B, C, H, W = x.size()
        _query = self.query_conv(x).view(B, -1, H * W).permute(0, 2, 1)  # B x (H*W) x C', where C' = C // 8
        _key = self.key_conv(x).view(B, -1, H * W)  # B x C' x (H*W)
        attn_matrix = torch.bmm(_query, _key)  # B x (H*W) x (H*W)
        attention = self.softmax(attn_matrix)  # softmax over key positions (last dim)
        _value = self.value_conv(x).view(B, -1, H * W)  # B x C x (H*W)
        out = torch.bmm(_value, attention.permute(0, 2, 1))  # B x C x (H*W)
out = out.view(B, C, H, W)
out = self.gamma * out + x
return out
class ContextAwareAttentionBlock(nn.Module):
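    """Context-aware attention: self-attention followed by a small scoring MLP;
    the softmax-normalized scores pool the feature map into a B x C style vector.
    """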
def __init__(self, in_channels, hidden_dim=128):
super().__init__()
self.self_attn = SelfAttentionBlock(in_channels)
self.fc = nn.Linear(in_channels, hidden_dim)
self.context_vector = nn.Linear(hidden_dim, 1, bias=False)
self.softmax = nn.Softmax(dim=1)
def forward(self, style_features):
B, C, H, W = style_features.size()
h = self.self_attn(style_features)
h = h.permute(0, 2, 3, 1).reshape(-1, C)
        h = torch.tanh(self.fc(h))  # (B*H*W) x hidden_dim
        h = self.context_vector(h)  # (B*H*W) x 1
attention_score = self.softmax(h.view(B, H * W)).view(B, 1, H, W) # B x 1 x H x W
return torch.sum(style_features * attention_score, dim=[2, 3]) # B x C
class LayerAttentionBlock(nn.Module):
"""from FTransGAN
"""
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.width_feat = 4
self.height_feat = 4
self.fc = nn.Linear(self.in_channels * self.width_feat * self.height_feat, 3)
self.softmax = nn.Softmax(dim=1)
def forward(self, style_features, style_features_1, style_features_2, style_features_3, B, K):
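        # style_features: (B*K, C, height_feat, width_feat) deepest feature map;
        # style_features_1/2/3: (B*K, C) vectors from the three local attention levels.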
style_features = torch.mean(style_features.view(B, K, self.in_channels, self.height_feat, self.width_feat), dim=1)
style_features = style_features.view(B, -1)
weight = self.softmax(self.fc(style_features))
style_features_1 = torch.mean(style_features_1.view(B, K, self.in_channels), dim=1)
style_features_2 = torch.mean(style_features_2.view(B, K, self.in_channels), dim=1)
style_features_3 = torch.mean(style_features_3.view(B, K, self.in_channels), dim=1)
style_features = (style_features_1 * weight.narrow(1, 0, 1) +
style_features_2 * weight.narrow(1, 1, 1) +
style_features_3 * weight.narrow(1, 2, 1))
style_features = style_features.view(B, self.in_channels, 1, 1)
return style_features
class StyleAttentionBlock(nn.Module):
"""from FTransGAN
"""
def __init__(self, in_channels):
super().__init__()
self.num_local_attention = 3
for module_idx in range(1, self.num_local_attention + 1):
self.add_module(f"local_attention_{module_idx}",
ContextAwareAttentionBlock(in_channels))
for module_idx in range(1, self.num_local_attention):
self.add_module(f"downsample_{module_idx}",
Conv2d(in_channels, in_channels,
kernel_size=3, stride=2, padding=1, bias=False))
self.add_module(f"layer_attention", LayerAttentionBlock(in_channels))
def forward(self, x, B, K):
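        # x: (B*K, C, H, W) style feature maps, with K style references per batch element.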
feature_1 = self.local_attention_1(x)
x = self.downsample_1(x)
feature_2 = self.local_attention_2(x)
x = self.downsample_2(x)
feature_3 = self.local_attention_3(x)
out = self.layer_attention(x, feature_1, feature_2, feature_3, B, K)
return out
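

if __name__ == "__main__":
    # Minimal shape sanity check added for illustration (not part of the original
    # pipeline). The channel count and the 16x16 input size are assumptions chosen
    # so that two stride-2 downsamples reach LayerAttentionBlock's 4x4 grid.
    B, K, C = 2, 4, 64
    style_maps = torch.randn(B * K, C, 16, 16)
    block = StyleAttentionBlock(C)
    style_code = block(style_maps, B, K)
    print(style_code.shape)  # expected: torch.Size([2, 64, 1, 1])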