|
import torch |
|
from diffusers import AutoencoderKL |
|
from einops import rearrange |
|
from torch import Tensor |
|
|
|
|
|
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder |
|
from xora.models.autoencoders.video_autoencoder import Downsample3D, VideoAutoencoder |
|
import xora.utils.dist_util |
|
|
|
|
|
def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae_per_channel_normalize=False) -> Tensor:
    """
    Encodes media items (images or videos) into latent representations using a specified VAE model.
    The function supports processing batches of images or video frames and can handle the processing
    in smaller sub-batches if needed.

    Args:
        media_items (Tensor): A torch Tensor containing the media items to encode. The expected
            shape is (batch_size, channels, height, width) for images or (batch_size, channels,
            frames, height, width) for videos. `channels` must be 3.
        vae (AutoencoderKL): The VAE used for encoding. This is either a 2D image VAE such as
            `AutoencoderKL` from `diffusers`, or a `VideoAutoencoder` / `CausalVideoAutoencoder`.
            The video autoencoders receive 5D video tensors as-is.
        split_size (int, optional): The number of sub-batches to split the input batch into for
            encoding. If greater than 1, the batch is encoded in `split_size` equal chunks. For
            video input to an image VAE, the batch is the frame-flattened one. Defaults to 1,
            which processes all items in a single batch.
        vae_per_channel_normalize (bool, optional): If True, latents are standardized per channel
            using the VAE's `mean_of_means` / `std_of_means`. Otherwise they are multiplied by
            `vae.config.scaling_factor`. Defaults to False.

    Returns:
        Tensor: Latents sampled from the VAE's latent distribution and normalized as described
        above. For video input encoded with an image VAE, the per-frame latents are rearranged
        back to (batch_size, latent_channels, frames, latent_height, latent_width).

    Raises:
        ValueError: If `media_items` does not have 3 channels, or if `split_size > 1` and the
            batch size is not divisible by `split_size`.

    Examples:
        >>> import torch
        >>> from diffusers import AutoencoderKL
        >>> vae = AutoencoderKL.from_pretrained('your-model-name')
        >>> videos = torch.rand(10, 3, 8, 256, 256)  # Example tensor with 10 videos of 8 frames.
        >>> latents = vae_encode(videos, vae)
        >>> print(latents.shape)  # Output shape will depend on the model's latent configuration.

    Note:
        When a video is encoded with an image VAE (anything other than `VideoAutoencoder` /
        `CausalVideoAutoencoder`), the function encodes the media item frame-by-frame.
    """
    is_video_shaped = media_items.dim() == 5
    batch_size, channels = media_items.shape[0:2]

    if channels != 3:
        raise ValueError(f"Expects tensors with 3 channels, got {channels}.")

    # Image VAEs only take 4D input, so video frames are folded into the batch dimension.
    frames_as_images = is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder))
    if frames_as_images:
        media_items = rearrange(media_items, "b c n h w -> (b n) c h w")

    if split_size > 1:
        if len(media_items) % split_size != 0:
            raise ValueError(
                f"Error: The batch size ({len(media_items)}) must be divisible by "
                f"'train.vae_bs_split' (split_size={split_size})"
            )
        encode_bs = len(media_items) // split_size

        # The module is imported as `import xora.utils.dist_util`. That statement binds only the
        # name `xora`, so a bare `dist_util` reference would raise NameError here.
        # NOTE(review): execute_graph presumably flushes a lazily built compute graph between
        # sub-batches; confirm against xora.utils.dist_util.
        execute_graph = xora.utils.dist_util.execute_graph

        latents = []
        execute_graph()
        for image_batch in media_items.split(encode_bs):
            latents.append(vae.encode(image_batch).latent_dist.sample())
            execute_graph()
        latents = torch.cat(latents, dim=0)
    else:
        latents = vae.encode(media_items).latent_dist.sample()

    latents = normalize_latents(latents, vae, vae_per_channel_normalize)
    if frames_as_images:
        # Undo the frame folding: restore the per-video frame axis.
        latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
    return latents
|
|
|
|
|
def vae_decode(
    latents: Tensor, vae: AutoencoderKL, is_video: bool = True, split_size: int = 1, vae_per_channel_normalize=False
) -> Tensor:
    """
    Decodes latent representations back to pixel space. This is the counterpart of `vae_encode`.

    Args:
        latents (Tensor): Latents shaped (batch_size, channels, height, width) or
            (batch_size, channels, frames, height, width).
        vae (AutoencoderKL): The VAE used for decoding. This is either a 2D image VAE or a
            `VideoAutoencoder` / `CausalVideoAutoencoder`. The video autoencoders receive 5D
            latents as-is.
        is_video (bool, optional): Forwarded to the decoder. For the video autoencoders, False
            requests a single output frame. Defaults to True.
        split_size (int, optional): Number of equal sub-batches to decode sequentially. Defaults
            to 1, which decodes everything in one call.
        vae_per_channel_normalize (bool, optional): Must match the flag used when encoding. It
            selects per-channel un-normalization instead of dividing by
            `vae.config.scaling_factor`. Defaults to False.

    Returns:
        Tensor: The decoded media. Video latents decoded with an image VAE are rearranged back to
        (batch_size, channels, frames, height, width).

    Raises:
        ValueError: If `split_size > 1` and the batch size is not divisible by `split_size`.
    """
    is_video_shaped = latents.dim() == 5
    batch_size = latents.shape[0]

    # Image VAEs decode 4D tensors, so the latent frames are folded into the batch dimension.
    frames_as_images = is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder))
    if frames_as_images:
        latents = rearrange(latents, "b c n h w -> (b n) c h w")

    if split_size > 1:
        if len(latents) % split_size != 0:
            raise ValueError(
                f"Error: The batch size ({len(latents)}) must be divisible by "
                f"'train.vae_bs_split' (split_size={split_size})"
            )
        decode_bs = len(latents) // split_size
        decoded_chunks = [
            _run_decoder(latent_batch, vae, is_video, vae_per_channel_normalize)
            for latent_batch in latents.split(decode_bs)
        ]
        images = torch.cat(decoded_chunks, dim=0)
    else:
        images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize)

    if frames_as_images:
        # Undo the frame folding: restore the per-video frame axis.
        images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
    return images
|
|
|
|
|
def _run_decoder(latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_channel_normalize=False) -> Tensor:
    """Un-normalizes `latents`, decodes them with `vae`, and returns the first decoder output.

    For `VideoAutoencoder` / `CausalVideoAutoencoder`, the latents are first cast to `vae.dtype`.
    The decoder is also given an explicit `target_shape`, derived from the latent size and the
    VAE's temporal/spatial scale factors. When `is_video` is False, that shape has a single frame.
    Any other VAE is called without a target shape and without a dtype cast.
    """
    if not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
        decoded = vae.decode(
            un_normalize_latents(latents, vae, vae_per_channel_normalize),
            return_dict=False,
        )
        return decoded[0]

    latent_frames, latent_height, latent_width = latents.shape[-3:]
    temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
    latents = latents.to(vae.dtype)

    output_frames = latent_frames * temporal_scale if is_video else 1
    # NOTE(review): the batch entry of target_shape is fixed at 1 even when `latents` carries a
    # larger batch; confirm this is what vae.decode expects.
    target_shape = (1, 3, output_frames, latent_height * spatial_scale, latent_width * spatial_scale)

    decoded = vae.decode(
        un_normalize_latents(latents, vae, vae_per_channel_normalize),
        return_dict=False,
        target_shape=target_shape,
    )
    return decoded[0]
|
|
|
|
|
def get_vae_size_scale_factor(vae: AutoencoderKL) -> tuple:
    """Returns the (temporal, spatial, spatial) downscale factors of `vae` as a 3-tuple.

    For `CausalVideoAutoencoder`, the factors are read directly from the model. For other VAEs
    they are derived from the encoder: every down block whose `downsample` is a `Downsample3D`
    contributes a factor of 2, on top of `config.patch_size` (spatial) and
    `config.patch_size_t` (temporal). The temporal factor is 1 unless `vae` is a
    `VideoAutoencoder`.
    """
    if isinstance(vae, CausalVideoAutoencoder):
        spatial = vae.spatial_downscale_factor
        temporal = vae.temporal_downscale_factor
    else:
        # NOTE(review): this branch reads `block.downsample` and `config.patch_size`. A plain
        # diffusers AutoencoderKL may not expose these; confirm it is only reached with
        # VideoAutoencoder-style models.
        down_blocks = sum(1 for block in vae.encoder.down_blocks if isinstance(block.downsample, Downsample3D))
        spatial = vae.config.patch_size * 2**down_blocks
        temporal = vae.config.patch_size_t * 2**down_blocks if isinstance(vae, VideoAutoencoder) else 1

    return (temporal, spatial, spatial)
|
|
|
|
|
def normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False) -> Tensor:
    """Maps raw VAE latents into the normalized latent space.

    Args:
        latents: Latent tensor with channels on dim 1. Any rank >= 2 is supported, e.g.
            (B, C, H, W) image latents or (B, C, F, H, W) video latents.
        vae: Provides either per-channel statistics (`mean_of_means`, `std_of_means`) or a global
            `config.scaling_factor`.
        vae_per_channel_normalize: If True, return (latents - mean_of_means) / std_of_means per
            channel. Otherwise return latents * vae.config.scaling_factor.

    Returns:
        Tensor with the same shape as `latents`.
    """
    if not vae_per_channel_normalize:
        return latents * vae.config.scaling_factor

    # Put the per-channel stats on dim 1, with singleton dims matching the latent rank.
    # A hard-coded 5D view mis-broadcasts 4D image latents.
    stats_shape = (1, -1) + (1,) * (latents.dim() - 2)
    mean = vae.mean_of_means.to(device=latents.device, dtype=latents.dtype).view(stats_shape)
    std = vae.std_of_means.to(device=latents.device, dtype=latents.dtype).view(stats_shape)
    return (latents - mean) / std
|
|
|
|
|
def un_normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False) -> Tensor:
    """Maps normalized latents back to the raw VAE latent space (inverse of `normalize_latents`).

    Args:
        latents: Normalized latent tensor with channels on dim 1. Any rank >= 2 is supported,
            e.g. (B, C, H, W) or (B, C, F, H, W).
        vae: Provides either per-channel statistics (`mean_of_means`, `std_of_means`) or a global
            `config.scaling_factor`.
        vae_per_channel_normalize: If True, return latents * std_of_means + mean_of_means per
            channel. Otherwise return latents / vae.config.scaling_factor.

    Returns:
        Tensor with the same shape as `latents`.
    """
    if not vae_per_channel_normalize:
        return latents / vae.config.scaling_factor

    # Put the per-channel stats on dim 1, with singleton dims matching the latent rank.
    # A hard-coded 5D view mis-broadcasts 4D image latents.
    stats_shape = (1, -1) + (1,) * (latents.dim() - 2)
    std = vae.std_of_means.to(device=latents.device, dtype=latents.dtype).view(stats_shape)
    mean = vae.mean_of_means.to(device=latents.device, dtype=latents.dtype).view(stats_shape)
    return latents * std + mean