| """ | |
| VQGAN Autoencoder module for encoding/decoding images to/from latent space. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from pathlib import Path | |
| import sys | |
| import os | |
| from huggingface_hub import hf_hub_download | |
# Handle import of ldm from the latent-diffusion repository.
# Check whether the ldm directory exists locally (cloned from the latent-diffusion repo).
_ldm_path = Path(__file__).parent.parent / "ldm"
if _ldm_path.exists() and str(_ldm_path.parent) not in sys.path:
    sys.path.insert(0, str(_ldm_path.parent))

try:
    from ldm.models.autoencoder import VQModelTorch
except ImportError:
    # Fallback: try importing from site-packages if latent-diffusion is installed.
    try:
        import importlib.util

        spec = importlib.util.find_spec("ldm.models.autoencoder")
        if spec is None:
            raise ImportError("Could not find ldm.models.autoencoder")
        from ldm.models.autoencoder import VQModelTorch
    except ImportError as e:
        raise ImportError(
            "Could not import VQModelTorch from ldm.models.autoencoder. "
            "Please ensure the latent-diffusion repository is cloned and the ldm directory exists, "
            "or install the latent-diffusion package. Error: " + str(e)
        )
from config import (
    autoencoder_ckpt_path,
    autoencoder_use_fp16,
    autoencoder_embed_dim,
    autoencoder_n_embed,
    autoencoder_double_z,
    autoencoder_z_channels,
    autoencoder_resolution,
    autoencoder_in_channels,
    autoencoder_out_ch,
    autoencoder_ch,
    autoencoder_ch_mult,
    autoencoder_num_res_blocks,
    autoencoder_attn_resolutions,
    autoencoder_dropout,
    autoencoder_padding_mode,
    _project_root,
    device,
)
# Hugging Face repo ID for the pretrained weights.
HF_WEIGHTS_REPO_ID = "shekkari21/DiffusionSR-weights"


def load_vqgan(ckpt_path=None, device=device):
    """
    Load the VQGAN autoencoder from a checkpoint.

    Args:
        ckpt_path: Path to the checkpoint file. If None, uses the config path.
        device: Device to load the model on.

    Returns:
        VQGAN model in eval mode.
    """
    if ckpt_path is None:
        ckpt_path = autoencoder_ckpt_path

    # Resolve the path relative to the project root.
    if not Path(ckpt_path).is_absolute():
        ckpt_path = _project_root / ckpt_path

    # Download from Hugging Face if the checkpoint is not found locally.
    if not Path(ckpt_path).exists():
        print("VQGAN checkpoint not found locally. Downloading from Hugging Face...")
        try:
            # Files live in the root of the weights repo; download them into the
            # local pretrained_weights directory.
            local_weights_dir = _project_root / "pretrained_weights"
            local_weights_dir.mkdir(parents=True, exist_ok=True)
            downloaded_path = hf_hub_download(
                repo_id=HF_WEIGHTS_REPO_ID,
                filename="autoencoder_vq_f4.pth",
                local_dir=str(local_weights_dir),
                local_dir_use_symlinks=False,
            )
            ckpt_path = Path(downloaded_path)
            print(f"✓ Downloaded VQGAN checkpoint: {ckpt_path}")
        except Exception as e:
            raise FileNotFoundError(
                f"VQGAN checkpoint not found at: {ckpt_path}\n"
                f"Could not download from Hugging Face: {e}\n"
                f"Please ensure the file exists in the repository."
            )
| print(f"Loading VQGAN from: {ckpt_path}") | |
| # Load checkpoint | |
| checkpoint = torch.load(ckpt_path, map_location=device) | |
| # Extract state_dict | |
| if isinstance(checkpoint, dict): | |
| if 'state_dict' in checkpoint: | |
| state_dict = checkpoint['state_dict'] | |
| elif 'model' in checkpoint: | |
| state_dict = checkpoint['model'] | |
| else: | |
| state_dict = checkpoint | |
| else: | |
| raise ValueError(f"Unexpected checkpoint format: {type(checkpoint)}") | |
    # Create the model architecture from the config values.
    ddconfig = {
        'double_z': autoencoder_double_z,
        'z_channels': autoencoder_z_channels,
        'resolution': autoencoder_resolution,
        'in_channels': autoencoder_in_channels,
        'out_ch': autoencoder_out_ch,
        'ch': autoencoder_ch,
        'ch_mult': autoencoder_ch_mult,
        'num_res_blocks': autoencoder_num_res_blocks,
        'attn_resolutions': autoencoder_attn_resolutions,
        'dropout': autoencoder_dropout,
        'padding_mode': autoencoder_padding_mode,
    }
    model = VQModelTorch(
        ddconfig=ddconfig,
        n_embed=autoencoder_n_embed,
        embed_dim=autoencoder_embed_dim,
    )

    # Load the state_dict (strict=False tolerates key mismatches).
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    model.to(device)
    if autoencoder_use_fp16:
        model = model.half()

    print(f"VQGAN loaded successfully on {device}")
    return model


class VQGANWrapper(nn.Module):
    """
    Simple wrapper around the VQGAN autoencoder.
    """
    def __init__(self, model):
        super().__init__()
        self.model = model

    def encode(self, x):
        """
        Encode an image into latent space.

        Args:
            x: (B, 3, H, W) image tensor in range [0, 1]

        Returns:
            z: (B, 3, H//4, W//4) latent tensor
        """
        # Ensure the model is in eval mode.
        self.model.eval()
        with torch.no_grad():
            # Normalize to [-1, 1] if the input looks like it is in [0, 1].
            if x.max() <= 1.0:
                x = x * 2.0 - 1.0  # [0, 1] -> [-1, 1]

            # Match the model's dtype and device (handles fp16 models).
            model_param = next(self.model.parameters())
            if x.dtype != model_param.dtype:
                x = x.to(model_param.dtype)
            if x.device != model_param.device:
                x = x.to(model_param.device)

            # Encode.
            z = self.model.encode(x)

            # Extract the latent from a tuple/dict return value if needed.
            if isinstance(z, (tuple, list)):
                z = z[0]
            elif isinstance(z, dict):
                z = z.get('z', z.get('latent', z))

            # Convert back to float32 for consistency.
            if z.dtype != torch.float32:
                z = z.float()
        return z

    def decode(self, z):
        """
        Decode a latent back to image space.

        Args:
            z: (B, 3, H, W) latent tensor

        Returns:
            x: (B, 3, H*4, W*4) image tensor in range [0, 1]
        """
        with torch.no_grad():
            # Match the model's dtype (handles fp16 models).
            model_dtype = next(self.model.parameters()).dtype
            if z.dtype != model_dtype:
                z = z.to(model_dtype)

            # Decode.
            x = self.model.decode(z)

            # Convert back to float32.
            if x.dtype != torch.float32:
                x = x.float()

            # Normalize back to [0, 1] if the decoder output is in [-1, 1].
            if x.min() < 0:
                x = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]
            x = torch.clamp(x, 0, 1)
        return x


# Convenience function
def get_vqgan(ckpt_path=None, device=device):
    """Get a ready-to-use VQGAN wrapper instance."""
    model = load_vqgan(ckpt_path=ckpt_path, device=device)
    return VQGANWrapper(model)
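

# Minimal usage sketch (illustrative, not part of the original module): it assumes
# config.py provides the names imported above and that the checkpoint is available
# locally or downloadable from the weights repo. The batch size and 256x256 input
# size are arbitrary choices for the example.
if __name__ == "__main__":
    vqgan = get_vqgan()  # load weights and wrap the model

    # Dummy image in [0, 1]; a real caller would pass a preprocessed image tensor.
    dummy = torch.rand(1, 3, 256, 256, device=device)

    z = vqgan.encode(dummy)   # (1, 3, 64, 64) latent for the f4 autoencoder
    recon = vqgan.decode(z)   # (1, 3, 256, 256) reconstruction in [0, 1]
    print("latent shape:", tuple(z.shape), "reconstruction shape:", tuple(recon.shape))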