import torch
import torch.nn.functional as F
from StyleTransfer.srcTransformer.function import calc_mean_std, normal
from StyleTransfer.srcTransformer.misc import (
NestedTensor,
nested_tensor_from_tensor_list,
)
from StyleTransfer.srcTransformer.ViT_helper import to_2tuple
from torch import nn


class PatchEmbed(nn.Module):
    """Image to patch embedding: split an image into non-overlapping
    ``patch_size`` x ``patch_size`` patches and project each one to an
    ``embed_dim``-dimensional token with a strided convolution."""
def __init__(
self,
img_size: int = 256,
patch_size: int = 8,
in_chans: int = 3,
embed_dim: int = 512,
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
        self.up1 = nn.Upsample(scale_factor=2, mode="nearest")  # unused in forward

    def forward(self, x):
        # (B, 3, H, W) -> (B, embed_dim, H // patch_size, W // patch_size)
        x = self.proj(x)
        return x
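

# A minimal shape sketch (not part of the original file): with the default
# 256x256 input and 8x8 patches, the embedding yields a 32x32 grid of
# 512-dimensional tokens. Wrapped in a function so importing this module stays
# side-effect free; the function name is hypothetical.
def _demo_patch_embed() -> None:
    embed = PatchEmbed(img_size=256, patch_size=8, in_chans=3, embed_dim=512)
    tokens = embed(torch.rand(1, 3, 256, 256))
    print(tokens.shape)  # torch.Size([1, 512, 32, 32])

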
decoder = nn.Sequential(
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(512, 256, (3, 3)),
nn.ReLU(),
nn.Upsample(scale_factor=2, mode="nearest"),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(256, 256, (3, 3)),
nn.ReLU(),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(256, 256, (3, 3)),
nn.ReLU(),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(256, 256, (3, 3)),
nn.ReLU(),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(256, 128, (3, 3)),
nn.ReLU(),
nn.Upsample(scale_factor=2, mode="nearest"),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(128, 128, (3, 3)),
nn.ReLU(),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(128, 64, (3, 3)),
nn.ReLU(),
nn.Upsample(scale_factor=2, mode="nearest"),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(64, 64, (3, 3)),
nn.ReLU(),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(64, 3, (3, 3)),
)
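

# The decoder mirrors the encoder: three nearest-neighbour upsamplings take a
# 512-channel feature grid back to full resolution (32x32 -> 256x256 for the
# default 256x256 input) while the channel count tapers
# 512 -> 256 -> 128 -> 64 -> 3. A minimal shape sketch (not part of the
# original file):
def _demo_decoder() -> None:
    feats = torch.rand(1, 512, 32, 32)
    image = decoder(feats)
    print(image.shape)  # torch.Size([1, 3, 256, 256])

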
vgg = nn.Sequential(
nn.Conv2d(3, 3, (1, 1)),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(3, 64, (3, 3)),
nn.ReLU(), # relu1-1
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(64, 64, (3, 3)),
nn.ReLU(), # relu1-2
nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(64, 128, (3, 3)),
nn.ReLU(), # relu2-1
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(128, 128, (3, 3)),
nn.ReLU(), # relu2-2
nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(128, 256, (3, 3)),
nn.ReLU(), # relu3-1
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(256, 256, (3, 3)),
nn.ReLU(), # relu3-2
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(256, 256, (3, 3)),
nn.ReLU(), # relu3-3
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(256, 256, (3, 3)),
nn.ReLU(), # relu3-4
nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(512, 512, (3, 3)),
nn.ReLU(), # relu4-2
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(512, 512, (3, 3)),
nn.ReLU(), # relu4-3
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(512, 512, (3, 3)),
nn.ReLU(), # relu4-4
nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1, the deepest layer used by StyTrans
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(512, 512, (3, 3)),
nn.ReLU(), # relu5-2
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(512, 512, (3, 3)),
nn.ReLU(), # relu5-3
nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(512, 512, (3, 3)),
nn.ReLU(), # relu5-4
)
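

# ``vgg`` follows the VGG-19 layout used by AdaIN-style models and is intended
# to be initialised with pretrained weights before use; ``StyTrans`` freezes it
# and reads features only up to relu5-1. A hedged loading sketch -- the
# checkpoint filename is an assumption, not defined in this file:
def _load_vgg_weights(path: str = "vgg_normalised.pth") -> None:
    vgg.load_state_dict(torch.load(path))
    vgg.eval()

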
class MLP(nn.Module):
"""Very simple multi-layer perceptron (also called FFN)"""
def __init__(
self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
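
# ``MLP`` applies ReLU after every layer except the last. A small usage sketch
# with hypothetical dimensions (the class is not used elsewhere in this file):
def _demo_mlp() -> None:
    mlp = MLP(input_dim=512, hidden_dim=256, output_dim=64, num_layers=3)
    out = mlp(torch.rand(8, 512))
    print(out.shape)  # torch.Size([8, 64])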


class StyTrans(nn.Module):
    """Style transfer transformer: a frozen VGG encoder for the losses, a
    patch embedding, a transformer, and a CNN decoder."""
def __init__(
self, encoder: nn.Sequential, decoder: nn.Sequential, PatchEmbed, transformer
):
super().__init__()
enc_layers = list(encoder.children())
self.enc_1 = nn.Sequential(*enc_layers[:4]) # input -> relu1_1
self.enc_2 = nn.Sequential(*enc_layers[4:11]) # relu1_1 -> relu2_1
self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1
self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1
self.enc_5 = nn.Sequential(*enc_layers[31:44]) # relu4_1 -> relu5_1
for name in ["enc_1", "enc_2", "enc_3", "enc_4", "enc_5"]:
for param in getattr(self, name).parameters():
param.requires_grad = False
self.mse_loss = nn.MSELoss()
self.transformer = transformer
self.decode = decoder
        self.embedding = PatchEmbed

    def encode_with_intermediate(self, input):
        """Return the frozen VGG features [relu1_1, relu2_1, ..., relu5_1]."""
        results = [input]
        for i in range(5):
            func = getattr(self, "enc_{:d}".format(i + 1))
            results.append(func(results[-1]))
        return results[1:]

    def calc_content_loss(self, input, target):
        assert input.size() == target.size()
        assert target.requires_grad is False
        return self.mse_loss(input, target)

    def calc_style_loss(self, input, target):
        # Style is compared through channel-wise feature statistics
        # (AdaIN-style mean and standard deviation).
        assert input.size() == target.size()
        assert target.requires_grad is False
        input_mean, input_std = calc_mean_std(input)
        target_mean, target_std = calc_mean_std(target)
        return self.mse_loss(input_mean, target_mean) + self.mse_loss(
            input_std, target_std
        )
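
    @staticmethod
    def _demo_style_stats() -> None:
        # Hedged illustration (not part of the original file): the style loss
        # compares only per-channel statistics, so two feature maps with equal
        # mean/std have zero style distance regardless of spatial layout.
        # ``calc_mean_std`` is assumed to return (N, C, 1, 1) tensors, as in
        # the AdaIN-style helper this module imports.
        feat = torch.rand(1, 512, 32, 32)
        mean, std = calc_mean_std(feat)
        print(mean.shape, std.shape)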

    def forward(self, samples_c: NestedTensor, samples_s: NestedTensor):
        """The forward pass expects NestedTensors, each consisting of:

        - samples.tensors: batched images, of shape [batch_size x 3 x H x W]
        - samples.mask: a binary mask of shape [batch_size x H x W],
          containing 1 on padded pixels
        """
        content_input = samples_c
        style_input = samples_s
        if isinstance(samples_c, (list, torch.Tensor)):
            # Pad different-sized images to a common size; the mask marks the
            # padded pixels ([tensors, mask]).
            samples_c = nested_tensor_from_tensor_list(samples_c)
        if isinstance(samples_s, (list, torch.Tensor)):
            samples_s = nested_tensor_from_tensor_list(samples_s)
        # features used to calculate the losses
content_feats = self.encode_with_intermediate(samples_c.tensors)
style_feats = self.encode_with_intermediate(samples_s.tensors)
# Linear projection
style = self.embedding(samples_s.tensors)
content = self.embedding(samples_c.tensors)
        # The positional embedding is computed inside the transformer
        # (see transformer.py), so no masks or positions are passed here.
        pos_s = None
        pos_c = None
        mask = None
hs = self.transformer(style, mask, content, pos_c, pos_s)
Ics = self.decode(hs)
        Ics_feats = self.encode_with_intermediate(Ics)
        # Content loss on the two deepest, mean-std normalised feature levels
        loss_c = self.calc_content_loss(
            normal(Ics_feats[-1]), normal(content_feats[-1])
        ) + self.calc_content_loss(normal(Ics_feats[-2]), normal(content_feats[-2]))
# Style loss
loss_s = self.calc_style_loss(Ics_feats[0], style_feats[0])
for i in range(1, 5):
loss_s += self.calc_style_loss(Ics_feats[i], style_feats[i])
        # Identity reconstructions: stylising content with itself (and style
        # with itself) should reproduce the input image.
        Icc = self.decode(self.transformer(content, mask, content, pos_c, pos_c))
        Iss = self.decode(self.transformer(style, mask, style, pos_s, pos_s))
        # Identity loss 1 (weighted by lambda1 outside this module):
        # pixel-level reconstruction
        loss_lambda1 = self.calc_content_loss(
            Icc, content_input
        ) + self.calc_content_loss(Iss, style_input)
        # Identity loss 2 (weighted by lambda2): feature-level reconstruction
Icc_feats = self.encode_with_intermediate(Icc)
Iss_feats = self.encode_with_intermediate(Iss)
loss_lambda2 = self.calc_content_loss(
Icc_feats[0], content_feats[0]
) + self.calc_content_loss(Iss_feats[0], style_feats[0])
for i in range(1, 5):
loss_lambda2 += self.calc_content_loss(
Icc_feats[i], content_feats[i]
) + self.calc_content_loss(Iss_feats[i], style_feats[i])
        # Select one of the two return statements below: the first during
        # training, the second at inference.
        return Ics, loss_c, loss_s, loss_lambda1, loss_lambda2  # train
        # return Ics  # test
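

# A hedged end-to-end sketch of how the pieces compose. The transformer module
# and the VGG checkpoint live outside this file, so the import path, the
# ``Transformer()`` constructor, and the weight filename are all assumptions:
def _demo_stytrans() -> None:
    from StyleTransfer.srcTransformer.transformer import Transformer  # assumed

    vgg.load_state_dict(torch.load("vgg_normalised.pth"))  # assumed checkpoint
    model = StyTrans(vgg, decoder, PatchEmbed(), Transformer())
    content = torch.rand(1, 3, 256, 256)
    style = torch.rand(1, 3, 256, 256)
    Ics, loss_c, loss_s, loss_id1, loss_id2 = model(content, style)
    print(Ics.shape)  # expected: torch.Size([1, 3, 256, 256])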