|
import math |
|
import torch |
|
from torch import nn |
|
from torch.nn import functional as thf |
|
import pytorch_lightning as pl |
|
from ldm.util import instantiate_from_config |
|
import einops |
|
import kornia |
|
import numpy as np |
|
import torchvision |
|
from contextlib import contextmanager |
|
from ldm.modules.ema import LitEma |
|
|
|
|
|
class FlAE(pl.LightningModule):
    """Jointly train a secret encoder and a secret decoder with Lightning.

    The encoder hides ``secret`` inside ``cover`` to produce a stego image.
    The decoder predicts the secret from a distorted version of that stego
    image: centre-cropped when the noise model is inactive, passed through
    ``self.noise`` otherwise.

    Training follows a bit-accuracy-driven curriculum (see ``shared_step``):

    1. Reuse a single fixed cover batch until bit accuracy > 0.9.
    2. Then train on the full dataset, and activate the noise model (if one
       is configured) once bit accuracy > 0.95.
    3. Then call ``loss_layer.activate_ramp`` once bit accuracy > 0.98 while
       noise is active.

    NOTE(review): ``configure_optimizers`` reads ``self.learning_rate``,
    which is not set in this class; presumably the training script assigns
    it before fitting.
    """

    # Side length of the square images the decoder is configured for. Stego
    # images are centre-cropped to this size whenever noise is inactive.
    DECODER_RESOLUTION = 224

    def __init__(self,
                 cover_key,
                 secret_key,
                 secret_len,
                 resolution,
                 secret_encoder_config,
                 secret_decoder_config,
                 loss_config,
                 noise_config='__none__',
                 ckpt_path="__none__",
                 use_ema=False
                 ):
        """Build the sub-models and the warm-up state.

        Args:
            cover_key: Batch key of the cover images (channel-last layout,
                see ``get_input``).
            secret_key: Batch key of the secrets.
            secret_len: Secret length; written into both sub-model configs.
            resolution: Encoder input resolution; written into the encoder
                config.
            secret_encoder_config: Config for ``instantiate_from_config``;
                its ``params`` are mutated in place.
            secret_decoder_config: Config for ``instantiate_from_config``;
                its ``params`` are mutated in place.
            loss_config: Config for the loss layer.
            noise_config: Config for the noise model. ``'__none__'`` trains
                without one, in which case no ``self.noise`` attribute exists.
            ckpt_path: Checkpoint to warm-start from, or ``'__none__'``.
            use_ema: Keep ``LitEma`` copies of the encoder/decoder weights,
                swapped in by ``ema_scope``.
        """
        super().__init__()
        self.cover_key = cover_key
        self.secret_key = secret_key
        secret_encoder_config.params.secret_len = secret_len
        secret_decoder_config.params.secret_len = secret_len
        secret_encoder_config.params.resolution = resolution
        secret_decoder_config.params.resolution = self.DECODER_RESOLUTION
        self.encoder = instantiate_from_config(secret_encoder_config)
        self.decoder = instantiate_from_config(secret_decoder_config)
        self.loss_layer = instantiate_from_config(loss_config)
        if noise_config != '__none__':
            print('Using noise')
            self.noise = instantiate_from_config(noise_config)

        self.use_ema = use_ema
        if self.use_ema:
            print('Using EMA')
            self.encoder_ema = LitEma(self.encoder)
            self.decoder_ema = LitEma(self.decoder)
            print(f"Keeping EMAs of {len(list(self.encoder_ema.buffers()) + list(self.decoder_ema.buffers()))}.")

        # Warm-up state. ``fixed_input`` is a (persistent) buffer so it is
        # saved in checkpoints; ``fixed_img``/``fixed_secret`` are cached
        # lazily by ``get_input``.
        self.fixed_img = None
        self.fixed_secret = None
        self.register_buffer("fixed_input", torch.tensor(True))
        self.crop = kornia.augmentation.CenterCrop(
            (self.DECODER_RESOLUTION, self.DECODER_RESOLUTION), cropping_mode="resample")

        # Load the checkpoint only after every parameter and buffer exists.
        # Otherwise the non-strict load would drop a saved ``fixed_input``
        # and the warm-up phase would silently restart.
        if ckpt_path != "__none__":
            self.init_from_ckpt(ckpt_path, ignore_keys=[])

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load weights from a checkpoint's ``"state_dict"`` entry, non-strictly.

        Args:
            path: Checkpoint file loaded with ``torch.load`` onto the CPU.
            ignore_keys: Iterable of prefixes; state-dict entries starting
                with any of them are dropped before loading.
        """
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        sd = torch.load(path, map_location="cpu")["state_dict"]
        prefixes = tuple(ignore_keys)
        for k in list(sd.keys()):
            # Test all prefixes at once so a key matching several of them is
            # deleted exactly once.
            if k.startswith(prefixes):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        if missing:
            print(f"Missing keys when restoring: {len(missing)}")
        if unexpected:
            print(f"Unexpected keys when restoring: {len(unexpected)}")
        print(f"Restored from {path}")

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap the encoder/decoder weights for their EMA copies.

        Does nothing when ``use_ema`` is False. The training weights are
        restored on exit, even if the body raises.

        Args:
            context: Optional label printed when switching and restoring.
        """
        if self.use_ema:
            self.encoder_ema.store(self.encoder.parameters())
            self.decoder_ema.store(self.decoder.parameters())
            self.encoder_ema.copy_to(self.encoder)
            self.decoder_ema.copy_to(self.decoder)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.encoder_ema.restore(self.encoder.parameters())
                self.decoder_ema.restore(self.decoder.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA copies from the current weights after each training batch."""
        if self.use_ema:
            self.encoder_ema(self.encoder)
            self.decoder_ema(self.decoder)

    def _noise_active(self):
        """Return a truthy value when a noise model exists and has been activated."""
        return hasattr(self, 'noise') and self.noise.is_activated()

    @torch.no_grad()
    def get_input(self, batch, bs=None):
        """Extract ``[image, secret]`` from a batch.

        Images are rearranged from ``b h w c`` to ``b c h w``. While the
        warm-up flag ``fixed_input`` is set, the first cover batch seen is
        cached and returned instead of the current covers. The secrets still
        come from the current batch, and both are truncated to a common
        batch size.

        Args:
            batch: Mapping holding ``self.cover_key`` and ``self.secret_key``.
            bs: Optional number of leading samples to keep.

        Returns:
            ``[image, secret]``.
        """
        image = batch[self.cover_key]
        secret = batch[self.secret_key]
        if bs is not None:
            image, secret = image[:bs], secret[:bs]

        image = einops.rearrange(image, "b h w c -> b c h w").contiguous()

        if self.fixed_input:
            if self.fixed_img is None:
                print('[TRAINING] Warmup - using fixed input image for now!')
                self.fixed_img = image.detach().clone()
                self.fixed_secret = secret.detach().clone()
            image = self.fixed_img
            new_bs = min(secret.shape[0], image.shape[0])
            image, secret = image[:new_bs], secret[:new_bs]

        return [image, secret]

    def forward(self, cover, secret):
        """Embed ``secret`` into ``cover``.

        Returns:
            ``(stego, residual)``. If the encoder outputs a residual
            (``encoder.return_residual``), stego = cover + residual. Otherwise
            the encoder output is the stego image and residual = stego - cover.
        """
        enc_out = self.encoder(cover, secret)
        if self.encoder.return_residual:
            return cover + enc_out, enc_out
        return enc_out, enc_out - cover

    def shared_step(self, batch):
        """Run encode -> distort -> decode, compute the loss and advance the curriculum.

        Note that the curriculum state (warm-up flag, noise activation, loss
        ramp) is updated on every call, including from ``validation_step``.

        Returns:
            ``(loss, loss_dict)`` as produced by ``self.loss_layer``.
        """
        x, s = self.get_input(batch)
        stego, _ = self(x, s)
        if self._noise_active():
            stego_noised = self.noise(stego, self.global_step, p=0.9)
        else:
            stego_noised = self.crop(stego)
        stego_noised = torch.clamp(stego_noised, -1, 1)
        spred = self.decoder(stego_noised)

        loss, loss_dict = self.loss_layer(x, stego, None, s, spred, self.global_step)
        bit_acc_ = loss_dict["bit_acc"].item()

        # Order matters: the ramp check runs before noise activation, so the
        # ramp can start at the earliest one step after noise is switched on.
        if bit_acc_ > 0.98 and not self.fixed_input and self._noise_active():
            self.loss_layer.activate_ramp(self.global_step)

        if bit_acc_ > 0.95 and not self.fixed_input:
            if hasattr(self, 'noise') and not self.noise.is_activated():
                self.noise.activate(self.global_step)

        if bit_acc_ > 0.9 and self.fixed_input:
            print(f'[TRAINING] High bit acc ({bit_acc_}) achieved, switch to full image dataset training.')
            self.fixed_input = ~self.fixed_input
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log every loss-dict entry under ``train/``."""
        loss, loss_dict = self.shared_step(batch)
        loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
        self.log_dict(loss_dict, prog_bar=True,
                      logger=True, on_step=True, on_epoch=True)

        self.log("global_step", self.global_step,
                 prog_bar=True, logger=True, on_step=True, on_epoch=False)
        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Log validation metrics with the training weights and inside ``ema_scope``.

        The ``img_lw`` entry is excluded from both metric sets. Without EMA,
        ``ema_scope`` is a no-op, so the ``_ema`` metrics come from a second
        pass with the same weights.
        """
        _, loss_dict_no_ema = self.shared_step(batch)
        loss_dict_no_ema = {f"val/{key}": val for key, val in loss_dict_no_ema.items() if key != 'img_lw'}
        with self.ema_scope():
            _, loss_dict_ema = self.shared_step(batch)
            loss_dict_ema = {f"val/{key}_ema": val for key, val in loss_dict_ema.items() if key != 'img_lw'}
        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)

    @torch.no_grad()
    def log_images(self, batch, fixed_input=False, **kwargs):
        """Return images for visual logging.

        Args:
            batch: Batch passed to ``get_input``. Ignored when ``fixed_input``
                is True and a warm-up batch has been cached.
            fixed_input: Use the cached warm-up image/secret if available.

        Returns:
            Dict with ``'input'``, ``'stego'``, ``'residual'`` (min-max
            rescaled to [-1, 1]), plus ``'noised'`` when noise is active.
        """
        log = dict()
        if fixed_input and self.fixed_img is not None:
            x, s = self.fixed_img, self.fixed_secret
        else:
            x, s = self.get_input(batch)
        stego, residual = self(x, s)
        if self._noise_active():
            img_noise = self.noise(stego, self.global_step, p=1.0)
            log['noised'] = img_noise
        log['input'] = x
        log['stego'] = stego
        log['residual'] = (residual - residual.min()) / (residual.max() - residual.min() + 1e-8)*2 - 1
        return log

    def configure_optimizers(self):
        """Return one AdamW optimizer over the encoder and decoder parameters.

        Uses ``self.learning_rate``, which must be set externally.
        """
        lr = self.learning_rate
        params = list(self.encoder.parameters()) + list(self.decoder.parameters())
        optimizer = torch.optim.AdamW(params, lr=lr)
        return optimizer
|
|
|
|
|
|
|
|
|
class SecretEncoder(nn.Module):
    """U-Net style encoder that embeds a secret vector into an image.

    The secret is projected to a 3x16x16 map, upsampled to the image
    resolution, concatenated with the image along the channel axis and
    passed through a four-level U-Net. The 3-channel output is squashed into
    [-1, 1].

    Args:
        resolution: Height/width of the square input image. Must be a power
            of two and at least 16.
        secret_len: Length of the secret vector.
        return_residual: Stored for callers, which use it to decide whether
            the output is a residual or the stego image; unused here.
        act: ``'tanh'`` selects ``tanh``; any other value selects
            ``2 * sigmoid(x) - 1``.
    """

    def __init__(self, resolution=256, secret_len=100, return_residual=False, act='tanh') -> None:
        super().__init__()
        self.secret_len = secret_len
        self.return_residual = return_residual
        # Keep the activation name and dispatch in the ``act_fn`` method
        # (rather than storing a lambda) so the module stays picklable.
        self.act = act
        self.secret_dense = nn.Linear(secret_len, 16*16*3)
        # log2 is exact for powers of two.
        log_resolution = int(math.log2(resolution))
        assert resolution == 2 ** log_resolution, f"Image resolution must be a power of 2, got {resolution}."
        # The 16x16 secret map is upsampled by an integer factor, and the
        # four stride-2 stages need at least 16 pixels.
        assert resolution >= 16, f"Image resolution must be at least 16, got {resolution}."
        self.secret_upsample = nn.Upsample(scale_factor=(2**(log_resolution-4), 2**(log_resolution-4)))
        self.conv1 = nn.Conv2d(2 * 3, 32, 3, 1, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 2, 1)
        self.conv3 = nn.Conv2d(32, 64, 3, 2, 1)
        self.conv4 = nn.Conv2d(64, 128, 3, 2, 1)
        self.conv5 = nn.Conv2d(128, 256, 3, 2, 1)
        # Each decoder stage: x2 upsample, pad right/bottom by 1, 2x2 conv
        # (which keeps the upsampled size), then concat with the skip path.
        self.pad6 = nn.ZeroPad2d((0, 1, 0, 1))
        self.up6 = nn.Conv2d(256, 128, 2, 1)
        self.upsample6 = nn.Upsample(scale_factor=(2, 2))
        self.conv6 = nn.Conv2d(128 + 128, 128, 3, 1, 1)
        self.pad7 = nn.ZeroPad2d((0, 1, 0, 1))
        self.up7 = nn.Conv2d(128, 64, 2, 1)
        self.upsample7 = nn.Upsample(scale_factor=(2, 2))
        self.conv7 = nn.Conv2d(64 + 64, 64, 3, 1, 1)
        self.pad8 = nn.ZeroPad2d((0, 1, 0, 1))
        self.up8 = nn.Conv2d(64, 32, 2, 1)
        self.upsample8 = nn.Upsample(scale_factor=(2, 2))
        self.conv8 = nn.Conv2d(32 + 32, 32, 3, 1, 1)
        self.pad9 = nn.ZeroPad2d((0, 1, 0, 1))
        self.up9 = nn.Conv2d(32, 32, 2, 1)
        self.upsample9 = nn.Upsample(scale_factor=(2, 2))
        self.conv9 = nn.Conv2d(32 + 32 + 2 * 3, 32, 3, 1, 1)
        self.conv10 = nn.Conv2d(32, 32, 3, 1, 1)
        self.residual = nn.Conv2d(32, 3, 1)

    def act_fn(self, x):
        """Squash ``x`` into [-1, 1] with the activation chosen at construction."""
        if self.act == 'tanh':
            return torch.tanh(x)
        return torch.sigmoid(x) * 2.0 - 1.0

    def forward(self, image, secret):
        """Encode ``secret`` into ``image``.

        Args:
            image: ``(B, 3, resolution, resolution)`` tensor.
            secret: ``(B, secret_len)`` float tensor.

        Returns:
            ``(B, 3, resolution, resolution)`` tensor with values in [-1, 1].
        """
        fingerprint = thf.relu(self.secret_dense(secret))
        fingerprint = fingerprint.view((-1, 3, 16, 16))
        fingerprint_enlarged = self.secret_upsample(fingerprint)

        inputs = torch.cat([fingerprint_enlarged, image], dim=1)

        conv1 = thf.relu(self.conv1(inputs))
        conv2 = thf.relu(self.conv2(conv1))
        conv3 = thf.relu(self.conv3(conv2))
        conv4 = thf.relu(self.conv4(conv3))
        conv5 = thf.relu(self.conv5(conv4))
        up6 = thf.relu(self.up6(self.pad6(self.upsample6(conv5))))
        merge6 = torch.cat([conv4, up6], dim=1)
        conv6 = thf.relu(self.conv6(merge6))
        up7 = thf.relu(self.up7(self.pad7(self.upsample7(conv6))))
        merge7 = torch.cat([conv3, up7], dim=1)
        conv7 = thf.relu(self.conv7(merge7))
        up8 = thf.relu(self.up8(self.pad8(self.upsample8(conv7))))
        merge8 = torch.cat([conv2, up8], dim=1)
        conv8 = thf.relu(self.conv8(merge8))
        up9 = thf.relu(self.up9(self.pad9(self.upsample9(conv8))))
        merge9 = torch.cat([conv1, up9, inputs], dim=1)
        conv9 = thf.relu(self.conv9(merge9))
        conv10 = thf.relu(self.conv10(conv9))
        residual = self.residual(conv10)
        return self.act_fn(residual)
|
|
|
|
|
class SecretEncoder1(nn.Module):
    """Placeholder for an alternative secret encoder; not implemented yet.

    Args:
        resolution: Intended input image resolution (stored only).
        secret_len: Intended secret length (stored only).
    """

    def __init__(self, resolution=256, secret_len=100) -> None:
        # nn.Module subclasses must initialise the base class, otherwise the
        # instance has no parameter/module registries and is unusable.
        super().__init__()
        self.resolution = resolution
        self.secret_len = secret_len

    def forward(self, image, secret):
        """Not implemented; raises ``NotImplementedError``."""
        raise NotImplementedError("SecretEncoder1 is a placeholder and has no forward pass yet.")
|
|
|
class SecretDecoder(nn.Module):
    """Predict ``secret_len`` logits from an image with a selectable backbone.

    Args:
        arch: ``'resnet18'`` or ``'resnet50'`` (torchvision model built with
            ``pretrained=True`` and its ``fc`` head replaced), or ``'simple'``
            (``SimpleCNN``).
        resolution: Target side length. ResNet inputs wider than this are
            bilinearly downsampled to ``resolution x resolution``;
            ``'simple'`` inputs are passed through unchanged.
        secret_len: Number of output logits.

    Raises:
        ValueError: For any other ``arch``.
    """

    _RESNETS = ('resnet18', 'resnet50')

    def __init__(self, arch='resnet18', resolution=224, secret_len=100):
        super().__init__()
        self.resolution = resolution
        self.arch = arch
        if arch in self._RESNETS:
            # Look the constructor up lazily so torchvision is only touched
            # for the ResNet variants.
            backbone = getattr(torchvision.models, arch)(pretrained=True, progress=False)
            backbone.fc = nn.Linear(backbone.fc.in_features, secret_len)
        elif arch == 'simple':
            backbone = SimpleCNN(resolution, secret_len)
        else:
            raise ValueError('Unknown architecture')
        self.decoder = backbone

    def forward(self, image):
        """Return the backbone's logits for ``image`` (resized first if needed)."""
        too_large = image.shape[-1] > self.resolution
        if too_large and self.arch in self._RESNETS:
            target = (self.resolution, self.resolution)
            image = thf.interpolate(image, size=target, mode='bilinear', align_corners=False)
        return self.decoder(image)
|
|
|
|
|
class SimpleCNN(nn.Module):
    """Small convolutional network mapping a 3-channel image to ``secret_len`` logits.

    Seven 3x3 convolutions (five of them stride 2) are followed by a
    two-layer MLP on the flattened features.

    Args:
        resolution: Height/width of the square input image. The flattened
            feature size is derived from it, so any positive value works.
        secret_len: Number of outputs.
    """

    def __init__(self, resolution=224, secret_len=100):
        super().__init__()
        self.resolution = resolution
        self.IMAGE_CHANNELS = 3
        self.decoder = nn.Sequential(
            nn.Conv2d(self.IMAGE_CHANNELS, 32, (3, 3), 2, 1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2, 1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 2, 1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.ReLU(),
            nn.Conv2d(128, 128, (3, 3), 2, 1),
            nn.ReLU(),
        )
        # Every 3x3 / stride-2 / padding-1 conv maps a side of n to
        # ceil(n / 2); stride-1 convs keep the size. For multiples of 32 this
        # equals resolution // 32 (e.g. 224 -> 7).
        side = resolution
        for _ in range(5):
            side = (side + 1) // 2
        self.flat_features = 128 * side * side
        self.dense = nn.Sequential(
            nn.Linear(self.flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, secret_len),
        )

    def forward(self, image):
        """Return ``(B, secret_len)`` logits for a ``(B, 3, resolution, resolution)`` image."""
        x = self.decoder(image)
        # Flatten per sample so the batch dimension can never be re-split.
        x = torch.flatten(x, 1)
        return self.dense(x)