import copy
import clip
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from enum import Enum
from PIL import Image
from torch import autograd
from .base_model import BaseModel
from models.modules import networks
from models.modules.stylegan2.model import (Generator, Discriminator, StyledConv, ToRGB, EqualLinear,
                                            ResBlock, ConvLayer, PixelNorm)
from models.modules.stylegan2.op import conv2d_gradfix
from models.modules.stylegan2.non_leaking import augment
from models.modules.vit.losses import LossG


class TrainingPhase(Enum):
    ENCODER = 1
    BASE_MODEL = 2
    CLIP_MAPPING = 3
    FEW_SHOT = 4


class CLIPFeats2Wplus(nn.Module):
    def __init__(self, n_tokens=16, embedding_dim=512):
        super().__init__()
        self.position_embedding = nn.Parameter(embedding_dim ** -0.5 * torch.randn(n_tokens, embedding_dim))
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=8, norm_first=True), num_layers=4)

    def forward(self, x):
        x_in = x.view(x.shape[0], 1, x.shape[1]) + self.position_embedding
        return F.leaky_relu(self.transformer(x_in.permute(1, 0, 2)), negative_slope=0.2)


class Stylizer(nn.Module):
    def __init__(self, ngf=64, phase=TrainingPhase.ENCODER, model_weights=None):
        super(Stylizer, self).__init__()

        # encoder
        self.encoder = nn.Sequential(
            ConvLayer(3, ngf, 3),           # 512
            ResBlock(ngf * 1, ngf * 1),     # 256
            ResBlock(ngf * 1, ngf * 2),     # 128
            ResBlock(ngf * 2, ngf * 4),     # 64
            ResBlock(ngf * 4, ngf * 8),     # 32
            ConvLayer(ngf * 8, ngf * 8, 3)  # 32
        )

        # mapping network
        self.mapping_z = nn.Sequential(*([
            PixelNorm()
        ] + [
            EqualLinear(512, 512, activation='fused_lrelu', lr_mul=0.01) for _ in range(8)
        ]))

        # style-based decoder
        channels = {
            32: ngf * 8,
            64: ngf * 8,
            128: ngf * 4,
            256: ngf * 2,
            512: ngf * 1
        }
        self.decoder0 = StyledConv(channels[32], channels[32], 3, 512)
        self.to_rgb0 = ToRGB(channels[32], 512, upsample=False)
        for i in range(4):
            ichan = channels[2 ** (i + 5)]
            ochan = channels[2 ** (i + 6)]
            setattr(self, f'decoder{i + 1}a', StyledConv(ichan, ochan, 3, 512, upsample=True))
            setattr(self, f'decoder{i + 1}b', StyledConv(ochan, ochan, 3, 512))
            setattr(self, f'to_rgb{i + 1}', ToRGB(ochan, 512))
        self.n_latent = 10

        # random style for testing
        self.test_z = torch.randn(1, 512)

        # load pretrained model weights
        if phase == TrainingPhase.ENCODER:
            # load pretrained stylegan2 and freeze its params
            for param in self.mapping_z.parameters():
                param.requires_grad = False
            for i in range(4):
                for key in [f'decoder{i + 1}a', f'decoder{i + 1}b', f'to_rgb{i + 1}']:
                    for param in getattr(self, key).parameters():
                        param.requires_grad = False
            self.load_state_dict(self._convert_stylegan2_dict(model_weights), strict=False)
        elif phase == TrainingPhase.BASE_MODEL:
            # load pretrained encoder and stylegan2 decoder
            self.load_state_dict(model_weights)
        elif phase == TrainingPhase.CLIP_MAPPING:
            self.clip_mapper = CLIPFeats2Wplus(n_tokens=self.n_latent)
            # load pretrained base model and freeze all params except the clip mapper
            self.load_state_dict(model_weights, strict=False)
            params = dict(self.named_parameters())
            for k in params.keys():
                if 'clip_mapper' in k:
                    print(f'{k} not frozen!')
                    continue
                params[k].requires_grad = False
        elif phase == TrainingPhase.FEW_SHOT:
            self.clip_mapper = CLIPFeats2Wplus(n_tokens=self.n_latent)
            # load pretrained base model and freeze encoder & mapping
            self.load_state_dict(model_weights)
            self.encoder.requires_grad_(False)
            self.mapping_z.requires_grad_(False)
            self.clip_mapper.requires_grad_(False)

    def _convert_stylegan2_dict(self, src):
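        # Maps checkpoint keys of the stock StyleGAN2 Generator onto this module's
        # names: 'style.*' -> 'mapping_z.*', and the upper-resolution layers
        # 'convs.{idx}' / 'to_rgbs.{idx}' -> 'decoder{n}{a,b}' / 'to_rgb{n}'.
        # For example, convs.6 -> decoder1a (6 // 2 - 2 = 1, chr(97 + 6 % 2) = 'a'),
        # convs.7 -> decoder1b, to_rgbs.3 -> to_rgb1. Lower-resolution layers
        # (idx < 6 / idx < 3) have no counterpart here and are dropped.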
        res = {}
        for k, v in src.items():
            if k.startswith('style.'):
                res[k.replace('style.', 'mapping_z.')] = v
            else:
                name, idx = k.split('.')[:2]
                if name == 'convs':
                    idx = int(idx)
                    if idx >= 6:
                        res[k.replace(f'{name}.{idx}.', f'decoder{idx // 2 - 2}{chr(97 + idx % 2)}.')] = v
                elif name == 'to_rgbs':
                    idx = int(idx)
                    if idx >= 3:
                        res[k.replace(f'{name}.{idx}.', f'to_rgb{idx - 2}.')] = v
        return res

    def get_styles(self, x, **kwargs):
        if len(kwargs) == 0:
            return self.mapping_z(self.test_z.to(x.device).repeat(x.shape[0], 1)).repeat(self.n_latent, 1, 1)
        elif 'mixing' in kwargs and kwargs['mixing']:
            w0 = self.mapping_z(torch.randn(x.shape[0], 512, device=x.device))
            w1 = self.mapping_z(torch.randn(x.shape[0], 512, device=x.device))
            inject_index = random.randint(1, self.n_latent - 1)
            return torch.cat([
                w0.repeat(inject_index, 1, 1),
                w1.repeat(self.n_latent - inject_index, 1, 1)
            ])
        elif 'z' in kwargs:
            return self.mapping_z(kwargs['z']).repeat(self.n_latent, 1, 1)
        elif 'clip_feats' in kwargs:
            return self.clip_mapper(kwargs['clip_feats'])
        else:
            z = torch.randn(x.shape[0], 512, device=x.device)
            return self.mapping_z(z).repeat(self.n_latent, 1, 1)

    def forward(self, x, **kwargs):
        # encode
        feat = self.encoder(x)
        # get style code
        styles = self.get_styles(x, **kwargs)
        # style-based generate
        feat = self.decoder0(feat, styles[0])
        out = self.to_rgb0(feat, styles[1])
        for i in range(4):
            feat = getattr(self, f'decoder{i + 1}a')(feat, styles[i * 2 + 1])
            feat = getattr(self, f'decoder{i + 1}b')(feat, styles[i * 2 + 2])
            out = getattr(self, f'to_rgb{i + 1}')(feat, styles[i * 2 + 3], out)
        return F.hardtanh(out)


class StyleBasedPix2PixIIModel(BaseModel):
    """This class implements the Style-Based Pix2Pix model, version II."""

    def __init__(self, config, DDP_device=None):
        BaseModel.__init__(self, config, DDP_device=DDP_device)

        self.d_reg_freq = 16
        self.lambda_r1 = 10
        self.step = 0
        self.phase = TrainingPhase(config['training']['phase'])

        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        if self.phase == TrainingPhase.ENCODER:
            self.loss_names = ['G', 'G_L1', 'G_Feat']
        elif self.phase == TrainingPhase.BASE_MODEL:
            self.loss_names = ['G', 'G_ST', 'G_GAN', 'D']
        elif self.phase == TrainingPhase.CLIP_MAPPING:
            self.loss_names = ['G', 'G_L1', 'G_Feat']
        elif self.phase == TrainingPhase.FEW_SHOT:
            self.loss_names = ['G', 'G_ST', 'G_CLIP', 'G_PROJ']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        self.visual_names = ['real_A', 'real_B', 'fake_B']
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
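        # At test time only the EMA generator is needed; G and D are saved so that
        # adversarial training can be resumed from a checkpoint.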
        if self.isTrain:
            self.model_names = ['G', 'G_ema', 'D']
        else:  # during test time, only load Gs
            self.model_names = ['G_ema']

        self.data_aug_prob = config['training']['data_aug_prob']
        min_feats_size = tuple(config['model']['min_feats_size'])

        def __init_net(model):
            return networks.init_net(model, init_type='none', init_gain=0.0, gpu_ids=self.gpu_ids,
                                     DDP_device=self.DDP_device,
                                     find_unused_parameters=config['training']['find_unused_parameters'])

        if self.phase == TrainingPhase.ENCODER:
            # train an encoder for stylegan2
            # load and init pretrained stylegan2
            model_dict = torch.load(config['training']['pretrained_model'], map_location='cpu')
            self.stylegan2 = Generator(512, 512, 8)
            self.stylegan2.load_state_dict(model_dict['g'])
            self.stylegan2 = __init_net(self.stylegan2)
            self.stylegan2.eval()
            self.stylegan2.requires_grad_(False)
            # init netG
            self.netG = Stylizer(ngf=config['model']['ngf'], phase=self.phase, model_weights=model_dict['g'])
            self.netG = __init_net(self.netG)
            # init netD
            self.netD = Discriminator(min(min_feats_size) * 128, min_feats_size)
            self.netD.load_state_dict(model_dict['d'])
            self.netD = __init_net(self.netD)
            self.netD.eval()
            self.netD.requires_grad_(False)
        elif self.phase == TrainingPhase.BASE_MODEL:
            # finetune the whole model
            model_dict = torch.load(config['training']['pretrained_model'], map_location='cpu')
            # init netG
            self.netG = Stylizer(ngf=config['model']['ngf'], phase=self.phase, model_weights=model_dict['G_ema_model'])
            self.netG = __init_net(self.netG)
            # init netD
            self.netD = Discriminator(min(min_feats_size) * 128, min_feats_size)
            self.netD.load_state_dict(model_dict['D_model'])
            self.netD = __init_net(self.netD)
        elif self.phase == TrainingPhase.CLIP_MAPPING or self.phase == TrainingPhase.FEW_SHOT:
            # train the clip mapper, or zero/one-shot finetuning
            # init CLIP
            self.clip_model, self.pil_to_tensor = clip.load('ViT-B/32', device=self.device)
            self.clip_model.eval()
            self.clip_model.requires_grad_(False)
            model_dict = torch.load(config['training']['pretrained_model'], map_location='cpu')
            # init netG
            self.netG = Stylizer(ngf=config['model']['ngf'], phase=self.phase, model_weights=model_dict['G_ema_model'])
            self.netG = __init_net(self.netG)
            # init netD
            self.netD = Discriminator(min(min_feats_size) * 128, min_feats_size)
            self.netD.load_state_dict(model_dict['D_model'])
            self.netD = __init_net(self.netD)
            self.netD.eval()
            self.netD.requires_grad_(False)
            if self.phase == TrainingPhase.FEW_SHOT:
                # set hook to get clip vit tokens
                def clip_vit_hook(model, feat_in, feat_out):
                    self.clip_vit_tokens = feat_out[1:].permute(1, 0, 2).float()  # remove cls token
                self.clip_model.visual.transformer.resblocks[3].register_forward_hook(clip_vit_hook)

        # create netG ema
        self.netG_ema = copy.deepcopy(self.netG)
        self.netG_ema.eval()
        self.netG_ema.requires_grad_(False)
        self.ema(self.netG_ema, self.netG, 0.0)

        # CLIP mean & std
        self.clip_mean = torch.tensor((0.48145466, 0.4578275, 0.40821073), device=self.device).view(1, 3, 1, 1)
        self.clip_std = torch.tensor((0.26862954, 0.26130258, 0.27577711), device=self.device).view(1, 3, 1, 1)

        if self.isTrain:
            # define loss functions
            if self.phase == TrainingPhase.ENCODER:
                self.criterionL1 = nn.L1Loss()
            elif self.phase == TrainingPhase.BASE_MODEL:
                self.criterionStyleGAN = networks.GANLoss('wgangp').to(self.device)
                self.vitLoss = LossG(self.device)
            elif self.phase == TrainingPhase.CLIP_MAPPING:
                self.criterionL1 = nn.L1Loss()
            elif self.phase == TrainingPhase.FEW_SHOT:
                self.criterionL1 = nn.L1Loss()
                self.vitLoss = LossG(self.device)
                self.cosineSim = nn.CosineSimilarity(dim=1)
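                # cosineSim drives the directional CLIP loss (G_CLIP) and vitLoss the
                # ViT self-similarity structure loss (G_ST) in backward_G.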
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=config['training']['lr'],
                                                betas=(config['training']['beta1'], 0.999))
            d_reg_ratio = self.d_reg_freq / (self.d_reg_freq + 1)
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=config['training']['lr'] * d_reg_ratio,
                                                betas=(config['training']['beta1'] ** d_reg_ratio, 0.999 ** d_reg_ratio))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def ema(self, tgt, src, decay=0.999):
        param_tgt = dict(tgt.named_parameters())
        param_src = dict(src.named_parameters())
        for key in param_tgt.keys():
            param_tgt[key].data.mul_(decay).add_(param_src[key].data, alpha=1.0 - decay)

    def preprocess_clip_image(self, x, size):
        x = x * 0.5 + 0.5
        x = F.interpolate(x, (size, size), mode='bilinear', antialias=True, align_corners=False)
        return (x - self.clip_mean) / self.clip_std

    def set_input(self, input):
        if self.phase == TrainingPhase.ENCODER:
            # sample via stylegan2
            self.z = torch.randn(self.config['dataset']['batch_size'], 512, device=self.device)
            with torch.no_grad():
                self.real_A = F.hardtanh(self.stylegan2.forward([self.z])[0])
            self.real_B = self.real_A.clone()
        elif self.phase == TrainingPhase.BASE_MODEL:
            if self.config['common']['phase'] == 'test':
                self.real_A = input['test_A'].to(self.device)
                self.real_B = input['test_B'].to(self.device)
                self.image_paths = input['test_A_path']
            else:
                self.real_A = input['unpaired_A'].to(self.device)
                self.real_B = input['unpaired_B'].to(self.device)
                self.image_paths = input['unpaired_A_path']
        elif self.phase == TrainingPhase.CLIP_MAPPING:
            self.real_A = input['unpaired_A'].to(self.device)
            with torch.no_grad():
                self.real_B = self.netG_ema(self.real_A, mixing=random.random() < self.config['training']['style_mixing_prob'])
                self.clip_feats = self.clip_model.encode_image(
                    self.preprocess_clip_image(self.real_B, self.clip_model.visual.input_resolution))
                self.clip_feats /= self.clip_feats.norm(dim=1, keepdim=True)
        elif self.phase == TrainingPhase.FEW_SHOT:
            self.real_A = input['unpaired_A'].to(self.device)
            self.real_B = self.real_A
            if not hasattr(self, 'clip_feats'):
                with torch.no_grad():
                    if os.path.isfile(self.config['training']['image_prompt']):
                        image = self.pil_to_tensor(Image.open(self.config['training']['image_prompt'])).unsqueeze(0).to(self.device)
                        self.clip_feats = self.clip_model.encode_image(image)
                        ref_tokens = self.clip_vit_tokens
                        ref_tokens /= ref_tokens.norm(dim=2, keepdim=True)
                        D = ref_tokens.shape[2]
                        ref_tokens = ref_tokens.reshape(-1, D).permute(1, 0)
                        U, _, _ = torch.linalg.svd(ref_tokens, full_matrices=False)
                        self.UUT = U @ U.permute(1, 0)
                        self.use_image_prompt = True
                    else:
                        text = clip.tokenize(self.config['training']['text_prompt']).to(self.device)
                        self.clip_feats = self.clip_model.encode_text(text)
                        self.use_image_prompt = False
                    # get source text prompt feature
                    text = clip.tokenize(self.config['training']['src_text_prompt']).to(self.device)
                    self.src_clip_feats = self.clip_model.encode_text(text)
                    self.src_clip_feats /= self.src_clip_feats.norm(dim=1, keepdim=True)
                    self.src_clip_feats = self.src_clip_feats.repeat(self.config['dataset']['batch_size'], 1)
                    self.clip_feats /= self.clip_feats.norm(dim=1, keepdim=True)
                    self.clip_feats = self.clip_feats.repeat(self.config['dataset']['batch_size'], 1)
            # get direction in clip space
            with torch.no_grad():
                self.real_A_clip_feats = self.clip_model.encode_image(
                    self.preprocess_clip_image(self.real_A, self.clip_model.visual.input_resolution))
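                # real_A_clip_feats is recomputed per batch; together with the prompt
                # features above it defines the CLIP-space edit direction used by the
                # directional loss in backward_G.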
                self.real_A_clip_feats /= self.real_A_clip_feats.norm(dim=1, keepdim=True)
            if self.use_image_prompt:
                self.src_clip_feats = self.real_A_clip_feats.mean(dim=0, keepdim=True).repeat(self.config['dataset']['batch_size'], 1)
            self.clip_feats_dir = self.clip_feats - self.src_clip_feats

    def forward(self, use_ema=False):
        if self.phase == TrainingPhase.ENCODER:
            if use_ema:
                self.fake_B = self.netG_ema(self.real_A, z=self.z)
            else:
                self.fake_B = self.netG(self.real_A, z=self.z)
        elif self.phase == TrainingPhase.BASE_MODEL:
            if not self.isTrain:
                self.fake_B = self.netG_ema(self.real_A, mixing=False)
            elif use_ema:
                self.fake_B = self.netG_ema(self.real_A, mixing=random.random() < self.config['training']['style_mixing_prob'])
            else:
                self.fake_B = self.netG(self.real_A, mixing=random.random() < self.config['training']['style_mixing_prob'])
        elif self.phase == TrainingPhase.CLIP_MAPPING or self.phase == TrainingPhase.FEW_SHOT:
            if use_ema:
                self.fake_B = self.netG_ema(self.real_A, clip_feats=self.clip_feats)
            else:
                self.fake_B = self.netG(self.real_A, clip_feats=self.clip_feats)

    def backward_D_r1(self):
        self.real_B.requires_grad = True
        if self.data_aug_prob == 0.0:
            real_aug = self.real_B
        else:
            real_aug, _ = augment(self.real_B, self.data_aug_prob)
        real_pred = self.netD(real_aug)
        with conv2d_gradfix.no_weight_gradients():
            grad, = autograd.grad(outputs=real_pred.sum(), inputs=real_aug, create_graph=True)
        r1_loss = grad.pow(2).reshape(grad.shape[0], -1).sum(1).mean()
        (r1_loss * self.lambda_r1 / 2 * self.d_reg_freq + 0 * real_pred[0]).backward()

    def backward_D(self, backward=True):
        if self.data_aug_prob == 0.0:
            loss_fake = self.criterionStyleGAN(self.netD(self.fake_B.detach()), False)
            loss_real = self.criterionStyleGAN(self.netD(self.real_B), True)
        else:
            fake_aug, _ = augment(self.fake_B.detach(), self.data_aug_prob)
            real_aug, _ = augment(self.real_B, self.data_aug_prob)
            loss_fake = self.criterionStyleGAN(self.netD(fake_aug), False)
            loss_real = self.criterionStyleGAN(self.netD(real_aug), True)
        self.loss_D = (loss_fake + loss_real) * 0.5
        if backward:
            self.loss_D.backward()

    def backward_G(self, backward=True):
        self.loss_G = 0
        if self.phase == TrainingPhase.ENCODER or self.phase == TrainingPhase.CLIP_MAPPING:
            self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B)
            with torch.no_grad():
                real_feats = self.netD(self.real_B, rtn_feats=True)
            fake_feats = self.netD(self.fake_B, rtn_feats=True)
            self.loss_G_Feat = sum([self.criterionL1(fake, real) for fake, real in zip(fake_feats, real_feats)])
            self.loss_G += self.loss_G_L1 * self.config['training']['lambda_L1']
            self.loss_G += self.loss_G_Feat * self.config['training']['lambda_Feat']
        elif self.phase == TrainingPhase.BASE_MODEL:
            self.loss_G_ST = self.vitLoss.calculate_global_ssim_loss(self.fake_B * 0.5 + 0.5, self.real_A * 0.5 + 0.5)
            if self.data_aug_prob == 0.0:
                self.loss_G_GAN = self.criterionStyleGAN(self.netD(self.fake_B), True)
            else:
                fake_aug, _ = augment(self.fake_B, self.data_aug_prob)
                self.loss_G_GAN = self.criterionStyleGAN(self.netD(fake_aug), True)
            self.loss_G += self.loss_G_ST * self.config['training']['lambda_ST']
            self.loss_G += self.loss_G_GAN * self.config['training']['lambda_GAN']
        elif self.phase == TrainingPhase.FEW_SHOT:
            self.loss_G_ST = self.vitLoss.calculate_global_ssim_loss(self.fake_B * 0.5 + 0.5, self.real_A * 0.5 + 0.5)
            fake_clip_feats = self.clip_model.encode_image(
                self.preprocess_clip_image(self.fake_B, self.clip_model.visual.input_resolution))
            fake_clip_feats = fake_clip_feats / fake_clip_feats.norm(dim=1, keepdim=True)
            fake_clip_feats_dir = fake_clip_feats - self.real_A_clip_feats
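            # directional CLIP loss: the fake image's shift away from real_A in CLIP
            # space should be parallel to the prompt direction computed in set_input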
            self.loss_G_CLIP = (1.0 - self.cosineSim(fake_clip_feats_dir, self.clip_feats_dir)).mean()
            if self.use_image_prompt:
                fake_tokens = self.clip_vit_tokens
                fake_tokens = fake_tokens / fake_tokens.norm(dim=2, keepdim=True)
                D = fake_tokens.shape[2]
                fake_tokens = fake_tokens.reshape(-1, D).permute(1, 0)
                self.loss_G_PROJ = self.criterionL1(self.UUT @ fake_tokens, fake_tokens)
            else:
                self.loss_G_PROJ = 0.0
            self.loss_G += self.loss_G_ST * self.config['training']['lambda_ST']
            self.loss_G += self.loss_G_CLIP * self.config['training']['lambda_CLIP']
            self.loss_G += self.loss_G_PROJ * self.config['training']['lambda_PROJ']
        if backward:
            self.loss_G.backward()

    def optimize_parameters(self):
        # forward
        self.forward()
        if self.phase != TrainingPhase.BASE_MODEL:
            # only G
            self.optimizer_G.zero_grad()
            self.backward_G()
            self.optimizer_G.step()
            # update G_ema
            self.ema(self.netG_ema, self.netG, decay=self.config['training']['ema'])
        else:
            # G
            self.set_requires_grad([self.netD], False)
            self.optimizer_G.zero_grad()
            self.backward_G()
            self.optimizer_G.step()
            # D
            self.set_requires_grad([self.netD], True)
            self.optimizer_D.zero_grad()
            self.backward_D()
            self.optimizer_D.step()
            # update G_ema
            self.ema(self.netG_ema, self.netG, decay=self.config['training']['ema'])
            # r1 reg
            self.step += 1
            if self.step % self.d_reg_freq == 0:
                self.optimizer_D.zero_grad()
                self.backward_D_r1()
                self.optimizer_D.step()

    def eval_step(self):
        self.forward(use_ema=True)
        self.backward_G(False)
        if self.phase == TrainingPhase.BASE_MODEL:
            self.backward_D(False)
        self.step += 1

    def trace_jit(self, input):
        self.netG = self.netG.module.cpu()
        traced_script_module = torch.jit.trace(self.netG, input)
        dummy_output = self.netG_ema(input)
        dummy_output_traced = traced_script_module(input)
        return traced_script_module, dummy_output, dummy_output_traced
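# Usage sketch (hypothetical; the actual entry points are the training/test scripts,
# which build the config dict and the dataloaders):
#
#     model = StyleBasedPix2PixIIModel(config)   # config['training']['phase'] selects the TrainingPhase
#     for batch in dataloader:
#         model.set_input(batch)
#         model.optimize_parameters()            # G step (+ D and lazy R1 steps in the BASE_MODEL phase)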