import os
import requests
import sys
import pdb
import copy
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, PretrainedConfig, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from diffusers.utils.peft_utils import set_weights_and_activate_adapters
from peft import LoraConfig
p = "src/"
sys.path.append(p)
from model import make_1step_sched


# Supported model names mapped to the URL of their LoRA checkpoint.
# The local file name is the basename of the URL.
_CHECKPOINT_URLS = {
    "edge_to_image": "https://www.cs.cmu.edu/~img2img-turbo/models/edge_to_image_loras.pkl",
    "sketch_to_image_stochastic": "https://www.cs.cmu.edu/~img2img-turbo/models/sketch_to_image_stochastic_lora.pkl",
}


def my_vae_encoder_fwd(self, sample):
    """The forward method of the `Encoder` class.

    Same sequence as the stock encoder (conv_in -> down blocks -> mid block ->
    norm/act/conv_out), but the input of every down block is additionally
    recorded in ``self.current_down_blocks`` so the decoder can consume these
    activations as skip connections.

    Args:
        sample: input tensor fed to ``self.conv_in``.

    Returns:
        The output of ``self.conv_out``.
    """
    sample = self.conv_in(sample)
    l_blocks = []
    # down
    for down_block in self.down_blocks:
        # store the activation *entering* each down block
        l_blocks.append(sample)
        sample = down_block(sample)
    # middle
    sample = self.mid_block(sample)
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    self.current_down_blocks = l_blocks
    return sample


def my_vae_decoder_fwd(self, sample, latent_embeds=None):
    """The forward method of the `Decoder` class.

    When ``self.ignore_skip`` is false, the activations stored in
    ``self.incoming_skip_acts`` are taken in reverse order, passed through
    ``self.skip_conv_1`` .. ``self.skip_conv_4`` and added to the running
    sample before the matching up block.

    Args:
        sample: latent tensor fed to ``self.conv_in``.
        latent_embeds: optional embeddings forwarded to the mid block, the up
            blocks and (when not None) to ``self.conv_norm_out``.

    Returns:
        The output of ``self.conv_out``.
    """
    sample = self.conv_in(sample)
    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
    # middle
    sample = self.mid_block(sample, latent_embeds)
    sample = sample.to(upscale_dtype)
    if not self.ignore_skip:
        skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
        # the encoder stored activations first-to-last; the decoder consumes them last-to-first
        skip_acts = self.incoming_skip_acts[::-1]
        # up
        for idx, up_block in enumerate(self.up_blocks):
            skip_in = skip_convs[idx](skip_acts[idx])
            # add skip
            sample = sample + skip_in
            sample = up_block(sample, latent_embeds)
    else:
        for up_block in self.up_blocks:
            sample = up_block(sample, latent_embeds)
    # post-process
    if latent_embeds is None:
        sample = self.conv_norm_out(sample)
    else:
        sample = self.conv_norm_out(sample, latent_embeds)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    return sample


class TwinConv(torch.nn.Module):
    """Blend the outputs of two copies of an input convolution.

    ``forward`` returns ``pretrained(x).detach() * (1 - r) + curr(x) * r``.
    ``self.r`` starts as ``None`` and must be set by the caller before the
    module is invoked.
    """

    def __init__(self, convin_pretrained, convin_curr):
        """Deep-copy both convolutions so this module owns independent weights."""
        super().__init__()
        self.conv_in_pretrained = copy.deepcopy(convin_pretrained)
        self.conv_in_curr = copy.deepcopy(convin_curr)
        self.r = None

    def forward(self, x):
        """Return the ``r``-weighted mix of the two convolutions applied to ``x``."""
        # the pretrained branch is detached, so no gradient flows through it
        x1 = self.conv_in_pretrained(x).detach()
        x2 = self.conv_in_curr(x)
        return x1 * (1 - self.r) + x2 * (self.r)


def _download_checkpoint(url, outf):
    """Download ``url`` to ``outf`` unless ``outf`` already exists.

    The payload is streamed into ``outf + ".part"`` and renamed only after the
    transfer has been validated, so an interrupted or truncated download never
    leaves a corrupt file at the cache path.

    Raises:
        requests.HTTPError: the server answered with an error status.
        RuntimeError: fewer/more bytes than the advertised Content-Length arrived.
    """
    if os.path.exists(outf):
        return
    print(f"Downloading checkpoint to {outf}")
    tmp_path = outf + ".part"
    try:
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get('content-length', 0))
            # With a Content-Encoding the decoded byte count does not have to
            # match Content-Length, so the size check is only valid without one.
            content_encoding = response.headers.get('content-encoding', '').lower()
            size_is_comparable = content_encoding in ('', 'identity')
            block_size = 1024  # 1 Kibibyte
            n_written = 0
            with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar:
                with open(tmp_path, 'wb') as file:
                    for data in response.iter_content(block_size):
                        progress_bar.update(len(data))
                        file.write(data)
                        n_written += len(data)
        if total_size_in_bytes != 0 and size_is_comparable and n_written != total_size_in_bytes:
            raise RuntimeError(
                f"Incomplete download of {url}: got {n_written} of {total_size_in_bytes} bytes"
            )
        os.replace(tmp_path, outf)
    finally:
        # after a successful os.replace the partial file no longer exists
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"Downloaded successfully to {outf}")


class Pix2Pix_Turbo(torch.nn.Module):
    """One-step image translation model built on the ``stabilityai/sd-turbo`` weights.

    The constructor loads the tokenizer, text encoder, VAE and UNet, downloads
    the LoRA checkpoint selected by ``name`` if needed, patches the VAE with
    skip connections, attaches LoRA adapters and moves everything to CUDA.
    """

    def __init__(self, name, ckpt_folder="checkpoints"):
        """Build the model.

        Args:
            name: ``"edge_to_image"`` or ``"sketch_to_image_stochastic"``.
            ckpt_folder: directory in which the checkpoint file is cached.

        Raises:
            ValueError: ``name`` is not one of the supported model names.
        """
        super().__init__()
        # fail fast, before any heavy model is loaded
        if name not in _CHECKPOINT_URLS:
            raise ValueError(
                f"Unknown model name {name!r}; expected one of {sorted(_CHECKPOINT_URLS)}"
            )
        self.tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained("stabilityai/sd-turbo", subfolder="text_encoder").cuda()
        self.sched = make_1step_sched()

        vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")

        url = _CHECKPOINT_URLS[name]
        os.makedirs(ckpt_folder, exist_ok=True)
        p_ckpt = os.path.join(ckpt_folder, os.path.basename(url))
        _download_checkpoint(url, p_ckpt)
        # SECURITY NOTE: torch.load unpickles the file; only use checkpoints
        # from a source you trust.
        sd = torch.load(p_ckpt, map_location="cpu")
        unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])

        if name == "sketch_to_image_stochastic":
            # Only the stochastic model blends the pretrained and fine-tuned
            # input convolutions; forward() sets the blend ratio on this module.
            convin_pretrained = copy.deepcopy(unet.conv_in)
            unet.conv_in = TwinConv(convin_pretrained, unet.conv_in)

        # bind the patched forwards as methods of the existing encoder/decoder instances
        vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
        vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
        # add the skip connection convs
        vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
        vae.decoder.ignore_skip = False
        vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
        unet.add_adapter(unet_lora_config)

        # overlay the checkpoint tensors on the full state dicts, then load them back
        _sd_unet = unet.state_dict()
        _sd_unet.update(sd["state_dict_unet"])
        unet.load_state_dict(_sd_unet)
        unet.enable_xformers_memory_efficient_attention()
        _sd_vae = vae.state_dict()
        _sd_vae.update(sd["state_dict_vae"])
        vae.load_state_dict(_sd_vae)

        unet.to("cuda")
        vae.to("cuda")
        unet.eval()
        vae.eval()
        self.unet, self.vae = unet, vae
        self.timesteps = torch.tensor([999], device="cuda").long()

    def forward(self, c_t, prompt, deterministic=True, r=1.0, noise_map=None):
        """Translate the conditioning image ``c_t`` guided by ``prompt``.

        Args:
            c_t: conditioning image tensor passed to ``self.vae.encode``.
            prompt: text prompt passed to the tokenizer.
            deterministic: when True, the encoded control image alone is the
                UNet input; when False, it is mixed with ``noise_map``.
            r: blend ratio used only when ``deterministic`` is False. It scales
                the UNet and VAE adapters, weights the control latent against
                the noise, and is set on ``self.unet.conv_in``.
            noise_map: noise tensor, required when ``deterministic`` is False.
                Presumably it must match the encoded latent's shape and device
                — TODO confirm.

        Returns:
            The decoded image, clamped to ``[-1, 1]``.

        Raises:
            ValueError: ``deterministic`` is False and ``noise_map`` is None.
        """
        if not deterministic and noise_map is None:
            raise ValueError("noise_map must be provided when deterministic=False")

        # encode the text prompt
        caption_tokens = self.tokenizer(prompt, max_length=self.tokenizer.model_max_length,
                                        padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
        caption_enc = self.text_encoder(caption_tokens)[0]
        if deterministic:
            # NOTE(review): adapter weights set by an earlier non-deterministic
            # call are not reset here — confirm whether that is intended.
            encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
            model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=caption_enc,).sample
            x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample
            # hand the encoder's recorded activations to the decoder as skip inputs
            self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
            output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1)
        else:
            # scale the lora weights based on the r value
            self.unet.set_adapters(["default"], weights=[r])
            set_weights_and_activate_adapters(self.vae, ["vae_skip"], [r])
            encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
            # combine the input and noise
            unet_input = encoded_control * r + noise_map * (1 - r)
            self.unet.conv_in.r = r
            try:
                unet_output = self.unet(unet_input, self.timesteps, encoder_hidden_states=caption_enc,).sample
            finally:
                # always clear the blend ratio, even if the UNet call raised
                self.unet.conv_in.r = None
            x_denoised = self.sched.step(unet_output, self.timesteps, unet_input, return_dict=True).prev_sample
            self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
            output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1)
        return output_image