Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,100 Bytes
37aeb5b 8cb0437 37aeb5b ecc65f1 37aeb5b 8981664 37aeb5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, EulerAncestralDiscreteScheduler, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline
from transformers import CLIPVisionModelWithProjection
import torch
from copy import deepcopy
# When True, the caching decorators below memoise loader results in
# ``cached_models``; when False every call loads afresh.
ENABLE_CPU_CACHE: bool = False
# Base model used when callers do not pass ``base_model``.
DEFAULT_BASE_MODEL: str = "runwayml/stable-diffusion-v1-5"
# Memoised loader results, keyed by the loader's name plus its stringified
# positional and keyword arguments.
cached_models: dict = dict()
def _make_cache_decorator(func, copy_result):
    """Wrap *func* so its results are memoised in ``cached_models``.

    Caching is active only while the module-level ``ENABLE_CPU_CACHE`` flag is
    true. The flag is read on every call, so toggling it at runtime takes
    effect immediately.

    Args:
        func: The loader function to wrap.
        copy_result: If True, every call returns a ``deepcopy`` of the cached
            value, so callers may mutate it freely. If False, all callers share
            the single cached object.

    Returns:
        The wrapper. ``functools.wraps`` keeps the loader's name and docstring.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not ENABLE_CPU_CACHE:
            return func(*args, **kwargs)
        # Key format kept identical to the original implementation. Positional
        # and keyword spellings of the same call therefore map to different
        # entries.
        model_name = func.__name__ + str(args) + str(kwargs)
        if model_name not in cached_models:
            cached_models[model_name] = func(*args, **kwargs)
        cached = cached_models[model_name]
        return deepcopy(cached) if copy_result else cached

    return wrapper


def cache_model(func):
    """Memoise *func* in ``cached_models`` when ``ENABLE_CPU_CACHE`` is on.

    Repeated calls with the same arguments return the same shared object.
    """
    return _make_cache_decorator(func, copy_result=False)


def copied_cache_model(func):
    """Like ``cache_model``, but each call returns a deep copy of the cached value.

    The cached value is never handed out directly, so mutations by one caller
    cannot leak into the next.
    """
    return _make_cache_decorator(func, copy_result=True)
def model_from_ckpt_or_pretrained(ckpt_or_pretrained, model_cls, original_config_file='ckpt/v1-inference.yaml', torch_dtype=torch.float16, **kwargs):
    """Load *model_cls* from a single-file checkpoint or from pretrained weights.

    Args:
        ckpt_or_pretrained: A single-file checkpoint path (``.safetensors`` or
            ``.ckpt``, case-insensitive; ``str`` or ``os.PathLike``), or any
            other identifier accepted by ``model_cls.from_pretrained``.
        model_cls: A class exposing ``from_single_file`` and ``from_pretrained``
            classmethods, e.g. a diffusers pipeline class.
        original_config_file: Config forwarded only to ``from_single_file``.
        torch_dtype: dtype forwarded to whichever loader is used.
        **kwargs: Extra keyword arguments forwarded to the selected loader.

    Returns:
        Whatever the selected loader returns.
    """
    single_file_suffixes = (".safetensors", ".ckpt")
    # str() lets pathlib.Path inputs through; lower() accepts e.g. "X.SAFETENSORS".
    # The original value, not the normalised string, is passed to the loader.
    if str(ckpt_or_pretrained).lower().endswith(single_file_suffixes):
        return model_cls.from_single_file(ckpt_or_pretrained, original_config_file=original_config_file, torch_dtype=torch_dtype, **kwargs)
    return model_cls.from_pretrained(ckpt_or_pretrained, torch_dtype=torch_dtype, **kwargs)
@copied_cache_model
def load_base_model_components(base_model=DEFAULT_BASE_MODEL, torch_dtype=torch.float16):
    """Load a ``StableDiffusionPipeline`` on CPU and return its ``components``.

    The safety checker is disabled. Because of ``copied_cache_model``, callers
    receive a deep copy of the cached components when CPU caching is enabled.

    Args:
        base_model: Identifier or single-file checkpoint path passed to
            ``model_from_ckpt_or_pretrained``.
        torch_dtype: dtype for the loaded weights.

    Returns:
        The pipeline's ``components`` mapping.
    """
    base_pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
        base_model,
        StableDiffusionPipeline,
        torch_dtype=torch_dtype,
        requires_safety_checker=False,
        safety_checker=None,
    )
    base_pipe.to("cpu")
    return base_pipe.components
@cache_model
def load_controlnet(controlnet_path, torch_dtype=torch.float16):
    """Load a ``ControlNetModel`` via ``ControlNetModel.from_pretrained``.

    With CPU caching enabled, ``cache_model`` returns the same shared instance
    for identical arguments (no copy is made).
    """
    return ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch_dtype)
@cache_model
def load_image_encoder(torch_dtype=torch.float16):
    """Load the CLIP vision encoder shipped in ``h94/IP-Adapter``.

    Args:
        torch_dtype: dtype for the encoder weights. Defaults to float16, the
            value that was previously hard-coded, so existing no-argument calls
            behave (and are cached) exactly as before.

    Returns:
        The loaded ``CLIPVisionModelWithProjection``. With CPU caching enabled,
        ``cache_model`` returns the same shared instance for identical arguments.
    """
    return CLIPVisionModelWithProjection.from_pretrained(
        "h94/IP-Adapter",
        subfolder="models/image_encoder",
        torch_dtype=torch_dtype,
    )
def load_common_sd15_pipe(
    base_model=DEFAULT_BASE_MODEL,
    device="balanced",
    controlnet=None,
    ip_adapter=False,
    plus_model=True,
    torch_dtype=torch.float16,
    model_cpu_offload_seq=None,
    enable_sequential_cpu_offload=False,
    vae_slicing=False,
    pipeline_class=None,
    **kwargs,
):
    """Build a Stable Diffusion 1.5 pipeline with optional ControlNet and IP-Adapter.

    Args:
        base_model: Base model passed to ``load_base_model_components`` and
            ``model_from_ckpt_or_pretrained``.
        device: Currently unused (the ``device_map`` argument is disabled);
            kept for backward compatibility.
        controlnet: ``None``, a single ControlNet path/id, or a list/tuple of them.
        ip_adapter: If True, attach the IP-Adapter image encoder and weights at
            scale 1.0. If False, ``unload_ip_adapter()`` is called.
        plus_model: With ``ip_adapter``, load ``ip-adapter-plus_sd15`` weights
            instead of ``ip-adapter_sd15``.
        torch_dtype: dtype for the base components and ControlNet(s).
        model_cpu_offload_seq: Explicit CPU-offload order. When None, a
            ControlNet-aware order is set for the two ControlNet pipeline
            classes, and other pipelines keep their own default.
        enable_sequential_cpu_offload: Use sequential CPU offload instead of
            model CPU offload. Exactly one of the two is enabled.
        vae_slicing: Enable VAE slicing.
        pipeline_class: Pipeline class to instantiate. Defaults to
            ``StableDiffusionControlNetPipeline`` when ``controlnet`` is given,
            otherwise ``StableDiffusionPipeline``.
        **kwargs: Extra pipeline constructor arguments. They override the
            safety-checker defaults and the base components, but not
            ``controlnet``.

    Returns:
        The configured pipeline, using ``EulerAncestralDiscreteScheduler``.
    """
    import gc

    model_kwargs = dict(
        torch_dtype=torch_dtype,
        requires_safety_checker=False,
        safety_checker=None,
    )
    # NOTE(review): passing device_map=device was disabled, so `device` has no effect.
    model_kwargs.update(load_base_model_components(base_model=base_model, torch_dtype=torch_dtype))
    model_kwargs.update(kwargs)

    controlnet = _load_controlnets(controlnet, torch_dtype)
    if controlnet is not None:
        model_kwargs.update(controlnet=controlnet)

    if pipeline_class is None:
        pipeline_class = StableDiffusionControlNetPipeline if controlnet is not None else StableDiffusionPipeline

    pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
        base_model,
        pipeline_class,
        **model_kwargs
    )

    _configure_ip_adapter(pipe, ip_adapter, plus_model)
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    _configure_cpu_offload(pipe, model_cpu_offload_seq, enable_sequential_cpu_offload)
    if vae_slicing:
        pipe.enable_vae_slicing()

    gc.collect()
    return pipe


def _load_controlnets(controlnet, torch_dtype):
    """Turn a ControlNet path, or a list/tuple of paths, into loaded model(s); None passes through."""
    if controlnet is None:
        return None
    if isinstance(controlnet, (list, tuple)):
        return [load_controlnet(controlnet_path, torch_dtype=torch_dtype) for controlnet_path in controlnet]
    return load_controlnet(controlnet, torch_dtype=torch_dtype)


def _configure_ip_adapter(pipe, ip_adapter, plus_model):
    """Attach the IP-Adapter image encoder and weights to *pipe*, or unload IP-Adapter."""
    if not ip_adapter:
        pipe.unload_ip_adapter()
        return
    # NOTE(review): load_image_encoder() is called without a dtype, so the
    # encoder is float16 even when the pipeline's torch_dtype differs.
    pipe.image_encoder = load_image_encoder()
    weight_name = "ip-adapter-plus_sd15.safetensors" if plus_model else "ip-adapter_sd15.safetensors"
    pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name=weight_name)
    pipe.set_ip_adapter_scale(1.0)


def _configure_cpu_offload(pipe, model_cpu_offload_seq, enable_sequential_cpu_offload):
    """Set *pipe*'s CPU-offload order and enable exactly one offload strategy."""
    if model_cpu_offload_seq is not None:
        pipe.model_cpu_offload_seq = model_cpu_offload_seq
    elif isinstance(pipe, StableDiffusionControlNetPipeline):
        pipe.model_cpu_offload_seq = "text_encoder->controlnet->unet->vae"
    elif isinstance(pipe, StableDiffusionControlNetImg2ImgPipeline):
        pipe.model_cpu_offload_seq = "text_encoder->controlnet->vae->unet->vae"
    # Sequential and model CPU offload are alternative strategies. Previously
    # model offload was also enabled after sequential offload, overriding the
    # caller's choice.
    if enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload()
    else:
        pipe.enable_model_cpu_offload()
|