# Hugging Face Space — running on ZeroGPU hardware.
import torch | |
from DeepCache import DeepCacheSDHelper | |
from diffusers import ControlNetModel | |
from diffusers.models.attention_processor import AttnProcessor2_0, IPAdapterAttnProcessor2_0 | |
from .config import Config | |
from .logger import Logger | |
from .upscaler import RealESRGAN | |
from .utils import timer | |
class Loader:
    """
    A lazy-loading resource manager for Stable Diffusion pipelines. Lifecycles are managed by
    comparing the current state with desired. Can be used as a singleton when created by the
    `get_loader()` helper.

    Usage:
        loader = get_loader(singleton=True)
        loader.load(
            pipeline_id="controlnet_txt2img",
            ip_adapter_model="full-face",
            model="XpucT/Reliberate",
            scheduler="UniPC",
            controlnet_annotator="canny",
            deepcache_interval=2,
            scale=2,
            use_karras=True
        )
    """

    def __init__(self):
        self.model = ""  # repo ID of the currently loaded checkpoint ("" = none)
        self.pipeline = None  # active diffusers pipeline (None = not loaded)
        self.upscaler = None  # RealESRGAN instance (None = not loaded)
        self.controlnet = None  # ControlNetModel instance (None = not loaded)
        self.annotator = ""  # controlnet annotator (canny)
        self.ip_adapter = ""  # ip-adapter kind (full-face or plus)
        self.log = Logger("Loader")

    def should_unload_upscaler(self, scale: int = 1) -> bool:
        """True when an upscaler is loaded with a different scale than requested."""
        return self.upscaler is not None and self.upscaler.scale != scale

    def should_unload_deepcache(self, cache_interval: int = 1) -> bool:
        """True when DeepCache is attached but was requested off (interval 1) or
        with a different interval."""
        has_deepcache = hasattr(self.pipeline, "deepcache")
        if has_deepcache and cache_interval == 1:
            # An interval of 1 means caching was requested off
            return True
        if has_deepcache and self.pipeline.deepcache.params["cache_interval"] != cache_interval:
            # Unload if interval is different so it can be reloaded
            return True
        return False

    def should_unload_ip_adapter(self, ip_adapter_model: str = "") -> bool:
        """True when an IP-Adapter is loaded but none, or a different kind, was requested."""
        if not self.ip_adapter:
            return False
        if not ip_adapter_model:
            return True
        if self.ip_adapter != ip_adapter_model:
            # Unload if model is different so it can be reloaded
            return True
        return False

    def should_unload_controlnet(self, pipeline_id: str = "", annotator: str = "") -> bool:
        """True when ControlNet is loaded but the annotator changed or the requested
        pipeline is not a controlnet one."""
        if self.controlnet is None:
            return False
        if self.annotator != annotator:
            return True
        if not pipeline_id.startswith("controlnet_"):
            return True
        return False

    def should_unload_pipeline(self, model: str = "") -> bool:
        """True when a pipeline is loaded for a different model than requested."""
        if self.pipeline is None:
            return False
        if self.model != model:
            return True
        return False

    # Copied from https://github.com/huggingface/diffusers/blob/v0.28.0/src/diffusers/loaders/ip_adapter.py#L300
    def unload_ip_adapter(self) -> None:
        """Strip IP-Adapter components from the current pipeline and restore
        the default attention processors."""
        # Remove the image encoder if text-to-image
        if isinstance(self.pipeline, Config.PIPELINES["txt2img"]):
            self.pipeline.image_encoder = None
            self.pipeline.register_to_config(image_encoder=[None, None])
        # Remove hidden projection layer added by IP-Adapter
        self.pipeline.unet.encoder_hid_proj = None
        self.pipeline.unet.config.encoder_hid_dim_type = None
        # Remove the feature extractor
        self.pipeline.feature_extractor = None
        self.pipeline.register_to_config(feature_extractor=[None, None])
        # Replace the custom attention processors with defaults
        attn_procs = {}
        for name, value in self.pipeline.unet.attn_processors.items():
            attn_processor_class = AttnProcessor2_0()  # raises if not torch 2
            attn_procs[name] = (
                attn_processor_class
                if isinstance(value, IPAdapterAttnProcessor2_0)
                else value.__class__()
            )
        self.pipeline.unet.set_attn_processor(attn_procs)
        self.ip_adapter = ""

    def unload_all(
        self,
        pipeline_id: str = "",
        ip_adapter_model: str = "",
        model: str = "",
        controlnet_annotator: str = "",
        deepcache_interval: int = 1,
        scale: int = 1,
    ) -> None:
        """Release every resource whose current state differs from the desired
        state described by the arguments. Order matters: DeepCache wraps the
        pipeline, so it is removed before anything else is torn down."""
        if self.should_unload_deepcache(deepcache_interval):  # remove deepcache first
            self.log.info("Disabling DeepCache")
            self.pipeline.deepcache.disable()
            delattr(self.pipeline, "deepcache")
        if self.should_unload_ip_adapter(ip_adapter_model):
            self.log.info("Unloading IP-Adapter")
            self.unload_ip_adapter()
        if self.should_unload_controlnet(pipeline_id, controlnet_annotator):
            self.log.info("Unloading ControlNet")
            self.controlnet = None
            self.annotator = ""
        if self.should_unload_upscaler(scale):
            self.log.info("Unloading upscaler")
            self.upscaler = None
        if self.should_unload_pipeline(model):
            self.log.info("Unloading pipeline")
            self.pipeline = None
            self.model = ""

    def should_load_upscaler(self, scale: int = 1) -> bool:
        """True when upscaling (scale > 1) was requested and no upscaler is loaded."""
        return self.upscaler is None and scale > 1

    def should_load_deepcache(self, cache_interval: int = 1) -> bool:
        """True when caching (interval > 1) was requested and DeepCache is not attached."""
        has_deepcache = hasattr(self.pipeline, "deepcache")
        if not has_deepcache and cache_interval > 1:
            return True
        return False

    def should_load_controlnet(self, pipeline_id: str = "") -> bool:
        """True when a controlnet pipeline was requested and no ControlNet is loaded."""
        return self.controlnet is None and pipeline_id.startswith("controlnet_")

    def should_load_ip_adapter(self, ip_adapter_model: str = "") -> bool:
        """True when an IP-Adapter was requested and the pipeline's UNet does not
        already carry the IP-Adapter image-projection layer."""
        has_ip_adapter = (
            hasattr(self.pipeline.unet, "encoder_hid_proj")
            and self.pipeline.unet.config.encoder_hid_dim_type == "ip_image_proj"
        )
        return not has_ip_adapter and ip_adapter_model != ""

    def should_load_scheduler(self, cls, use_karras: bool = False) -> bool:
        """True when the pipeline's scheduler is not `cls`, or its Karras-sigmas
        setting (when the scheduler supports it) differs from `use_karras`."""
        has_karras = hasattr(self.pipeline.scheduler.config, "use_karras_sigmas")
        if not isinstance(self.pipeline.scheduler, cls):
            return True
        if has_karras and self.pipeline.scheduler.config.use_karras_sigmas != use_karras:
            return True
        return False

    def should_load_pipeline(self, pipeline_id: str = "") -> bool:
        """True when no pipeline is loaded or the loaded one is a different class
        than `pipeline_id` maps to."""
        if self.pipeline is None:
            return True
        if not isinstance(self.pipeline, Config.PIPELINES[pipeline_id]):
            return True
        return False

    def load_upscaler(self, scale: int = 1) -> None:
        """Load the RealESRGAN upscaler weights for the given scale onto the
        pipeline's device."""
        with timer(f"Loading {scale}x upscaler", logger=self.log.info):
            self.upscaler = RealESRGAN(scale, device=self.pipeline.device)
            self.upscaler.load_weights()

    def load_deepcache(self, cache_interval: int = 1) -> None:
        """Attach a DeepCache helper to the pipeline with the given cache interval."""
        self.log.info(f"Enabling DeepCache interval {cache_interval}")
        self.pipeline.deepcache = DeepCacheSDHelper(self.pipeline)
        self.pipeline.deepcache.set_params(cache_interval=cache_interval)
        self.pipeline.deepcache.enable()

    def load_controlnet(self, controlnet_annotator) -> None:
        """Download/load the fp16 ControlNet model for the given annotator."""
        with timer("Loading ControlNet", logger=self.log.info):
            self.controlnet = ControlNetModel.from_pretrained(
                Config.ANNOTATORS[controlnet_annotator],
                variant="fp16",
                torch_dtype=torch.float16,
            )
            self.annotator = controlnet_annotator

    def load_ip_adapter(self, ip_adapter_model: str = "") -> None:
        """Load the SD1.5 IP-Adapter weights of the given kind into the pipeline."""
        with timer("Loading IP-Adapter", logger=self.log.info):
            self.pipeline.load_ip_adapter(
                "h94/IP-Adapter",
                subfolder="models",
                weight_name=f"ip-adapter-{ip_adapter_model}_sd15.safetensors",
            )
            self.pipeline.set_ip_adapter_scale(0.5)  # 50% works the best
            self.ip_adapter = ip_adapter_model

    def load_scheduler(self, cls, use_karras: bool = False, **kwargs) -> None:
        """Replace the pipeline's scheduler with a fresh instance of `cls`."""
        self.log.info(f"Loading {cls.__name__}{' with Karras' if use_karras else ''}")
        self.pipeline.scheduler = cls(**kwargs)

    def load_pipeline(
        self,
        pipeline_id,
        model,
        **kwargs,
    ) -> None:
        """Load the pipeline class mapped to `pipeline_id` for `model`, either
        from scratch or by converting the already-loaded pipeline in place."""
        Pipeline = Config.PIPELINES[pipeline_id]
        # Load from scratch
        if self.pipeline is None:
            with timer(f"Loading {model} ({pipeline_id})", logger=self.log.info):
                if self.controlnet is not None:
                    kwargs["controlnet"] = self.controlnet
                if model in Config.SINGLE_FILE_MODELS:
                    checkpoint = Config.HF_REPOS[model][0]
                    self.pipeline = Pipeline.from_single_file(
                        f"https://huggingface.co/{model}/{checkpoint}",
                        **kwargs,
                    ).to("cuda")
                else:
                    self.pipeline = Pipeline.from_pretrained(model, **kwargs).to("cuda")
        # Change to a different one
        else:
            with timer(f"Changing pipeline to {pipeline_id}", logger=self.log.info):
                # from_pipe reuses the loaded components, so the caller's kwargs
                # (dtype, scheduler, ...) are intentionally discarded here
                kwargs = {}
                if self.controlnet is not None:
                    kwargs["controlnet"] = self.controlnet
                self.pipeline = Pipeline.from_pipe(
                    self.pipeline,
                    **kwargs,
                ).to("cuda")
        # Update model and disable terminal progress bars
        self.model = model
        self.pipeline.set_progress_bar_config(disable=True)

    def load(
        self,
        pipeline_id,
        ip_adapter_model,
        model,
        scheduler,
        controlnet_annotator,
        deepcache_interval,
        scale,
        use_karras,
    ) -> None:
        """Reconcile all managed resources with the requested configuration:
        unload whatever no longer matches, then load whatever is missing.
        ControlNet is loaded before the pipeline so it can be passed in."""
        Scheduler = Config.SCHEDULERS[scheduler]
        scheduler_kwargs = {
            "beta_start": 0.00085,
            "beta_end": 0.012,
            "beta_schedule": "scaled_linear",
            "timestep_spacing": "leading",
            "steps_offset": 1,
        }
        if scheduler not in ["Euler a"]:
            # Euler a does not accept use_karras_sigmas
            scheduler_kwargs["use_karras_sigmas"] = use_karras

        pipeline_kwargs = {
            "torch_dtype": torch.float16,  # defaults to fp32
            "safety_checker": None,
            "requires_safety_checker": False,
            "scheduler": Scheduler(**scheduler_kwargs),
        }
        # Single-file models don't need a variant
        if model not in Config.SINGLE_FILE_MODELS:
            pipeline_kwargs["variant"] = "fp16"
        else:
            pipeline_kwargs["variant"] = None

        # Prepare state for loading checks
        self.unload_all(
            pipeline_id,
            ip_adapter_model,
            model,
            controlnet_annotator,
            deepcache_interval,
            scale,
        )

        # Load controlnet model before pipeline
        if self.should_load_controlnet(pipeline_id):
            self.load_controlnet(controlnet_annotator)
        if self.should_load_pipeline(pipeline_id):
            self.load_pipeline(pipeline_id, model, **pipeline_kwargs)
        if self.should_load_scheduler(Scheduler, use_karras):
            self.load_scheduler(Scheduler, use_karras, **scheduler_kwargs)
        if self.should_load_deepcache(deepcache_interval):
            self.load_deepcache(deepcache_interval)
        if self.should_load_ip_adapter(ip_adapter_model):
            self.load_ip_adapter(ip_adapter_model)
        if self.should_load_upscaler(scale):
            self.load_upscaler(scale)
# Get a singleton or a new instance of the Loader
def get_loader(singleton=False):
    """Return a `Loader` instance.

    Args:
        singleton: When True, lazily create and reuse one shared instance
            (cached as an attribute on this function); otherwise return a
            fresh `Loader` on every call.

    Returns:
        A `Loader` (the shared one when `singleton` is True).
    """
    if not singleton:
        return Loader()
    # Cache the shared instance on the function object so no module-level
    # global is needed. The previous `assert isinstance(...)` sanity check
    # was removed: it is trivially true (only a Loader is ever assigned)
    # and asserts are stripped under `python -O`.
    if not hasattr(get_loader, "_instance"):
        get_loader._instance = Loader()
    return get_loader._instance