| import os |
| import requests |
| import sys |
| import copy |
| import random |
| import time |
| import glob |
| import math |
| import yaml |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from tqdm import tqdm |
| from peft import LoraConfig |
| from types import SimpleNamespace |
| from transformers import AutoTokenizer, CLIPTextModel |
| from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler |
| from diffusers.utils.peft_utils import set_weights_and_activate_adapters |
| from diffusers.utils.import_utils import is_xformers_available |
|
|
def make_1step_sched(pretrained_model_path, device="cuda"):
    """Build a DDPM scheduler configured for a single inference step.

    Args:
        pretrained_model_path: Model directory or hub id whose ``scheduler``
            subfolder holds the scheduler configuration.
        device: Device used for the timestep tensor and for
            ``alphas_cumprod``. Defaults to ``"cuda"``, the value that was
            previously hard-coded, so existing callers are unaffected.

    Returns:
        The ``DDPMScheduler`` after ``set_timesteps(1)`` with its
        ``alphas_cumprod`` moved to ``device``.
    """
    scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
    scheduler.set_timesteps(1, device=device)
    # Keep the cumulative alphas on the same device as the latents so that
    # later indexing/broadcasting does not need a host round trip.
    scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(device)
    return scheduler
|
|
def find_filepath(directory, filename):
    """Locate ``filename`` anywhere below ``directory``.

    The search uses the recursive glob pattern ``{directory}/**/{filename}``,
    so ``filename`` may itself contain glob wildcards.

    Args:
        directory: Root directory of the search.
        filename: File name (or glob pattern) to look for.

    Returns:
        The matching path, or ``None`` when nothing matches. When several
        paths match, the lexicographically smallest one is returned so the
        result does not depend on filesystem enumeration order.
    """
    pattern = f"{directory}/**/{filename}"
    # iglob avoids materialising every match; min(..., default=None) covers
    # the "no match" case without a separate emptiness check.
    return min(glob.iglob(pattern, recursive=True), default=None)
|
|
|
|
def read_yaml(file_path):
    """Parse a YAML file with the safe loader and return the loaded object.

    The file is decoded as UTF-8 explicitly instead of relying on the
    platform's default text encoding, so the same file parses identically on
    every machine.

    Args:
        file_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document.
    """
    with open(file_path, "r", encoding="utf-8") as stream:
        return yaml.safe_load(stream)
|
|
def initialize_vae(rank, return_lora_module_names=False, pretrained_model_name_or_path=None):
    """Load the VAE and attach encoder/decoder LoRA adapters to it.

    All VAE parameters get ``requires_grad_(False)`` and the module is put in
    train mode before two adapters (``default_encoder`` and
    ``default_decoder``) are added with the given LoRA rank.

    Args:
        rank: LoRA rank used for both adapters.
        return_lora_module_names: When true, also return the three lists of
            module names collected while scanning the parameters.
        pretrained_model_name_or_path: Location passed to
            ``AutoencoderKL.from_pretrained`` (``vae`` subfolder).

    Returns:
        The VAE, or ``(vae, encoder_names, decoder_names, other_names)`` when
        ``return_lora_module_names`` is true.
    """
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
    vae.requires_grad_(False)
    vae.train()

    encoder_targets, decoder_targets, other_targets = [], [], []
    name_fragments = (
        "conv1", "conv2", "conv_in", "conv_shortcut",
        "conv", "conv_out", "to_k", "to_q", "to_v", "to_out.0",
    )

    def pick_bucket(param_name):
        """Return the list this parameter belongs to, or None if it is skipped."""
        in_encoder = "encoder" in param_name
        in_decoder = "decoder" in param_name
        is_post_quant = "post_quant_conv" in param_name
        is_quant = "quant_conv" in param_name and not is_post_quant
        # The checks are evaluated per fragment, in this exact order; the
        # first one that fires decides the bucket.
        for fragment in name_fragments:
            matched = fragment in param_name
            if matched and in_encoder:
                return encoder_targets
            if matched and in_decoder:
                return decoder_targets
            if is_quant:
                return encoder_targets
            if is_post_quant:
                return decoder_targets
            if matched:
                return other_targets
        return None

    for param_name, _ in vae.named_parameters():
        # Biases and normalisation layers are never LoRA targets.
        if "bias" in param_name or "norm" in param_name:
            continue
        bucket = pick_bucket(param_name)
        if bucket is not None:
            bucket.append(param_name.replace(".weight", ""))

    encoder_config = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=encoder_targets)
    decoder_config = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=decoder_targets)
    vae.add_adapter(encoder_config, adapter_name="default_encoder")
    vae.add_adapter(decoder_config, adapter_name="default_decoder")

    if not return_lora_module_names:
        return vae
    return vae, encoder_targets, decoder_targets, other_targets
|
|
def initialize_unet(rank, return_lora_module_names=False, pretrained_model_name_or_path=None):
    """Load the UNet and attach three LoRA adapters to it.

    All UNet parameters get ``requires_grad_(False)`` and the module is put
    in train mode. Weight parameters whose name contains one of the target
    fragments are grouped by location (down blocks / ``conv_in``, up blocks,
    everything else) and each group receives its own adapter
    (``default_encoder``, ``default_decoder``, ``default_others``).

    Args:
        rank: LoRA rank used for every adapter.
        return_lora_module_names: When true, also return the three lists of
            targeted module names.
        pretrained_model_name_or_path: Location passed to
            ``UNet2DConditionModel.from_pretrained`` (``unet`` subfolder).

    Returns:
        The UNet, or ``(unet, encoder_names, decoder_names, other_names)``
        when ``return_lora_module_names`` is true.
    """
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
    unet.requires_grad_(False)
    unet.train()

    name_fragments = (
        "to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2",
        "conv_in", "conv_shortcut", "conv_out", "proj_out", "proj_in",
        "ff.net.2", "ff.net.0.proj",
    )
    encoder_targets, decoder_targets, other_targets = [], [], []

    for param_name, _ in unet.named_parameters():
        # Biases and normalisation layers are never LoRA targets.
        if "bias" in param_name or "norm" in param_name:
            continue
        if not any(fragment in param_name for fragment in name_fragments):
            continue
        module_name = param_name.replace(".weight", "")
        if "down_blocks" in param_name or "conv_in" in param_name:
            encoder_targets.append(module_name)
        elif "up_blocks" in param_name:
            decoder_targets.append(module_name)
        else:
            other_targets.append(module_name)

    adapter_specs = (
        ("default_encoder", encoder_targets),
        ("default_decoder", decoder_targets),
        ("default_others", other_targets),
    )
    # Build every config first, then register the adapters in order.
    configs = [
        (adapter_name, LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=targets))
        for adapter_name, targets in adapter_specs
    ]
    for adapter_name, config in configs:
        unet.add_adapter(config, adapter_name=adapter_name)

    if not return_lora_module_names:
        return unet
    return unet, encoder_targets, decoder_targets, other_targets
|
|
def initialize_unet_sr(rank, return_lora_module_names=False, pretrained_model_name_or_path=None, args=None):
    """Load the UNet, optionally widen ``conv_in``, and attach LoRA adapters.

    When ``args.use_lr_concat_lr_999noise`` is truthy the input convolution
    is replaced by one that accepts twice as many input channels; both halves
    of the new weight are initialised from the pretrained ``conv_in`` weight
    and the bias is reused. Afterwards all parameters get
    ``requires_grad_(False)``, the module is put in train mode and three
    adapters (``default_encoder``, ``default_decoder``, ``default_others``)
    are added.

    Args:
        rank: LoRA rank used for every adapter.
        return_lora_module_names: When true, also return the three lists of
            targeted module names.
        pretrained_model_name_or_path: Location passed to
            ``UNet2DConditionModel.from_pretrained`` (``unet`` subfolder).
        args: Optional namespace; only ``use_lr_concat_lr_999noise`` is read.
            ``None`` (the default) leaves ``conv_in`` untouched.

    Returns:
        The UNet, or ``(unet, encoder_names, decoder_names, other_names)``
        when ``return_lora_module_names`` is true.
    """
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")

    # ``args`` defaults to None, so it must be checked before reading the flag.
    if args is not None and args.use_lr_concat_lr_999noise:
        old_conv = unet.conv_in
        base_channels = old_conv.in_channels
        # Mirror the pretrained layer's geometry instead of hard-coding it.
        new_conv_in = torch.nn.Conv2d(
            2 * base_channels,
            old_conv.out_channels,
            old_conv.kernel_size,
            old_conv.stride,
            old_conv.padding,
            bias=old_conv.bias is not None,
        )
        new_conv_in.weight.data[:, :base_channels, ...] = old_conv.weight.data
        new_conv_in.weight.data[:, base_channels:, ...] = old_conv.weight.data
        if old_conv.bias is not None:
            new_conv_in.bias.data = old_conv.bias.data
        unet.conv_in = new_conv_in

    unet.requires_grad_(False)
    unet.train()

    name_fragments = (
        "to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2",
        "conv_in", "conv_shortcut", "conv_out", "proj_out", "proj_in",
        "ff.net.2", "ff.net.0.proj",
    )
    encoder_targets, decoder_targets, other_targets = [], [], []

    for param_name, _ in unet.named_parameters():
        # Biases and normalisation layers are never LoRA targets.
        if "bias" in param_name or "norm" in param_name:
            continue
        if not any(fragment in param_name for fragment in name_fragments):
            continue
        module_name = param_name.replace(".weight", "")
        if "down_blocks" in param_name or "conv_in" in param_name:
            encoder_targets.append(module_name)
        elif "up_blocks" in param_name:
            decoder_targets.append(module_name)
        else:
            other_targets.append(module_name)

    lora_conf_encoder = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=encoder_targets)
    lora_conf_decoder = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=decoder_targets)
    lora_conf_others = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=other_targets)
    unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
    unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
    unet.add_adapter(lora_conf_others, adapter_name="default_others")

    if return_lora_module_names:
        return unet, encoder_targets, decoder_targets, other_targets
    return unet
|
|
class VSD(torch.nn.Module):
    """Score-distillation helper pairing a frozen UNet with a LoRA UNet.

    ``unet_fix`` is the pretrained UNet with gradients disabled and
    ``unet_update`` is a second copy carrying LoRA adapters (built by
    ``initialize_unet``). The class provides the denoising loss used to train
    the LoRA copy (``compute_lora_loss``) and a distribution-matching loss
    between two noise predictors (``distribution_matching_loss``).
    """

    def __init__(self, args, accelerator):
        """Load tokenizer, text encoder, scheduler, VAE and both UNets.

        Args:
            args: Namespace read for ``pretrained_model_name_or_path``,
                ``lora_rank_unet_vsd``,
                ``enable_xformers_memory_efficient_attention`` and
                ``gradient_checkpointing``.
            accelerator: Object exposing ``mixed_precision`` and ``device``.

        Raises:
            ValueError: If xformers attention is requested but unavailable.
        """
        super().__init__()

        model_path = args.pretrained_model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder")
        self.sched = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
        self.args = args

        # Frozen modules are stored in the accelerator's mixed-precision dtype.
        weight_dtype = torch.float32
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

        self.vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
        self.unet_fix = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet")
        (
            self.unet_update,
            self.lora_unet_modules_encoder,
            self.lora_unet_modules_decoder,
            self.lora_unet_others,
        ) = initialize_unet(
            rank=args.lora_rank_unet_vsd,
            pretrained_model_name_or_path=model_path,
            return_lora_module_names=True,
        )
        self.lora_rank_unet = args.lora_rank_unet_vsd

        if args.enable_xformers_memory_efficient_attention:
            if is_xformers_available():
                self.unet_fix.enable_xformers_memory_efficient_attention()
                self.unet_update.enable_xformers_memory_efficient_attention()
            else:
                raise ValueError("xformers is not available, please install it by running `pip install xformers`")

        if args.gradient_checkpointing:
            self.unet_fix.enable_gradient_checkpointing()
            self.unet_update.enable_gradient_checkpointing()

        self.text_encoder.to(accelerator.device, dtype=weight_dtype)
        self.unet_fix.to(accelerator.device, dtype=weight_dtype)
        # The LoRA UNet and the VAE keep their loaded dtype.
        self.unet_update.to(accelerator.device)
        self.vae.to(accelerator.device)

        self.text_encoder.requires_grad_(False)
        self.vae.requires_grad_(False)
        self.unet_fix.requires_grad_(False)

    def set_eval(self):
        """Put both UNets in eval mode."""
        # This class only owns ``unet_fix`` and ``unet_update``; the former
        # ``self.unet.eval()`` call referenced an attribute that is never set.
        self.unet_fix.eval()
        self.unet_update.eval()

    def set_train(self):
        """Put the LoRA UNet in train mode and enable grads on its LoRA weights."""
        self.unet_update.train()
        for name, param in self.unet_update.named_parameters():
            if "lora" in name:
                param.requires_grad = True

    def forward(self, c_t, prompt=None, neg_prompt_tokens=None, prompt_tokens=None, deterministic=True, r=1.0, noise_map=None, args=None):
        """Encode ``c_t``, run one UNet/scheduler step and decode the result.

        NOTE(review): this method reads ``self.unet`` and ``self.timesteps``,
        neither of which is assigned in ``__init__``; unless a caller sets
        them externally it raises ``AttributeError``. The intended UNet and
        timestep cannot be determined from this file, so the logic is left
        unchanged -- confirm and fix against the caller.

        Returns:
            ``(output_image, caption_enc, neg_caption_enc)`` where
            ``output_image`` is the decoded sample clamped to ``[-1, 1]``.
        """
        caption_enc = self.text_encoder(prompt_tokens)[0]
        neg_caption_enc = self.text_encoder(neg_prompt_tokens)[0]

        encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
        model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=caption_enc.to(torch.float32),).sample
        x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample

        output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1)

        return output_image, caption_enc, neg_caption_enc

    def forward_latent(self, model, latents, timestep, prompt_embeds):
        """Run ``model`` on noisy latents and return its ``.sample`` output."""
        noise_pred = model(
            latents,
            timestep=timestep,
            encoder_hidden_states=prompt_embeds,
        ).sample
        return noise_pred

    def compute_lora_loss(self, latents_pred, prompt_embeds, args):
        """Denoising loss for the LoRA UNet on (detached) predicted latents.

        Random noise is added to ``latents_pred`` at uniformly drawn
        timesteps and ``unet_update`` is asked to predict that noise.

        Args:
            latents_pred: Latents to train on; detached before use.
            prompt_embeds: Text conditioning; detached before use.
            args: Namespace read for ``snr_gamma_vsd``. ``None`` selects a
                plain MSE; otherwise the per-sample MSE is weighted by
                ``min(snr, snr_gamma_vsd) / snr``.

        Returns:
            Scalar loss tensor.
        """
        latents_pred = latents_pred.detach()
        prompt_embeds = prompt_embeds.detach()
        noise = torch.randn_like(latents_pred)
        bsz = latents_pred.shape[0]
        timesteps = torch.randint(0, self.sched.config.num_train_timesteps, (bsz,), device=latents_pred.device)
        timesteps = timesteps.long()
        noisy_latents = self.sched.add_noise(latents_pred, noise, timesteps)
        disc_pred = self.forward_latent(
            self.unet_update,
            timestep=timesteps,
            latents=noisy_latents,
            prompt_embeds=prompt_embeds,
        )
        if args.snr_gamma_vsd is None:
            return F.mse_loss(disc_pred.float(), noise.float(), reduction="mean")

        # Imported here because the module does not import it at file level.
        from diffusers.training_utils import compute_snr

        snr = compute_snr(self.sched, timesteps)
        if self.sched.config.prediction_type == "v_prediction":
            # For the velocity objective one is added to the SNR before dividing.
            snr = snr + 1
        mse_loss_weights = torch.stack([snr, args.snr_gamma_vsd * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr

        # Same prediction/target pair as the unweighted branch: the former
        # ``model_pred``/``target`` names were never defined in this method.
        loss = F.mse_loss(disc_pred.float(), noise.float(), reduction="none")
        loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
        return loss.mean()

    def eps_to_mu(self, scheduler, model_output, sample, timesteps):
        """Turn a noise prediction into the implied clean-sample estimate.

        Computes ``(sample - sqrt(1 - a_t) * model_output) / sqrt(a_t)`` with
        ``a_t = scheduler.alphas_cumprod[timesteps]`` broadcast over the
        trailing dimensions of ``sample``.
        """
        alphas_cumprod = scheduler.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
        alpha_prod_t = alphas_cumprod[timesteps]
        while len(alpha_prod_t.shape) < len(sample.shape):
            alpha_prod_t = alpha_prod_t.unsqueeze(-1)
        beta_prod_t = 1 - alpha_prod_t
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        return pred_original_sample

    @staticmethod
    def _module_dtype(module, default=torch.float32):
        """Return the dtype of ``module``'s first parameter (``default`` if none)."""
        first_param = next(module.parameters(), None)
        return default if first_param is None else first_param.dtype

    def distribution_matching_loss(
        self,
        real_model,
        fake_model,
        noise_scheduler,
        latents,
        prompt_embeds,
        negative_prompt_embeds,
        args,
    ):
        """Distribution-matching loss between a real and a fake noise predictor.

        ``latents`` are noised at random timesteps drawn from
        ``[min_dm_step_ratio, max_dm_step_ratio) * num_train_timesteps``. Both
        models predict the noise (the real model with classifier-free
        guidance scale ``args.cfg_vsd``); the difference of the implied clean
        latents, normalised by ``mean(|latents - pred_real|)`` per sample,
        forms the gradient direction. Gradients reach ``latents`` only
        through the first argument of the final MSE.

        Args:
            real_model: Noise predictor whose output is guided with CFG.
                Inputs are cast to the dtype of its parameters.
            fake_model: Noise predictor run on float32 prompt embeddings.
            noise_scheduler: Scheduler providing ``config``, ``add_noise``
                and ``alphas_cumprod``.
            latents: Latents receiving the gradient.
            prompt_embeds: Positive text conditioning.
            negative_prompt_embeds: Negative text conditioning for CFG.
            args: Namespace read for ``min_dm_step_ratio``,
                ``max_dm_step_ratio`` and ``cfg_vsd``.

        Returns:
            Scalar loss tensor.
        """
        bsz = latents.shape[0]
        min_dm_step = int(noise_scheduler.config.num_train_timesteps * args.min_dm_step_ratio)
        max_dm_step = int(noise_scheduler.config.num_train_timesteps * args.max_dm_step_ratio)

        timestep = torch.randint(min_dm_step, max_dm_step, (bsz,), device=latents.device).long()
        noise = torch.randn_like(latents)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timestep)

        # Neither prediction contributes gradients: both only enter the loss
        # through ``stopgrad`` below, so no graph needs to be recorded.
        with torch.no_grad():
            noise_pred = self.forward_latent(
                fake_model,
                latents=noisy_latents,
                timestep=timestep,
                prompt_embeds=prompt_embeds.float(),
            )
            pred_fake_latents = self.eps_to_mu(noise_scheduler, noise_pred, noisy_latents, timestep)

            # Unconditional and conditional passes are batched together.
            noisy_latents_input = torch.cat([noisy_latents] * 2)
            timestep_input = torch.cat([timestep] * 2)
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

            # Feed the real model in its own dtype rather than assuming fp16,
            # so fp32 and bf16 models work as well.
            real_dtype = self._module_dtype(real_model)
            noise_pred = self.forward_latent(
                real_model,
                latents=noisy_latents_input.to(dtype=real_dtype),
                timestep=timestep_input,
                prompt_embeds=prompt_embeds.to(dtype=real_dtype),
            )
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + args.cfg_vsd * (noise_pred_text - noise_pred_uncond)
            # ``Tensor.to`` is not in-place; the result has to be re-bound.
            noise_pred = noise_pred.to(dtype=torch.float32)

            pred_real_latents = self.eps_to_mu(noise_scheduler, noise_pred, noisy_latents, timestep)

        weighting_factor = torch.abs(latents - pred_real_latents).mean(dim=[1, 2, 3], keepdim=True)

        grad = (pred_fake_latents - pred_real_latents) / weighting_factor
        loss = F.mse_loss(latents, self.stopgrad(latents - grad))
        return loss

    def stopgrad(self, x):
        """Return ``x`` detached from the autograd graph."""
        return x.detach()

    def save_model(self, outf):
        """Save the LoRA UNet's adapter weights and metadata to ``outf``.

        The file holds the three LoRA module-name lists, the LoRA rank and
        every ``unet_update`` state-dict entry whose key contains ``"lora"``.
        """
        sd = {
            "unet_lora_encoder_modules": self.lora_unet_modules_encoder,
            "unet_lora_decoder_modules": self.lora_unet_modules_decoder,
            "unet_lora_others_modules": self.lora_unet_others,
            "rank_unet": self.lora_rank_unet,
            # ``unet_update`` is the module that carries the adapters; the
            # former ``self.unet`` attribute does not exist on this class.
            "state_dict_unet": {k: v for k, v in self.unet_update.state_dict().items() if "lora" in k},
        }
        torch.save(sd, outf)
|
|
class NAOSD(torch.nn.Module):
    """One-step generator made of a LoRA-adapted VAE and UNet.

    The input image is encoded by the VAE, noised at ``args.time_step_noise``
    and denoised by a single UNet/scheduler step at ``args.time_step`` before
    being decoded again. LoRA weights are either freshly initialised or
    restored from ``args.pretrained_path``.
    """

    def __init__(self, args):
        """Build all sub-models and move them to CUDA.

        Args:
            args: Namespace read for ``pretrained_model_name_or_path``,
                ``pretrained_path``, ``lora_rank_vae``, ``lora_rank_unet``,
                ``time_step`` and ``time_step_noise`` (plus whatever
                ``initialize_unet_sr`` reads).
        """
        super().__init__()

        model_path = args.pretrained_model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder").cuda()
        self.sched = make_1step_sched(model_path)
        self.sched2 = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if args.pretrained_path is None:
            # Fresh LoRA adapters: the initialisers load the base weights
            # themselves, so the models are not loaded a second time here.
            vae, lora_vae_modules_encoder, lora_vae_modules_decoder, lora_vae_others = initialize_vae(
                rank=args.lora_rank_vae,
                pretrained_model_name_or_path=model_path,
                return_lora_module_names=True,
            )
            unet, lora_unet_modules_encoder, lora_unet_modules_decoder, lora_unet_others = initialize_unet_sr(
                rank=args.lora_rank_unet,
                pretrained_model_name_or_path=model_path,
                return_lora_module_names=True,
                args=args,
            )
            self.lora_rank_unet = args.lora_rank_unet
            self.lora_rank_vae = args.lora_rank_vae
            self.lora_vae_modules_encoder = lora_vae_modules_encoder
            self.lora_vae_modules_decoder = lora_vae_modules_decoder
            self.lora_vae_others = lora_vae_others
            self.lora_unet_modules_encoder = lora_unet_modules_encoder
            self.lora_unet_modules_decoder = lora_unet_modules_decoder
            self.lora_unet_others = lora_unet_others
            self.unet, self.vae = unet, vae
        else:
            self.vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
            self.unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet")
            print('==================================> loading pre-trained weight')
            # NOTE: torch.load unpickles the file -- only load checkpoints
            # from trusted sources. Tensors are mapped to CPU because the
            # models are still on CPU at this point.
            sd = torch.load(args.pretrained_path, map_location="cpu")
            self.load_ckpt_from_state_dict(sd)
            self.lora_rank_unet = sd['rank_unet']
            self.lora_rank_vae = sd['rank_vae']
            self.lora_vae_modules_encoder = sd['vae_lora_encoder_modules']
            self.lora_vae_modules_decoder = sd['vae_lora_decoder_modules']
            self.lora_vae_others = sd['vae_lora_others_modules']
            self.lora_unet_modules_encoder = sd['unet_lora_encoder_modules']
            self.lora_unet_modules_decoder = sd['unet_lora_decoder_modules']
            self.lora_unet_others = sd['unet_lora_others_modules']

        self.unet, self.vae = self.unet.cuda(), self.vae.cuda()
        self.timesteps = torch.tensor([args.time_step], device="cuda").long()
        self.timestepsnoise = torch.tensor([args.time_step_noise], device="cuda").long()
        self.text_encoder.requires_grad_(False)

    def set_eval(self):
        """Switch UNet and VAE to eval mode and disable all their gradients."""
        self.unet.eval()
        self.vae.eval()
        self.unet.requires_grad_(False)
        self.vae.requires_grad_(False)

    def set_train(self):
        """Switch UNet and VAE to train mode and enable the trainable weights.

        Trainable are: every UNet/VAE parameter whose name contains
        ``"lora"`` and all parameters of ``unet.conv_in``.
        """
        self.unet.train()
        self.vae.train()
        for name, param in self.unet.named_parameters():
            if "lora" in name:
                param.requires_grad = True
        self.unet.conv_in.requires_grad_(True)
        for name, param in self.vae.named_parameters():
            if "lora" in name:
                param.requires_grad = True

    def encode_prompt(self, prompt):
        """Tokenise ``prompt`` and return the text encoder's first output.

        Prompts are padded/truncated to ``tokenizer.model_max_length`` and the
        encoder runs without gradient tracking.
        """
        with torch.no_grad():
            text_input_ids = self.tokenizer(
                prompt, max_length=self.tokenizer.model_max_length,
                padding="max_length", truncation=True, return_tensors="pt"
            ).input_ids
            prompt_embeds = self.text_encoder(
                text_input_ids.to(self.text_encoder.device),
            )[0]
        return prompt_embeds

    def forward(self, c_t, positive_prompt=None, negative_prompt=None, args=None):
        """Run the one-step pipeline on ``c_t``.

        Args:
            c_t: Input image batch passed to ``vae.encode``.
            positive_prompt: Prompt(s) conditioning the UNet.
            negative_prompt: Prompt(s) that are only encoded and returned.
            args: Unused.

        Returns:
            ``(output_image, x_denoised, caption_enc, neg_caption_enc, noise)``
            where ``output_image`` is the decoded result clamped to
            ``[-1, 1]`` and ``noise`` is the noise added to the latents.
        """
        caption_enc = self.encode_prompt(positive_prompt)
        neg_caption_enc = self.encode_prompt(negative_prompt)
        encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
        noise = torch.randn_like(encoded_control)
        encoded_control = self.sched2.add_noise(encoded_control, noise, self.timestepsnoise)

        model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=caption_enc.to(torch.float32),).sample
        x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample
        output_image = self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample
        output_image = output_image.clamp(-1, 1)

        return output_image, x_denoised, caption_enc, neg_caption_enc, noise

    def save_model(self, outf):
        """Save LoRA metadata and the trainable UNet/VAE weights to ``outf``.

        UNet entries are kept when their key contains ``"lora"`` or
        ``"conv_in"``; VAE entries when it contains ``"lora"`` or ``"skip"``.
        """
        sd = {
            "vae_lora_encoder_modules": self.lora_vae_modules_encoder,
            "vae_lora_decoder_modules": self.lora_vae_modules_decoder,
            "vae_lora_others_modules": self.lora_vae_others,
            "unet_lora_encoder_modules": self.lora_unet_modules_encoder,
            "unet_lora_decoder_modules": self.lora_unet_modules_decoder,
            "unet_lora_others_modules": self.lora_unet_others,
            "rank_unet": self.lora_rank_unet,
            "rank_vae": self.lora_rank_vae,
            "state_dict_unet": {k: v for k, v in self.unet.state_dict().items() if "lora" in k or "conv_in" in k},
            "state_dict_vae": {k: v for k, v in self.vae.state_dict().items() if "lora" in k or "skip" in k},
        }
        torch.save(sd, outf)

    def load_ckpt_from_state_dict(self, sd):
        """Recreate the LoRA adapters described by ``sd`` and copy its weights.

        Args:
            sd: Dictionary in the layout written by ``save_model``.

        Raises:
            KeyError: If a LoRA/``conv_in`` parameter has no entry in ``sd``.
        """
        unet_adapters = (
            ("default_encoder", "unet_lora_encoder_modules"),
            ("default_decoder", "unet_lora_decoder_modules"),
            ("default_others", "unet_lora_others_modules"),
        )
        for adapter_name, modules_key in unet_adapters:
            config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd[modules_key])
            self.unet.add_adapter(config, adapter_name=adapter_name)
        for name, param in self.unet.named_parameters():
            if "lora" in name or "conv_in" in name:
                param.data.copy_(sd["state_dict_unet"][name])

        vae_adapters = (
            ("default_encoder", "vae_lora_encoder_modules"),
            ("default_decoder", "vae_lora_decoder_modules"),
        )
        for adapter_name, modules_key in vae_adapters:
            config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd[modules_key])
            self.vae.add_adapter(config, adapter_name=adapter_name)
        for name, param in self.vae.named_parameters():
            if "lora" in name:
                param.data.copy_(sd["state_dict_vae"][name])
|
|
class GDPOSR(torch.nn.Module):
    """One-step generator with a LoRA UNet and a frozen reference UNet.

    ``unet`` carries LoRA adapters (fresh or restored from
    ``args.pretrained_path``); ``ref_unet`` is a plain copy of the base UNet
    used to produce reference outputs without gradients.
    """

    def __init__(self, args):
        """Build all sub-models and move them to CUDA.

        Args:
            args: Namespace read for ``pretrained_model_name_or_path``
                (tokenizer, text encoder, schedulers), ``basemodel_path``
                (VAE and UNets), ``pretrained_path``, ``lora_rank_unet``,
                ``time_step`` and ``time_step_noise`` (plus whatever
                ``initialize_unet_sr`` reads).
        """
        super().__init__()

        model_path = args.pretrained_model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder").cuda()
        self.sched = make_1step_sched(model_path)
        self.sched2 = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        vae = AutoencoderKL.from_pretrained(args.basemodel_path, subfolder="vae")
        ref_unet = UNet2DConditionModel.from_pretrained(args.basemodel_path, subfolder="unet")

        if args.pretrained_path is None:
            print('==================================> randomly initiate the weight')
            # initialize_unet_sr loads the base UNet itself, so it is not
            # loaded a second time here.
            unet, lora_unet_modules_encoder, lora_unet_modules_decoder, lora_unet_others = initialize_unet_sr(
                rank=args.lora_rank_unet,
                pretrained_model_name_or_path=args.basemodel_path,
                return_lora_module_names=True,
                args=args,
            )
            self.lora_rank_unet = args.lora_rank_unet
            self.lora_unet_modules_encoder = lora_unet_modules_encoder
            self.lora_unet_modules_decoder = lora_unet_modules_decoder
            self.lora_unet_others = lora_unet_others
            self.unet, self.vae = unet, vae
        else:
            self.unet = UNet2DConditionModel.from_pretrained(args.basemodel_path, subfolder="unet")
            self.vae = vae
            print('==================================> loading pre-trained weight')
            # NOTE: torch.load unpickles the file -- only load checkpoints
            # from trusted sources. Tensors are mapped to CPU because the
            # models are still on CPU at this point.
            sd = torch.load(args.pretrained_path, map_location="cpu")
            self.load_ckpt_from_state_dict(sd)
            self.lora_rank_unet = sd['rank_unet']
            self.lora_unet_modules_encoder = sd['unet_lora_encoder_modules']
            self.lora_unet_modules_decoder = sd['unet_lora_decoder_modules']
            self.lora_unet_others = sd['unet_lora_others_modules']

        self.unet, self.vae = self.unet.cuda(), self.vae.cuda()
        self.ref_unet = ref_unet.cuda()
        self.timesteps = torch.tensor([args.time_step], device="cuda").long()
        self.timestepsnoise = torch.tensor([args.time_step_noise], device="cuda").long()
        self.text_encoder.requires_grad_(False)

    def set_eval(self):
        """Switch UNet, VAE and reference UNet to eval mode and freeze them."""
        self.unet.eval()
        self.vae.eval()
        self.ref_unet.eval()
        self.unet.requires_grad_(False)
        self.vae.requires_grad_(False)
        self.ref_unet.requires_grad_(False)

    def set_train(self):
        """Train mode for UNet/VAE; only UNet LoRA weights become trainable.

        The reference UNet's parameters are explicitly kept frozen.
        """
        self.unet.train()
        self.vae.train()
        for name, param in self.unet.named_parameters():
            if "lora" in name:
                param.requires_grad = True
        for _name, param in self.ref_unet.named_parameters():
            param.requires_grad = False

    def encode_prompt(self, prompt):
        """Tokenise ``prompt`` and return the text encoder's first output.

        Prompts are padded/truncated to ``tokenizer.model_max_length`` and the
        encoder runs without gradient tracking.
        """
        with torch.no_grad():
            text_input_ids = self.tokenizer(
                prompt, max_length=self.tokenizer.model_max_length,
                padding="max_length", truncation=True, return_tensors="pt"
            ).input_ids
            prompt_embeds = self.text_encoder(
                text_input_ids.to(self.text_encoder.device),
            )[0]
        return prompt_embeds

    def forward(self, c_t, positive_prompt=[''], negative_prompt=[''], args=None):
        """Run the LoRA UNet and the reference UNet on the same noised latents.

        The VAE encoding and noising run without gradients; the LoRA UNet
        step and decode are differentiable; the reference branch runs under
        ``torch.no_grad()`` with the same noise and timesteps.

        Args:
            c_t: Input image batch passed to ``vae.encode``.
            positive_prompt: Prompt(s) conditioning both UNets. The default
                list is never mutated.
            negative_prompt: Prompt(s) that are only encoded and returned.
            args: Unused.

        Returns:
            ``(output_image, x_denoised, model_pred, caption_enc,
            neg_caption_enc, noise, ref_output_image, ref_x_denoised,
            ref_model_pred)``; both images are clamped to ``[-1, 1]``.
        """
        caption_enc = self.encode_prompt(positive_prompt)
        neg_caption_enc = self.encode_prompt(negative_prompt)
        with torch.no_grad():
            encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
            # Keep the clean latents so the reference branch can be noised
            # with the very same noise sample.
            encoded_control_ref = encoded_control
            noise = torch.randn_like(encoded_control)
            encoded_control = self.sched2.add_noise(encoded_control, noise, self.timestepsnoise)

        model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=caption_enc.to(torch.float32),).sample
        x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample
        output_image = self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample
        output_image = output_image.clamp(-1, 1)

        with torch.no_grad():
            encoded_control_ref = self.sched2.add_noise(encoded_control_ref, noise, self.timestepsnoise)
            ref_model_pred = self.ref_unet(encoded_control_ref, self.timesteps, encoder_hidden_states=caption_enc.to(torch.float32),).sample
            ref_x_denoised = self.sched.step(ref_model_pred, self.timesteps, encoded_control_ref, return_dict=True).prev_sample
            ref_output_image = self.vae.decode(ref_x_denoised / self.vae.config.scaling_factor).sample
            ref_output_image = ref_output_image.clamp(-1, 1)

        return output_image, x_denoised, model_pred, caption_enc, neg_caption_enc, noise, ref_output_image, ref_x_denoised, ref_model_pred

    def GDPOReference(self, c_t, positive_prompt=[''], negative_prompt=[''], args=None, groupsize=6):
        """Sample ``groupsize`` reference outputs per input with the reference UNet.

        Each input's latents and prompt embedding are repeated ``groupsize``
        times along a new dimension 1, every copy gets its own noise sample,
        and the batch entries are processed one at a time. Everything runs
        under ``torch.no_grad()``.

        Args:
            c_t: Input image batch passed to ``vae.encode``.
            positive_prompt: Prompt(s) conditioning the reference UNet.
            negative_prompt: Unused.
            args: Unused.
            groupsize: Number of samples generated per input.

        Returns:
            ``(output_image, x_denoised, model_pred)``, each with a group
            dimension of size ``groupsize`` inserted after the batch
            dimension; images are clamped to ``[-1, 1]``.
        """
        with torch.no_grad():
            caption_enc = self.encode_prompt(positive_prompt).unsqueeze(1)
            encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
            batch_size = encoded_control.shape[0]
            encoded_control = encoded_control.unsqueeze(1)
            caption_enc = caption_enc.repeat(1, groupsize, 1, 1)
            encoded_control = encoded_control.repeat(1, groupsize, 1, 1, 1)
            noise = torch.randn_like(encoded_control)
            output_image = torch.zeros_like(c_t).unsqueeze(1).repeat(1, groupsize, 1, 1, 1)
            x_denoised = torch.zeros_like(noise)
            model_pred = torch.zeros_like(noise)
            for i in range(batch_size):
                # The group dimension acts as the batch for this one input.
                encoded_control_i = self.sched2.add_noise(encoded_control[i], noise[i], self.timestepsnoise)

                model_pred_i = self.ref_unet(encoded_control_i, self.timesteps, encoder_hidden_states=caption_enc[i],).sample
                x_denoised_i = self.sched.step(model_pred_i, self.timesteps, encoded_control_i, return_dict=True).prev_sample
                output_image_i = self.vae.decode(x_denoised_i / self.vae.config.scaling_factor).sample
                output_image_i = output_image_i.clamp(-1, 1)
                output_image[i] = output_image_i
                x_denoised[i] = x_denoised_i
                model_pred[i] = model_pred_i

        return output_image, x_denoised, model_pred

    def save_model(self, outf):
        """Save LoRA metadata and the trainable UNet weights to ``outf``.

        UNet entries are kept when their key contains ``"lora"`` or
        ``"conv_in"``.
        """
        sd = {
            "unet_lora_encoder_modules": self.lora_unet_modules_encoder,
            "unet_lora_decoder_modules": self.lora_unet_modules_decoder,
            "unet_lora_others_modules": self.lora_unet_others,
            "rank_unet": self.lora_rank_unet,
            "state_dict_unet": {k: v for k, v in self.unet.state_dict().items() if "lora" in k or "conv_in" in k},
        }
        torch.save(sd, outf)

    def load_ckpt_from_state_dict(self, sd):
        """Recreate the UNet LoRA adapters described by ``sd`` and copy its weights.

        Args:
            sd: Dictionary in the layout written by ``save_model``.

        Raises:
            KeyError: If a LoRA/``conv_in`` parameter has no entry in ``sd``.
        """
        unet_adapters = (
            ("default_encoder", "unet_lora_encoder_modules"),
            ("default_decoder", "unet_lora_decoder_modules"),
            ("default_others", "unet_lora_others_modules"),
        )
        for adapter_name, modules_key in unet_adapters:
            config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd[modules_key])
            self.unet.add_adapter(config, adapter_name=adapter_name)
        for name, param in self.unet.named_parameters():
            if "lora" in name or "conv_in" in name:
                param.data.copy_(sd["state_dict_unet"][name])
|
|
class GDPOSRTest(torch.nn.Module):
    """Inference-only one-step generator loaded from ``args.pretrained_path``.

    Tokenizer, text encoder and schedulers come from
    ``args.pretrained_model_name_or_path``; VAE and UNet come from the
    ``vae`` / ``unet`` subfolders of ``args.pretrained_path``.
    """

    def __init__(self, args):
        """Load every sub-model and place the networks on CUDA.

        Args:
            args: Namespace read for ``pretrained_model_name_or_path``,
                ``pretrained_path``, ``time_step`` and ``time_step_noise``.
        """
        super().__init__()

        base_path = args.pretrained_model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder")
        self.text_encoder = text_encoder.cuda()
        self.sched = make_1step_sched(base_path)
        self.sched2 = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        loaded_vae = AutoencoderKL.from_pretrained(args.pretrained_path, subfolder="vae")
        loaded_unet = UNet2DConditionModel.from_pretrained(args.pretrained_path, subfolder="unet")
        self.unet = loaded_unet.cuda()
        self.vae = loaded_vae.cuda()

        # Single-element timestep tensors for the denoising and noising steps.
        self.timesteps = torch.tensor([args.time_step], device="cuda").long()
        self.timestepsnoise = torch.tensor([args.time_step_noise], device="cuda").long()
        self.text_encoder.requires_grad_(False)

    def set_eval(self):
        """Put UNet and VAE in eval mode and turn off their gradients."""
        for module in (self.unet, self.vae):
            module.eval()
        for module in (self.unet, self.vae):
            module.requires_grad_(False)

    def encode_prompt(self, prompt):
        """Return the text encoder's first output for ``prompt``.

        The prompt is padded/truncated to ``tokenizer.model_max_length`` and
        encoded without gradient tracking.
        """
        with torch.no_grad():
            tokenized = self.tokenizer(
                prompt,
                max_length=self.tokenizer.model_max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            token_ids = tokenized.input_ids.to(self.text_encoder.device)
            encoder_outputs = self.text_encoder(token_ids)
            prompt_embeds = encoder_outputs[0]
        return prompt_embeds

    def forward(self, c_t, positive_prompt=['']):
        """Encode, noise, denoise in one step and decode ``c_t``.

        Args:
            c_t: Input image batch passed to ``vae.encode``.
            positive_prompt: Prompt(s) conditioning the UNet.

        Returns:
            The decoded image clamped to ``[-1, 1]``.
        """
        text_embeds = self.encode_prompt(positive_prompt)

        scale = self.vae.config.scaling_factor
        clean_latents = self.vae.encode(c_t).latent_dist.sample() * scale
        noise = torch.randn_like(clean_latents)
        noisy_latents = self.sched2.add_noise(clean_latents, noise, self.timestepsnoise)

        unet_out = self.unet(
            noisy_latents,
            self.timesteps,
            encoder_hidden_states=text_embeds.to(torch.float32),
        )
        step_out = self.sched.step(unet_out.sample, self.timesteps, noisy_latents, return_dict=True)
        decoded = self.vae.decode(step_out.prev_sample / self.vae.config.scaling_factor).sample
        return decoded.clamp(-1, 1)
|
|