from typing import Optional
import torch
import torch.nn as nn
from torch.distributions.distribution import Distribution
from mld.models.operator.cross_attention import (
SkipTransformerEncoder,
SkipTransformerDecoder,
TransformerDecoder,
TransformerDecoderLayer,
TransformerEncoder,
TransformerEncoderLayer,
)
from mld.models.operator.position_encoding import build_position_encoding
from mld.utils.temos_utils import lengths_to_mask


class MldVae(nn.Module):
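    """Transformer-based motion VAE used in MLD.

    Pose sequences are embedded, encoded together with learnable
    distribution tokens into a Gaussian posterior over a small set of
    latent vectors, and decoded back into motion features.
    """
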
def __init__(self,
nfeats: int,
latent_dim: list = [1, 256],
ff_size: int = 1024,
num_layers: int = 9,
num_heads: int = 4,
dropout: float = 0.1,
arch: str = "encoder_decoder",
normalize_before: bool = False,
activation: str = "gelu",
position_embedding: str = "learned") -> None:
super().__init__()
self.latent_size = latent_dim[0]
self.latent_dim = latent_dim[-1]
input_feats = nfeats
output_feats = nfeats
self.arch = arch
        self.query_pos_encoder = build_position_encoding(
            self.latent_dim, position_embedding=position_embedding)
        # Built unconditionally: the "all_encoder" decode path uses it as well.
        self.query_pos_decoder = build_position_encoding(
            self.latent_dim, position_embedding=position_embedding)
encoder_layer = TransformerEncoderLayer(
self.latent_dim,
num_heads,
ff_size,
dropout,
activation,
normalize_before,
)
encoder_norm = nn.LayerNorm(self.latent_dim)
self.encoder = SkipTransformerEncoder(encoder_layer, num_layers,
encoder_norm)
if self.arch == "all_encoder":
decoder_norm = nn.LayerNorm(self.latent_dim)
self.decoder = SkipTransformerEncoder(encoder_layer, num_layers,
decoder_norm)
        elif self.arch == "encoder_decoder":
decoder_layer = TransformerDecoderLayer(
self.latent_dim,
num_heads,
ff_size,
dropout,
activation,
normalize_before,
)
decoder_norm = nn.LayerNorm(self.latent_dim)
self.decoder = SkipTransformerDecoder(decoder_layer, num_layers,
decoder_norm)
else:
            raise ValueError(f"Unsupported architecture: {self.arch}!")
        # Learnable tokens whose encoder outputs parameterize the latent
        # distribution (mu and logvar)
        self.global_motion_token = nn.Parameter(
            torch.randn(self.latent_size * 2, self.latent_dim))
self.skel_embedding = nn.Linear(input_feats, self.latent_dim)
self.final_layer = nn.Linear(self.latent_dim, output_feats)

    def forward(self, features: torch.Tensor,
lengths: Optional[list[int]] = None) -> tuple[torch.Tensor, torch.Tensor, Distribution]:
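        """Encode ``features`` into a latent sample and decode it back.

        Returns the reconstructed features [bs, nframes, nfeats], the sampled
        latent z [latent_size, bs, latent_dim], and the posterior Normal
        distribution.
        """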
z, dist = self.encode(features, lengths)
feats_rst = self.decode(z, lengths)
return feats_rst, z, dist

    def encode(self, features: torch.Tensor,
lengths: Optional[list[int]] = None) -> tuple[torch.Tensor, Distribution]:
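        """Map padded motion features [bs, nframes, nfeats] to a latent sample.

        Returns the reparameterized latent [latent_size, bs, latent_dim] and
        the Normal posterior it was drawn from.
        """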
if lengths is None:
lengths = [len(feature) for feature in features]
device = features.device
bs, nframes, nfeats = features.shape
mask = lengths_to_mask(lengths, device)
x = features
        # Embed each human pose into a latent vector
x = self.skel_embedding(x)
        # Switch sequence and batch dimensions: the PyTorch Transformer
        # expects [sequence, batch size, ...]
        x = x.permute(1, 0, 2)  # now [nframes, bs, latent_dim]
        # Each sequence in the batch gets its own copy of the learned
        # distribution tokens
        dist = torch.tile(self.global_motion_token[:, None, :], (1, bs, 1))
        # Extend the padding mask so the distribution tokens can always be
        # attended to
        dist_masks = torch.ones((bs, dist.shape[0]),
                                dtype=torch.bool,
                                device=x.device)
        aug_mask = torch.cat((dist_masks, mask), 1)
        # Prepend the distribution tokens to every sequence
xseq = torch.cat((dist, x), 0)
xseq = self.query_pos_encoder(xseq)
dist = self.encoder(xseq, src_key_padding_mask=~aug_mask)[:dist.shape[0]]
mu = dist[0:self.latent_size, ...]
logvar = dist[self.latent_size:, ...]
        # Reparameterized sampling: std = exp(0.5 * logvar)
std = logvar.exp().pow(0.5)
dist = torch.distributions.Normal(mu, std)
latent = dist.rsample()
return latent, dist

    def decode(self, z: torch.Tensor, lengths: list[int]) -> torch.Tensor:
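        """Decode latent z [latent_size, bs, latent_dim] into motion features.

        `lengths` gives the number of valid frames per sequence; padded
        frames in the output are zeroed.
        """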
mask = lengths_to_mask(lengths, z.device)
bs, nframes = mask.shape
queries = torch.zeros(nframes, bs, self.latent_dim, device=z.device)
        if self.arch == "all_encoder":
            xseq = torch.cat((z, queries), dim=0)
            z_mask = torch.ones((bs, self.latent_size),
                                dtype=torch.bool,
                                device=z.device)
            aug_mask = torch.cat((z_mask, mask), dim=1)
            xseq = self.query_pos_decoder(xseq)
            output = self.decoder(
                xseq, src_key_padding_mask=~aug_mask)[z.shape[0]:]
elif self.arch == "encoder_decoder":
queries = self.query_pos_decoder(queries)
output = self.decoder(
tgt=queries,
memory=z,
tgt_key_padding_mask=~mask)
output = self.final_layer(output)
# zero for padded area
output[~mask.T] = 0
        # Back from the PyTorch Transformer layout [sequence, batch size, ...]
        # to [batch size, sequence, ...]
        feats = output.permute(1, 0, 2)
return feats
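

if __name__ == "__main__":
    # Minimal smoke test; a sketch only. The feature dimension (263, the
    # HumanML3D representation) and the lengths below are illustrative
    # assumptions, not values defined in this file.
    vae = MldVae(nfeats=263)
    bs, max_frames = 2, 60
    features = torch.randn(bs, max_frames, 263)
    lengths = [60, 48]

    feats_rst, z, dist = vae(features, lengths)
    # With the default config we expect z: [1, 2, 256] and
    # feats_rst: [2, 60, 263].
    print(feats_rst.shape, z.shape)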