| | """IRDiffAE: standalone HuggingFace-compatible iRDiffAE model.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | from pathlib import Path |
| |
|
| | import torch |
| | from torch import Tensor, nn |
| |
|
| | from .config import IRDiffAEConfig, IRDiffAEInferenceConfig |
| | from .decoder import Decoder |
| | from .encoder import Encoder |
| | from .samplers import run_ddim, run_dpmpp_2m |
| | from .vp_diffusion import get_schedule, make_initial_state, sample_noise |
| |
|
| |
|
| | def _resolve_model_dir( |
| | path_or_repo_id: str | Path, |
| | *, |
| | revision: str | None, |
| | cache_dir: str | Path | None, |
| | ) -> Path: |
| | """Resolve a local path or HuggingFace Hub repo ID to a local directory.""" |
| |
|
| | local = Path(path_or_repo_id) |
| | if local.is_dir(): |
| | return local |
| | |
| | repo_id = str(path_or_repo_id) |
| | try: |
| | from huggingface_hub import snapshot_download |
| | except ImportError: |
| | raise ImportError( |
| | f"'{repo_id}' is not an existing local directory. " |
| | "To download from HuggingFace Hub, install huggingface_hub: " |
| | "pip install huggingface_hub" |
| | ) |
| | cache_dir_str = str(cache_dir) if cache_dir is not None else None |
| | local_dir = snapshot_download( |
| | repo_id, |
| | revision=revision, |
| | cache_dir=cache_dir_str, |
| | ) |
| | return Path(local_dir) |
| |
|
| |
|
class IRDiffAE(nn.Module):
    """Standalone iRDiffAE model for HuggingFace distribution.

    A diffusion autoencoder that encodes images to compact latents and
    decodes them back via iterative VP diffusion.

    Usage::

        model = IRDiffAE.from_pretrained("path/to/weights")
        model = model.to("cuda", dtype=torch.bfloat16)

        # Encode
        latents = model.encode(images)  # images: [B,3,H,W] in [-1,1]

        # Decode (1 step by default — PSNR-optimal)
        recon = model.decode(latents, height=H, width=W)

        # Reconstruct (encode + 1-step decode)
        recon = model.reconstruct(images)
    """

    def __init__(self, config: IRDiffAEConfig) -> None:
        super().__init__()
        self.config = config

        self.encoder = Encoder(
            in_channels=config.in_channels,
            patch_size=config.patch_size,
            model_dim=config.model_dim,
            depth=config.encoder_depth,
            bottleneck_dim=config.bottleneck_dim,
            mlp_ratio=config.mlp_ratio,
            depthwise_kernel_size=config.depthwise_kernel_size,
        )

        self.decoder = Decoder(
            in_channels=config.in_channels,
            patch_size=config.patch_size,
            model_dim=config.model_dim,
            depth=config.decoder_depth,
            bottleneck_dim=config.bottleneck_dim,
            mlp_ratio=config.mlp_ratio,
            depthwise_kernel_size=config.depthwise_kernel_size,
            adaln_low_rank_rank=config.adaln_low_rank_rank,
        )

    def _model_dtype(self) -> torch.dtype:
        """Return the dtype of the first parameter, or float32 if there are none."""
        first_param = next(self.parameters(), None)
        return torch.float32 if first_param is None else first_param.dtype

    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *,
        dtype: torch.dtype = torch.bfloat16,
        device: str | torch.device = "cpu",
        revision: str | None = None,
        cache_dir: str | Path | None = None,
    ) -> IRDiffAE:
        """Load a pretrained model from a local directory or HuggingFace Hub.

        The directory (or repo) should contain:
            - config.json: Model architecture config.
            - model.safetensors (preferred) or model.pt: Model weights.

        Args:
            path_or_repo_id: Local directory path or HuggingFace Hub repo ID
                (e.g. ``"data-archetype/irdiffae-v1"``).
            dtype: Load weights in this dtype (float32 or bfloat16).
            device: Target device.
            revision: Git revision (branch, tag, or commit) for Hub downloads.
            cache_dir: Where to cache Hub downloads. Uses HF default if None.

        Returns:
            Loaded model in eval mode.

        Raises:
            ImportError: ``model.safetensors`` is present but the
                ``safetensors`` package is not installed.
            FileNotFoundError: Neither weight file exists in the directory.
        """
        model_dir = _resolve_model_dir(
            path_or_repo_id, revision=revision, cache_dir=cache_dir
        )
        config = IRDiffAEConfig.load(model_dir / "config.json")
        model = cls(config)

        safetensors_path = model_dir / "model.safetensors"
        pt_path = model_dir / "model.pt"

        if safetensors_path.exists():
            # Only the import is guarded, so an ImportError raised while
            # loading is not misreported as a missing package.
            try:
                from safetensors.torch import load_file
            except ImportError as err:
                raise ImportError(
                    "safetensors package required to load .safetensors files. "
                    "Install with: pip install safetensors"
                ) from err
            state_dict = load_file(str(safetensors_path), device=str(device))
        elif pt_path.exists():
            # weights_only=True restricts what torch.load will unpickle.
            state_dict = torch.load(
                str(pt_path), map_location=device, weights_only=True
            )
        else:
            raise FileNotFoundError(
                f"No model weights found in {model_dir}. "
                "Expected model.safetensors or model.pt."
            )

        model.load_state_dict(state_dict)
        model = model.to(dtype=dtype, device=torch.device(device))
        model.eval()
        return model

    def encode(self, images: Tensor) -> Tensor:
        """Encode images to latents.

        Gradients are not disabled here (unlike ``decode``). The input is cast
        to the model's parameter dtype but is not moved between devices.

        Args:
            images: [B, 3, H, W] in [-1, 1], H and W must be divisible by patch_size.

        Returns:
            Latents [B, bottleneck_dim, H/patch, W/patch].
        """
        return self.encoder(images.to(dtype=self._model_dtype()))

    @torch.no_grad()
    def decode(
        self,
        latents: Tensor,
        height: int,
        width: int,
        *,
        inference_config: IRDiffAEInferenceConfig | None = None,
    ) -> Tensor:
        """Decode latents to images via VP diffusion.

        Args:
            latents: [B, bottleneck_dim, h, w] encoder latents.
            height: Output image height (must be divisible by patch_size).
            width: Output image width (must be divisible by patch_size).
            inference_config: Optional inference parameters. Uses defaults if None.

        Returns:
            Reconstructed images [B, 3, H, W] in float32.

        Raises:
            ValueError: Unsupported ``inference_config.sampler``, or
                ``height``/``width`` not divisible by ``patch_size``.
        """
        cfg = (
            IRDiffAEInferenceConfig() if inference_config is None else inference_config
        )

        # Validate cheap arguments first, before any noise or schedule work.
        if cfg.sampler == "ddim":
            sampler_fn = run_ddim
        elif cfg.sampler == "dpmpp_2m":
            sampler_fn = run_dpmpp_2m
        else:
            raise ValueError(
                f"Unsupported sampler: {cfg.sampler!r}. Use 'ddim' or 'dpmpp_2m'."
            )

        config = self.config
        if height % config.patch_size != 0 or width % config.patch_size != 0:
            raise ValueError(
                f"height={height} and width={width} must be divisible by patch_size={config.patch_size}"
            )

        batch = int(latents.shape[0])
        device = latents.device
        model_dtype = self._model_dtype()

        # Noise is drawn on CPU in float32 and moved to `device` afterwards.
        # NOTE(review): presumably this makes a given seed reproduce across
        # devices; confirm against sample_noise.
        shape = (batch, config.in_channels, height, width)
        noise = sample_noise(
            shape,
            noise_std=config.pixel_noise_std,
            seed=cfg.seed,
            device=torch.device("cpu"),
            dtype=torch.float32,
        )

        schedule = get_schedule(cfg.schedule, cfg.num_steps).to(device=device)

        initial_state = make_initial_state(
            noise=noise.to(device=device),
            t_start=schedule[0:1],
            logsnr_min=config.logsnr_min,
            logsnr_max=config.logsnr_max,
        )

        def _forward_fn(
            x_t: Tensor,
            t: Tensor,
            latents: Tensor,
            *,
            drop_middle_blocks: bool = False,
        ) -> Tensor:
            # Inputs are cast explicitly to the parameter dtype, since
            # autocast is disabled for the sampler run below.
            return self.decoder(
                x_t.to(dtype=model_dtype),
                t,
                latents.to(dtype=model_dtype),
                drop_middle_blocks=drop_middle_blocks,
            )

        # Disable autocast for the whole sampler loop (not just the device
        # move) so any caller-level autocast does not recast decoder math.
        # Non-CUDA devices map to the "cpu" autocast context.
        device_type = "cuda" if device.type == "cuda" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            latents_in = latents.to(device=device)
            result = sampler_fn(
                forward_fn=_forward_fn,
                initial_state=initial_state,
                schedule=schedule,
                latents=latents_in,
                logsnr_min=config.logsnr_min,
                logsnr_max=config.logsnr_max,
                pdg_enabled=cfg.pdg_enabled,
                pdg_strength=cfg.pdg_strength,
                device=device,
            )

        return result

    @torch.no_grad()
    def reconstruct(
        self,
        images: Tensor,
        *,
        inference_config: IRDiffAEInferenceConfig | None = None,
    ) -> Tensor:
        """Encode then decode. Convenience wrapper.

        Args:
            images: [B, 3, H, W] in [-1, 1].
            inference_config: Optional inference parameters.

        Returns:
            Reconstructed images [B, 3, H, W] in float32.
        """
        latents = self.encode(images)
        _, _, h, w = images.shape
        return self.decode(
            latents, height=h, width=w, inference_config=inference_config
        )
| |
|