test / flae /models.py
Tu Bui
first commit
6142a25
import math
import torch
from torch import nn
from torch.nn import functional as thf
import pytorch_lightning as pl
from ldm.util import instantiate_from_config
import einops
import kornia
import numpy as np
import torchvision
from contextlib import contextmanager
from ldm.modules.ema import LitEma
class FlAE(pl.LightningModule):
    """Lightning module that jointly trains a secret encoder and decoder.

    The encoder embeds a secret into a cover image, producing a "stego"
    image; the decoder predicts the secret back from the (optionally noised
    or centre-cropped) stego image.  Training starts in a warm-up phase that
    reuses one fixed image batch (tracked by the ``fixed_input`` buffer) and
    switches to the regular batches once the ``bit_acc`` value reported by
    the loss layer exceeds 0.9.
    """

    def __init__(self,
                 cover_key,
                 secret_key,
                 secret_len,
                 resolution,
                 secret_encoder_config,
                 secret_decoder_config,
                 loss_config,
                 noise_config='__none__',
                 ckpt_path="__none__",
                 use_ema=False
                 ):
        """Build encoder, decoder, loss and (optionally) noise / EMA helpers.

        Args:
            cover_key: key used to read the cover image from a batch.
            secret_key: key used to read the secret from a batch.
            secret_len: written into both encoder and decoder config params.
            resolution: written into the encoder config params.
            secret_encoder_config: config passed to ``instantiate_from_config``.
            secret_decoder_config: config passed to ``instantiate_from_config``.
            loss_config: config passed to ``instantiate_from_config``.
            noise_config: noise model config, or ``'__none__'`` for no noise.
            ckpt_path: checkpoint to restore from, or ``"__none__"``.
            use_ema: keep exponential moving averages of encoder and decoder.
        """
        super().__init__()
        self.cover_key = cover_key
        self.secret_key = secret_key

        # Inject the shared hyper-parameters into the sub-module configs.
        secret_encoder_config.params.secret_len = secret_len
        secret_decoder_config.params.secret_len = secret_len
        secret_encoder_config.params.resolution = resolution
        # The decoder resolution is fixed at 224, the same size as the
        # centre crop created at the end of this constructor.
        secret_decoder_config.params.resolution = 224

        self.encoder = instantiate_from_config(secret_encoder_config)
        self.decoder = instantiate_from_config(secret_decoder_config)
        self.loss_layer = instantiate_from_config(loss_config)

        # ``self.noise`` only exists when a noise config is given; every
        # user of it must therefore check ``hasattr(self, 'noise')`` first.
        if noise_config != '__none__':
            print('Using noise')
            self.noise = instantiate_from_config(noise_config)

        self.use_ema = use_ema
        if self.use_ema:
            print('Using EMA')
            self.encoder_ema = LitEma(self.encoder)
            self.decoder_ema = LitEma(self.decoder)
            print(f"Keeping EMAs of {len(list(self.encoder_ema.buffers()) + list(self.decoder_ema.buffers()))}.")

        # NOTE(review): the checkpoint is loaded before ``fixed_input`` is
        # registered, so a ``fixed_input`` entry stored in the checkpoint is
        # presumably dropped by the non-strict load and the warm-up phase
        # starts again -- confirm this is intended.
        if ckpt_path != "__none__":
            self.init_from_ckpt(ckpt_path, ignore_keys=[])

        # early training phase
        self.fixed_img = None
        self.fixed_secret = None
        self.register_buffer("fixed_input", torch.tensor(True))
        self.crop = kornia.augmentation.CenterCrop((224, 224), cropping_mode="resample")  # early training phase

    def init_from_ckpt(self, path, ignore_keys=None):
        """Load weights from ``path`` (non-strict), skipping ignored prefixes.

        Args:
            path: checkpoint file containing a ``"state_dict"`` entry.
            ignore_keys: iterable of key prefixes to drop before loading;
                ``None`` (the default) drops nothing.
        """
        # NOTE(security): torch.load unpickles the file -- only load
        # checkpoints from trusted sources.
        sd = torch.load(path, map_location="cpu")["state_dict"]
        prefixes = tuple(ignore_keys) if ignore_keys else ()
        for k in list(sd.keys()):
            # A single startswith() over all prefixes deletes each key at
            # most once, even when several prefixes match the same key.
            if prefixes and k.startswith(prefixes):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap the EMA weights into encoder and decoder.

        Does nothing when ``use_ema`` is false.  The training weights are
        restored on exit, including when the body raises.
        """
        if self.use_ema:
            self.encoder_ema.store(self.encoder.parameters())
            self.decoder_ema.store(self.decoder.parameters())
            self.encoder_ema.copy_to(self.encoder)
            self.decoder_ema.copy_to(self.decoder)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.encoder_ema.restore(self.encoder.parameters())
                self.decoder_ema.restore(self.decoder.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA copies after each training batch (if enabled)."""
        if self.use_ema:
            self.encoder_ema(self.encoder)
            self.decoder_ema(self.decoder)

    @torch.no_grad()
    def get_input(self, batch, bs=None):
        """Extract ``[image, secret]`` from a batch.

        The image is rearranged from ``b h w c`` to ``b c h w``.  While the
        ``fixed_input`` buffer is true, the first image batch seen is cached
        and returned for every call instead of the batch's own images.

        Args:
            batch: mapping indexed by ``cover_key`` and ``secret_key``.
            bs: optional maximum number of samples to keep.
        """
        image = batch[self.cover_key]
        secret = batch[self.secret_key]
        if bs is not None:
            image = image[:bs]
            secret = secret[:bs]
        else:
            bs = image.shape[0]
        # encode image 1st stage
        image = einops.rearrange(image, "b h w c -> b c h w").contiguous()

        # check if using fixed input (early training phase)
        if self.fixed_input:
            if self.fixed_img is None:  # first iteration
                print('[TRAINING] Warmup - using fixed input image for now!')
                self.fixed_img = image.detach().clone()[:bs]
                # used by log_images with the fixed_input option only
                self.fixed_secret = secret.detach().clone()[:bs]
            image = self.fixed_img
            # The cached batch and the current secrets may differ in size
            # (e.g. a shorter last batch); truncate both to the smaller one.
            new_bs = min(secret.shape[0], image.shape[0])
            image, secret = image[:new_bs], secret[:new_bs]

        out = [image, secret]
        return out

    def forward(self, cover, secret):
        """Return the tuple ``(stego, residual)`` for a cover/secret pair."""
        enc_out = self.encoder(cover, secret)
        if self.encoder.return_residual:
            # The encoder predicts the residual; add it to the cover.
            return cover + enc_out, enc_out
        # The encoder predicts the stego image directly.
        return enc_out, enc_out - cover

    def shared_step(self, batch):
        """Run encoder -> noise/crop -> decoder -> loss and advance the schedule.

        Besides computing the loss, this method moves the training schedule
        forward based on the ``bit_acc`` entry of the loss dict:

        * > 0.9 while in warm-up: leave the fixed-input warm-up phase.
        * > 0.95 after warm-up: activate the noise model (if configured).
        * > 0.98 after warm-up with active noise: activate the loss ramp.

        NOTE(review): ``validation_step`` also calls this method, so these
        state changes can be triggered by validation batches as well.

        Returns:
            ``(loss, loss_dict)`` as produced by the loss layer.
        """
        x, s = self.get_input(batch)
        stego, _ = self(x, s)

        has_noise = hasattr(self, "noise")
        if has_noise and self.noise.is_activated():
            stego_noised = self.noise(stego, self.global_step, p=0.9)
        else:
            stego_noised = self.crop(stego)
        stego_noised = torch.clamp(stego_noised, -1, 1)
        spred = self.decoder(stego_noised)

        loss, loss_dict = self.loss_layer(x, stego, None, s, spred, self.global_step)
        bit_acc = loss_dict["bit_acc"]
        bit_acc_ = bit_acc.item()

        # A model without a noise module is treated as "noise not activated",
        # consistent with the hasattr guards used elsewhere in this class;
        # accessing self.noise unguarded would raise AttributeError here.
        if (bit_acc_ > 0.98) and (not self.fixed_input) and has_noise and self.noise.is_activated():
            self.loss_layer.activate_ramp(self.global_step)

        if (bit_acc_ > 0.95) and (not self.fixed_input):  # ramp up image loss at late training stage
            if has_noise and (not self.noise.is_activated()):
                self.noise.activate(self.global_step)

        if (bit_acc_ > 0.9) and self.fixed_input:  # execute only once
            print(f'[TRAINING] High bit acc ({bit_acc_}) achieved, switch to full image dataset training.')
            self.fixed_input = ~self.fixed_input
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log all loss terms under ``train/``."""
        loss, loss_dict = self.shared_step(batch)
        loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
        self.log_dict(loss_dict, prog_bar=True,
                      logger=True, on_step=True, on_epoch=True)

        self.log("global_step", self.global_step,
                 prog_bar=True, logger=True, on_step=True, on_epoch=False)
        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Log validation losses for the training weights and the EMA weights."""
        _, loss_dict_no_ema = self.shared_step(batch)
        loss_dict_no_ema = {f"val/{key}": val for key, val in loss_dict_no_ema.items() if key != 'img_lw'}
        with self.ema_scope():
            _, loss_dict_ema = self.shared_step(batch)
            loss_dict_ema = {'val/' + key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)

    @torch.no_grad()
    def log_images(self, batch, fixed_input=False, **kwargs):
        """Return a dict of images for logging.

        Keys: ``input``, ``stego``, ``residual`` (min-max rescaled to
        [-1, 1]) and, when an activated noise model exists, ``noised``.

        Args:
            batch: batch to visualise (ignored when the cached warm-up
                batch is used).
            fixed_input: use the cached warm-up image/secret if available.
        """
        log = dict()
        if fixed_input and self.fixed_img is not None:
            x, s = self.fixed_img, self.fixed_secret
        else:
            x, s = self.get_input(batch)
        stego, residual = self(x, s)
        if hasattr(self, 'noise') and self.noise.is_activated():
            img_noise = self.noise(stego, self.global_step, p=1.0)
            log['noised'] = img_noise
        log['input'] = x
        log['stego'] = stego
        # The epsilon avoids division by zero for a constant residual.
        log['residual'] = (residual - residual.min()) / (residual.max() - residual.min() + 1e-8)*2 - 1
        return log

    def configure_optimizers(self):
        """Create one AdamW optimizer over encoder and decoder parameters."""
        # NOTE(review): ``learning_rate`` is not set in __init__; presumably
        # the training script assigns it before fitting -- confirm.
        lr = self.learning_rate
        params = list(self.encoder.parameters()) + list(self.decoder.parameters())
        optimizer = torch.optim.AdamW(params, lr=lr)
        return optimizer
class SecretEncoder(nn.Module):
    """Encoder-decoder CNN with skip connections that embeds a secret.

    The secret vector is projected to a 3x16x16 map, upsampled to the image
    resolution, concatenated with the image and passed through four
    stride-2 convolutions followed by four upsampling stages with skip
    concatenations.  The 3-channel output is squashed to [-1, 1].
    """

    def __init__(self, resolution=256, secret_len=100, return_residual=False, act='tanh') -> None:
        """
        Args:
            resolution: input image side length; must be a power of 2.
            secret_len: length of the secret vector.
            return_residual: flag stored for callers; it is not used inside
                this class.
            act: ``'tanh'`` for a tanh output; any other value selects
                ``sigmoid(x) * 2 - 1``.
        """
        super().__init__()
        self.secret_len = secret_len
        self.return_residual = return_residual
        # Store the activation name instead of a lambda so the module stays
        # picklable; the activation itself is the ``act_fn`` method.
        self.act = act
        self.secret_dense = nn.Linear(secret_len, 16*16*3)

        # log2 avoids the rounding error of the log(x)/log(2) ratio, which
        # could otherwise truncate to the wrong integer.
        log_resolution = int(math.log2(resolution))
        assert resolution == 2 ** log_resolution, f"Image resolution must be a power of 2, got {resolution}."
        # Scale the 16x16 secret map up to resolution x resolution.
        scale = 2 ** (log_resolution - 4)
        self.secret_upsample = nn.Upsample(scale_factor=(scale, scale))

        # Contracting path: secret map (3 ch) + image (3 ch) as input.
        self.conv1 = nn.Conv2d(2 * 3, 32, 3, 1, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 2, 1)
        self.conv3 = nn.Conv2d(32, 64, 3, 2, 1)
        self.conv4 = nn.Conv2d(64, 128, 3, 2, 1)
        self.conv5 = nn.Conv2d(128, 256, 3, 2, 1)

        # Expanding path: upsample x2, pad by one pixel right/bottom so the
        # 2x2 convolution keeps the spatial size, then merge with the skip.
        self.pad6 = nn.ZeroPad2d((0, 1, 0, 1))
        self.up6 = nn.Conv2d(256, 128, 2, 1)
        self.upsample6 = nn.Upsample(scale_factor=(2, 2))
        self.conv6 = nn.Conv2d(128 + 128, 128, 3, 1, 1)
        self.pad7 = nn.ZeroPad2d((0, 1, 0, 1))
        self.up7 = nn.Conv2d(128, 64, 2, 1)
        self.upsample7 = nn.Upsample(scale_factor=(2, 2))
        self.conv7 = nn.Conv2d(64 + 64, 64, 3, 1, 1)
        self.pad8 = nn.ZeroPad2d((0, 1, 0, 1))
        self.up8 = nn.Conv2d(64, 32, 2, 1)
        self.upsample8 = nn.Upsample(scale_factor=(2, 2))
        self.conv8 = nn.Conv2d(32 + 32, 32, 3, 1, 1)
        self.pad9 = nn.ZeroPad2d((0, 1, 0, 1))
        self.up9 = nn.Conv2d(32, 32, 2, 1)
        self.upsample9 = nn.Upsample(scale_factor=(2, 2))
        self.conv9 = nn.Conv2d(32 + 32 + 2 * 3, 32, 3, 1, 1)
        self.conv10 = nn.Conv2d(32, 32, 3, 1, 1)
        self.residual = nn.Conv2d(32, 3, 1)

    def act_fn(self, x):
        """Map ``x`` into [-1, 1] with the configured output activation."""
        if self.act == 'tanh':
            return torch.tanh(x)
        return thf.sigmoid(x) * 2.0 - 1.0

    def forward(self, image, secret):
        """Return the activated 3-channel encoder output for ``image``/``secret``."""
        fingerprint = thf.relu(self.secret_dense(secret))
        fingerprint = fingerprint.view((-1, 3, 16, 16))
        fingerprint_enlarged = self.secret_upsample(fingerprint)
        inputs = torch.cat([fingerprint_enlarged, image], dim=1)

        # Contracting path.
        conv1 = thf.relu(self.conv1(inputs))
        conv2 = thf.relu(self.conv2(conv1))
        conv3 = thf.relu(self.conv3(conv2))
        conv4 = thf.relu(self.conv4(conv3))
        conv5 = thf.relu(self.conv5(conv4))

        # Expanding path with skip connections.
        up6 = thf.relu(self.up6(self.pad6(self.upsample6(conv5))))
        merge6 = torch.cat([conv4, up6], dim=1)
        conv6 = thf.relu(self.conv6(merge6))
        up7 = thf.relu(self.up7(self.pad7(self.upsample7(conv6))))
        merge7 = torch.cat([conv3, up7], dim=1)
        conv7 = thf.relu(self.conv7(merge7))
        up8 = thf.relu(self.up8(self.pad8(self.upsample8(conv7))))
        merge8 = torch.cat([conv2, up8], dim=1)
        conv8 = thf.relu(self.conv8(merge8))
        up9 = thf.relu(self.up9(self.pad9(self.upsample9(conv8))))
        # The last merge also re-injects the raw network input.
        merge9 = torch.cat([conv1, up9, inputs], dim=1)
        conv9 = thf.relu(self.conv9(merge9))
        conv10 = thf.relu(self.conv10(conv9))

        residual = self.residual(conv10)
        residual = self.act_fn(residual)
        return residual
class SecretEncoder1(nn.Module):
    """Placeholder encoder variant; no layers or forward pass are defined."""

    def __init__(self, resolution=256, secret_len=100) -> None:
        """Initialise the empty module; both arguments are currently unused."""
        # nn.Module sets up its internal registries in __init__; without
        # this call, parameters(), state_dict() and repr() all fail.
        super().__init__()
class SecretDecoder(nn.Module):
    """Predicts a secret of length ``secret_len`` from an image.

    The backbone is either a torchvision ResNet (``'resnet18'`` or
    ``'resnet50'``, created with ``pretrained=True`` and a replaced final
    linear layer) or the file's ``SimpleCNN`` (``'simple'``).
    """

    def __init__(self, arch='resnet18', resolution=224, secret_len=100):
        """
        Args:
            arch: ``'resnet18'``, ``'resnet50'`` or ``'simple'``.
            resolution: side length ResNet inputs are shrunk to when wider;
                also passed to ``SimpleCNN``.
            secret_len: number of outputs of the final layer.

        Raises:
            ValueError: if ``arch`` is not one of the supported names.
        """
        super().__init__()
        self.resolution = resolution
        self.arch = arch

        if arch == 'simple':
            self.decoder = SimpleCNN(resolution, secret_len)
        elif arch in ('resnet18', 'resnet50'):
            # Both ResNet variants are built the same way; look the
            # constructor up by name and swap the classification head.
            build_backbone = getattr(torchvision.models, arch)
            self.decoder = build_backbone(pretrained=True, progress=False)
            head_inputs = self.decoder.fc.in_features
            self.decoder.fc = nn.Linear(head_inputs, secret_len)
        else:
            raise ValueError('Unknown architecture')

    def forward(self, image):
        """Decode ``image``; ResNet inputs wider than ``resolution`` are resized first."""
        is_resnet = self.arch in ['resnet50', 'resnet18']
        if is_resnet and image.shape[-1] > self.resolution:
            target_size = (self.resolution, self.resolution)
            image = thf.interpolate(image, size=target_size, mode='bilinear', align_corners=False)
        return self.decoder(image)
class SimpleCNN(nn.Module):
    """Small convolutional decoder followed by a two-layer MLP head.

    Seven 3x3 convolutions (five of them with stride 2) reduce a 3-channel
    image to a 128-channel feature map, which is flattened and mapped to
    ``secret_len`` outputs.
    """

    # Number of stride-2 convolutions in ``self.decoder``.
    _NUM_DOWNSAMPLES = 5

    def __init__(self, resolution=224, secret_len=100):
        """
        Args:
            resolution: side length of the (square) input images.
            secret_len: number of outputs of the final linear layer.
        """
        super().__init__()
        self.resolution = resolution
        self.IMAGE_CHANNELS = 3
        self.decoder = nn.Sequential(
            nn.Conv2d(self.IMAGE_CHANNELS, 32, (3, 3), 2, 1),  # resolution / 2
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2, 1),  # resolution / 4
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 2, 1),  # resolution / 8
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2, 1),  # resolution / 16
            nn.ReLU(),
            nn.Conv2d(128, 128, (3, 3), 2, 1),  # resolution / 32
            nn.ReLU(),
        )

        # A 3x3 / stride-2 / padding-1 convolution maps a side of n to
        # ceil(n / 2).  Applying that once per stride-2 layer gives the
        # exact output side; for multiples of 32 this equals
        # resolution / 32, and it stays correct for other resolutions.
        side = resolution
        for _ in range(self._NUM_DOWNSAMPLES):
            side = (side + 1) // 2
        self._feature_dim = 128 * side * side

        self.dense = nn.Sequential(
            nn.Linear(self._feature_dim, 512),
            nn.ReLU(),
            nn.Linear(512, secret_len),
        )

    def forward(self, image):
        """Return the ``(batch, secret_len)`` prediction for ``image``."""
        x = self.decoder(image)
        # Flatten per sample (rather than view(-1, D)) so an input of the
        # wrong size fails in the linear layer instead of being silently
        # re-batched.
        x = torch.flatten(x, 1)
        return self.dense(x)