import os
import einops
import torch
import torch as th
import torch.nn as nn
import cv2
from pytorch_lightning.utilities.distributed import rank_zero_only
import numpy as np
from torch.optim.lr_scheduler import LambdaLR

from ldm.modules.diffusionmodules.util import (
    conv_nd,
    linear,
    zero_module,
    timestep_embedding,
)
from einops import rearrange, repeat
from torchvision.utils import make_grid
from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.openaimodel import (
    UNetModel,
    TimestepEmbedSequential,
    ResBlock,
    Downsample,
    AttentionBlock,
)
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.util import log_txt_as_img, exists, instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import load_state_dict


class CustomNet(LatentDiffusion):
    """LatentDiffusion variant with a standalone text encoder whose embeddings are
    appended to the cross-attention conditioning; the dual-attention text projections
    (`to_k_text` / `to_v_text`) are trained with a scaled learning rate."""

    def __init__(self, text_encoder_config, sd_15_ckpt=None, use_cond_concat=False,
                 use_bbox_mask=False, use_bg_inpainting=False, learning_rate_scale=10,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.text_encoder = instantiate_from_config(text_encoder_config)
        if sd_15_ckpt is not None:
            self.load_model_from_ckpt(ckpt=sd_15_ckpt)
        self.use_cond_concat = use_cond_concat
        self.use_bbox_mask = use_bbox_mask
        self.use_bg_inpainting = use_bg_inpainting
        self.learning_rate_scale = learning_rate_scale

    def load_model_from_ckpt(self, ckpt, verbose=True):
        """Copy the text-encoder weights (stored under the `cond_stage_model.` prefix)
        from a Stable Diffusion 1.5 checkpoint into `self.text_encoder`."""
        print(" =========================== init Stable Diffusion pretrained checkpoint =========================== ")
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
        sd = pl_sd["state_dict"]
        sd_keys = sd.keys()
        missing = []
        text_encoder_sd = self.text_encoder.state_dict()
        for k in text_encoder_sd.keys():
            sd_k = "cond_stage_model." + k
            if sd_k in sd_keys:
                text_encoder_sd[k] = sd[sd_k]
            else:
                missing.append(k)
        if verbose and len(missing) > 0:
            print(f"Text-encoder keys not found in checkpoint: {missing}")
        self.text_encoder.load_state_dict(text_encoder_sd)

    def configure_optimizers(self):
        lr = self.learning_rate
        params = []
        params += list(self.cc_projection.parameters())
        params_dualattn = []
        # Dual-attention text projections get a scaled learning rate;
        # all remaining UNet parameters use the base learning rate.
        for k, v in self.model.named_parameters():
            if "to_k_text" in k or "to_v_text" in k:
                params_dualattn.append(v)
                print("training weight: ", k)
            else:
                params.append(v)
        opt = torch.optim.AdamW([
            {'params': params_dualattn, 'lr': lr * self.learning_rate_scale},
            {'params': params, 'lr': lr}
        ])
        if self.use_scheduler:
            assert 'target' in self.scheduler_config
            scheduler = instantiate_from_config(self.scheduler_config)
            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [opt], scheduler
        return opt

    def training_step(self, batch, batch_idx):
        loss, loss_dict = self.shared_step(batch)
        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False)
        if self.use_scheduler:
            lr = self.optimizers().param_groups[0]['lr']
            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
        return loss

    def shared_step(self, batch, **kwargs):
        # Unconditional-guidance training: drop each caption with probability p
        # so the model also learns the unconditional branch.
        if 'txt' in self.ucg_training:
            k = 'txt'
            p = self.ucg_training[k]
            for i in range(len(batch[k])):
                if self.ucg_prng.choice(2, p=[1 - p, p]):
                    if isinstance(batch[k], list):
                        batch[k][i] = ""
        # Text embeddings are computed under no_grad, so no gradients flow
        # into the text encoder; they are appended to the cross-attention conditioning.
        with torch.no_grad():
            text = batch['txt']
            text_embedding = self.text_encoder(text)
        x, c = self.get_input(batch, self.first_stage_key)
        c["c_crossattn"].append(text_embedding)
        loss = self(x, c)
        return loss
    def apply_model(self, x_noisy, t, cond, return_ids=False, **kwargs):
        if isinstance(cond, dict):
            # hybrid case: cond is expected to be a dict with 'c_concat'/'c_crossattn' keys
            pass
        else:
            # wrap a bare conditioning tensor/list into the dict format the wrapper expects
            if not isinstance(cond, list):
                cond = [cond]
            key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
            cond = {key: cond}
        x_recon = self.model(x_noisy, t, **cond)
        if isinstance(x_recon, tuple) and not return_ids:
            return x_recon[0]
        else:
            return x_recon
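

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the training code above): how a
# CustomNet instance is typically built from a YAML config via
# instantiate_from_config. The config path and the `model` key layout below are
# hypothetical placeholders; adapt them to the configs actually shipped with
# this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from omegaconf import OmegaConf

    # Assumed layout: config.model.target points at CustomNet and
    # config.model.params carries text_encoder_config, sd_15_ckpt, etc.
    config = OmegaConf.load("configs/customnet_train.yaml")  # hypothetical path
    model = instantiate_from_config(config.model)
    print(type(model).__name__)  # expected: "CustomNet"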