# https://github.com/GaParmar/img2img-turbo/blob/main/src/pix2pix_turbo.py
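# One-step image-to-image translation (Pix2Pix-Turbo): SD-Turbo with LoRA
# adapters on the UNet and VAE, plus skip connections from the VAE encoder
# into the decoder.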
import os
import copy

import requests
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, PretrainedConfig, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from diffusers.utils.peft_utils import set_weights_and_activate_adapters
from peft import LoraConfig
from pipelines.pix2pix.model import (
make_1step_sched,
my_vae_encoder_fwd,
my_vae_decoder_fwd,
)
class TwinConv(torch.nn.Module):
    """Blend of the pretrained and fine-tuned input convolutions.

    The interpolation weight ``r`` is set externally before each forward
    pass; the pretrained branch is detached so only the current conv
    receives gradients.
    """

    def __init__(self, convin_pretrained, convin_curr):
        super().__init__()
        self.conv_in_pretrained = copy.deepcopy(convin_pretrained)
        self.conv_in_curr = copy.deepcopy(convin_curr)
        self.r = None

    def forward(self, x):
        x1 = self.conv_in_pretrained(x).detach()
        x2 = self.conv_in_curr(x)
        return x1 * (1 - self.r) + x2 * (self.r)
class Pix2Pix_Turbo(torch.nn.Module):
    """One-step image-to-image translation model built on SD-Turbo.

    ``name`` selects one of the two released checkpoints,
    ``edge_to_image`` or ``sketch_to_image_stochastic``.
    """

    def __init__(self, name, ckpt_folder="checkpoints"):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(
            "stabilityai/sd-turbo", subfolder="tokenizer"
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            "stabilityai/sd-turbo", subfolder="text_encoder"
        ).cuda()
        self.sched = make_1step_sched()
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained(
            "stabilityai/sd-turbo", subfolder="unet"
        )
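        # Each released checkpoint is a pickled dict holding the LoRA ranks,
        # the LoRA target-module lists, and state dicts for the adapted UNet
        # and VAE (keys: rank_unet, unet_lora_target_modules, rank_vae,
        # vae_lora_target_modules, state_dict_unet, state_dict_vae).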
if name == "edge_to_image":
url = "https://www.cs.cmu.edu/~img2img-turbo/models/edge_to_image_loras.pkl"
os.makedirs(ckpt_folder, exist_ok=True)
outf = os.path.join(ckpt_folder, "edge_to_image_loras.pkl")
if not os.path.exists(outf):
print(f"Downloading checkpoint to {outf}")
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(
total=total_size_in_bytes, unit="iB", unit_scale=True
)
with open(outf, "wb") as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
print("ERROR, something went wrong")
print(f"Downloaded successfully to {outf}")
p_ckpt = outf
sd = torch.load(p_ckpt, map_location="cpu")
unet_lora_config = LoraConfig(
r=sd["rank_unet"],
init_lora_weights="gaussian",
target_modules=sd["unet_lora_target_modules"],
)
if name == "sketch_to_image_stochastic":
# download from url
url = "https://www.cs.cmu.edu/~img2img-turbo/models/sketch_to_image_stochastic_lora.pkl"
os.makedirs(ckpt_folder, exist_ok=True)
outf = os.path.join(ckpt_folder, "sketch_to_image_stochastic_lora.pkl")
if not os.path.exists(outf):
print(f"Downloading checkpoint to {outf}")
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(
total=total_size_in_bytes, unit="iB", unit_scale=True
)
with open(outf, "wb") as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
print("ERROR, something went wrong")
print(f"Downloaded successfully to {outf}")
p_ckpt = outf
sd = torch.load(p_ckpt, map_location="cpu")
unet_lora_config = LoraConfig(
r=sd["rank_unet"],
init_lora_weights="gaussian",
target_modules=sd["unet_lora_target_modules"],
)
            # for the stochastic variant, blend the pretrained and fine-tuned
            # input convolutions with a runtime weight r (see TwinConv)
            convin_pretrained = copy.deepcopy(unet.conv_in)
            unet.conv_in = TwinConv(convin_pretrained, unet.conv_in)
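        # rebind the VAE encoder/decoder forward methods so the encoder
        # records its down-block activations and the decoder can consume
        # them as skip connections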
vae.encoder.forward = my_vae_encoder_fwd.__get__(
vae.encoder, vae.encoder.__class__
)
vae.decoder.forward = my_vae_decoder_fwd.__get__(
vae.decoder, vae.decoder.__class__
)
        # 1x1 convs that project the encoder's intermediate activations into
        # the decoder as skip connections
vae.decoder.skip_conv_1 = torch.nn.Conv2d(
512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
).cuda()
vae.decoder.skip_conv_2 = torch.nn.Conv2d(
256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
).cuda()
vae.decoder.skip_conv_3 = torch.nn.Conv2d(
128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
).cuda()
vae.decoder.skip_conv_4 = torch.nn.Conv2d(
128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
).cuda()
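        # LoRA adapter for the VAE, registered as "vae_skip" so its weight
        # can be rescaled at inference time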
vae_lora_config = LoraConfig(
r=sd["rank_vae"],
init_lora_weights="gaussian",
target_modules=sd["vae_lora_target_modules"],
)
        vae.decoder.ignore_skip = False  # enable the skip path in the patched decoder
vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
unet.add_adapter(unet_lora_config)
        # overwrite the adapted UNet parameters with the fine-tuned weights
        # from the checkpoint, leaving the other SD-Turbo weights untouched
        _sd_unet = unet.state_dict()
        for k in sd["state_dict_unet"]:
            _sd_unet[k] = sd["state_dict_unet"][k]
        unet.load_state_dict(_sd_unet)
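        # NOTE: the call below requires the xformers package; remove or
        # guard it on installs without xformers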
unet.enable_xformers_memory_efficient_attention()
        # likewise for the VAE
        _sd_vae = vae.state_dict()
for k in sd["state_dict_vae"]:
_sd_vae[k] = sd["state_dict_vae"][k]
vae.load_state_dict(_sd_vae)
unet.to("cuda")
vae.to("cuda")
unet.eval()
vae.eval()
        self.unet, self.vae = unet, vae
        self.vae.decoder.gamma = 1  # decoder skip scale (rescaled per call in stochastic mode)
        # one-step inference: a single denoising step at the final timestep
        self.timesteps = torch.tensor([999], device="cuda").long()
        # cache for the encoded prompt, refreshed whenever the prompt changes
        self.last_prompt = ""
        self.caption_enc = None
        self.device = "cuda"
    @torch.no_grad()
    def forward(self, c_t, prompt, deterministic=True, r=1.0, noise_map=1.0):
        """Translate the conditioning image ``c_t`` in a single step.

        With ``deterministic=True`` the encoded input is denoised directly.
        Otherwise it is blended with ``noise_map`` (a latent-shaped tensor)
        and the LoRA and skip-connection weights are rescaled by ``r``.
        """
        # encode the text prompt, caching the result so repeated calls with
        # the same prompt skip the tokenizer and text encoder
if prompt != self.last_prompt:
caption_tokens = self.tokenizer(
prompt,
max_length=self.tokenizer.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt",
).input_ids.cuda()
caption_enc = self.text_encoder(caption_tokens)[0]
self.caption_enc = caption_enc
self.last_prompt = prompt
if deterministic:
encoded_control = (
self.vae.encode(c_t).latent_dist.sample()
* self.vae.config.scaling_factor
)
model_pred = self.unet(
encoded_control,
self.timesteps,
encoder_hidden_states=self.caption_enc,
).sample
x_denoised = self.sched.step(
model_pred, self.timesteps, encoded_control, return_dict=True
).prev_sample
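            # pass the encoder's down-block activations to the decoder as skips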
self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
output_image = (
self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample
).clamp(-1, 1)
else:
# scale the lora weights based on the r value
self.unet.set_adapters(["default"], weights=[r])
set_weights_and_activate_adapters(self.vae, ["vae_skip"], [r])
encoded_control = (
self.vae.encode(c_t).latent_dist.sample()
* self.vae.config.scaling_factor
)
            # blend the encoded input with the noise map: r=1 recovers the
            # deterministic path, smaller r injects more noise
            unet_input = encoded_control * r + noise_map * (1 - r)
            self.unet.conv_in.r = r  # set the TwinConv blending weight
unet_output = self.unet(
unet_input,
self.timesteps,
encoder_hidden_states=self.caption_enc,
).sample
            self.unet.conv_in.r = None  # reset the blend weight after the pass
x_denoised = self.sched.step(
unet_output, self.timesteps, unet_input, return_dict=True
).prev_sample
            self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
            self.vae.decoder.gamma = r  # scale the decoder skip strength by r
output_image = (
self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample
).clamp(-1, 1)
return output_image
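

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal way to run the
# deterministic edge_to_image variant. Assumes a CUDA GPU; "edges.png" and
# the prompt are placeholders. The conditioning image is fed in [0, 1] here
# (check the upstream repo's inference scripts for the exact preprocessing);
# the output is mapped back from [-1, 1] before saving.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image
    import torchvision.transforms.functional as TF

    model = Pix2Pix_Turbo("edge_to_image")
    image = Image.open("edges.png").convert("RGB")
    c_t = TF.to_tensor(image).unsqueeze(0).cuda()  # 1x3xHxW in [0, 1]
    output = model(c_t, prompt="a photo of a modern living room")
    out_pil = TF.to_pil_image((output[0].cpu() * 0.5 + 0.5).clamp(0, 1))
    out_pil.save("output.png")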