import math

import torch
import torch.nn.functional as F
from torch import nn

from .munit import ResBlocks, Conv2dBlock


class Unet(nn.Module):
    """U-Net style network that embeds a secret vector into an image."""

    def __init__(self, resolution=256, secret_len=100, return_residual=False) -> None:
        super().__init__()
        self.secret_len = secret_len
        self.return_residual = return_residual  # stored but not used in forward
        # Project the secret bits to a 16x16x3 map, then upsample it to the image resolution.
        self.secret_dense = nn.Linear(secret_len, 16 * 16 * 3)
        log_resolution = int(math.log2(resolution))
        assert resolution == 2 ** log_resolution, f"Image resolution must be a power of 2, got {resolution}."
        self.secret_upsample = nn.Upsample(scale_factor=(2 ** (log_resolution - 4), 2 ** (log_resolution - 4)))

        # 6 input channels: 3 image channels + 3 channels of the upsampled secret map.
        self.enc = Encoder(2, 4, 6, 64, 'bn', 'relu', 'reflect')
        self.dec = Decoder(2, 4, self.enc.output_dim, 3, 'bn', 'relu', 'reflect')

    def forward(self, image, secret):
        # Expand the secret into a spatial map and concatenate it with the image channels.
        fingerprint = F.relu(self.secret_dense(secret))
        fingerprint = fingerprint.view((-1, 3, 16, 16))
        fingerprint_enlarged = self.secret_upsample(fingerprint)
        inputs = torch.cat([fingerprint_enlarged, image], dim=1)
        emb = self.enc(inputs)
        out = self.dec(emb)
        return out


class Encoder(nn.Module):
    def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type):
        super().__init__()
        self.model = []
        self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
        # Downsampling blocks: each halves the spatial size and doubles the channel count.
        for i in range(n_downsample):
            self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
            dim *= 2
        # Residual blocks at the bottleneck.
        self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
        self.model = nn.ModuleList(self.model)
        self.output_dim = dim

    def forward(self, x):
        # Collect the output of every block so the decoder can use them as skip connections.
        out = []
        for block in self.model:
            x = block(x)
            out.append(x)
        return out


class Decoder(nn.Module):
    def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='relu', pad_type='zero'):
        super().__init__()
        self.model = []
        # Bottleneck residual blocks (for a 'resblock' DecoderBlock, the second argument is the number of residual blocks).
        self.model += [DecoderBlock('resblock', n_res, dim, res_norm, activ, pad_type=pad_type)]
        # Upsampling blocks: each doubles the spatial size and halves the channel count.
        for i in range(n_upsample):
            self.model += [DecoderBlock('upsample', dim, dim // 2, 'bn', activ, pad_type)]
            dim //= 2
        self.output_layer = Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)
        self.model = nn.ModuleList(self.model)

    def forward(self, x):
        # x is the list of encoder features, ordered from shallowest to deepest.
        x1 = x.pop()  # deepest feature map
        for block in self.model:
            x2 = x.pop()  # matching skip connection from the encoder
            x1 = block(x1, x2)
        x1 = self.output_layer(x1)
        return x1


class Merge(nn.Module):
    """Fuse a decoder feature map with its encoder skip connection via a 3x3 convolution."""

    def __init__(self, dim, activation='relu'):
        super().__init__()
        self.conv = nn.Conv2d(2 * dim, dim, 3, 1, 1)

        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

    def forward(self, x1, x2):
        x = torch.cat([x1, x2], dim=1)
        x = self.conv(x)
        if self.activation is not None:  # 'none' leaves the output linear
            x = self.activation(x)
        return x


class DecoderBlock(nn.Module):
    def __init__(self, block_type, in_dim, out_dim, norm, activ='relu', pad_type='reflect'):
        super().__init__()
        assert block_type in ['resblock', 'upsample']
        if block_type == 'resblock':
            # For 'resblock', in_dim is reused as the number of residual blocks and out_dim as their channel count.
            self.core_layer = ResBlocks(in_dim, out_dim, norm, activ, pad_type=pad_type)
        else:
            assert out_dim == in_dim // 2
            self.core_layer = nn.Sequential(
                nn.Upsample(scale_factor=2),
                Conv2dBlock(in_dim, out_dim, 5, 1, 2, norm=norm, activation=activ, pad_type=pad_type),
            )
        self.merge = Merge(out_dim, activ)

    def forward(self, x1, x2):
        x1 = self.core_layer(x1)
        return self.merge(x1, x2)
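

# A minimal smoke test, as a sketch rather than part of the model code: it assumes this file
# lives in a package next to munit.py (see the relative import above), so it has to be run as
# a module, e.g. `python -m <package>.unet`, where <package> is a placeholder for the actual
# package name.
if __name__ == "__main__":
    model = Unet(resolution=256, secret_len=100)
    image = torch.randn(1, 3, 256, 256)             # dummy cover image
    secret = torch.randint(0, 2, (1, 100)).float()  # dummy secret bits
    out = model(image, secret)
    # With the default settings the decoder returns a 3-channel map at the input resolution.
    print(out.shape)  # expected: torch.Size([1, 3, 256, 256])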