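"""Round-trip a single image through a VAE.

The image is encoded into latents, decoded back to pixels, and the original
image and its reconstruction are displayed side by side for visual comparison.
"""
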
import argparse
import typing
from typing import Optional, Union

import torch
from PIL import Image
from torchvision import transforms

from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.autoencoder_kl import (
    AutoencoderKL,
    AutoencoderKLOutput,
)
from diffusers.models.autoencoders.autoencoder_tiny import (
    AutoencoderTiny,
    AutoencoderTinyOutput,
)
from diffusers.models.autoencoders.vae import DecoderOutput


SupportedAutoencoder = Union[AutoencoderKL, AutoencoderTiny]


def load_vae_model(
    *,
    device: torch.device,
    model_name_or_path: str,
    revision: Optional[str],
    variant: Optional[str],
    subfolder: Optional[str],
    use_tiny_nn: bool,
) -> SupportedAutoencoder:
    """Load an AutoencoderKL, or an AutoencoderTiny when `use_tiny_nn` is set, in eval mode on `device`."""
    if use_tiny_nn:
        # The tiny autoencoder trades reconstruction fidelity for speed and memory.
        down_scale = 2
        up_scale = 2
        vae = AutoencoderTiny.from_pretrained(
            model_name_or_path,
            subfolder=subfolder,
            revision=revision,
            variant=variant,
            downscaling_scaling_factor=down_scale,
            upsampling_scaling_factor=up_scale,
        )
        assert isinstance(vae, AutoencoderTiny)
    else:
        vae = AutoencoderKL.from_pretrained(
            model_name_or_path,
            subfolder=subfolder,
            revision=revision,
            variant=variant,
        )
        assert isinstance(vae, AutoencoderKL)
    vae = vae.to(device)
    vae.eval()
    return vae


def pil_to_nchw(
    *,
    device: torch.device,
    image: Image.Image,
) -> torch.Tensor:
    """Convert an RGB PIL image into a 1xCxHxW float tensor in [0, 1] on `device`."""
    assert image.mode == "RGB"
    transform = transforms.ToTensor()
    nchw = transform(image).unsqueeze(0).to(device)
    assert isinstance(nchw, torch.Tensor)
    return nchw


def nchw_to_pil(
    *,
    nchw: torch.Tensor,
) -> Image.Image:
    """Convert a single-image 1xCxHxW tensor back into a PIL image."""
    assert nchw.shape[0] == 1
    chw = nchw.squeeze(0).cpu()
    return transforms.ToPILImage()(chw)


def concatenate_images(
    *,
    left: Image.Image,
    right: Image.Image,
    vertical: bool = False,
) -> Image.Image:
    """Paste two images next to each other, or stack them when `vertical` is True."""
    width1, height1 = left.size
    width2, height2 = right.size
    if vertical:
        total_height = height1 + height2
        max_width = max(width1, width2)
        new_image = Image.new("RGB", (max_width, total_height))
        new_image.paste(left, (0, 0))
        new_image.paste(right, (0, height1))
    else:
        total_width = width1 + width2
        max_height = max(height1, height2)
        new_image = Image.new("RGB", (total_width, max_height))
        new_image.paste(left, (0, 0))
        new_image.paste(right, (width1, 0))
    return new_image


def to_latent(
    *,
    rgb_nchw: torch.Tensor,
    vae: SupportedAutoencoder,
) -> torch.Tensor:
    """Encode an RGB batch in [0, 1] into VAE latents."""
    # Map pixel values from [0, 1] to the [-1, 1] range the VAE expects.
    rgb_nchw = VaeImageProcessor.normalize(rgb_nchw)
    encoding_nchw = vae.encode(typing.cast(torch.FloatTensor, rgb_nchw))
    if isinstance(encoding_nchw, AutoencoderKLOutput):
        latent = encoding_nchw.latent_dist.sample()
        assert isinstance(latent, torch.Tensor)
    elif isinstance(encoding_nchw, AutoencoderTinyOutput):
        latent = encoding_nchw.latents
        # Optionally round-trip the latents through the tiny VAE's uint8 quantization
        # (disabled by default).
        do_internal_vae_scaling = False
        if do_internal_vae_scaling:
            latent = vae.scale_latents(latent).mul(255).round().byte()
            latent = vae.unscale_latents(latent / 255.0)
        assert isinstance(latent, torch.Tensor)
    else:
        assert False, f"Unknown encoding type: {type(encoding_nchw)}"
    return latent


def from_latent(
    *,
    latent_nchw: torch.Tensor,
    vae: SupportedAutoencoder,
) -> torch.Tensor:
    """Decode VAE latents back into an RGB batch in [0, 1]."""
    decoding_nchw = vae.decode(latent_nchw)
    assert isinstance(decoding_nchw, DecoderOutput)
    # Map the decoded samples from [-1, 1] back to [0, 1].
    rgb_nchw = VaeImageProcessor.denormalize(decoding_nchw.sample)
    assert isinstance(rgb_nchw, torch.Tensor)
    return rgb_nchw


def main_kwargs(
    *,
    device: torch.device,
    input_image_path: str,
    pretrained_model_name_or_path: str,
    revision: Optional[str],
    variant: Optional[str],
    subfolder: Optional[str],
    use_tiny_nn: bool,
) -> None:
    """Encode and decode a single image, then display the original and the reconstruction."""
    vae = load_vae_model(
        device=device,
        model_name_or_path=pretrained_model_name_or_path,
        revision=revision,
        variant=variant,
        subfolder=subfolder,
        use_tiny_nn=use_tiny_nn,
    )
    original_pil = Image.open(input_image_path).convert("RGB")
    original_image = pil_to_nchw(
        device=device,
        image=original_pil,
    )
    print(f"Original image shape: {original_image.shape}")
    reconstructed_image: Optional[torch.Tensor] = None

    with torch.no_grad():
        latent_image = to_latent(rgb_nchw=original_image, vae=vae)
        print(f"Latent shape: {latent_image.shape}")
        reconstructed_image = from_latent(latent_nchw=latent_image, vae=vae)
    reconstructed_pil = nchw_to_pil(nchw=reconstructed_image)
    combined_image = concatenate_images(
        left=original_pil,
        right=reconstructed_pil,
        vertical=False,
    )
    combined_image.show("Original | Reconstruction")
    print(f"Reconstructed image shape: {reconstructed_image.shape}")


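# Example invocations for the CLI defined below. The script name, paths, and
# model ids are only illustrative, not requirements of this script:
#
#   python vae_roundtrip.py --input_image ./input.png \
#       --pretrained_model_name_or_path stabilityai/sd-vae-ft-mse --use_cuda
#
#   python vae_roundtrip.py --input_image ./input.png \
#       --pretrained_model_name_or_path madebyollin/taesd --use_tiny_nn

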
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Inference with VAE")
    parser.add_argument(
        "--input_image",
        type=str,
        required=True,
        help="Path to the input image for inference.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        required=True,
        help="Path or Hub id of the pretrained VAE model.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        help="Model revision (branch, tag, or commit hash).",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Model weight variant, e.g. 'fp16'.",
    )
    parser.add_argument(
        "--subfolder",
        type=str,
        default=None,
        help="Subfolder within the model repository that contains the VAE weights.",
    )
    parser.add_argument(
        "--use_cuda",
        action="store_true",
        help="Run on CUDA (GPU) instead of CPU.",
    )
    parser.add_argument(
        "--use_tiny_nn",
        action="store_true",
        help="Use AutoencoderTiny instead of AutoencoderKL.",
    )
    return parser.parse_args()


def main_cli() -> None:
    """Entry point: validate the parsed arguments and run the round trip."""
    args = parse_args()

    input_image_path = args.input_image
    assert isinstance(input_image_path, str)

    pretrained_model_name_or_path = args.pretrained_model_name_or_path
    assert isinstance(pretrained_model_name_or_path, str)

    revision = args.revision
    assert isinstance(revision, (str, type(None)))

    variant = args.variant
    assert isinstance(variant, (str, type(None)))

    subfolder = args.subfolder
    assert isinstance(subfolder, (str, type(None)))

    use_cuda = args.use_cuda
    assert isinstance(use_cuda, bool)

    use_tiny_nn = args.use_tiny_nn
    assert isinstance(use_tiny_nn, bool)

    device = torch.device("cuda" if use_cuda else "cpu")

    main_kwargs(
        device=device,
        input_image_path=input_image_path,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        revision=revision,
        variant=variant,
        subfolder=subfolder,
        use_tiny_nn=use_tiny_nn,
    )


if __name__ == "__main__":
    main_cli()
|
|