# FloodDiffusion / ldf_models / vae_wan_1d.py
# herrscher0's picture
# Initial commit: FloodDiffusion text-to-motion generation model
# ebc7f2e verified
import numpy as np
import torch
import torch.nn as nn
from .tools.wan_vae_1d import WanVAE_
class VAEWanModel(nn.Module):
    """Wrapper around ``WanVAE_`` operating on (batch, time, channel) features.

    The wrapper normalizes inputs with the ``mean``/``std`` buffers, swaps the
    time and channel axes before calling the wrapped model, and computes the
    training losses (reconstruction, velocity window and KL) in ``forward``.
    """

    def __init__(
        self,
        input_dim,
        mean_path=None,
        std_path=None,
        z_dim=256,
        dim=160,
        dec_dim=512,
        num_res_blocks=1,
        dropout=0.0,
        dim_mult=None,
        temperal_downsample=None,
        vel_window=None,
        **kwargs,
    ):
        """Build the wrapped VAE and register the normalization buffers.

        Args:
            input_dim: Number of feature channels; also the size of the
                fallback ``mean``/``std`` buffers.
            mean_path: Optional path loaded with ``np.load`` into the ``mean``
                buffer. Zeros are used when it is ``None``.
            std_path: Optional path loaded with ``np.load`` into the ``std``
                buffer. Ones are used when it is ``None``.
            z_dim, dim, dec_dim, num_res_blocks, dropout: Forwarded unchanged
                to ``WanVAE_``.
            dim_mult: Forwarded to ``WanVAE_``. Defaults to ``[1, 1, 1]``.
            temperal_downsample: Per-stage flags forwarded to ``WanVAE_``.
                Each truthy flag doubles ``downsample_factor``. Defaults to
                ``[True, True]``.
            vel_window: ``[start, end]`` channel slice used for the velocity
                loss. Defaults to ``[0, 0]`` (empty window, no velocity loss).
            **kwargs: May carry ``LAMBDA_FEATURE``, ``LAMBDA_VELOCITY`` and
                ``LAMBDA_KL`` loss weights; other keys are ignored.
        """
        super().__init__()
        # ``None`` sentinels avoid sharing one mutable default list between
        # instances; the values are copied into fresh lists.
        dim_mult = [1, 1, 1] if dim_mult is None else list(dim_mult)
        temperal_downsample = (
            [True, True]
            if temperal_downsample is None
            else list(temperal_downsample)
        )
        vel_window = [0, 0] if vel_window is None else list(vel_window)

        self.mean_path = mean_path
        self.std_path = std_path
        self.input_dim = input_dim
        self.z_dim = z_dim
        self.dim = dim
        self.dec_dim = dec_dim
        self.num_res_blocks = num_res_blocks
        self.dropout = dropout
        self.dim_mult = dim_mult
        self.temperal_downsample = temperal_downsample
        self.vel_window = vel_window

        self.RECONS_LOSS = nn.SmoothL1Loss()
        self.LAMBDA_FEATURE = kwargs.get("LAMBDA_FEATURE", 1.0)
        self.LAMBDA_VELOCITY = kwargs.get("LAMBDA_VELOCITY", 0.5)
        self.LAMBDA_KL = kwargs.get("LAMBDA_KL", 10e-6)

        if self.mean_path is not None:
            self.register_buffer(
                "mean", torch.from_numpy(np.load(self.mean_path)).float()
            )
        else:
            self.register_buffer("mean", torch.zeros(input_dim))
        if self.std_path is not None:
            self.register_buffer(
                "std", torch.from_numpy(np.load(self.std_path)).float()
            )
        else:
            self.register_buffer("std", torch.ones(input_dim))

        self.model = WanVAE_(
            input_dim=self.input_dim,
            dim=self.dim,
            dec_dim=self.dec_dim,
            z_dim=self.z_dim,
            dim_mult=self.dim_mult,
            num_res_blocks=self.num_res_blocks,
            temperal_downsample=self.temperal_downsample,
            dropout=self.dropout,
        )

        # Every enabled temporal stage doubles the downsampling factor.
        self.downsample_factor = 2 ** sum(
            1 for flag in self.temperal_downsample if flag
        )

    @staticmethod
    def _as_lengths(feature_length, batch_size, device):
        """Return the first ``batch_size`` lengths as a 1-D long tensor."""
        if torch.is_tensor(feature_length):
            lengths = feature_length.to(device=device, dtype=torch.long)
            lengths = lengths.reshape(-1)
        else:
            lengths = torch.as_tensor(
                [int(n) for n in feature_length],
                dtype=torch.long,
                device=device,
            )
        if lengths.numel() < batch_size:
            raise ValueError(
                f"feature_length has {lengths.numel()} entries, "
                f"expected at least {batch_size}"
            )
        return lengths[:batch_size]

    @staticmethod
    def _length_mask(lengths, max_len):
        """Return a (B, max_len) bool mask, True where ``t < lengths[b]``."""
        steps = torch.arange(max_len, device=lengths.device)
        return steps.unsqueeze(0) < lengths.unsqueeze(1)

    def preprocess(self, x):
        """Swap the last two axes: (bs, T, C) -> (bs, C, T)."""
        return x.permute(0, 2, 1)

    def postprocess(self, x):
        """Swap the last two axes back: (bs, C, T) -> (bs, T, C)."""
        return x.permute(0, 2, 1)

    def forward(self, x):
        """Compute the training losses for one batch.

        Args:
            x: Mapping with ``"feature"`` (bs, T, input_dim) and
                ``"feature_length"`` (per-sample valid lengths).

        Returns:
            Dict with scalar tensors under ``"total"``, ``"recons"``,
            ``"velocity"`` and ``"kl"``.
        """
        features = (x["feature"] - self.mean) / self.std
        batch_size, seq_len = features.shape[:2]

        # True for valid (non-padded) timesteps of each sample.
        lengths = self._as_lengths(
            x["feature_length"], batch_size, features.device
        )
        mask = self._length_mask(lengths, seq_len)

        x_in = self.preprocess(features)  # (bs, input_dim, T)
        mu, log_var = self.model.encode(x_in, scale=[0, 1], return_dist=True)
        z = self.model.reparameterize(mu, log_var)
        x_decoder = self.model.decode(z, scale=[0, 1])
        x_out = self.postprocess(x_decoder)  # (bs, T', input_dim)

        # The decoder output may differ in length from the input; compare
        # only the common prefix.
        if x_out.size(1) != features.size(1):
            min_len = min(x_out.size(1), features.size(1))
            x_out = x_out[:, :min_len, :]
            features = features[:, :min_len, :]
            mask = mask[:, :min_len]

        # Padded positions are zeroed on both sides so they add no error.
        # The loss mean still averages over all positions, padding included.
        mask_expanded = mask.unsqueeze(-1)
        x_out_masked = x_out * mask_expanded
        features_masked = features * mask_expanded
        loss_recons = self.RECONS_LOSS(x_out_masked, features_masked)

        vel_start = self.vel_window[0]
        vel_end = self.vel_window[1]
        vel_pred = x_out_masked[..., vel_start:vel_end]
        if vel_pred.numel() == 0:
            # A mean-reduced loss over an empty window is NaN and would
            # poison the total; an empty window contributes nothing instead.
            loss_vel = x_out_masked.new_zeros(())
        else:
            loss_vel = self.RECONS_LOSS(
                vel_pred, features_masked[..., vel_start:vel_end]
            )

        # KL(N(mu, sigma) || N(0, 1)) = -0.5 * (1 + log_var - mu^2 - sigma^2)
        # per element, with log_var = log(sigma^2).
        # Latent length is ceil(feature_length / downsample_factor).
        latent_lengths = torch.div(
            lengths + self.downsample_factor - 1,
            self.downsample_factor,
            rounding_mode="floor",
        )
        mask_downsampled = self._length_mask(latent_lengths, mu.size(2))
        mask_latent = mask_downsampled.unsqueeze(1)  # (B, 1, T_latent)

        kl_per_element = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())
        kl_masked = kl_per_element * mask_latent
        # Normalize by valid timesteps * latent_dim; the clamp avoids 0/0
        # when no sample has a valid timestep.
        valid_elements = (mask_downsampled.sum() * mu.size(1)).clamp(min=1)
        kl_loss = kl_masked.sum() / valid_elements

        total_loss = (
            self.LAMBDA_FEATURE * loss_recons
            + self.LAMBDA_VELOCITY * loss_vel
            + self.LAMBDA_KL * kl_loss
        )
        return {
            "total": total_loss,
            "recons": loss_recons,
            "velocity": loss_vel,
            "kl": kl_loss,
        }

    def encode(self, x):
        """Normalize ``x`` (bs, T, input_dim) and return the encoder output.

        The result of ``self.model.encode`` is returned with the last two
        axes swapped, i.e. time before channels.
        """
        x = (x - self.mean) / self.std
        x_in = self.preprocess(x)  # (bs, input_dim, T)
        mu = self.model.encode(x_in, scale=[0, 1])
        return self.postprocess(mu)  # (bs, T_latent, z_dim)

    def decode(self, mu):
        """Decode latents (bs, T, z_dim) and undo the input normalization."""
        mu_in = self.preprocess(mu)  # (bs, z_dim, T)
        x_decoder = self.model.decode(mu_in, scale=[0, 1])  # (bs, input_dim, T)
        x_out = self.postprocess(x_decoder)  # (bs, T, input_dim)
        return x_out * self.std + self.mean

    @torch.no_grad()
    def stream_encode(self, x, first_chunk=True):
        """Normalize one chunk and pass it to ``self.model.stream_encode``.

        ``first_chunk`` is forwarded unchanged to the wrapped model.
        """
        x = (x - self.mean) / self.std
        x_in = self.preprocess(x)  # (bs, input_dim, T)
        mu = self.model.stream_encode(
            x_in, first_chunk=first_chunk, scale=[0, 1]
        )
        return self.postprocess(mu)  # (bs, T, z_dim)

    @torch.no_grad()
    def stream_decode(self, mu, first_chunk=True):
        """Decode one latent chunk via ``self.model.stream_decode``.

        The output is de-normalized with the ``std``/``mean`` buffers.
        """
        mu_in = self.preprocess(mu)  # (bs, z_dim, T)
        x_decoder = self.model.stream_decode(
            mu_in, first_chunk=first_chunk, scale=[0, 1]
        )
        x_out = self.postprocess(x_decoder)  # (bs, T, input_dim)
        return x_out * self.std + self.mean

    def clear_cache(self):
        """Delegate to ``self.model.clear_cache()``."""
        self.model.clear_cache()

    def generate(self, x):
        """Reconstruct each sample and trim it to its valid length.

        Args:
            x: Mapping with ``"feature"`` (bs, T, input_dim) and
                ``"feature_length"`` (per-sample valid lengths).

        Returns:
            Dict whose ``"generated"`` entry is a list with one
            (valid_len, input_dim) tensor per sample.
        """
        features = x["feature"]
        feature_length = x["feature_length"]
        y_hat = self.decode(self.encode(features))

        factor = self.downsample_factor
        y_hat_out = []
        for i in range(y_hat.shape[0]):
            # Cut off the padding: keep the largest length of the form
            # k * factor + 1 that does not exceed the sample's length.
            length = int(feature_length[i])
            valid_len = (length - 1) // factor * factor + 1
            # A zero-length sample would give a negative cut-off that
            # slices from the end; clamp it to an empty result.
            valid_len = max(valid_len, 0)
            y_hat_out.append(y_hat[i, :valid_len, :])
        return {"generated": y_hat_out}