import torch
import copy
import os
from diffusers import DDPMScheduler
from transformers import AutoTokenizer, CLIPTextModel, pipeline
from diffusers import AutoencoderKL, UNet2DConditionModel
from peft import LoraConfig
from S2I.modules.utils import sc_vae_encoder_fwd, sc_vae_decoder_fwd, download_models, get_model_path, get_s2i_home


class RelationShipConvolution(torch.nn.Module):
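    """Mix the detached output of a copy of the pretrained conv_in with the output of the
    current conv_in as x_pretrained * (1 - r) + x_curr * r."""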
    def __init__(self, conv_in_pretrained, conv_in_curr, r):
        super().__init__()
        self.conv_in_pretrained = copy.deepcopy(conv_in_pretrained)
        self.conv_in_curr = copy.deepcopy(conv_in_curr)
        self.r = r

    def forward(self, x):
        x1 = self.conv_in_pretrained(x).detach()
        x2 = self.conv_in_curr(x)
        return x1 * (1 - self.r) + x2 * self.r


class PrimaryModel:
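    """Lazily build and cache the backbone components (prompt-enhancement pipelines, tokenizer,
    text encoder, one-step scheduler, skip-connection VAE and LoRA-adapted UNet)."""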
    def __init__(self, backbone_diffusion_path='stabilityai/sd-turbo'):
        self.backbone_diffusion_path = backbone_diffusion_path
        self.global_unet = None
        self.global_vae = None
        self.global_tokenizer = None
        self.global_text_encoder = None
        self.global_scheduler = None
        self.global_medium_prompt = None
        self.global_long_prompt = None

    @staticmethod
    def _load_model(path, model_class, unet_mode=False):
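        """Load the 'unet' or 'vae' subfolder of a pretrained checkpoint and move it to CUDA."""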
        model = model_class.from_pretrained(path, subfolder='unet' if unet_mode else 'vae').to('cuda')
        return model

    def one_step_scheduler(self):
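        """Return a DDPM noise scheduler configured for single-step inference on CUDA."""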
        noise_scheduler_1step = DDPMScheduler.from_pretrained(self.backbone_diffusion_path, subfolder="scheduler")
        noise_scheduler_1step.set_timesteps(1, device="cuda")
        noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
        return noise_scheduler_1step

    def skip_connections(self, vae):
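        """Replace the VAE encoder/decoder forwards with skip-connection variants and add the
        1x1 convolutions that carry encoder features into the decoder."""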
        vae.encoder.forward = sc_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
        vae.decoder.forward = sc_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
        vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.ignore_skip = False
        return vae

    def weights_adapter(self, p_ckpt, model_name):
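        """Load the checkpoint state dict; for '350k-adapter', overlay the local
        sketch2image_lora_350k weights on top of it."""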
        if model_name == '350k-adapter':
            home = get_s2i_home()
            sd_sketch = torch.load(os.path.join(home, "sketch2image_lora_350k.pkl"), map_location="cpu")
            sd = torch.load(p_ckpt, map_location="cpu")
            sd.update(sd_sketch)
            return sd
        else:
            sd = torch.load(p_ckpt, map_location="cpu")
            return sd

    def from_pretrained(self, model_name, r):
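        """Initialise all sub-models on first use, blend the UNet conv_in through
        RelationShipConvolution with ratio r, and load the UNet/VAE LoRA adapter weights."""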
        if self.global_medium_prompt is None:
            self.global_medium_prompt = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance",
                                                 device='cuda' if torch.cuda.is_available() else 'cpu')

        if self.global_long_prompt is None:
            self.global_long_prompt = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long",
                                               device='cuda' if torch.cuda.is_available() else 'cpu')

        if self.global_tokenizer is None:
            self.global_tokenizer = AutoTokenizer.from_pretrained("myn0908/stable-diffusion-3", subfolder="tokenizer_2")

        if self.global_text_encoder is None:
            self.global_text_encoder = CLIPTextModel.from_pretrained(self.backbone_diffusion_path,
                                                                     subfolder="text_encoder").to(device='cuda')

        if self.global_scheduler is None:
            self.global_scheduler = self.one_step_scheduler()

        if self.global_vae is None:
            self.global_vae = self._load_model(self.backbone_diffusion_path, AutoencoderKL)
            self.global_vae = self.skip_connections(self.global_vae)

        if self.global_unet is None:
            self.global_unet = self._load_model(self.backbone_diffusion_path, UNet2DConditionModel, unet_mode=True)
            p_ckpt_path = download_models()
            p_ckpt = get_model_path(model_name=model_name, model_paths=p_ckpt_path)
            sd = self.weights_adapter(p_ckpt, model_name)
            conv_in_pretrained = copy.deepcopy(self.global_unet.conv_in)
            self.global_unet.conv_in = RelationShipConvolution(conv_in_pretrained, self.global_unet.conv_in, r)
            unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian",
                                          target_modules=sd["unet_lora_target_modules"])
            vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian",
                                         target_modules=sd["vae_lora_target_modules"])
            self.global_vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
            _sd_vae = self.global_vae.state_dict()
            for k in sd["state_dict_vae"]:
                _sd_vae[k] = sd["state_dict_vae"][k]
            self.global_vae.load_state_dict(_sd_vae)
            self.global_unet.add_adapter(unet_lora_config)
            _sd_unet = self.global_unet.state_dict()
            for k in sd["state_dict_unet"]:
                _sd_unet[k] = sd["state_dict_unet"][k]
            self.global_unet.load_state_dict(_sd_unet, strict=False)
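

# Minimal usage sketch (assumptions: r=0.5 is an arbitrary blend ratio; '350k-adapter'
# is the variant handled in weights_adapter above):
#
#   model = PrimaryModel(backbone_diffusion_path='stabilityai/sd-turbo')
#   model.from_pretrained(model_name='350k-adapter', r=0.5)
#   unet, vae, tokenizer = model.global_unet, model.global_vae, model.global_tokenizer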