| import sys |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import List, Optional, Union |
|
|
| import numpy as np |
| import torch |
| from diffusers import DDIMScheduler, DiffusionPipeline |
| from diffusers.utils import BaseOutput |
| from PIL import Image |
|
|
# Make this file's directory importable so the sibling `modular_pipeline`
# module resolves regardless of the current working directory.
_ROOT = Path(__file__).resolve().parent
if str(_ROOT) not in sys.path:
    sys.path.insert(0, str(_ROOT))

# Alias this module under the name "pipeline" as well. NOTE(review): this is
# presumably so objects/configs serialized with a module path of "pipeline"
# (e.g. when diffusers loads this file as a dynamic custom pipeline) resolve
# back to this module — TODO confirm the intended consumer.
sys.modules["pipeline"] = sys.modules[__name__]

# Imported only after the sys.path tweak above: `modular_pipeline` lives next
# to this file, not on the default import path.
from modular_pipeline import load_components, resolve_model_root
|
|
|
|
@dataclass
class CRSDiffPipelineOutput(BaseOutput):
    """Output of ``CRSDiffPipeline.__call__``.

    ``images`` holds PIL images when ``output_type="pil"`` (a numpy uint8
    batch is returned instead when ``output_type="numpy"``, despite the
    annotation).
    """

    images: List[Image.Image]
|
|
|
|
class CRSDiffPipeline(DiffusionPipeline):
    """Diffusers-compatible wrapper around the CRS-Diff model.

    Bundles a composite ``crs_model`` (an LDM-style object exposing
    ``get_learned_conditioning``, ``apply_model`` and ``decode_first_stage``
    — see ``__call__``) with a diffusers scheduler, and runs
    classifier-free-guided sampling conditioned on text prompts plus
    local/global control tensors and metadata.
    """

    def register_modules(self, **kwargs):
        """Override of ``DiffusionPipeline.register_modules``.

        The stock implementation cannot derive a ``(library, class)`` tuple
        for the custom ``crs_model``, so None/placeholder components are
        handled here and only real module objects fall through to diffusers'
        own resolution.
        """
        for name, module in kwargs.items():
            if module is None or (
                isinstance(module, (tuple, list)) and len(module) > 0 and module[0] is None
            ):
                # Missing component: record a (None, None) placeholder in the
                # config and still expose the attribute on the pipeline.
                self.register_to_config(**{name: (None, None)})
                setattr(self, name, module)
            elif _is_component_list(module):
                # Already a ("library", "ClassName") config placeholder —
                # re-register it verbatim.
                self.register_to_config(**{name: (module[0], module[1])})
                setattr(self, name, module)
            else:
                # Real module object: let diffusers work out its
                # (library, class) tuple. Imported lazily because this is a
                # private diffusers helper.
                from diffusers.pipelines.pipeline_loading_utils import _fetch_class_library_tuple

                library, class_name = _fetch_class_library_tuple(module)
                self.register_to_config(**{name: (library, class_name)})
                setattr(self, name, module)

    def __init__(
        self,
        crs_model=None,
        scheduler=None,
        scale_factor: float = 0.18215,
        model_path: Optional[Union[str, Path]] = None,
        _name_or_path: Optional[Union[str, Path]] = None,
    ):
        """Build the pipeline.

        Args:
            crs_model: CRS-Diff model instance, or a ("library", "class")
                placeholder coming from a saved pipeline config.
            scheduler: Diffusers scheduler instance, or a placeholder.
            scale_factor: Latent scaling factor (stored as
                ``vae_scale_factor`` — see the NOTE below).
            model_path: Explicit model root used to load real components when
                only placeholders were supplied.
            _name_or_path: Path hint as recorded by diffusers in the config.

        Raises:
            ValueError: if placeholders were supplied but no model root can
                be resolved from any of the path hints.
        """
        super().__init__()
        # When loaded via DiffusionPipeline.from_pretrained(custom_pipeline=...),
        # diffusers passes the config's ("library", "class") placeholders
        # instead of real modules; resolve a model root and load them here.
        if _is_component_list(crs_model) or _is_component_list(scheduler):
            model_root = (
                resolve_model_root(model_path)
                or resolve_model_root(_name_or_path)
                or resolve_model_root(getattr(getattr(self, "config", None), "_name_or_path", None))
            )
            if model_root is None:
                raise ValueError(
                    "CRS-Diff received config placeholders but could not resolve model path. "
                    "Pass `model_path` or load via DiffusionPipeline.from_pretrained(<path>, custom_pipeline=...)."
                )
            loaded = load_components(model_root)
            crs_model = loaded["crs_model"]
            scheduler = loaded["scheduler"]
            scale_factor = loaded["scale_factor"]

        self.register_modules(crs_model=crs_model, scheduler=scheduler)
        # NOTE(review): despite the name, this holds the latent *scaling*
        # factor (default 0.18215), not the spatial downsample factor; the
        # spatial factor is hardcoded as `// 8` in __call__.
        self.vae_scale_factor = scale_factor

    @property
    def device(self) -> torch.device:
        """Device of the underlying model; CPU when it has no parameters."""
        params = list(self.crs_model.parameters())
        if params:
            return params[0].device
        return torch.device("cpu")

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        device: Optional[Union[str, torch.device]] = None,
        subfolder: Optional[str] = None,
        **kwargs,
    ) -> "CRSDiffPipeline":
        """Load the pipeline from a resolvable model root.

        NOTE(review): when called with ``subfolder="scheduler"`` this returns
        a ``DDIMScheduler``, not a ``CRSDiffPipeline``, contradicting the
        return annotation — presumably to serve diffusers' sub-component
        fetches; confirm before tightening the annotation.
        """
        path = resolve_model_root(pretrained_model_name_or_path)
        if path is None:
            raise ValueError(f"Could not resolve CRS-Diff model root from: {pretrained_model_name_or_path}")

        # `subfolder` may arrive either as the named parameter or inside
        # **kwargs depending on the caller; prefer the kwargs copy.
        subfolder = kwargs.pop("subfolder", subfolder)
        if subfolder == "scheduler":
            return DDIMScheduler.from_pretrained(path, subfolder="scheduler")

        loaded = load_components(path)
        pipe = cls(crs_model=loaded["crs_model"], scheduler=loaded["scheduler"], scale_factor=loaded["scale_factor"])
        if device is not None:
            pipe = pipe.to(device)
        return pipe

    def _to_tensor(self, x, device: torch.device, dtype=torch.float32) -> torch.Tensor:
        """Coerce *x* (torch.Tensor or np.ndarray) onto *device*/*dtype*."""
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if not isinstance(x, torch.Tensor):
            raise TypeError("Expected torch.Tensor or np.ndarray for conditioning inputs.")
        return x.to(device=device, dtype=dtype)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        local_control,
        global_control,
        metadata,
        negative_prompt: Union[str, List[str]] = "",
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
        strength: float = 1.0,
        global_strength: float = 1.0,
        generator: Optional[torch.Generator] = None,
        output_type: str = "pil",
    ) -> CRSDiffPipelineOutput:
        """Run conditioned sampling and decode the result to images.

        Args:
            prompt: Text prompt(s); a single string is broadcast to the batch.
            local_control: Batched control tensor; its batch size and spatial
                size drive the latents (assumed NCHW — TODO confirm).
            global_control: Global conditioning tensor.
            metadata: Per-sample metadata vector(s); a 1-D tensor is repeated
                for the whole batch.
            negative_prompt: Prompt(s) for the unconditional branch.
            num_inference_steps: Number of scheduler timesteps.
            guidance_scale: Classifier-free guidance weight; <= 1.0 disables
                the unconditional pass.
            eta: DDIM eta forwarded to ``scheduler.step``.
            strength: Per-block local-control scale.
            global_strength: Scale forwarded to ``apply_model``.
            generator: Optional RNG for reproducible latents/noise.
            output_type: "pil" or "numpy".

        Returns:
            CRSDiffPipelineOutput with the decoded images.

        Raises:
            ValueError: on an unknown ``output_type``.
        """
        device = self.device
        # Conditioning tensors must live on the model's device.
        local_control = self._to_tensor(local_control, device=device)
        global_control = self._to_tensor(global_control, device=device)
        metadata = self._to_tensor(metadata, device=device)

        # Batch size is dictated by the local-control batch; broadcast the
        # prompts and metadata to match it.
        batch_size = local_control.shape[0]
        if isinstance(prompt, str):
            prompt = [prompt] * batch_size
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size

        if metadata.dim() == 1:
            metadata = metadata.unsqueeze(0).repeat(batch_size, 1)

        # Conditional and unconditional (classifier-free guidance) inputs.
        # The unconditional branch keeps the local control but zeroes the
        # global control so it carries no semantic guidance.
        cond = {
            "local_control": [local_control],
            "c_crossattn": [self.crs_model.get_learned_conditioning(prompt)],
            "global_control": [global_control],
        }
        un_cond = {
            "local_control": [local_control],
            "c_crossattn": [self.crs_model.get_learned_conditioning(negative_prompt)],
            "global_control": [torch.zeros_like(global_control)],
        }

        # 13 scales — presumably one per control block of the model;
        # TODO confirm against the crs_model definition.
        if hasattr(self.crs_model, "local_control_scales"):
            self.crs_model.local_control_scales = [strength] * 13

        # Initial noise latents at 1/8 spatial resolution (the VAE's 8x
        # downsample is hardcoded here, independent of vae_scale_factor).
        _, _, h, w = local_control.shape
        latents = torch.randn(
            (batch_size, self.crs_model.channels, h // 8, w // 8),
            generator=generator,
            device=device,
        )
        latents = latents * self.scheduler.init_noise_sigma

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        for t in self.scheduler.timesteps:
            # The model expects a per-sample timestep vector.
            ts = torch.full((batch_size,), int(t), device=device, dtype=torch.long)
            if guidance_scale > 1.0:
                noise_text = self.crs_model.apply_model(latents, ts, cond, metadata, global_strength)
                noise_uncond = self.crs_model.apply_model(latents, ts, un_cond, metadata, global_strength)
                # Standard classifier-free guidance combination.
                noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
            else:
                noise_pred = self.crs_model.apply_model(latents, ts, cond, metadata, global_strength)

            latents = self.scheduler.step(
                model_output=noise_pred,
                timestep=t,
                sample=latents,
                eta=eta,
                generator=generator,
                return_dict=True,
            ).prev_sample

        # Decode latents to images in [-1, 1], then map to uint8 HWC.
        images = self.crs_model.decode_first_stage(latents)
        images = images.clamp(-1, 1)
        images = ((images + 1.0) / 2.0).permute(0, 2, 3, 1).cpu().numpy()
        images = (images * 255.0).clip(0, 255).astype(np.uint8)

        if output_type == "pil":
            images = [Image.fromarray(img) for img in images]
        elif output_type != "numpy":
            raise ValueError("output_type must be 'pil' or 'numpy'")

        return CRSDiffPipelineOutput(images=images)
|
|
|
|
| def _is_component_list(v): |
| return ( |
| isinstance(v, (list, tuple)) |
| and len(v) == 2 |
| and isinstance(v[0], str) |
| and isinstance(v[1], str) |
| ) |
|
|