import numpy as np
import torch
import torch.nn as nn

from .tools.wan_vae_1d import WanVAE_


class VAEWanModel(nn.Module):
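    """1D VAE wrapper around WanVAE_ for sequence features.

    Inputs are normalized with precomputed per-dimension mean/std statistics,
    encoded to a temporally downsampled latent, and trained with a masked
    reconstruction loss, a velocity-window loss, and a KL regularizer.
    """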

    def __init__(
        self,
        input_dim,
        mean_path=None,
        std_path=None,
        z_dim=256,
        dim=160,
        dec_dim=512,
        num_res_blocks=1,
        dropout=0.0,
        dim_mult=[1, 1, 1],
        temperal_downsample=[True, True],  # spelling follows the WanVAE_ keyword
        vel_window=[0, 0],
        **kwargs,
    ):
        super().__init__()

        self.mean_path = mean_path
        self.std_path = std_path
        self.input_dim = input_dim
        self.z_dim = z_dim
        self.dim = dim
        self.dec_dim = dec_dim
        self.num_res_blocks = num_res_blocks
        self.dropout = dropout
        self.dim_mult = dim_mult
        self.temperal_downsample = temperal_downsample
        self.vel_window = vel_window

        # Loss function and term weights; weights can be overridden via kwargs.
        self.RECONS_LOSS = nn.SmoothL1Loss()
        self.LAMBDA_FEATURE = kwargs.get("LAMBDA_FEATURE", 1.0)
        self.LAMBDA_VELOCITY = kwargs.get("LAMBDA_VELOCITY", 0.5)
        self.LAMBDA_KL = kwargs.get("LAMBDA_KL", 1e-5)

        # Per-dimension normalization statistics; fall back to identity
        # normalization (zero mean, unit std) when no files are provided.
        if self.mean_path is not None:
            self.register_buffer(
                "mean", torch.from_numpy(np.load(self.mean_path)).float()
            )
        else:
            self.register_buffer("mean", torch.zeros(input_dim))

        if self.std_path is not None:
            self.register_buffer(
                "std", torch.from_numpy(np.load(self.std_path)).float()
            )
        else:
            self.register_buffer("std", torch.ones(input_dim))

        self.model = WanVAE_(
            input_dim=self.input_dim,
            dim=self.dim,
            dec_dim=self.dec_dim,
            z_dim=self.z_dim,
            dim_mult=self.dim_mult,
            num_res_blocks=self.num_res_blocks,
            temperal_downsample=self.temperal_downsample,
            dropout=self.dropout,
        )

        # Total temporal downsampling: each enabled stage halves the length.
        downsample_factor = 1
        for flag in self.temperal_downsample:
            if flag:
                downsample_factor *= 2
        self.downsample_factor = downsample_factor

    def preprocess(self, x):
        # (B, T, C) -> (B, C, T): WanVAE_ operates on channel-first tensors.
        x = x.permute(0, 2, 1)
        return x

    def postprocess(self, x):
        # (B, C, T) -> (B, T, C).
        x = x.permute(0, 2, 1)
        return x

    def forward(self, x):
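        """Compute the training losses for a batch.

        Expects x["feature"] of shape (B, T, C) and x["feature_length"]
        giving the number of valid (non-padded) steps per sequence.
        """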
        features = x["feature"]
        feature_length = x["feature_length"]
        features = (features - self.mean) / self.std

        # Boolean validity mask over time steps; padded positions are False.
        batch_size, seq_len = features.shape[:2]
        mask = torch.zeros(
            batch_size, seq_len, dtype=torch.bool, device=features.device
        )
        for i in range(batch_size):
            mask[i, : feature_length[i]] = True

        x_in = self.preprocess(features)
        mu, log_var = self.model.encode(x_in, scale=[0, 1], return_dist=True)
        z = self.model.reparameterize(mu, log_var)
        x_decoder = self.model.decode(z, scale=[0, 1])
        x_out = self.postprocess(x_decoder)

        # Temporal down/upsampling may change the sequence length slightly;
        # trim prediction, target, and mask to the common length.
        if x_out.size(1) != features.size(1):
            min_len = min(x_out.size(1), features.size(1))
            x_out = x_out[:, :min_len, :]
            features = features[:, :min_len, :]
            mask = mask[:, :min_len]

        # Zero out padded positions before the reconstruction loss. Note that
        # SmoothL1Loss still averages over every element, so zeroed padding
        # dilutes the mean rather than being excluded from it.
        mask_expanded = mask.unsqueeze(-1)
        x_out_masked = x_out * mask_expanded
        features_masked = features * mask_expanded
        loss_recons = self.RECONS_LOSS(x_out_masked, features_masked)

        # Extra reconstruction weight on the velocity channels. The default
        # vel_window=[0, 0] selects an empty slice, so this term only
        # contributes when a non-empty window is configured.
        vel_start, vel_end = self.vel_window
        loss_vel = self.RECONS_LOSS(
            x_out_masked[..., vel_start:vel_end],
            features_masked[..., vel_start:vel_end],
        )

        # Validity mask in latent time: each sequence occupies
        # ceil(feature_length / downsample_factor) latent steps.
        T_latent = mu.size(2)
        mask_downsampled = torch.zeros(
            batch_size, T_latent, dtype=torch.bool, device=features.device
        )
        for i in range(batch_size):
            latent_length = (
                feature_length[i] + self.downsample_factor - 1
            ) // self.downsample_factor
            mask_downsampled[i, :latent_length] = True
        mask_latent = mask_downsampled.unsqueeze(1)  # (B, 1, T_latent)

        # Gaussian KL against a standard normal prior,
        # KL = -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2),
        # masked to valid latent steps and averaged over valid positions
        # and latent channels.
        kl_per_element = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())
        kl_masked = kl_per_element * mask_latent
        kl_loss = torch.sum(kl_masked) / (
            torch.sum(mask_downsampled) * mu.size(1)
        )

        total_loss = (
            self.LAMBDA_FEATURE * loss_recons
            + self.LAMBDA_VELOCITY * loss_vel
            + self.LAMBDA_KL * kl_loss
        )

        loss_dict = {
            "total": total_loss,
            "recons": loss_recons,
            "velocity": loss_vel,
            "kl": kl_loss,
        }
        return loss_dict

    def encode(self, x):
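        """Encode (B, T, C) features to posterior means of shape
        (B, T_latent, z_dim)."""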
        x = (x - self.mean) / self.std
        x_in = self.preprocess(x)
        mu = self.model.encode(x_in, scale=[0, 1])
        mu = self.postprocess(mu)
        return mu

    def decode(self, mu):
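        """Decode (B, T_latent, z_dim) latents back to de-normalized
        (B, T, C) features."""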
        mu_in = self.preprocess(mu)
        x_decoder = self.model.decode(mu_in, scale=[0, 1])
        x_out = self.postprocess(x_decoder)
        x_out = x_out * self.std + self.mean
        return x_out

    @torch.no_grad()
    def stream_encode(self, x, first_chunk=True):
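        """Streaming variant of encode(); pass first_chunk=True when the
        chunk starts a new sequence."""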
        x = (x - self.mean) / self.std
        x_in = self.preprocess(x)
        mu = self.model.stream_encode(x_in, first_chunk=first_chunk, scale=[0, 1])
        mu = self.postprocess(mu)
        return mu

    @torch.no_grad()
    def stream_decode(self, mu, first_chunk=True):
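        """Streaming variant of decode(); pass first_chunk=True when the
        chunk starts a new sequence."""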
        mu_in = self.preprocess(mu)
        x_decoder = self.model.stream_decode(
            mu_in, first_chunk=first_chunk, scale=[0, 1]
        )
        x_out = self.postprocess(x_decoder)
        x_out = x_out * self.std + self.mean
        return x_out

    def clear_cache(self):
        # Reset the streaming state held by the underlying WanVAE_.
        self.model.clear_cache()

    def generate(self, x):
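        """Round-trip features through encode/decode and trim each output
        to the length the VAE can reproduce exactly."""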
        features = x["feature"]
        feature_length = x["feature_length"]
        y_hat = self.decode(self.encode(features))

        y_hat_out = []
        for i in range(y_hat.shape[0]):
            # Trim to floor((L - 1) / f) * f + 1 steps, where f is the
            # temporal downsample factor, so each output matches a length
            # the encoder/decoder round trip preserves.
            valid_len = (
                feature_length[i] - 1
            ) // self.downsample_factor * self.downsample_factor + 1
            y_hat_out.append(y_hat[i, :valid_len, :])

        out = {"generated": y_hat_out}
        return out
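

# Minimal usage sketch (illustrative only; input_dim, vel_window, and the
# tensor shapes below are assumed values, not fixed by this module):
#
#   model = VAEWanModel(input_dim=64, vel_window=[0, 3])
#   batch = {
#       "feature": torch.randn(2, 100, 64),         # (B, T, C)
#       "feature_length": torch.tensor([100, 80]),  # valid steps per item
#   }
#   losses = model(batch)          # {"total", "recons", "velocity", "kl"}
#   recon = model.generate(batch)  # {"generated": [tensor(T_i, C), ...]}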