import os
import time
from typing import Union
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, MultiAdapter, StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline # pylint: disable=unused-import
from modules.shared import log
from modules import errors, sd_models
from modules.control.units import detect


what = 'T2I-Adapter'
# trace logging is enabled only when SD_CONTROL_DEBUG is present in the environment; otherwise debug() is a no-op
debug = log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None
debug('Trace: CONTROL')

# display name -> model path passed to T2IAdapter.from_pretrained, offered when the loaded model type is 'sd'
predefined_sd15 = {
    'Segment': 'TencentARC/t2iadapter_seg_sd14v1',
    'Zoe Depth': 'TencentARC/t2iadapter_zoedepth_sd15v1',
    'OpenPose': 'TencentARC/t2iadapter_openpose_sd14v1',
    'KeyPose': 'TencentARC/t2iadapter_keypose_sd14v1',
    'Color': 'TencentARC/t2iadapter_color_sd14v1',
    'Depth v1': 'TencentARC/t2iadapter_depth_sd14v1',
    'Depth v2': 'TencentARC/t2iadapter_depth_sd15v2',
    'Canny v1': 'TencentARC/t2iadapter_canny_sd14v1',
    'Canny v2': 'TencentARC/t2iadapter_canny_sd15v2',
    'Sketch v1': 'TencentARC/t2iadapter_sketch_sd14v1',
    'Sketch v2': 'TencentARC/t2iadapter_sketch_sd15v2',
}
# display name -> model path passed to T2IAdapter.from_pretrained, offered when the loaded model type is 'sdxl'
predefined_sdxl = {
    'Canny XL': 'TencentARC/t2i-adapter-canny-sdxl-1.0',
    'LineArt XL': 'TencentARC/t2i-adapter-lineart-sdxl-1.0',
    'Sketch XL': 'TencentARC/t2i-adapter-sketch-sdxl-1.0',
    'Zoe Depth XL': 'TencentARC/t2i-adapter-depth-zoe-sdxl-1.0',
    'OpenPose XL': 'TencentARC/t2i-adapter-openpose-sdxl-1.0',
    'Midas Depth XL': 'TencentARC/t2i-adapter-depth-midas-sdxl-1.0',
}
# cached result of list_models(); always a list of display names once populated
models = []
# lookup table used by Adapter.load() to resolve a display name to a model path
all_models = {}
all_models.update(predefined_sd15)
all_models.update(predefined_sdxl)
cache_dir = 'models/control/adapter'


def list_models(refresh=False):
    """Return the adapter display names selectable for the currently loaded model type.

    The result is cached in the module-level `models` list and reused until
    `refresh` is true. The list always starts with 'None'; the remaining
    entries are the sorted keys of the predefined table that matches
    `modules.shared.sd_model_type`. For an unrecognized model type a warning
    is logged and the names of both tables are returned.
    """
    import modules.shared
    global models # pylint: disable=global-statement
    if not refresh and len(models) > 0:
        return models
    if modules.shared.sd_model_type == 'none':
        models = ['None']
    elif modules.shared.sd_model_type == 'sdxl':
        models = ['None'] + sorted(predefined_sdxl)
    elif modules.shared.sd_model_type == 'sd':
        models = ['None'] + sorted(predefined_sd15)
    else:
        log.warning(f'Control {what} model list failed: unknown model type')
        models = ['None'] + sorted(list(predefined_sd15) + list(predefined_sdxl))
    debug(f'Control list {what}: path={cache_dir} models={models}')
    return models


class AdapterModel(T2IAdapter):
    """Project-local alias of the diffusers T2IAdapter class; adds no behavior."""


class Adapter():
    """Holder for a single T2I-Adapter model identified by its display name."""

    def __init__(self, model_id: str = None, device = None, dtype = None, load_config = None):
        """Store the settings and, when `model_id` is given, load the model immediately.

        Args:
            model_id: display name, a key of `all_models`; None defers loading.
            device: when not None, the loaded model is moved to it.
            dtype: when not None, the loaded model is converted to it.
            load_config: extra keyword arguments for `T2IAdapter.from_pretrained`;
                merged over the default `{'cache_dir': cache_dir}`.
        """
        self.model: AdapterModel = None
        self.model_id: str = model_id
        self.device = device
        self.dtype = dtype
        self.load_config = { 'cache_dir': cache_dir }
        if load_config is not None:
            self.load_config.update(load_config)
        if model_id is not None:
            self.load()

    def reset(self):
        """Drop the reference to the loaded model and clear the model id."""
        if self.model is not None:
            debug(f'Control {what} model unloaded')
        self.model = None
        self.model_id = None

    def load(self, model_id: str = None) -> Union[str, None]:
        """Load the adapter registered under `model_id` (defaults to `self.model_id`).

        Returns:
            A status message describing success or failure, or None when the
            id is None/'None' and the adapter was reset instead. Exceptions
            raised while loading are logged and reported through the returned
            message rather than propagated.
        """
        try:
            t0 = time.time()
            model_id = model_id or self.model_id
            if model_id is None or model_id == 'None':
                self.reset()
                return None
            # .get() so that an unknown id reaches the dedicated error below instead of raising KeyError
            model_path = all_models.get(model_id)
            if model_path is None:
                log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id')
                return f'{what} failed to load model: {model_id}'
            log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}"')
            self.model = T2IAdapter.from_pretrained(model_path, **self.load_config)
            if self.device is not None:
                self.model.to(self.device)
            if self.dtype is not None:
                self.model.to(self.dtype)
            t1 = time.time()
            self.model_id = model_id
            log.debug(f'Control {what} loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}')
            return f'{what} loaded model: {model_id}'
        except Exception as e:
            log.error(f'Control {what} model load failed: id="{model_id}" error={e}')
            errors.display(e, f'Control {what} load')
            return f'{what} failed to load model: {model_id}'


class AdapterPipeline():
    """Adapter-enabled pipeline assembled from the components of an existing pipeline."""

    def __init__(self, adapter: Union[T2IAdapter, list[T2IAdapter]], pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline], dtype = None):
        """Build the adapter pipeline; on any failure `self.pipeline` stays None.

        Args:
            adapter: one adapter or a list of adapters; a list with more than
                one entry is wrapped in `MultiAdapter`.
            pipeline: source pipeline whose components are reused; kept in
                `self.orig_pipeline` so `restore()` can hand it back.
            dtype: when not None, assigned to the new pipeline's `dtype` attribute.
        """
        t0 = time.time()
        self.orig_pipeline = pipeline
        self.pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline] = None
        if pipeline is None:
            log.error(f'Control {what} pipeline: model not loaded')
            return
        if isinstance(adapter, list):
            if len(adapter) == 0:
                log.error(f'Control {what} pipeline: no adapters provided')
                return
            if len(adapter) == 1:
                # a one-element list must be unwrapped: a plain list has no .to() method
                adapter = adapter[0]
            else:
                adapter = MultiAdapter(adapter)
        adapter.to(device=pipeline.device, dtype=pipeline.dtype)
        if detect.is_sdxl(pipeline):
            self.pipeline = StableDiffusionXLAdapterPipeline(
                vae=pipeline.vae,
                text_encoder=pipeline.text_encoder,
                text_encoder_2=pipeline.text_encoder_2,
                tokenizer=pipeline.tokenizer,
                tokenizer_2=pipeline.tokenizer_2,
                unet=pipeline.unet,
                scheduler=pipeline.scheduler,
                feature_extractor=getattr(pipeline, 'feature_extractor', None),
                adapter=adapter,
            )
            sd_models.move_model(self.pipeline, pipeline.device)
        elif detect.is_sd15(pipeline):
            self.pipeline = StableDiffusionAdapterPipeline(
                vae=pipeline.vae,
                text_encoder=pipeline.text_encoder,
                tokenizer=pipeline.tokenizer,
                unet=pipeline.unet,
                scheduler=pipeline.scheduler,
                feature_extractor=getattr(pipeline, 'feature_extractor', None),
                requires_safety_checker=False,
                safety_checker=None,
                adapter=adapter,
            )
            sd_models.move_model(self.pipeline, pipeline.device)
        else:
            log.error(f'Control {what} pipeline: class={pipeline.__class__.__name__} unsupported model type')
            return
        if dtype is not None and self.pipeline is not None:
            # NOTE(review): `dtype` looks like a read-only property on diffusers pipelines, in which case
            # this assignment would raise AttributeError - confirm against the installed diffusers version
            self.pipeline.dtype = dtype
        t1 = time.time()
        if self.pipeline is not None:
            log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}')
        else:
            log.error(f'Control {what} pipeline: not initialized')

    def restore(self):
        """Drop the adapter pipeline and return the pipeline passed to the constructor."""
        self.pipeline = None
        return self.orig_pipeline