| """ |
| Helper functions to encode image or decode latents using Qwen-VAE |
| """ |
|
|
from typing import List, Optional

import numpy as np
import torch
from diffusers import AutoencoderKLQwenImage
from PIL import Image
|
|
|
|
def encode_image(
    images: List[Image.Image], model: AutoencoderKLQwenImage
) -> torch.Tensor:
    """
    Encode a batch of PIL Images to latents.

    Args:
        images: List of PIL Images. Any mode is accepted; each image is
            converted to RGB before encoding.
        model: Qwen VAE model

    Returns:
        Latents tensor of shape [batch_size, channel, height, width],
        on the model's device and in the model's dtype
    """
    # Read device/dtype from a single parameter lookup instead of two.
    param = next(model.parameters())
    device, dtype = param.device, param.dtype

    tensors = []
    for img in images:
        # Force 3-channel RGB: a grayscale image yields a 2-D array (which
        # would break the permute below) and RGBA would feed 4 channels
        # into a 3-channel VAE.
        arr = np.array(img.convert("RGB"))
        tensor = torch.from_numpy(arr).permute(2, 0, 1).to(dtype)
        # Map uint8 [0, 255] -> [-1, 1], the input range the VAE expects.
        tensor = tensor / 127.5 - 1.0
        tensors.append(tensor)

    batch = torch.stack(tensors).to(device)

    # The Qwen VAE operates on video tensors: insert a singleton temporal
    # dimension -> [B, C, 1, H, W].
    batch = batch.unsqueeze(2)

    with torch.no_grad():
        latents = model.encode(batch).latent_dist.sample()

    # Drop the temporal dimension to return image-shaped latents [B, C, H, W].
    latents = latents.squeeze(2)

    return latents
|
|
|
|
def decode_latents(
    latents: torch.Tensor, model: AutoencoderKLQwenImage
) -> List[Image.Image]:
    """
    Decode latents to a batch of PIL Images.

    Args:
        latents: Latents tensor of shape [B, C, H, W]
        model: Qwen VAE model

    Returns:
        List of PIL Images (RGB, uint8)
    """
    # Read device AND dtype from the model's parameters. (The dtype must
    # come from .dtype — casting latents to the model dtype is required
    # for e.g. a bfloat16 model fed float32 latents.)
    param = next(model.parameters())
    device = param.device
    dtype = param.dtype

    # The Qwen VAE operates on video tensors: insert a singleton temporal
    # dimension -> [B, C, 1, H, W], and match the model's device/dtype.
    latents = latents.unsqueeze(2).to(device).to(dtype)

    with torch.no_grad():
        reconstructed = model.decode(latents).sample

    # Drop the temporal dimension -> [B, C, H, W].
    reconstructed = reconstructed.squeeze(2)

    images = []
    for tensor in reconstructed:
        # Map [-1, 1] -> [0, 255]; clamp first so out-of-range decoder
        # output can't wrap around in the uint8 cast.
        tensor = torch.clamp(tensor, -1.0, 1.0)
        tensor = (tensor + 1.0) * 127.5

        # Round before the uint8 cast (a plain cast truncates, biasing
        # every pixel value downward).
        img_array = tensor.permute(1, 2, 0).round().cpu().to(torch.uint8).numpy()
        images.append(Image.fromarray(img_array))

    return images
|
|
|
|
if __name__ == "__main__":
    # Round-trip smoke test: encode a PNG to latents, decode it back,
    # and save the reconstruction next to the input.
    vae = AutoencoderKLQwenImage.from_pretrained(
        "REPA-E/e2e-qwenimage-vae", torch_dtype=torch.bfloat16, device_map="cuda:0"
    )

    source = Image.open("test_image.png")

    z = encode_image([source], vae)
    print(f"Latents shape: {z.shape}")
    print(f"Latents dtype: {z.dtype}")
    print(f"Latents device: {z.device}")

    outputs = decode_latents(z, vae)
    print(f"Reconstructed {len(outputs)} images")

    outputs[0].save("test_reconstruction.png")
    print("Saved reconstructed image to test_reconstruction.png")
|
|
|
|
def get_vae_stats(
    vae_path: str, device: Optional[torch.device] = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Get latent mean and std from VAE config.

    Args:
        vae_path: Path to VAE model (local path or hub id)
        device: Device to put tensors on; kept on CPU when None

    Returns:
        (mean, std) tuple of tensors with shape [1, C, 1, 1]
    """
    # Local import: transformers is only needed by this helper.
    from transformers import PretrainedConfig

    try:
        config = PretrainedConfig.from_pretrained(vae_path)
    except Exception:
        # Fallback: load the full local model just to read its config,
        # then release it immediately.
        vae = AutoencoderKLQwenImage.from_pretrained(
            vae_path, torch_dtype=torch.bfloat16, local_files_only=True
        )
        config = vae.config
        del vae

    # getattr with a None default collapses the "missing attribute" and
    # "attribute is None" cases into one check.
    latents_mean = getattr(config, "latents_mean", None)
    if latents_mean is not None:
        mean = torch.tensor(latents_mean).view(1, -1, 1, 1)
    else:
        # No stats in the config: fall back to identity normalization
        # assuming 16 latent channels (Qwen VAE default).
        mean = torch.zeros(1, 16, 1, 1)

    latents_std = getattr(config, "latents_std", None)
    if latents_std is not None:
        std = torch.tensor(latents_std).view(1, -1, 1, 1)
    else:
        std = torch.ones(1, 16, 1, 1)

    # Explicit None check — don't rely on torch.device truthiness.
    if device is not None:
        mean = mean.to(device)
        std = std.to(device)

    return mean, std
|
|