Spaces:
Running
on
Zero
Running
on
Zero
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, EulerAncestralDiscreteScheduler, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline | |
from transformers import CLIPVisionModelWithProjection | |
import torch | |
from copy import deepcopy | |
ENABLE_CPU_CACHE = False | |
DEFAULT_BASE_MODEL = "runwayml/stable-diffusion-v1-5" | |
cached_models = {} # cache for models to avoid repeated loading, key is model name | |
def cache_model(func): | |
def wrapper(*args, **kwargs): | |
if ENABLE_CPU_CACHE: | |
model_name = func.__name__ + str(args) + str(kwargs) | |
if model_name not in cached_models: | |
cached_models[model_name] = func(*args, **kwargs) | |
return cached_models[model_name] | |
else: | |
return func(*args, **kwargs) | |
return wrapper | |
def copied_cache_model(func): | |
def wrapper(*args, **kwargs): | |
if ENABLE_CPU_CACHE: | |
model_name = func.__name__ + str(args) + str(kwargs) | |
if model_name not in cached_models: | |
cached_models[model_name] = func(*args, **kwargs) | |
return deepcopy(cached_models[model_name]) | |
else: | |
return func(*args, **kwargs) | |
return wrapper | |
def model_from_ckpt_or_pretrained(ckpt_or_pretrained, model_cls, original_config_file='ckpt/v1-inference.yaml', torch_dtype=torch.float16, **kwargs): | |
if ckpt_or_pretrained.endswith(".safetensors"): | |
pipe = model_cls.from_single_file(ckpt_or_pretrained, original_config_file=original_config_file, torch_dtype=torch_dtype, **kwargs) | |
else: | |
pipe = model_cls.from_pretrained(ckpt_or_pretrained, torch_dtype=torch_dtype, **kwargs) | |
return pipe | |
def load_base_model_components(base_model=DEFAULT_BASE_MODEL, torch_dtype=torch.float16): | |
model_kwargs = dict( | |
torch_dtype=torch_dtype, | |
requires_safety_checker=False, | |
safety_checker=None, | |
) | |
pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained( | |
base_model, | |
StableDiffusionPipeline, | |
**model_kwargs | |
) | |
pipe.to("cpu") | |
return pipe.components | |
def load_controlnet(controlnet_path, torch_dtype=torch.float16): | |
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch_dtype) | |
return controlnet | |
def load_image_encoder(): | |
image_encoder = CLIPVisionModelWithProjection.from_pretrained( | |
"h94/IP-Adapter", | |
subfolder="models/image_encoder", | |
torch_dtype=torch.float16, | |
) | |
return image_encoder | |
def load_common_sd15_pipe(base_model=DEFAULT_BASE_MODEL, device="balanced", controlnet=None, ip_adapter=False, plus_model=True, torch_dtype=torch.float16, model_cpu_offload_seq=None, enable_sequential_cpu_offload=False, vae_slicing=False, pipeline_class=None, **kwargs): | |
model_kwargs = dict( | |
torch_dtype=torch_dtype, | |
# device_map=device, | |
requires_safety_checker=False, | |
safety_checker=None, | |
) | |
components = load_base_model_components(base_model=base_model, torch_dtype=torch_dtype) | |
model_kwargs.update(components) | |
model_kwargs.update(kwargs) | |
if controlnet is not None: | |
if isinstance(controlnet, list): | |
controlnet = [load_controlnet(controlnet_path, torch_dtype=torch_dtype) for controlnet_path in controlnet] | |
else: | |
controlnet = load_controlnet(controlnet, torch_dtype=torch_dtype) | |
model_kwargs.update(controlnet=controlnet) | |
if pipeline_class is None: | |
if controlnet is not None: | |
pipeline_class = StableDiffusionControlNetPipeline | |
else: | |
pipeline_class = StableDiffusionPipeline | |
pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained( | |
base_model, | |
pipeline_class, | |
**model_kwargs | |
) | |
if ip_adapter: | |
image_encoder = load_image_encoder() | |
pipe.image_encoder = image_encoder | |
if plus_model: | |
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.safetensors") | |
else: | |
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.safetensors") | |
pipe.set_ip_adapter_scale(1.0) | |
else: | |
pipe.unload_ip_adapter() | |
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config) | |
if model_cpu_offload_seq is None: | |
if isinstance(pipe, StableDiffusionControlNetPipeline): | |
pipe.model_cpu_offload_seq = "text_encoder->controlnet->unet->vae" | |
elif isinstance(pipe, StableDiffusionControlNetImg2ImgPipeline): | |
pipe.model_cpu_offload_seq = "text_encoder->controlnet->vae->unet->vae" | |
else: | |
pipe.model_cpu_offload_seq = model_cpu_offload_seq | |
if enable_sequential_cpu_offload: | |
pipe.enable_sequential_cpu_offload() | |
else: | |
pass | |
pipe.enable_model_cpu_offload() | |
if vae_slicing: | |
pipe.enable_vae_slicing() | |
import gc | |
gc.collect() | |
return pipe | |