File size: 5,098 Bytes
37aeb5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cb0437
37aeb5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8981664
37aeb5b
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, EulerAncestralDiscreteScheduler, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline
from transformers import CLIPVisionModelWithProjection
import torch
from copy import deepcopy

# Toggle for the caching decorators: when True, decorated loaders keep their
# results in the in-process ``cached_models`` dict instead of reloading.
ENABLE_CPU_CACHE = False
# Base model identifier used when the caller does not provide one.
DEFAULT_BASE_MODEL = "runwayml/stable-diffusion-v1-5"

# Shared memoisation store for the caching decorators; each key is derived from
# the decorated function's name plus the arguments of the call.
cached_models = dict()
def _cache_key(func, signature, args, kwargs):
    """Return a string cache key that is identical for equivalent calls of *func*.

    When *signature* is available the call is bound to it with defaults applied,
    so e.g. ``f(1)``, ``f(1, y=2)``, ``f(1, 2)`` and ``f(x=1, y=2)`` (for
    ``def f(x, y=2)``) all map to the same key. Remaining keyword arguments are
    sorted by name so their call order does not matter.

    Raises:
        TypeError: if the arguments do not fit *signature* (the same error the
            wrapped call itself would raise).
    """
    if signature is not None:
        bound = signature.bind(*args, **kwargs)
        bound.apply_defaults()
        args, kwargs = bound.args, bound.kwargs
    # Sorting compares only the (unique) names, never the argument values.
    return func.__name__ + str(args) + str(sorted(kwargs.items()))


def _caching_decorator(func, copy_result):
    """Wrap *func* so its results are memoised in the module-level ``cached_models``.

    Whether to cache is decided on every call from the module-level
    ``ENABLE_CPU_CACHE`` flag; when it is false *func* is simply called.
    With ``copy_result`` the caller always receives a ``deepcopy`` of the cached
    value, so mutating the returned object cannot corrupt the cache.
    """
    import functools
    import inspect

    try:
        signature = inspect.signature(func)
    except (TypeError, ValueError):  # some callables expose no signature
        signature = None

    @functools.wraps(func)  # keep __name__/__doc__ of the decorated loader
    def wrapper(*args, **kwargs):
        if not ENABLE_CPU_CACHE:
            return func(*args, **kwargs)
        key = _cache_key(func, signature, args, kwargs)
        if key not in cached_models:
            cached_models[key] = func(*args, **kwargs)
        cached = cached_models[key]
        return deepcopy(cached) if copy_result else cached

    return wrapper


def cache_model(func):
    """Decorator: memoise *func*'s result and return the shared cached object.

    Only active while ``ENABLE_CPU_CACHE`` is true; otherwise *func* runs on
    every call.
    """
    return _caching_decorator(func, copy_result=False)


def copied_cache_model(func):
    """Decorator: memoise *func*'s result but return a deep copy on every call.

    Intended for results callers may mutate (such as a components mapping fed
    into a new pipeline). Only active while ``ENABLE_CPU_CACHE`` is true.
    """
    return _caching_decorator(func, copy_result=True)

def model_from_ckpt_or_pretrained(ckpt_or_pretrained, model_cls, original_config_file='ckpt/v1-inference.yaml', torch_dtype=torch.float16, **kwargs):
    """Instantiate ``model_cls`` from a single-file checkpoint or a pretrained source.

    Args:
        ckpt_or_pretrained: A path to a single checkpoint file ending in
            ``.safetensors`` or ``.ckpt`` (case-insensitive), or anything
            ``model_cls.from_pretrained`` accepts (e.g. a Hub id or directory).
            Both ``str`` and ``os.PathLike`` values are accepted.
        model_cls: Class exposing ``from_single_file`` and ``from_pretrained``
            class methods (e.g. a diffusers pipeline class).
        original_config_file: Config file forwarded only to ``from_single_file``.
        torch_dtype: dtype forwarded to whichever loader is used.
        **kwargs: Extra keyword arguments forwarded to whichever loader is used.

    Returns:
        The object returned by the selected ``model_cls`` loader.
    """
    location = str(ckpt_or_pretrained)  # lets Path objects use str suffix checks
    # ``.ckpt`` is a single-file format too; without this it would be handed to
    # ``from_pretrained``, which expects a repo id or a directory.
    if location.lower().endswith((".safetensors", ".ckpt")):
        return model_cls.from_single_file(location, original_config_file=original_config_file, torch_dtype=torch_dtype, **kwargs)
    return model_cls.from_pretrained(ckpt_or_pretrained, torch_dtype=torch_dtype, **kwargs)

@copied_cache_model
def load_base_model_components(base_model=DEFAULT_BASE_MODEL, torch_dtype=torch.float16):
    """Return the ``components`` mapping of ``base_model`` loaded as a ``StableDiffusionPipeline``.

    The pipeline is loaded with the safety checker disabled and moved to the
    CPU before its components are returned. Because of ``copied_cache_model``,
    callers get a deep copy of the cached mapping when CPU caching is enabled.
    """
    base_pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
        base_model,
        StableDiffusionPipeline,
        torch_dtype=torch_dtype,
        requires_safety_checker=False,
        safety_checker=None,
    )
    base_pipe.to("cpu")
    return base_pipe.components

@cache_model
def load_controlnet(controlnet_path, torch_dtype=torch.float16):
    """Load a ``ControlNetModel`` via ``from_pretrained(controlnet_path)`` in ``torch_dtype``.

    Wrapped by ``cache_model``: with CPU caching enabled, repeated calls share
    one instance.
    """
    return ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch_dtype)

@cache_model
def load_image_encoder(torch_dtype=torch.float16):
    """Load the CLIP vision encoder from the ``h94/IP-Adapter`` repository.

    Args:
        torch_dtype: dtype to load the weights in. Defaults to ``torch.float16``
            so existing zero-argument calls behave exactly as before; exposed so
            callers can match the dtype used for the rest of the pipeline, like
            the other loaders in this module.

    Returns:
        A ``CLIPVisionModelWithProjection`` loaded from the repo's
        ``models/image_encoder`` subfolder.
    """
    return CLIPVisionModelWithProjection.from_pretrained(
        "h94/IP-Adapter",
        subfolder="models/image_encoder",
        torch_dtype=torch_dtype,
    )

def load_common_sd15_pipe(base_model=DEFAULT_BASE_MODEL, device="balanced", controlnet=None, ip_adapter=False, plus_model=True, torch_dtype=torch.float16, model_cpu_offload_seq=None, enable_sequential_cpu_offload=False, vae_slicing=False, pipeline_class=None, **kwargs):
    """Build a Stable Diffusion 1.5 pipeline, optionally with ControlNet and IP-Adapter.

    The components of ``base_model`` (from ``load_base_model_components``) are
    passed explicitly to ``pipeline_class`` so they are reused for the chosen
    pipeline class.

    Args:
        base_model: Hub id / directory, or a single-file checkpoint path.
        device: Forwarded to the pipeline loader as ``device_map``.
        controlnet: ``None``, one ControlNet path, or a list/tuple of paths.
        ip_adapter: When true, attach the IP-Adapter image encoder and weights
            (scale 1.0); otherwise ``unload_ip_adapter()`` is called.
        plus_model: Pick ``ip-adapter-plus_sd15`` (true) or ``ip-adapter_sd15``.
        torch_dtype: dtype for the base components and ControlNet(s).
        model_cpu_offload_seq: Explicit offload order; when ``None`` a default
            is set for the two ControlNet pipeline classes.
        enable_sequential_cpu_offload: Use sequential CPU offload instead of
            model CPU offload.
        vae_slicing: Enable sliced VAE decoding.
        pipeline_class: Pipeline class to build; inferred when ``None``
            (ControlNet pipeline if ``controlnet`` is given, else plain SD).
        **kwargs: Extra loader keyword arguments; they override both the
            defaults above and the base-model components.

    Returns:
        The configured pipeline, using an ``EulerAncestralDiscreteScheduler``.
    """
    model_kwargs = dict(
        torch_dtype=torch_dtype,
        device_map=device,
        requires_safety_checker=False,
        safety_checker=None,
    )
    # Reuse the base model's components; explicit **kwargs take precedence.
    model_kwargs.update(load_base_model_components(base_model=base_model, torch_dtype=torch_dtype))
    model_kwargs.update(kwargs)

    if controlnet is not None:
        # A list or tuple of paths becomes a list of models (multi-ControlNet).
        if isinstance(controlnet, (list, tuple)):
            controlnet = [load_controlnet(path, torch_dtype=torch_dtype) for path in controlnet]
        else:
            controlnet = load_controlnet(controlnet, torch_dtype=torch_dtype)
        model_kwargs.update(controlnet=controlnet)

    if pipeline_class is None:
        pipeline_class = StableDiffusionPipeline if controlnet is None else StableDiffusionControlNetPipeline

    pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
        base_model,
        pipeline_class,
        **model_kwargs
    )

    if ip_adapter:
        # NOTE(review): the image encoder is loaded independently of torch_dtype;
        # confirm it matches the pipeline when torch_dtype is not float16.
        pipe.image_encoder = load_image_encoder()
        weight_name = "ip-adapter-plus_sd15.safetensors" if plus_model else "ip-adapter_sd15.safetensors"
        pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name=weight_name)
        pipe.set_ip_adapter_scale(1.0)
    else:
        # presumably clears any IP-Adapter state left on the reused UNet — confirm
        pipe.unload_ip_adapter()

    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

    if model_cpu_offload_seq is not None:
        pipe.model_cpu_offload_seq = model_cpu_offload_seq
    elif isinstance(pipe, StableDiffusionControlNetPipeline):
        pipe.model_cpu_offload_seq = "text_encoder->controlnet->unet->vae"
    elif isinstance(pipe, StableDiffusionControlNetImg2ImgPipeline):
        # presumably "vae" appears twice because img2img encodes the init image
        # before the UNet and decodes after it — confirm
        pipe.model_cpu_offload_seq = "text_encoder->controlnet->vae->unet->vae"

    if enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload()
    else:
        pipe.enable_model_cpu_offload()
    if vae_slicing:
        pipe.enable_vae_slicing()

    import gc
    gc.collect()  # presumably to release temporaries created while loading
    return pipe