| """CRS-Diff modular loading utilities for custom diffusers pipeline.""" |
|
|
| import importlib |
| import json |
| import sys |
| from pathlib import Path |
| from typing import Dict, Optional, Union |
|
|
| import torch |
| from diffusers import DDIMScheduler |
|
|
# Make the directory containing this file importable so that the sibling
# ``crs_core`` package (referenced by the _TARGET_MAP entries below) can be
# imported when this module is loaded from a downloaded pipeline snapshot.
_PIPELINE_DIR = Path(__file__).resolve().parent
if str(_PIPELINE_DIR) not in sys.path:
    sys.path.insert(0, str(_PIPELINE_DIR))
|
|
# Sub-folder names expected under the model root for a "split" export;
# each folder holds a config.json and (optionally) safetensors weights.
_COMPONENT_NAMES = (
    "unet",
    "vae",
    "text_encoder",
    "local_adapter",
    "global_content_adapter",
    "global_text_adapter",
    "metadata_encoder",
)
|
|
# Remap of serialized ``_target`` strings (from component config.json files)
# to importable class paths. Currently an identity map; it exists as a hook
# so old exports keep loading if crs_core classes are ever moved or renamed.
_TARGET_MAP = {
    "crs_core.local_adapter.LocalControlUNetModel": "crs_core.local_adapter.LocalControlUNetModel",
    "crs_core.autoencoder.AutoencoderKL": "crs_core.autoencoder.AutoencoderKL",
    "crs_core.text_encoder.FrozenCLIPEmbedder": "crs_core.text_encoder.FrozenCLIPEmbedder",
    "crs_core.local_adapter.LocalAdapter": "crs_core.local_adapter.LocalAdapter",
    "crs_core.global_adapter.GlobalContentAdapter": "crs_core.global_adapter.GlobalContentAdapter",
    "crs_core.global_adapter.GlobalTextAdapter": "crs_core.global_adapter.GlobalTextAdapter",
    "crs_core.metadata_embedding.metadata_embeddings": "crs_core.metadata_embedding.metadata_embeddings",
}
|
|
|
|
def ensure_model_path(pretrained_model_name_or_path: Union[str, Path]) -> Path:
    """Return a local directory for the model, downloading a HF snapshot if needed.

    If the argument does not name an existing local path it is treated as a
    Hugging Face repo id and fetched via ``snapshot_download``. The resolved
    directory is also prepended to ``sys.path`` (once) so Python modules
    shipped alongside the weights become importable.
    """
    local = Path(pretrained_model_name_or_path)
    if not local.exists():
        # Lazy import: huggingface_hub is only needed for the download case.
        from huggingface_hub import snapshot_download

        local = Path(snapshot_download(str(pretrained_model_name_or_path)))
    resolved = local.resolve()
    resolved_str = str(resolved)
    if resolved_str not in sys.path:
        sys.path.insert(0, resolved_str)
    return resolved
|
|
|
|
def resolve_model_root(candidate: Optional[Union[str, Path]]) -> Optional[Path]:
    """Locate the directory holding ``model_index.json`` for *candidate*.

    The candidate is first resolved (and possibly downloaded) via
    ``ensure_model_path``. The resolved path itself is checked, then up to
    five parent directories. Returns ``None`` for a falsy candidate or when
    no ``model_index.json`` is found.
    """
    if not candidate:
        return None
    base = ensure_model_path(candidate)
    if (base / "model_index.json").exists():
        return base
    node = base
    for _ in range(5):
        up = node.parent
        if up == node:
            # Reached the filesystem root; nothing further to climb.
            break
        if (up / "model_index.json").exists():
            return up
        node = up
    return None
|
|
|
|
def _get_class(target: str):
    """Import and return the object named by a dotted ``module.ClassName`` path."""
    module_name, _, attr_name = target.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)
|
|
|
|
def load_component(model_root: Path, name: str):
    """Instantiate one split component from ``<model_root>/<name>/``.

    Reads ``config.json`` (whose ``_target`` entry names the class to
    construct, optionally remapped via ``_TARGET_MAP``), builds the module
    from the remaining non-underscore config keys, loads
    ``diffusion_pytorch_model.safetensors`` weights when present, and
    returns the module in eval mode.
    """
    comp_dir = Path(model_root) / name
    cfg_file = comp_dir / "config.json"
    with cfg_file.open("r", encoding="utf-8") as fh:
        config = json.load(fh)

    target = config.pop("_target", None)
    if not target:
        raise ValueError(f"No _target in {cfg_file}")

    component_cls = _get_class(_TARGET_MAP.get(target, target))
    # Keys starting with "_" are export metadata, not constructor kwargs.
    kwargs = {key: value for key, value in config.items() if not key.startswith("_")}
    component = component_cls(**kwargs)

    weights = comp_dir / "diffusion_pytorch_model.safetensors"
    if weights.exists():
        # Lazy import: safetensors only needed when a weight file exists.
        from safetensors.torch import load_file

        component.load_state_dict(load_file(str(weights)), strict=True)
    component.eval()
    return component
|
|
|
|
class CRSModelWrapper(torch.nn.Module):
    """Adapter exposing the split components through the CRSControlNet-style
    API the sampling pipeline calls: ``get_learned_conditioning``,
    ``apply_model`` and ``decode_first_stage``.
    """

    def __init__(
        self,
        unet,
        vae,
        text_encoder,
        local_adapter,
        global_content_adapter,
        global_text_adapter,
        metadata_encoder,
        channels: int = 4,
    ):
        super().__init__()
        # Mirror the original CRS layout: the UNet lives at model.diffusion_model.
        container = torch.nn.Module()
        container.add_module("diffusion_model", unet)
        self.model = container
        self.first_stage_model = vae
        self.cond_stage_model = text_encoder
        self.local_adapter = local_adapter
        self.global_content_adapter = global_content_adapter
        self.global_text_adapter = global_text_adapter
        self.metadata_emb = metadata_encoder
        # One scale per local-control feature map (13 expected by the UNet).
        self.local_control_scales = [1.0 for _ in range(13)]
        self.channels = channels

    @torch.no_grad()
    def get_learned_conditioning(self, prompts):
        """Encode text prompts with the conditioning (text) encoder."""
        encoder = self.cond_stage_model
        if hasattr(encoder, "device"):
            # Keep the encoder's device string in sync with the wrapper's weights.
            encoder.device = str(next(self.parameters()).device)
        return encoder.encode(prompts)

    def apply_model(self, x_noisy, t, cond, metadata=None, global_strength=1.0, **kwargs):
        """Run one denoising UNet step with optional global/local control.

        ``cond`` is a dict with keys ``c_crossattn`` (list of text embeddings),
        and optionally ``metadata``, ``global_control`` and ``local_control``
        (lists whose first entry may be None to disable that branch).
        """
        del kwargs  # accepted for API compatibility, unused
        if metadata is None:
            metadata = cond["metadata"]
        cond_txt = torch.cat(cond["c_crossattn"], 1)

        global_entries = cond.get("global_control")
        if global_entries is not None and global_entries[0] is not None:
            metadata = self.metadata_emb(metadata)
            # First half of the global tensor carries the content embedding.
            content_tokens, _ = global_entries[0].chunk(2, dim=1)
            projected = self.global_content_adapter(content_tokens)
            cond_txt = self.global_text_adapter(cond_txt)
            cond_txt = torch.cat([cond_txt, global_strength * projected], dim=1)

        local_control = None
        local_entries = cond.get("local_control")
        if local_entries is not None and local_entries[0] is not None:
            stacked = torch.cat(local_entries, 1)
            features = self.local_adapter(
                x=x_noisy, timesteps=t, context=cond_txt, local_conditions=stacked
            )
            local_control = [f * s for f, s in zip(features, self.local_control_scales)]

        return self.model.diffusion_model(
            x=x_noisy,
            timesteps=t,
            metadata=metadata,
            context=cond_txt,
            local_control=local_control,
            meta=True,
        )

    def decode_first_stage(self, z):
        """Decode latents to image space with the first-stage VAE."""
        return self.first_stage_model.decode(z)
|
|
|
|
def load_components(model_root: Union[str, Path]) -> Dict[str, object]:
    """Load pipeline components from a split CRS-Diff export.

    Parameters
    ----------
    model_root:
        Local path or Hugging Face repo id containing a ``scheduler/``
        folder plus one sub-folder per entry in ``_COMPONENT_NAMES``.

    Returns
    -------
    dict
        ``{"crs_model": CRSModelWrapper, "scheduler": DDIMScheduler,
        "scale_factor": float}``.

    Raises
    ------
    FileNotFoundError
        If any expected component folder lacks a ``config.json``.
    """
    root = ensure_model_path(model_root)
    scheduler = DDIMScheduler.from_pretrained(root, subfolder="scheduler")

    # Defaults match the standard SD latent space; model_index.json may override.
    scale_factor = 0.18215
    channels = 4
    index_file = root / "model_index.json"
    if index_file.exists():
        with index_file.open("r", encoding="utf-8") as f:
            idx = json.load(f)
        scale_factor = float(idx.get("scale_factor", scale_factor))
        channels = int(idx.get("channels", channels))

    # Single scan: the original code walked the component folders twice
    # (once for all(), once to rebuild the missing list on failure).
    missing = [name for name in _COMPONENT_NAMES if not (root / name / "config.json").exists()]
    if missing:
        raise FileNotFoundError(
            f"CRS-Diff split component export incomplete. Missing: {missing}. "
            "Expected split folders with config.json and weights."
        )

    loaded = {name: load_component(root, name) for name in _COMPONENT_NAMES}
    crs_model = CRSModelWrapper(
        unet=loaded["unet"],
        vae=loaded["vae"],
        text_encoder=loaded["text_encoder"],
        local_adapter=loaded["local_adapter"],
        global_content_adapter=loaded["global_content_adapter"],
        global_text_adapter=loaded["global_text_adapter"],
        metadata_encoder=loaded["metadata_encoder"],
        channels=channels,
    )

    return {"crs_model": crs_model, "scheduler": scheduler, "scale_factor": scale_factor}
|
|