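"""Primary model components for S2I sketch-to-image generation.

Lazily builds the diffusion backbone components (tokenizer, text encoder, one-step
scheduler, VAE, UNet), patches the VAE with skip connections, blends the UNet conv_in
with a pretrained copy, and loads LoRA adapter weights from a downloaded checkpoint.
"""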
import torch
import copy
import os
from diffusers import DDPMScheduler
from transformers import AutoTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel
from peft import LoraConfig
from S2I.modules.utils import sc_vae_encoder_fwd, sc_vae_decoder_fwd, download_models, get_model_path, get_s2i_home
class RelationShipConvolution(torch.nn.Module):
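    """Blends a pretrained copy of conv_in with a trainable copy.

    The output is (1 - r) * pretrained(x) + r * current(x), with the pretrained
    branch detached so it receives no gradients.
    """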
    def __init__(self, conv_in_pretrained, conv_in_curr, r):
        super().__init__()
        # Copy of the pretrained conv_in; its output is detached in forward(), so it receives no gradients.
        self.conv_in_pretrained = copy.deepcopy(conv_in_pretrained)
        # Trainable copy of the current conv_in.
        self.conv_in_curr = copy.deepcopy(conv_in_curr)
        self.r = r

    def forward(self, x):
        x1 = self.conv_in_pretrained(x).detach()
        x2 = self.conv_in_curr(x)
        # Linearly blend the pretrained and current outputs with ratio r.
        return x1 * (1 - self.r) + x2 * self.r
class PrimaryModel:
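    """Lazily builds and caches the backbone components and loads LoRA adapter weights.

    The tokenizer, text encoder, scheduler, VAE, and UNet are created on first use in
    from_pretrained() and reused on subsequent calls.
    """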
def __init__(self, backbone_diffusion_path='stabilityai/sd-turbo'):
self.backbone_diffusion_path = backbone_diffusion_path
self.global_unet = None
self.global_vae = None
self.global_tokenizer = None
self.global_text_encoder = None
self.global_scheduler = None
    @staticmethod
    def _load_model(path, model_class, unet_mode=False):
        # Load the 'unet' or 'vae' subfolder of the diffusers checkpoint onto the GPU.
        model = model_class.from_pretrained(path, subfolder='unet' if unet_mode else 'vae').to('cuda')
        return model
    def one_step_scheduler(self):
        # DDPM scheduler configured for single-step (turbo) sampling on the GPU.
        noise_scheduler_1step = DDPMScheduler.from_pretrained(self.backbone_diffusion_path, subfolder="scheduler")
        noise_scheduler_1step.set_timesteps(1, device="cuda")
        noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
        return noise_scheduler_1step
    def skip_connections(self, vae):
        # Rebind the encoder/decoder forward passes to the skip-connection-aware versions.
        vae.encoder.forward = sc_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
        vae.decoder.forward = sc_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
        # 1x1 convolutions that project encoder features into the decoder stages.
        vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.ignore_skip = False
        return vae
    def weights_adapter(self, p_ckpt, model_name):
        # Load the checkpoint; for the 350k adapter, merge in the sketch LoRA weights.
        if model_name == '350k-adapter':
            home = get_s2i_home()
            sd_sketch = torch.load(os.path.join(home, "sketch2image_lora_350k.pkl"), map_location="cpu")
            sd = torch.load(p_ckpt, map_location="cpu")
            sd.update(sd_sketch)
            return sd
        else:
            sd = torch.load(p_ckpt, map_location="cpu")
            return sd
def from_pretrained(self, model_name, r):
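        # Create each shared component only once; later calls reuse the cached instances.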
if self.global_tokenizer is None:
# self.global_tokenizer = AutoTokenizer.from_pretrained(self.backbone_diffusion_path,
# subfolder="tokenizer")
self.global_tokenizer = AutoTokenizer.from_pretrained("myn0908/stable-diffusion-3", subfolder="tokenizer_2")
if self.global_text_encoder is None:
self.global_text_encoder = CLIPTextModel.from_pretrained(self.backbone_diffusion_path,
subfolder="text_encoder").to(device='cuda')
if self.global_scheduler is None:
self.global_scheduler = self.one_step_scheduler()
if self.global_vae is None:
self.global_vae = self._load_model(self.backbone_diffusion_path, AutoencoderKL)
self.global_vae = self.skip_connections(self.global_vae)
if self.global_unet is None:
self.global_unet = self._load_model(self.backbone_diffusion_path, UNet2DConditionModel, unet_mode=True)
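            # Download the adapter checkpoints, resolve the one matching model_name, and load its weights.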
p_ckpt_path = download_models()
p_ckpt = get_model_path(model_name=model_name, model_paths=p_ckpt_path)
sd = self.weights_adapter(p_ckpt, model_name)
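            # Replace conv_in with a blend of its pretrained copy and the current copy (ratio r),
            # then build LoRA configs from the ranks and target modules stored in the checkpoint.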
conv_in_pretrained = copy.deepcopy(self.global_unet.conv_in)
self.global_unet.conv_in = RelationShipConvolution(conv_in_pretrained, self.global_unet.conv_in, r)
unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian",
target_modules=sd["unet_lora_target_modules"])
vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian",
target_modules=sd["vae_lora_target_modules"])
self.global_vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
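            # Copy the checkpoint's VAE and UNet weights (including the LoRA layers) into the models.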
_sd_vae = self.global_vae.state_dict()
for k in sd["state_dict_vae"]:
_sd_vae[k] = sd["state_dict_vae"][k]
self.global_vae.load_state_dict(_sd_vae)
self.global_unet.add_adapter(unet_lora_config)
_sd_unet = self.global_unet.state_dict()
for k in sd["state_dict_unet"]:
_sd_unet[k] = sd["state_dict_unet"][k]
self.global_unet.load_state_dict(_sd_unet, strict=False)
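
# Illustrative usage sketch (assumes a CUDA device; '350k-adapter' is one of the model
# names handled by weights_adapter above, and r is the conv_in blending ratio):
#
#   primary = PrimaryModel(backbone_diffusion_path='stabilityai/sd-turbo')
#   primary.from_pretrained(model_name='350k-adapter', r=0.5)
#   unet, vae = primary.global_unet, primary.global_vae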