# Source: test/flae/unet.py
# Author: Tu Bui — commit 6142a25 ("first commit")
# (file-viewer residue: raw | history | blame | 4.86 kB)
from torch import nn
from torch.autograd import Variable
import torch
import torch.nn.functional as F
from .munit import ResBlocks, Conv2dBlock
import math
class Unet(nn.Module):
    """Encoder/decoder network conditioned on a secret vector.

    The secret is projected by a linear layer (+ ReLU) to a 3x16x16
    "fingerprint", upsampled to ``resolution x resolution`` and concatenated
    with the image along the channel axis. The 6-channel result goes through
    ``Encoder`` and then ``Decoder``, which produces a 3-channel output.

    Args:
        resolution: spatial size of the input image; must be a power of 2.
        secret_len: size of the last dimension of the ``secret`` tensor.
        return_residual: stored on the instance but not read by this class.
            NOTE(review): confirm whether external code relies on it.
    """

    def __init__(self, resolution=256, secret_len=100, return_residual=False) -> None:
        super().__init__()
        self.secret_len = secret_len
        self.return_residual = return_residual  # not read by forward()
        self.secret_dense = nn.Linear(secret_len, 16 * 16 * 3)
        # math.log2 is exact for powers of two; int(math.log(x, 2)) divides two
        # rounded logarithms and can land an ULP below the true exponent.
        log_resolution = int(math.log2(resolution))
        assert resolution == 2 ** log_resolution, f"Image resolution must be a power of 2, got {resolution}."
        # 16 * 2**(log_resolution - 4) == resolution, so the 16x16 fingerprint
        # is enlarged to the full image size.
        self.secret_upsample = nn.Upsample(
            scale_factor=(2 ** (log_resolution - 4), 2 ** (log_resolution - 4))
        )
        # 6 input channels = 3 fingerprint channels + 3 image channels.
        self.enc = Encoder(2, 4, 6, 64, 'bn', 'relu', 'reflect')
        self.dec = Decoder(2, 4, self.enc.output_dim, 3, 'bn', 'relu', 'reflect')

    def forward(self, image, secret):
        """Run the secret-conditioned encoder/decoder.

        Args:
            image: image tensor; assumed (B, 3, resolution, resolution) so it
                can be concatenated with the upsampled fingerprint.
            secret: tensor whose last dimension is ``secret_len``; assumed
                (B, secret_len).

        Returns:
            The decoder output (3 channels).
        """
        fingerprint = F.relu(self.secret_dense(secret))
        fingerprint = fingerprint.view((-1, 3, 16, 16))
        fingerprint_enlarged = self.secret_upsample(fingerprint)
        inputs = torch.cat([fingerprint_enlarged, image], dim=1)
        emb = self.enc(inputs)
        return self.dec(emb)
class Encoder(nn.Module):
    """Convolutional encoder that returns every intermediate feature map.

    Layers, in order:
      1. a ``Conv2dBlock(input_dim, dim, 7, 1, 3)``;
      2. ``n_downsample`` downsampling ``Conv2dBlock(w, 2 * w, 4, 2, 1)``
         blocks, each doubling the channel width ``w``;
      3. one ``ResBlocks(n_res, w)`` stack at the final width.
    (The positional ints are presumably kernel/stride/padding as in MUNIT's
    ``Conv2dBlock`` — confirm against ``.munit``.)

    Args:
        n_downsample: number of width-doubling downsampling blocks.
        n_res: first argument given to ``ResBlocks``.
        input_dim: channels of the input tensor.
        dim: channels produced by the first conv block.
        norm, activ, pad_type: forwarded to every block.

    Attributes:
        output_dim: channel width after the last block,
            i.e. ``dim * 2 ** n_downsample``.
    """

    def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type):
        super().__init__()
        block_kwargs = dict(norm=norm, activation=activ, pad_type=pad_type)
        layers = [Conv2dBlock(input_dim, dim, 7, 1, 3, **block_kwargs)]
        width = dim
        for _ in range(n_downsample):
            layers.append(Conv2dBlock(width, width * 2, 4, 2, 1, **block_kwargs))
            width = width * 2
        layers.append(ResBlocks(n_res, width, **block_kwargs))
        self.model = nn.ModuleList(layers)
        self.output_dim = width

    def forward(self, x):
        """Apply the blocks in sequence and collect each block's output.

        Returns:
            A list with one tensor per block, shallowest first; the last
            element is the output of the ``ResBlocks`` stack.
        """
        features = []
        for layer in self.model:
            x = layer(x)
            features.append(x)
        return features
class Decoder(nn.Module):
    """Decoder that merges encoder skip features while upsampling.

    Stages, in order:
      1. ``DecoderBlock('resblock', n_res, dim, res_norm, ...)`` at width ``dim``;
      2. ``n_upsample`` ``DecoderBlock('upsample', w, w // 2, 'bn', ...)``
         stages, each halving the channel width ``w``;
      3. an output ``Conv2dBlock(w, output_dim, 7, 1, 3)`` with no norm and
         tanh activation.

    Args:
        n_upsample: number of width-halving upsampling stages.
        n_res: passed as ``in_dim`` of the residual ``DecoderBlock``.
        dim: channel width of the deepest encoder feature.
        output_dim: channels of the final output.
        res_norm: norm for the residual stage (upsampling stages use 'bn').
        activ, pad_type: forwarded to the stages.
    """

    def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='relu', pad_type='zero'):
        super().__init__()
        blocks = [DecoderBlock('resblock', n_res, dim, res_norm, activ, pad_type=pad_type)]
        for _ in range(n_upsample):
            blocks.append(DecoderBlock('upsample', dim, dim // 2, 'bn', activ, pad_type))
            dim //= 2
        # use reflection padding in the last conv layer (when pad_type='reflect')
        self.output_layer = Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)
        self.model = nn.ModuleList(blocks)

    def forward(self, x):
        """Decode a sequence of encoder feature maps.

        Args:
            x: feature maps ordered shallowest-first (as returned by
                ``Encoder.forward``). The last element is the decoder input,
                and each stage merges with the next remaining feature taken
                from the end. At least ``len(self.model) + 1`` elements are
                required. The sequence is copied, so the caller's list is
                left unchanged.

        Returns:
            The output of ``self.output_layer``.
        """
        features = list(x)  # work on a copy so the caller's list survives
        out = features.pop()
        for block in self.model:
            skip = features.pop()
            out = block(out, skip)
        return self.output_layer(out)
class Merge(nn.Module):
    """Fuse two feature maps: channel concatenation, 3x3 conv, activation.

    Args:
        dim: channel count of each input and of the output.
        activation: one of 'relu', 'lrelu', 'prelu', 'selu', 'tanh', 'none'.
    """

    def __init__(self, dim, activation='relu'):
        super().__init__()
        # 2*dim -> dim channels; kernel 3, stride 1, padding 1 keeps H and W.
        self.conv = nn.Conv2d(2 * dim, dim, 3, 1, 1)
        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

    def forward(self, x1, x2):
        """Concatenate ``x1`` and ``x2`` on dim 1, convolve, then activate.

        Both inputs are assumed to be (B, dim, H, W); the output has the same
        shape.
        """
        x = torch.cat([x1, x2], dim=1)  # B, 2*dim, H, W
        x = self.conv(x)  # B, dim, H, W
        # activation='none' stores None: skip it rather than calling None.
        if self.activation is not None:
            x = self.activation(x)
        return x
class DecoderBlock(nn.Module):
    """One decoder stage: transform the running features, then merge a skip.

    How ``in_dim``/``out_dim`` are used depends on ``block_type``:

    * ``'resblock'``: they are passed as the first two positional arguments
      of ``ResBlocks`` (``Decoder`` supplies the residual-block count and the
      channel width there).
    * ``'upsample'``: 2x ``nn.Upsample`` followed by a ``Conv2dBlock`` from
      ``in_dim`` to ``out_dim`` channels; ``out_dim`` must be ``in_dim // 2``.

    Either way, the transformed tensor is fused with the skip input by
    ``Merge(out_dim, activ)``.
    """

    def __init__(self, block_type, in_dim, out_dim, norm, activ='relu', pad_type='reflect'):
        super().__init__()
        assert block_type in ['resblock', 'upsample']
        if block_type == 'resblock':
            core = ResBlocks(in_dim, out_dim, norm, activ, pad_type=pad_type)
        else:
            assert out_dim == in_dim // 2
            upsample = nn.Upsample(scale_factor=2)
            conv = Conv2dBlock(
                in_dim, out_dim, 5, 1, 2,
                norm=norm, activation=activ, pad_type=pad_type,
            )
            core = nn.Sequential(upsample, conv)
        self.core_layer = core
        self.merge = Merge(out_dim, activ)

    def forward(self, x1, x2):
        """Transform ``x1`` with ``core_layer`` and merge it with skip ``x2``."""
        transformed = self.core_layer(x1)
        return self.merge(transformed, x2)