File size: 5,570 Bytes
55a3c9a e2053b7 55a3c9a 8811405 55a3c9a e4c85fa 55a3c9a 8811405 55a3c9a e4c85fa 55a3c9a a2fc6fa cc3415b 8811405 a2fc6fa cc3415b 8811405 55a3c9a e4c85fa 55a3c9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import torch
import copy
import os
from diffusers import DDPMScheduler
from transformers import AutoTokenizer, CLIPTextModel, pipeline
from diffusers import AutoencoderKL, UNet2DConditionModel
from peft import LoraConfig
from S2I.modules.utils import sc_vae_encoder_fwd, sc_vae_decoder_fwd, download_models, get_model_path, get_s2i_home
class RelationShipConvolution(torch.nn.Module):
    """Blend the outputs of two independent copies of an input layer.

    ``forward`` returns ``pretrained(x) * (1 - r) + current(x) * r``.  The
    pretrained branch's output is detached, so gradients only reach the
    ``conv_in_curr`` copy.

    Args:
        conv_in_pretrained: Layer whose (deep-copied) output forms the detached branch.
        conv_in_curr: Layer whose (deep-copied) output forms the trainable branch.
        r: Mixing weight given to the trainable branch.
    """

    def __init__(self, conv_in_pretrained, conv_in_curr, r):
        super().__init__()
        # Own independent copies so the modules passed in are never shared/mutated.
        self.conv_in_pretrained = copy.deepcopy(conv_in_pretrained)
        self.conv_in_curr = copy.deepcopy(conv_in_curr)
        self.r = r

    def forward(self, x):
        mix = self.r
        # detach(): no gradient flows back through the pretrained branch.
        frozen_out = self.conv_in_pretrained(x).detach()
        trainable_out = self.conv_in_curr(x)
        return frozen_out * (1 - mix) + trainable_out * mix
class PrimaryModel:
    """Lazily build and cache the components of the one-step sketch-to-image model.

    Every component starts as ``None`` and is created on the first call to
    :meth:`from_pretrained`. Later calls reuse the cached objects and only
    (re)load the adapter weights of the requested checkpoint.
    """

    def __init__(self, backbone_diffusion_path='stabilityai/sd-turbo'):
        """Store the backbone location; all components are loaded on demand.

        Args:
            backbone_diffusion_path: Hub id or local path of the diffusion backbone;
                its ``unet``, ``vae``, ``text_encoder`` and ``scheduler`` subfolders are used.
        """
        self.backbone_diffusion_path = backbone_diffusion_path
        # Lazily-populated components (filled in by from_pretrained).
        self.global_unet = None
        self.global_vae = None
        self.global_tokenizer = None
        self.global_text_encoder = None
        self.global_scheduler = None
        self.global_medium_prompt = None
        self.global_long_prompt = None

    @staticmethod
    def _load_model(path, model_class, unet_mode=False):
        """Load ``model_class`` from the ``unet`` (if ``unet_mode``) or ``vae`` subfolder of ``path`` onto CUDA."""
        subfolder = 'unet' if unet_mode else 'vae'
        return model_class.from_pretrained(path, subfolder=subfolder).to('cuda')

    def one_step_scheduler(self):
        """Return the backbone's DDPM scheduler configured for a single timestep on CUDA."""
        noise_scheduler_1step = DDPMScheduler.from_pretrained(self.backbone_diffusion_path, subfolder="scheduler")
        noise_scheduler_1step.set_timesteps(1, device="cuda")
        # Keep the cumulative-alpha table on the same device as the timesteps set above.
        noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
        return noise_scheduler_1step

    def skip_connections(self, vae):
        """Patch ``vae`` with encoder-to-decoder skip connections and return it.

        The encoder/decoder ``forward`` methods are replaced by the skip-aware
        versions from ``S2I.modules.utils``, and four bias-free 1x1 projections
        ``skip_conv_1`` .. ``skip_conv_4`` are attached to the decoder on CUDA.
        Their weights are presumably overwritten by the checkpoint loaded in
        :meth:`from_pretrained` — TODO confirm.
        """
        # __get__ binds the plain functions as methods of these specific instances.
        vae.encoder.forward = sc_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
        vae.decoder.forward = sc_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
        vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.ignore_skip = False
        return vae

    def weights_adapter(self, p_ckpt, model_name):
        """Load the adapter checkpoint dict stored at ``p_ckpt`` (on CPU).

        For ``model_name == '350k-adapter'`` the entries of
        ``sketch2image_lora_350k.pkl`` in the S2I home directory are merged on
        top, overriding keys that exist in both.
        """
        # NOTE(review): torch.load unpickles arbitrary Python objects; only load
        # trusted checkpoints (consider weights_only=True if the format allows it).
        sd = torch.load(p_ckpt, map_location="cpu")
        if model_name == '350k-adapter':
            sketch_path = os.path.join(get_s2i_home(), "sketch2image_lora_350k.pkl")
            sd.update(torch.load(sketch_path, map_location="cpu"))
        return sd

    def from_pretrained(self, model_name, r):
        """Load any missing components, then apply the adapter checkpoint for ``model_name``.

        Args:
            model_name: Checkpoint identifier resolved through ``get_model_path``;
                ``'350k-adapter'`` additionally merges the 350k sketch LoRA
                (see :meth:`weights_adapter`).
            r: Blend weight of the trainable ``conv_in`` branch
                (see ``RelationShipConvolution``).

        Safe to call repeatedly on one instance: ``unet.conv_in`` is wrapped only
        the first time (later calls just update its ``r``) and each LoRA adapter
        is injected only if it is not already present, so later calls reload
        weights into the existing adapters instead of nesting wrappers or
        failing on a duplicate adapter name. A checkpoint whose LoRA rank differs
        from the one already injected will still fail with a size mismatch.
        """
        self._load_components()
        p_ckpt_path = download_models()
        p_ckpt = get_model_path(model_name=model_name, model_paths=p_ckpt_path)
        sd = self.weights_adapter(p_ckpt, model_name)
        self._wrap_conv_in(r)
        unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian",
                                      target_modules=sd["unet_lora_target_modules"])
        vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian",
                                     target_modules=sd["vae_lora_target_modules"])
        self._apply_lora(self.global_vae, vae_lora_config, sd["state_dict_vae"],
                         adapter_name="vae_skip", strict=True)
        # strict=False as in the original loader; presumably the checkpoint keys do
        # not exactly match the wrapped UNet's state dict — TODO confirm.
        self._apply_lora(self.global_unet, unet_lora_config, sd["state_dict_unet"],
                         adapter_name="default", strict=False)

    def _load_components(self):
        """Create every cached component that is still ``None``."""
        # Only the two summarization pipelines fall back to CPU; the rest is CUDA-only.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if self.global_medium_prompt is None:
            self.global_medium_prompt = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
        if self.global_long_prompt is None:
            self.global_long_prompt = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
        if self.global_tokenizer is None:
            # NOTE(review): tokenizer comes from a different repo than the text encoder — confirm vocabularies match.
            self.global_tokenizer = AutoTokenizer.from_pretrained("myn0908/stable-diffusion-3", subfolder="tokenizer_2")
        if self.global_text_encoder is None:
            self.global_text_encoder = CLIPTextModel.from_pretrained(self.backbone_diffusion_path,
                                                                     subfolder="text_encoder").to(device='cuda')
        if self.global_scheduler is None:
            self.global_scheduler = self.one_step_scheduler()
        if self.global_vae is None:
            self.global_vae = self._load_model(self.backbone_diffusion_path, AutoencoderKL)
            self.global_vae = self.skip_connections(self.global_vae)
        if self.global_unet is None:
            self.global_unet = self._load_model(self.backbone_diffusion_path, UNet2DConditionModel, unet_mode=True)

    def _wrap_conv_in(self, r):
        """Wrap ``unet.conv_in`` in a RelationShipConvolution, or just update ``r`` if already wrapped."""
        conv_in = self.global_unet.conv_in
        if isinstance(conv_in, RelationShipConvolution):
            # Re-wrapping would nest one blend inside another on every call.
            conv_in.r = r
            return
        self.global_unet.conv_in = RelationShipConvolution(copy.deepcopy(conv_in), conv_in, r)

    @staticmethod
    def _apply_lora(model, lora_config, weights, adapter_name, strict):
        """Inject ``lora_config`` into ``model`` if absent, then overlay ``weights`` on its state dict."""
        # add_adapter raises if the name is already registered, so only inject once.
        if adapter_name not in getattr(model, "peft_config", {}):
            model.add_adapter(lora_config, adapter_name=adapter_name)
        state = model.state_dict()
        state.update(weights)
        model.load_state_dict(state, strict=strict)
|