import math
from contextlib import contextmanager

import einops
import kornia
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from torch.nn import functional as thf

from ldm.modules.ema import LitEma
from ldm.util import instantiate_from_config


def _scaled_sigmoid(x):
    """Map ``x`` through a sigmoid rescaled to the range (-1, 1)."""
    return torch.sigmoid(x) * 2.0 - 1.0


class FlAE(pl.LightningModule):
    """Lightning module that trains a secret encoder/decoder pair.

    The encoder embeds a secret into a cover image (producing a "stego"
    image); the decoder predicts the secret back from the (optionally
    noised) stego image.  Training starts in a warm-up phase that reuses a
    single fixed batch until the bit accuracy reported by the loss layer is
    high enough, then moves on to the full dataset, noise and the loss ramp.
    """

    def __init__(self,
                 cover_key,
                 secret_key,
                 secret_len,
                 resolution,
                 secret_encoder_config,
                 secret_decoder_config,
                 loss_config,
                 noise_config='__none__',
                 ckpt_path="__none__",
                 use_ema=False
                 ):
        """Build encoder, decoder, loss and optional noise/EMA components.

        Args:
            cover_key: Batch key holding the cover images.
            secret_key: Batch key holding the secrets.
            secret_len: Secret length written into both sub-model configs.
            resolution: Resolution written into the encoder config.
            secret_encoder_config: Config passed to ``instantiate_from_config``
                for the encoder (its ``params`` are modified in place).
            secret_decoder_config: Config for the decoder (its ``params`` are
                modified in place; resolution is fixed to 224).
            loss_config: Config for the loss layer.
            noise_config: Config for the noise layer, or ``'__none__'`` to
                train without one.
            ckpt_path: Checkpoint to initialise from, or ``"__none__"``.
            use_ema: Whether to keep EMA copies of encoder and decoder.
        """
        super().__init__()
        self.cover_key = cover_key
        self.secret_key = secret_key

        secret_encoder_config.params.secret_len = secret_len
        secret_decoder_config.params.secret_len = secret_len
        secret_encoder_config.params.resolution = resolution
        # The decoder always sees the 224x224 centre crop defined below.
        secret_decoder_config.params.resolution = 224

        self.encoder = instantiate_from_config(secret_encoder_config)
        self.decoder = instantiate_from_config(secret_decoder_config)
        self.loss_layer = instantiate_from_config(loss_config)

        if noise_config != '__none__':
            print('Using noise')
            self.noise = instantiate_from_config(noise_config)

        self.use_ema = use_ema
        if self.use_ema:
            print('Using EMA')
            self.encoder_ema = LitEma(self.encoder)
            self.decoder_ema = LitEma(self.decoder)
            print(f"Keeping EMAs of {len(list(self.encoder_ema.buffers()) + list(self.decoder_ema.buffers()))}.")

        # NOTE(review): the checkpoint is loaded before ``fixed_input`` is
        # registered, so a ``fixed_input`` value stored in the checkpoint is
        # not restored here and warm-up starts again — confirm this is intended.
        if ckpt_path != "__none__":
            self.init_from_ckpt(ckpt_path, ignore_keys=[])

        # early training phase
        self.fixed_img = None
        self.fixed_secret = None
        self.register_buffer("fixed_input", torch.tensor(True))
        # Crop used in place of the noise layer while noise is inactive.
        self.crop = kornia.augmentation.CenterCrop((224, 224), cropping_mode="resample")

    def init_from_ckpt(self, path, ignore_keys=None):
        """Load weights from the ``"state_dict"`` entry of a checkpoint file.

        Args:
            path: Checkpoint file readable by ``torch.load``.
            ignore_keys: Iterable of key prefixes; state-dict entries whose
                key starts with any of them are dropped before loading.
                ``None`` (the default) drops nothing.
        """
        prefixes = tuple(ignore_keys) if ignore_keys else ()
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        sd = torch.load(path, map_location="cpu")["state_dict"]
        for k in list(sd.keys()):
            # A single tuple test removes each key at most once, even when
            # several prefixes match it.
            if prefixes and k.startswith(prefixes):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap encoder/decoder weights for their EMA versions.

        Does nothing when EMA is disabled.  The training weights are restored
        on exit, including when the body raises.

        Args:
            context: Optional label; when given, the switch and the restore
                are announced with ``print``.
        """
        if self.use_ema:
            self.encoder_ema.store(self.encoder.parameters())
            self.decoder_ema.store(self.decoder.parameters())
            self.encoder_ema.copy_to(self.encoder)
            self.decoder_ema.copy_to(self.decoder)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.encoder_ema.restore(self.encoder.parameters())
                self.decoder_ema.restore(self.decoder.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA copies after each training batch (if enabled)."""
        if self.use_ema:
            self.encoder_ema(self.encoder)
            self.decoder_ema(self.decoder)

    @torch.no_grad()
    def get_input(self, batch, bs=None):
        """Extract ``[image, secret]`` from a batch.

        Images are rearranged from ``b h w c`` to ``b c h w``.  While the
        ``fixed_input`` buffer is true, the first batch seen is cached and its
        image is returned on every call instead of the current one.

        Args:
            batch: Mapping indexed by ``cover_key`` and ``secret_key``.
            bs: Optional maximum number of samples to keep.

        Returns:
            A two-element list ``[image, secret]``.
        """
        image = batch[self.cover_key]
        secret = batch[self.secret_key]
        if bs is not None:
            image = image[:bs]
            secret = secret[:bs]
        else:
            bs = image.shape[0]
        # encode image 1st stage
        image = einops.rearrange(image, "b h w c -> b c h w").contiguous()

        # check if using fixed input (early training phase)
        if self.fixed_input:
            if self.fixed_img is None:  # first iteration
                print('[TRAINING] Warmup - using fixed input image for now!')
                self.fixed_img = image.detach().clone()[:bs]
                # use for log_images with fixed_input option only
                self.fixed_secret = secret.detach().clone()[:bs]
            image = self.fixed_img
            # The cached image and the current secret may differ in batch size.
            new_bs = min(secret.shape[0], image.shape[0])
            image, secret = image[:new_bs], secret[:new_bs]

        return [image, secret]

    def forward(self, cover, secret):
        """Embed ``secret`` into ``cover``.

        Returns:
            A tuple ``(stego, residual)``.  Whether the encoder output is
            treated as the residual or as the stego image is decided by
            ``self.encoder.return_residual``.
        """
        enc_out = self.encoder(cover, secret)
        if self.encoder.return_residual:
            return cover + enc_out, enc_out
        return enc_out, enc_out - cover

    def _noise_is_active(self):
        """Return True when a noise layer is configured and activated."""
        return hasattr(self, "noise") and self.noise.is_activated()

    def shared_step(self, batch):
        """Run one encode -> noise/crop -> decode pass and compute the loss.

        Also advances the training schedule based on the reported bit
        accuracy: leaving the fixed-input warm-up (> 0.9), activating the
        noise layer (> 0.95) and activating the loss ramp (> 0.98 with noise
        already active).

        Returns:
            A tuple ``(loss, loss_dict)`` as produced by the loss layer.
        """
        x, s = self.get_input(batch)
        stego, residual = self(x, s)
        if self._noise_is_active():
            stego_noised = self.noise(stego, self.global_step, p=0.9)
        else:
            stego_noised = self.crop(stego)
        stego_noised = torch.clamp(stego_noised, -1, 1)
        spred = self.decoder(stego_noised)

        loss, loss_dict = self.loss_layer(x, stego, None, s, spred, self.global_step)
        bit_acc_ = loss_dict["bit_acc"].item()

        # Evaluated before the noise-activation branch below, so the ramp is
        # only started on a step where noise was already active.  The helper
        # also covers the case where no noise layer was configured at all.
        if (bit_acc_ > 0.98) and (not self.fixed_input) and self._noise_is_active():
            self.loss_layer.activate_ramp(self.global_step)

        if (bit_acc_ > 0.95) and (not self.fixed_input):  # ramp up image loss at late training stage
            if hasattr(self, 'noise') and (not self.noise.is_activated()):
                self.noise.activate(self.global_step)

        if (bit_acc_ > 0.9) and self.fixed_input:  # execute only once
            print(f'[TRAINING] High bit acc ({bit_acc_}) achieved, switch to full image dataset training.')
            self.fixed_input = ~self.fixed_input
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log all metrics under ``train/``."""
        loss, loss_dict = self.shared_step(batch)
        loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
        self.log_dict(loss_dict, prog_bar=True,
                      logger=True, on_step=True, on_epoch=True)

        self.log("global_step", self.global_step,
                 prog_bar=True, logger=True, on_step=True, on_epoch=False)
        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Log validation metrics for the raw weights and for the EMA weights.

        Metrics are logged under ``val/<key>`` (excluding ``img_lw``) and
        ``val/<key>_ema`` respectively.
        """
        _, loss_dict_no_ema = self.shared_step(batch)
        loss_dict_no_ema = {f"val/{key}": val for key, val in loss_dict_no_ema.items() if key != 'img_lw'}
        with self.ema_scope():
            _, loss_dict_ema = self.shared_step(batch)
            loss_dict_ema = {'val/' + key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)

    @torch.no_grad()
    def log_images(self, batch, fixed_input=False, **kwargs):
        """Collect images for visual logging.

        Args:
            batch: Batch to take inputs from.
            fixed_input: When true and a warm-up batch has been cached, use
                that cached image/secret instead of ``batch``.

        Returns:
            A dict with ``input``, ``stego`` and ``residual`` (min-max
            rescaled to [-1, 1]), plus ``noised`` when noise is active.
        """
        log = dict()
        if fixed_input and self.fixed_img is not None:
            x, s = self.fixed_img, self.fixed_secret
        else:
            x, s = self.get_input(batch)
        stego, residual = self(x, s)
        if self._noise_is_active():
            img_noise = self.noise(stego, self.global_step, p=1.0)
            log['noised'] = img_noise
        log['input'] = x
        log['stego'] = stego
        # The epsilon avoids a division by zero for a constant residual.
        log['residual'] = (residual - residual.min()) / (residual.max() - residual.min() + 1e-8)*2 - 1
        return log

    def configure_optimizers(self):
        """Return an AdamW optimizer over encoder and decoder parameters.

        ``self.learning_rate`` is not set in ``__init__``; it must be
        assigned on the instance before this is called.
        """
        lr = self.learning_rate
        params = list(self.encoder.parameters()) + list(self.decoder.parameters())
        optimizer = torch.optim.AdamW(params, lr=lr)
        return optimizer


class SecretEncoder(nn.Module):
    """U-Net style network that turns (image, secret) into a 3-channel map."""

    def __init__(self, resolution=256, secret_len=100, return_residual=False, act='tanh') -> None:
        """
        Args:
            resolution: Input image side length; must be a power of 2.
            secret_len: Length of the secret vector.
            return_residual: Flag read by callers to decide how the output
                is interpreted; not used inside this module.
            act: ``'tanh'`` for a tanh output, anything else for a sigmoid
                rescaled to (-1, 1).
        """
        super().__init__()
        self.secret_len = secret_len
        self.return_residual = return_residual
        # Plain function references (rather than a lambda) keep the module
        # picklable.
        self.act_fn = torch.tanh if act == 'tanh' else _scaled_sigmoid
        self.secret_dense = nn.Linear(secret_len, 16*16*3)
        # log2 avoids the rounding error of the two-argument math.log form.
        log_resolution = int(math.log2(resolution))
        assert resolution == 2 ** log_resolution, f"Image resolution must be a power of 2, got {resolution}."
        # The secret is laid out as a 16x16 map and upsampled to the image size.
        self.secret_upsample = nn.Upsample(scale_factor=(2**(log_resolution-4), 2**(log_resolution-4)))

        # Contracting path.
        self.conv1 = nn.Conv2d(2 * 3, 32, 3, 1, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 2, 1)
        self.conv3 = nn.Conv2d(32, 64, 3, 2, 1)
        self.conv4 = nn.Conv2d(64, 128, 3, 2, 1)
        self.conv5 = nn.Conv2d(128, 256, 3, 2, 1)

        # Expanding path: upsample, pad by one, 2x2 conv, then merge with skip.
        self.pad6 = nn.ZeroPad2d((0, 1, 0, 1))
        self.up6 = nn.Conv2d(256, 128, 2, 1)
        self.upsample6 = nn.Upsample(scale_factor=(2, 2))
        self.conv6 = nn.Conv2d(128 + 128, 128, 3, 1, 1)

        self.pad7 = nn.ZeroPad2d((0, 1, 0, 1))
        self.up7 = nn.Conv2d(128, 64, 2, 1)
        self.upsample7 = nn.Upsample(scale_factor=(2, 2))
        self.conv7 = nn.Conv2d(64 + 64, 64, 3, 1, 1)

        self.pad8 = nn.ZeroPad2d((0, 1, 0, 1))
        self.up8 = nn.Conv2d(64, 32, 2, 1)
        self.upsample8 = nn.Upsample(scale_factor=(2, 2))
        self.conv8 = nn.Conv2d(32 + 32, 32, 3, 1, 1)

        self.pad9 = nn.ZeroPad2d((0, 1, 0, 1))
        self.up9 = nn.Conv2d(32, 32, 2, 1)
        self.upsample9 = nn.Upsample(scale_factor=(2, 2))
        self.conv9 = nn.Conv2d(32 + 32 + 2 * 3, 32, 3, 1, 1)

        self.conv10 = nn.Conv2d(32, 32, 3, 1, 1)
        self.residual = nn.Conv2d(32, 3, 1)

    def forward(self, image, secret):
        """Return the activated 3-channel encoder output for a batch.

        Args:
            image: Image batch; concatenated channel-wise with the upsampled
                3-channel secret map, so it must have 3 channels and match
                the configured resolution.
            secret: Secret batch fed to the ``secret_len``-input dense layer.
        """
        fingerprint = thf.relu(self.secret_dense(secret))
        fingerprint = fingerprint.view((-1, 3, 16, 16))
        fingerprint_enlarged = self.secret_upsample(fingerprint)
        inputs = torch.cat([fingerprint_enlarged, image], dim=1)

        conv1 = thf.relu(self.conv1(inputs))
        conv2 = thf.relu(self.conv2(conv1))
        conv3 = thf.relu(self.conv3(conv2))
        conv4 = thf.relu(self.conv4(conv3))
        conv5 = thf.relu(self.conv5(conv4))

        up6 = thf.relu(self.up6(self.pad6(self.upsample6(conv5))))
        merge6 = torch.cat([conv4, up6], dim=1)
        conv6 = thf.relu(self.conv6(merge6))

        up7 = thf.relu(self.up7(self.pad7(self.upsample7(conv6))))
        merge7 = torch.cat([conv3, up7], dim=1)
        conv7 = thf.relu(self.conv7(merge7))

        up8 = thf.relu(self.up8(self.pad8(self.upsample8(conv7))))
        merge8 = torch.cat([conv2, up8], dim=1)
        conv8 = thf.relu(self.conv8(merge8))

        up9 = thf.relu(self.up9(self.pad9(self.upsample9(conv8))))
        merge9 = torch.cat([conv1, up9, inputs], dim=1)
        conv9 = thf.relu(self.conv9(merge9))

        conv10 = thf.relu(self.conv10(conv9))
        residual = self.residual(conv10)
        residual = self.act_fn(residual)
        return residual


class SecretEncoder1(nn.Module):
    """Placeholder encoder variant; defines no layers and no forward pass."""

    def __init__(self, resolution=256, secret_len=100) -> None:
        # nn.Module must be initialised before the instance is usable.
        super().__init__()


class SecretDecoder(nn.Module):
    """Predicts a secret from an image using a ResNet or a small CNN."""

    def __init__(self, arch='resnet18', resolution=224, secret_len=100):
        """
        Args:
            arch: ``'resnet18'``, ``'resnet50'`` or ``'simple'``.
            resolution: Side length images are resized down to (ResNet
                archs) or expected at (``'simple'``).
            secret_len: Number of output values.

        Raises:
            ValueError: If ``arch`` is not one of the supported names.
        """
        super().__init__()
        self.resolution = resolution
        self.arch = arch
        if arch in ('resnet18', 'resnet50'):
            backbone = torchvision.models.resnet18 if arch == 'resnet18' else torchvision.models.resnet50
            self.decoder = backbone(pretrained=True, progress=False)
            # Replace the classification head with a secret-sized one.
            self.decoder.fc = nn.Linear(self.decoder.fc.in_features, secret_len)
        elif arch == 'simple':
            self.decoder = SimpleCNN(resolution, secret_len)
        else:
            raise ValueError('Unknown architecture')

    def forward(self, image):
        """Return the decoder output for ``image``.

        For the ResNet archs, inputs whose last dimension exceeds
        ``self.resolution`` are first resized bilinearly to
        ``resolution x resolution``; smaller inputs are passed through as is.
        """
        if self.arch in ['resnet50', 'resnet18'] and image.shape[-1] > self.resolution:
            image = thf.interpolate(image, size=(self.resolution, self.resolution), mode='bilinear', align_corners=False)
        x = self.decoder(image)
        return x


class SimpleCNN(nn.Module):
    """Small convolutional decoder followed by a two-layer dense head."""

    def __init__(self, resolution=224, secret_len=100):
        """
        Args:
            resolution: Side length of the square input images.
            secret_len: Number of output values.
        """
        super().__init__()
        self.resolution = resolution
        self.IMAGE_CHANNELS = 3
        self.decoder = nn.Sequential(
            nn.Conv2d(self.IMAGE_CHANNELS, 32, (3, 3), 2, 1),  # resolution / 2
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2, 1),  # resolution / 4
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 2, 1),  # resolution / 8
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2, 1),  # resolution / 16
            nn.ReLU(),
            nn.Conv2d(128, 128, (3, 3), 2, 1),  # resolution / 32
            nn.ReLU(),
        )
        # Each of the five stride-2 convs (kernel 3, padding 1) maps a side of
        # n to (n + 1) // 2.  For multiples of 32 this equals resolution / 32;
        # for other sizes it matches the real feature-map size.
        side = resolution
        for _ in range(5):
            side = (side + 1) // 2
        self.flat_features = 128 * side * side
        self.dense = nn.Sequential(
            nn.Linear(self.flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, secret_len),
        )

    def forward(self, image):
        """Return the dense-head output for a batch of images."""
        x = self.decoder(image)
        x = x.view(-1, self.flat_features)
        return self.dense(x)