# NOTE: the original file carried scraped Hugging Face page chrome here
# ("Spaces: Running on Zero", repeated) — preserved as a comment so the
# module remains valid Python.
import torch | |
import copy | |
import os | |
from diffusers import DDPMScheduler | |
from transformers import AutoTokenizer, CLIPTextModel | |
from diffusers import AutoencoderKL, UNet2DConditionModel | |
from peft import LoraConfig | |
from S2I.modules.utils import sc_vae_encoder_fwd, sc_vae_decoder_fwd, download_models, get_model_path, get_s2i_home | |
class RelationShipConvolution(torch.nn.Module):
    """Blend a frozen pretrained input convolution with a trainable copy.

    The forward pass mixes the two branch outputs linearly: the pretrained
    branch (gradient-detached) is weighted by ``(1 - r)`` and the current
    branch by ``r``.
    """

    def __init__(self, conv_in_pretrained, conv_in_curr, r):
        super(RelationShipConvolution, self).__init__()
        # Deep-copy both convolutions so this module owns its parameters
        # independently of the layers it was built from.
        self.conv_in_pretrained = copy.deepcopy(conv_in_pretrained)
        self.conv_in_curr = copy.deepcopy(conv_in_curr)
        self.r = r

    def forward(self, x):
        # No gradients flow through the pretrained branch.
        frozen_out = self.conv_in_pretrained(x).detach()
        current_out = self.conv_in_curr(x)
        blend = self.r
        return (1 - blend) * frozen_out + blend * current_out
class PrimaryModel:
    """Lazily assembles the Sketch2Image diffusion components on CUDA.

    Each ``global_*`` attribute starts as ``None`` and is populated at most
    once by :meth:`from_pretrained`, so repeated calls reuse the already
    loaded tokenizer, text encoder, scheduler, VAE and UNet.
    """

    def __init__(self, backbone_diffusion_path='stabilityai/sd-turbo'):
        # Hub repo id (or local path) of the backbone diffusion model.
        self.backbone_diffusion_path = backbone_diffusion_path
        # Components are created lazily in from_pretrained().
        self.global_unet = None
        self.global_vae = None
        self.global_tokenizer = None
        self.global_text_encoder = None
        self.global_scheduler = None

    def _load_model(self, path, model_class, unet_mode=False):
        """Load a pretrained sub-model (UNet or VAE) onto CUDA.

        BUGFIX: the original signature omitted ``self``, so the bound-method
        calls ``self._load_model(...)`` below shifted every argument by one
        (the instance became ``path``) and would fail at runtime.

        :param path: model repo id or local path.
        :param model_class: diffusers class exposing ``from_pretrained``.
        :param unet_mode: load the ``unet`` subfolder when True, else ``vae``.
        """
        subfolder = 'unet' if unet_mode else 'vae'
        return model_class.from_pretrained(path, subfolder=subfolder).to('cuda')

    def one_step_scheduler(self):
        """Build a DDPM scheduler configured for single-step inference."""
        noise_scheduler_1step = DDPMScheduler.from_pretrained(
            self.backbone_diffusion_path, subfolder="scheduler")
        noise_scheduler_1step.set_timesteps(1, device="cuda")
        # Keep the cumulative-alpha table on the GPU alongside the latents.
        noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
        return noise_scheduler_1step

    def skip_connections(self, vae):
        """Patch *vae* with encoder-to-decoder skip connections.

        Rebinds the encoder/decoder ``forward`` methods to the skip-aware
        versions and attaches four 1x1 projection convs used by the decoder.
        """
        vae.encoder.forward = sc_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
        vae.decoder.forward = sc_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
        vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.ignore_skip = False
        return vae

    def weights_adapter(self, p_ckpt, model_name):
        """Return the checkpoint state dict, merging the 350k sketch LoRA
        on top when *model_name* is ``'350k-adapter'``.

        NOTE(review): ``torch.load`` unpickles arbitrary code from the
        checkpoint; only feed it trusted files (consider ``weights_only=True``
        where the torch version supports it).
        """
        sd = torch.load(p_ckpt, map_location="cpu")
        if model_name == '350k-adapter':
            home = get_s2i_home()
            # (dropped a stray f-prefix on this placeholder-free literal)
            sd_sketch = torch.load(os.path.join(home, "sketch2image_lora_350k.pkl"),
                                   map_location="cpu")
            # Sketch-LoRA entries override overlapping keys of the base ckpt.
            sd.update(sd_sketch)
        return sd

    def from_pretrained(self, model_name, r):
        """Populate every missing component and load LoRA weights.

        :param model_name: checkpoint key understood by ``get_model_path``.
        :param r: blend ratio handed to :class:`RelationShipConvolution`.
        """
        if self.global_tokenizer is None:
            # Tokenizer deliberately comes from a different repo than the
            # backbone (the backbone-based load is commented out upstream).
            self.global_tokenizer = AutoTokenizer.from_pretrained(
                "myn0908/stable-diffusion-3", subfolder="tokenizer_2")
        if self.global_text_encoder is None:
            self.global_text_encoder = CLIPTextModel.from_pretrained(
                self.backbone_diffusion_path, subfolder="text_encoder").to(device='cuda')
        if self.global_scheduler is None:
            self.global_scheduler = self.one_step_scheduler()
        if self.global_vae is None:
            self.global_vae = self._load_model(self.backbone_diffusion_path, AutoencoderKL)
            self.global_vae = self.skip_connections(self.global_vae)
        if self.global_unet is None:
            # NOTE(review): source indentation was lost; the LoRA wiring below
            # is placed inside this guard because re-running it would
            # double-wrap conv_in and re-add adapters — confirm upstream.
            self.global_unet = self._load_model(self.backbone_diffusion_path,
                                                UNet2DConditionModel, unet_mode=True)
            p_ckpt_path = download_models()
            p_ckpt = get_model_path(model_name=model_name, model_paths=p_ckpt_path)
            sd = self.weights_adapter(p_ckpt, model_name)
            # Blend the pretrained input conv with the trainable one.
            conv_in_pretrained = copy.deepcopy(self.global_unet.conv_in)
            self.global_unet.conv_in = RelationShipConvolution(
                conv_in_pretrained, self.global_unet.conv_in, r)
            unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian",
                                          target_modules=sd["unet_lora_target_modules"])
            vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian",
                                         target_modules=sd["vae_lora_target_modules"])
            self.global_vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
            # Overlay checkpoint weights onto the freshly added adapters.
            _sd_vae = self.global_vae.state_dict()
            for k in sd["state_dict_vae"]:
                _sd_vae[k] = sd["state_dict_vae"][k]
            self.global_vae.load_state_dict(_sd_vae)
            self.global_unet.add_adapter(unet_lora_config)
            _sd_unet = self.global_unet.state_dict()
            for k in sd["state_dict_unet"]:
                _sd_unet[k] = sd["state_dict_unet"][k]
            # strict=False: the RelationShipConvolution wrapper renames the
            # conv_in parameter keys relative to the checkpoint.
            self.global_unet.load_state_dict(_sd_unet, strict=False)