# Pyramid-Flow / video_vae / causal_video_vae_wrapper.py
from collections import OrderedDict

import torch
import torch.nn as nn
from einops import rearrange
from PIL import Image

from utils import (
    is_context_parallel_initialized,
    get_context_parallel_group,
    get_context_parallel_world_size,
    get_context_parallel_rank,
    get_context_parallel_group_rank,
)

from .modeling_causal_vae import CausalVideoVAE
from .modeling_loss import LPIPSWithDiscriminator
from .context_parallel_ops import (
    conv_scatter_to_context_parallel_region,
    conv_gather_from_context_parallel_region,
)

class CausalVideoVAELossWrapper(nn.Module):
"""
The causal video vae training and inference running wrapper
"""
def __init__(self, model_path, model_dtype='fp32', disc_start=0, logvar_init=0.0, kl_weight=1.0,
pixelloss_weight=1.0, perceptual_weight=1.0, disc_weight=0.5, interpolate=True,
add_discriminator=True, freeze_encoder=False, load_loss_module=False, lpips_ckpt=None, **kwargs,
):
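        """
        Args:
            model_path: folder containing the pretrained `CausalVideoVAE` weights and config.
            model_dtype: 'fp32', 'fp16', or 'bf16'; the dtype the VAE is loaded in.
            disc_start: global step at which the GAN (discriminator) loss starts being used.
            freeze_encoder: if True, freeze the encoder and quant_conv parameters.
            load_loss_module: if True, build the LPIPS + discriminator loss (training only).
            lpips_ckpt: path to the LPIPS checkpoint used by the loss module.
        """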
super().__init__()
if model_dtype == 'bf16':
torch_dtype = torch.bfloat16
elif model_dtype == 'fp16':
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
        # NOTE: the `interpolate` argument is kept for interface compatibility but is
        # not forwarded; the VAE is always loaded with interpolate=False here.
        self.vae = CausalVideoVAE.from_pretrained(model_path, torch_dtype=torch_dtype, interpolate=False)
        self.vae_scale_factor = self.vae.config.scaling_factor
if freeze_encoder:
print("Freeze the parameters of vae encoder")
for parameter in self.vae.encoder.parameters():
parameter.requires_grad = False
for parameter in self.vae.quant_conv.parameters():
parameter.requires_grad = False
self.add_discriminator = add_discriminator
self.freeze_encoder = freeze_encoder
# Used for training
if load_loss_module:
self.loss = LPIPSWithDiscriminator(disc_start, logvar_init=logvar_init, kl_weight=kl_weight,
pixelloss_weight=pixelloss_weight, perceptual_weight=perceptual_weight, disc_weight=disc_weight,
add_discriminator=add_discriminator, using_3d_discriminator=False, disc_num_layers=4, lpips_ckpt=lpips_ckpt)
else:
self.loss = None
self.disc_start = disc_start
def load_checkpoint(self, checkpoint_path, **kwargs):
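        """
        Load a training checkpoint, splitting it into VAE weights (keys prefixed
        with `vae.`) and discriminator weights (keys prefixed with `loss.discriminator.`).
        """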
checkpoint = torch.load(checkpoint_path, map_location='cpu')
if 'model' in checkpoint:
checkpoint = checkpoint['model']
vae_checkpoint = OrderedDict()
disc_checkpoint = OrderedDict()
        for key, value in checkpoint.items():
            if key.startswith('vae.'):
                vae_checkpoint[key[len('vae.'):]] = value
            elif key.startswith('loss.discriminator.'):
                disc_checkpoint[key[len('loss.discriminator.'):]] = value
        vae_ckpt_load_result = self.vae.load_state_dict(vae_checkpoint, strict=False)
        print(f"Loaded the VAE checkpoint from {checkpoint_path}, load result: {vae_ckpt_load_result}")

        # Only load discriminator weights when the loss module and its discriminator exist
        if self.loss is not None and getattr(self.loss, 'discriminator', None) is not None:
            disc_ckpt_load_result = self.loss.discriminator.load_state_dict(disc_checkpoint, strict=False)
            print(f"Loaded the discriminator checkpoint from {checkpoint_path}, load result: {disc_ckpt_load_result}")
    def forward(self, x, step, identifier=('video',)):
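        """
        Run the VAE on a batch and compute the training losses.

        `identifier` marks whether the batch contains videos or images; image
        batches are reshaped into single-frame clips so the same 5D code path
        is used. Returns (reconstruct_loss, gan_loss or None, loss_log).
        """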
xdim = x.ndim
if xdim == 4:
x = x.unsqueeze(2) # (B, C, H, W) -> (B, C, 1, H , W)
if 'video' in identifier:
# The input is video
assert 'image' not in identifier
else:
# The input is image
assert 'video' not in identifier
            # Multiple images arrive packed in a 5D tensor for compatibility with the video input,
            # so we reformulate them into single-frame video tensors
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = x.unsqueeze(2) # [(b t) c 1 h w]
if is_context_parallel_initialized():
            assert self.training, "Context parallelism is only supported during training for now"
cp_world_size = get_context_parallel_world_size()
global_src_rank = get_context_parallel_group_rank() * cp_world_size
# sync the input and split
torch.distributed.broadcast(x, src=global_src_rank, group=get_context_parallel_group())
batch_x = conv_scatter_to_context_parallel_region(x, dim=2, kernel_size=1)
else:
batch_x = x
posterior, reconstruct = self.vae(batch_x, freeze_encoder=self.freeze_encoder,
is_init_image=True, temporal_chunk=False,)
        # The reconstruction loss (optimizer_idx=0)
reconstruct_loss, rec_log = self.loss(
batch_x, reconstruct, posterior,
optimizer_idx=0, global_step=step, last_layer=self.vae.get_last_layer(),
)
if step < self.disc_start:
return reconstruct_loss, None, rec_log
# The loss to train the discriminator
gan_loss, gan_log = self.loss(batch_x, reconstruct, posterior, optimizer_idx=1,
global_step=step, last_layer=self.vae.get_last_layer(),
)
loss_log = {**rec_log, **gan_log}
return reconstruct_loss, gan_loss, loss_log
def encode(self, x, sample=False, is_init_image=True,
temporal_chunk=False, window_size=16, tile_sample_min_size=256,):
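        """
        Encode pixels into VAE latents; returns a sample from the posterior when
        `sample=True`, otherwise the posterior mode. 4D image batches are lifted
        to single-frame 5D clips before encoding.
        """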
# x: (B, C, T, H, W) or (B, C, H, W)
B = x.shape[0]
xdim = x.ndim
if xdim == 4:
# The input is an image
x = x.unsqueeze(2)
        latent_dist = self.vae.encode(
            x, is_init_image=is_init_image, temporal_chunk=temporal_chunk,
            window_size=window_size, tile_sample_min_size=tile_sample_min_size,
        ).latent_dist
        x = latent_dist.sample() if sample else latent_dist.mode()
return x
def decode(self, x, is_init_image=True, temporal_chunk=False,
window_size=2, tile_sample_min_size=256,):
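        """
        Decode VAE latents back to pixel space. 4D latent batches are lifted to
        single-frame 5D clips before decoding.
        """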
# x: (B, C, T, H, W) or (B, C, H, W)
B = x.shape[0]
xdim = x.ndim
if xdim == 4:
# The input is an image
x = x.unsqueeze(2)
x = self.vae.decode(
x, is_init_image=is_init_image, temporal_chunk=temporal_chunk,
window_size=window_size, tile_sample_min_size=tile_sample_min_size,
).sample
return x
@staticmethod
def numpy_to_pil(images):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
if images.shape[-1] == 1:
# special case for grayscale (single channel) images
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
else:
pil_images = [Image.fromarray(image) for image in images]
return pil_images
def reconstruct(
self, x, sample=False, return_latent=False, is_init_image=True,
temporal_chunk=False, window_size=16, tile_sample_min_size=256, **kwargs
):
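        """
        Encode a single sample and decode it back, returning a list of PIL frames
        (and optionally the latent). The decode window is the encode window divided
        by the VAE downsample factor, so both cover the same pixel-space span.
        """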
assert x.shape[0] == 1
xdim = x.ndim
encode_window_size = window_size
decode_window_size = window_size // self.vae.downsample_scale
# Encode
x = self.encode(
x, sample, is_init_image, temporal_chunk, encode_window_size, tile_sample_min_size,
)
encode_latent = x
# Decode
x = self.decode(
x, is_init_image, temporal_chunk, decode_window_size, tile_sample_min_size
)
output_image = x.float()
output_image = (output_image / 2 + 0.5).clamp(0, 1)
# Convert to PIL images
output_image = rearrange(output_image, "B C T H W -> (B T) C H W")
output_image = output_image.cpu().permute(0, 2, 3, 1).numpy()
output_images = self.numpy_to_pil(output_image)
if return_latent:
return output_images, encode_latent
return output_images
    # Encode pixels into VAE latents
def encode_latent(self, x, sample=False, is_init_image=True,
temporal_chunk=False, window_size=16, tile_sample_min_size=256,):
# Encode
latent = self.encode(
x, sample, is_init_image, temporal_chunk, window_size, tile_sample_min_size,
)
return latent
    # Decode VAE latents into PIL images
def decode_latent(self, latent, is_init_image=True,
temporal_chunk=False, window_size=2, tile_sample_min_size=256,):
x = self.decode(
latent, is_init_image, temporal_chunk, window_size, tile_sample_min_size
)
output_image = x.float()
output_image = (output_image / 2 + 0.5).clamp(0, 1)
# Convert to PIL images
output_image = rearrange(output_image, "B C T H W -> (B T) C H W")
output_image = output_image.cpu().permute(0, 2, 3, 1).numpy()
output_images = self.numpy_to_pil(output_image)
return output_images
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
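

# Minimal usage sketch: reconstruct a dummy clip with a pretrained causal video VAE.
# The model path, frame count, and resolution below are placeholders, not values
# shipped with this repo; adjust them to your checkpoint. Assumes a CUDA device is
# available; drop `.cuda()` to run on CPU.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Reconstruct a random clip with the causal video VAE")
    parser.add_argument("--model_path", type=str, required=True, help="Folder with the pretrained CausalVideoVAE")
    args = parser.parse_args()

    # Inference-only wrapper: no discriminator, no loss module.
    vae_wrapper = CausalVideoVAELossWrapper(
        args.model_path, model_dtype='bf16',
        add_discriminator=False, load_loss_module=False,
    ).eval().cuda()

    # Dummy input in [-1, 1]: batch of 1, 3 channels, 9 frames, 256x256 pixels.
    video = torch.randn(
        1, 3, 9, 256, 256, dtype=vae_wrapper.dtype, device=vae_wrapper.device
    ).clamp(-1, 1)

    with torch.no_grad():
        frames, latent = vae_wrapper.reconstruct(
            video, sample=False, return_latent=True, temporal_chunk=False,
        )

    print(f"Latent shape: {tuple(latent.shape)}, reconstructed {len(frames)} PIL frames")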