import math

import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable

from .munit import ResBlocks, Conv2dBlock


class Unet(nn.Module):
    """U-Net style network that embeds a secret vector into an image.

    The secret is projected by a linear layer to a ``3 x 16 x 16`` map,
    upsampled to ``resolution x resolution``, concatenated with the image
    along the channel axis and passed through an ``Encoder``/``Decoder``
    pair with skip connections.

    Args:
        resolution: Spatial size of the input image; must be a power of 2.
        secret_len: Length of the secret vector fed to the linear layer.
        return_residual: Stored on the instance; it is not read anywhere in
            this module (NOTE(review): presumably consumed by callers).

    Raises:
        AssertionError: If ``resolution`` is not a power of 2.
    """

    def __init__(self, resolution=256, secret_len=100, return_residual=False) -> None:
        super().__init__()
        self.secret_len = secret_len
        self.return_residual = return_residual
        self.secret_dense = nn.Linear(secret_len, 16 * 16 * 3)

        # math.log2 is more accurate than math.log(x, 2), which is computed
        # as a quotient of two logarithms and can suffer from rounding.
        log_resolution = int(math.log2(resolution))
        assert resolution == 2 ** log_resolution, f"Image resolution must be a power of 2, got {resolution}."

        # The fingerprint map is 16x16 (2**4), so scale by 2**(log2(res) - 4)
        # to reach the image resolution.
        scale = 2 ** (log_resolution - 4)
        self.secret_upsample = nn.Upsample(scale_factor=(scale, scale))

        # 6 input channels: 3 from the fingerprint map plus the image
        # channels (assumes a 3-channel image -- TODO confirm with callers).
        self.enc = Encoder(2, 4, 6, 64, 'bn', 'relu', 'reflect')
        self.dec = Decoder(2, 4, self.enc.output_dim, 3, 'bn', 'relu', 'reflect')

    def forward(self, image: torch.Tensor, secret: torch.Tensor) -> torch.Tensor:
        """Embed ``secret`` into ``image`` and return the decoder output.

        Args:
            image: Image batch; concatenated with the 3-channel fingerprint
                along dim 1, so its spatial size must match ``resolution``.
            secret: Batch of secret vectors with ``secret_len`` features.

        Returns:
            The tensor produced by the decoder's output layer.
        """
        fingerprint = F.relu(self.secret_dense(secret))
        fingerprint = fingerprint.view((-1, 3, 16, 16))
        fingerprint_enlarged = self.secret_upsample(fingerprint)
        inputs = torch.cat([fingerprint_enlarged, image], dim=1)
        emb = self.enc(inputs)
        return self.dec(emb)


class Encoder(nn.Module):
    """Convolutional encoder that returns the output of every stage.

    Stages, in order: one stem ``Conv2dBlock``, ``n_downsample``
    downsampling ``Conv2dBlock`` stages that each double the channel count,
    and one ``ResBlocks`` stage.

    Args:
        n_downsample: Number of downsampling stages.
        n_res: Passed as the first argument to ``ResBlocks``.
        input_dim: Number of input channels.
        dim: Channel count after the stem block.
        norm: Normalisation name forwarded to the blocks.
        activ: Activation name forwarded to the blocks.
        pad_type: Padding name forwarded to the blocks.

    Attributes:
        output_dim: Channel count after the last downsampling stage.
    """

    def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type):
        super().__init__()
        stages = [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
        # Downsampling blocks: each one doubles the number of channels.
        for _ in range(n_downsample):
            stages.append(Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type))
            dim *= 2
        # Residual blocks.
        stages.append(ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type))
        self.model = nn.ModuleList(stages)
        self.output_dim = dim

    def forward(self, x: torch.Tensor) -> list:
        """Run all stages and return a list with each stage's output, in order."""
        out = []
        for block in self.model:
            x = block(x)
            out.append(x)
        return out


class Decoder(nn.Module):
    """Decoder that consumes the encoder's per-stage outputs as skip inputs.

    Consists of one ``'resblock'`` ``DecoderBlock``, ``n_upsample``
    ``'upsample'`` ``DecoderBlock`` stages that each halve the channel
    count, and a final ``Conv2dBlock`` output layer.

    Args:
        n_upsample: Number of upsampling stages.
        n_res: Forwarded to the ``'resblock'`` ``DecoderBlock``.
        dim: Channel count of the deepest feature map.
        output_dim: Number of output channels.
        res_norm: Normalisation name for the residual stage.
        activ: Activation name forwarded to the blocks.
        pad_type: Padding name forwarded to the blocks.
    """

    def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='relu', pad_type='zero'):
        super().__init__()
        # Residual stage first.
        blocks = [DecoderBlock('resblock', n_res, dim, res_norm, activ, pad_type=pad_type)]
        # Upsampling stages: each one halves the number of channels.
        for _ in range(n_upsample):
            blocks.append(DecoderBlock('upsample', dim, dim // 2, 'bn', activ, pad_type))
            dim //= 2
        # Final 7x7 conv layer, built with the same pad_type and no norm.
        self.output_layer = Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)
        self.model = nn.ModuleList(blocks)

    def forward(self, x) -> torch.Tensor:
        """Decode a list of encoder features, deepest feature last.

        Args:
            x: List of tensors with one more entry than there are decoder
                blocks. It is consumed from the end; the caller's list is
                not modified.

        Returns:
            The output of the final conv layer.
        """
        # Work on a shallow copy so the caller's list is left intact.
        skips = list(x)
        out = skips.pop()
        for block in self.model:
            out = block(out, skips.pop())
        return self.output_layer(out)


class Merge(nn.Module):
    """Fuse two ``dim``-channel tensors into one via concat + 3x3 conv.

    Args:
        dim: Channel count of each input and of the output.
        activation: One of ``'relu'``, ``'lrelu'``, ``'prelu'``, ``'selu'``,
            ``'tanh'`` or ``'none'`` (no activation applied).

    Raises:
        AssertionError: If ``activation`` is not one of the supported names.
    """

    def __init__(self, dim, activation='relu'):
        super().__init__()
        self.conv = nn.Conv2d(2 * dim, dim, 3, 1, 1)

        # Factories are used so that only the selected module is built.
        factories = {
            'relu': lambda: nn.ReLU(inplace=True),
            'lrelu': lambda: nn.LeakyReLU(0.2, inplace=True),
            'prelu': lambda: nn.PReLU(),
            'selu': lambda: nn.SELU(inplace=True),
            'tanh': lambda: nn.Tanh(),
            'none': lambda: None,
        }
        if activation not in factories:
            # Raised explicitly (not via ``assert``) so the check is not
            # stripped when Python runs with -O.
            raise AssertionError("Unsupported activation: {}".format(activation))
        self.activation = factories[activation]()

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Concatenate ``x1`` and ``x2`` on dim 1, convolve, then activate."""
        x = torch.cat([x1, x2], dim=1)  # 2 * dim channels
        x = self.conv(x)  # back to dim channels
        # ``activation='none'`` stores None; skip it instead of calling it.
        if self.activation is not None:
            x = self.activation(x)
        return x


class DecoderBlock(nn.Module):
    """One decoder stage: a core layer followed by a ``Merge`` with a skip.

    Args:
        block_type: ``'resblock'`` or ``'upsample'``.
        in_dim: For ``'upsample'``, the input channel count. For
            ``'resblock'``, it is passed as the first argument to
            ``ResBlocks``.
        out_dim: Output channel count; for ``'upsample'`` it must equal
            ``in_dim // 2``. For ``'resblock'`` it is passed as the second
            argument to ``ResBlocks``.
        norm: Normalisation name forwarded to the core layer.
        activ: Activation name for the core layer and the merge.
        pad_type: Padding name forwarded to the core layer.

    Raises:
        AssertionError: On an unknown ``block_type`` or mismatched dims.
    """

    def __init__(self, block_type, in_dim, out_dim, norm, activ='relu', pad_type='reflect'):
        super().__init__()
        assert block_type in ['resblock', 'upsample']
        if block_type == 'resblock':
            self.core_layer = ResBlocks(in_dim, out_dim, norm, activ, pad_type=pad_type)
        else:
            assert out_dim == in_dim // 2
            self.core_layer = nn.Sequential(
                nn.Upsample(scale_factor=2),
                Conv2dBlock(in_dim, out_dim, 5, 1, 2, norm=norm, activation=activ, pad_type=pad_type),
            )
        self.merge = Merge(out_dim, activ)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Apply the core layer to ``x1`` and merge the result with ``x2``."""
        x1 = self.core_layer(x1)
        return self.merge(x1, x2)