EMAGE / models /motion_representation.py
H-Liu1997's picture
Upload folder using huggingface_hub
2d47d90 verified
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import smplx
import copy
from .motion_encoder import *
# ----------- AE, VAE ------------- #
class VAEConvZero(nn.Module):
def __init__(self, args):
super(VAEConvZero, self).__init__()
self.encoder = VQEncoderV5(args)
# self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
self.decoder = VQDecoderV5(args)
def forward(self, inputs):
pre_latent = self.encoder(inputs)
# print(pre_latent.shape)
# embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent)
rec_pose = self.decoder(pre_latent)
return {
# "poses_feat":vq_latent,
# "embedding_loss":embedding_loss,
# "perplexity":perplexity,
"rec_pose": rec_pose
}
class VAEConv(nn.Module):
def __init__(self, args):
super(VAEConv, self).__init__()
self.encoder = VQEncoderV3(args)
self.decoder = VQDecoderV3(args)
self.fc_mu = nn.Linear(args.vae_length, args.vae_length)
self.fc_logvar = nn.Linear(args.vae_length, args.vae_length)
self.variational = args.variational
def forward(self, inputs):
pre_latent = self.encoder(inputs)
mu, logvar = None, None
if self.variational:
mu = self.fc_mu(pre_latent)
logvar = self.fc_logvar(pre_latent)
pre_latent = reparameterize(mu, logvar)
rec_pose = self.decoder(pre_latent)
return {
"poses_feat":pre_latent,
"rec_pose": rec_pose,
"pose_mu": mu,
"pose_logvar": logvar,
}
def map2latent(self, inputs):
pre_latent = self.encoder(inputs)
if self.variational:
mu = self.fc_mu(pre_latent)
logvar = self.fc_logvar(pre_latent)
pre_latent = reparameterize(mu, logvar)
return pre_latent
def decode(self, pre_latent):
rec_pose = self.decoder(pre_latent)
return rec_pose
class VAESKConv(VAEConv):
def __init__(self, args):
super(VAESKConv, self).__init__(args)
smpl_fname = args.data_path_1+'smplx_models/smplx/SMPLX_NEUTRAL_2020.npz'
smpl_data = np.load(smpl_fname, encoding='latin1')
parents = smpl_data['kintree_table'][0].astype(np.int32)
edges = build_edge_topology(parents)
self.encoder = LocalEncoder(args, edges)
self.decoder = VQDecoderV3(args)
class VAEConvMLP(VAEConv):
def __init__(self, args):
super(VAEConvMLP, self).__init__(args)
self.encoder = PoseEncoderConv(args.vae_test_len, args.vae_test_dim, feature_length=args.vae_length)
self.decoder = PoseDecoderConv(args.vae_test_len, args.vae_test_dim, feature_length=args.vae_length)
class VAELSTM(VAEConv):
def __init__(self, args):
super(VAELSTM, self).__init__(args)
pose_dim = args.vae_test_dim
feature_length = args.vae_length
self.encoder = PoseEncoderLSTM_Resnet(pose_dim, feature_length=feature_length)
self.decoder = PoseDecoderLSTM(pose_dim, feature_length=feature_length)
class VAETransformer(VAEConv):
def __init__(self, args):
super(VAETransformer, self).__init__(args)
self.encoder = Encoder_TRANSFORMER(args)
self.decoder = Decoder_TRANSFORMER(args)
# ----------- VQVAE --------------- #
class VQVAEConv(nn.Module):
def __init__(self, args):
super(VQVAEConv, self).__init__()
self.encoder = VQEncoderV3(args)
self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
self.decoder = VQDecoderV3(args)
def forward(self, inputs):
pre_latent = self.encoder(inputs)
# print(pre_latent.shape)
embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent)
rec_pose = self.decoder(vq_latent)
return {
"poses_feat":vq_latent,
"embedding_loss":embedding_loss,
"perplexity":perplexity,
"rec_pose": rec_pose
}
def map2index(self, inputs):
pre_latent = self.encoder(inputs)
index = self.quantizer.map2index(pre_latent)
return index
def map2latent(self, inputs):
pre_latent = self.encoder(inputs)
index = self.quantizer.map2index(pre_latent)
z_q = self.quantizer.get_codebook_entry(index)
return z_q
def decode(self, index):
z_q = self.quantizer.get_codebook_entry(index)
rec_pose = self.decoder(z_q)
return rec_pose
class VQVAESKConv(VQVAEConv):
def __init__(self, args):
super(VQVAESKConv, self).__init__(args)
smpl_fname = args.data_path_1+'smplx_models/smplx/SMPLX_NEUTRAL_2020.npz'
smpl_data = np.load(smpl_fname, encoding='latin1')
parents = smpl_data['kintree_table'][0].astype(np.int32)
edges = build_edge_topology(parents)
self.encoder = LocalEncoder(args, edges)
class VQVAEConvStride(nn.Module):
def __init__(self, args):
super(VQVAEConvStride, self).__init__()
self.encoder = VQEncoderV4(args)
self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
self.decoder = VQDecoderV4(args)
def forward(self, inputs):
pre_latent = self.encoder(inputs)
# print(pre_latent.shape)
embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent)
rec_pose = self.decoder(vq_latent)
return {
"poses_feat":vq_latent,
"embedding_loss":embedding_loss,
"perplexity":perplexity,
"rec_pose": rec_pose
}
def map2index(self, inputs):
pre_latent = self.encoder(inputs)
index = self.quantizer.map2index(pre_latent)
return index
def map2latent(self, inputs):
pre_latent = self.encoder(inputs)
index = self.quantizer.map2index(pre_latent)
z_q = self.quantizer.get_codebook_entry(index)
return z_q
def decode(self, index):
z_q = self.quantizer.get_codebook_entry(index)
rec_pose = self.decoder(z_q)
return rec_pose
class VQVAEConvZero(nn.Module):
def __init__(self, args):
super(VQVAEConvZero, self).__init__()
self.encoder = VQEncoderV5(args)
self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
self.decoder = VQDecoderV5(args)
def forward(self, inputs):
pre_latent = self.encoder(inputs)
# print(pre_latent.shape)
embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent)
rec_pose = self.decoder(vq_latent)
return {
"poses_feat":vq_latent,
"embedding_loss":embedding_loss,
"perplexity":perplexity,
"rec_pose": rec_pose
}
def map2index(self, inputs):
pre_latent = self.encoder(inputs)
index = self.quantizer.map2index(pre_latent)
return index
def map2latent(self, inputs):
pre_latent = self.encoder(inputs)
index = self.quantizer.map2index(pre_latent)
z_q = self.quantizer.get_codebook_entry(index)
return z_q
def decode(self, index):
z_q = self.quantizer.get_codebook_entry(index)
rec_pose = self.decoder(z_q)
return rec_pose
class VAEConvZero(nn.Module):
def __init__(self, args):
super(VAEConvZero, self).__init__()
self.encoder = VQEncoderV5(args)
# self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
self.decoder = VQDecoderV5(args)
def forward(self, inputs):
pre_latent = self.encoder(inputs)
# print(pre_latent.shape)
# embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent)
rec_pose = self.decoder(pre_latent)
return {
# "poses_feat":vq_latent,
# "embedding_loss":embedding_loss,
# "perplexity":perplexity,
"rec_pose": rec_pose
}
# def map2index(self, inputs):
# pre_latent = self.encoder(inputs)
# index = self.quantizer.map2index(pre_latent)
# return index
# def map2latent(self, inputs):
# pre_latent = self.encoder(inputs)
# index = self.quantizer.map2index(pre_latent)
# z_q = self.quantizer.get_codebook_entry(index)
# return z_q
# def decode(self, index):
# z_q = self.quantizer.get_codebook_entry(index)
# rec_pose = self.decoder(z_q)
# return rec_pose
class VQVAEConvZero3(nn.Module):
def __init__(self, args):
super(VQVAEConvZero3, self).__init__()
self.encoder = VQEncoderV5(args)
self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
self.decoder = VQDecoderV5(args)
def forward(self, inputs):
pre_latent = self.encoder(inputs)
# print(pre_latent.shape)
embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent)
rec_pose = self.decoder(vq_latent)
return {
"poses_feat":vq_latent,
"embedding_loss":embedding_loss,
"perplexity":perplexity,
"rec_pose": rec_pose
}
def map2index(self, inputs):
pre_latent = self.encoder(inputs)
index = self.quantizer.map2index(pre_latent)
return index
def map2latent(self, inputs):
pre_latent = self.encoder(inputs)
index = self.quantizer.map2index(pre_latent)
z_q = self.quantizer.get_codebook_entry(index)
return z_q
def decode(self, index):
z_q = self.quantizer.get_codebook_entry(index)
rec_pose = self.decoder(z_q)
return rec_pose
class VQVAEConvZero2(nn.Module):
def __init__(self, args):
super(VQVAEConvZero2, self).__init__()
self.encoder = VQEncoderV5(args)
self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
self.decoder = VQDecoderV7(args)
def forward(self, inputs):
pre_latent = self.encoder(inputs)
# print(pre_latent.shape)
embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent)
rec_pose = self.decoder(vq_latent)
return {
"poses_feat":vq_latent,
"embedding_loss":embedding_loss,
"perplexity":perplexity,
"rec_pose": rec_pose
}
def map2index(self, inputs):
pre_latent = self.encoder(inputs)
index = self.quantizer.map2index(pre_latent)
return index
def map2latent(self, inputs):
pre_latent = self.encoder(inputs)
index = self.quantizer.map2index(pre_latent)
z_q = self.quantizer.get_codebook_entry(index)
return z_q
def decode(self, index):
z_q = self.quantizer.get_codebook_entry(index)
rec_pose = self.decoder(z_q)
return rec_pose
class VQVAE2(nn.Module):
def __init__(self, args):
super(VQVAE2, self).__init__()
# Bottom-level encoder and decoder
args_bottom = copy.deepcopy(args)
args_bottom.vae_layer = 2
self.bottom_encoder = VQEncoderV6(args_bottom)
self.bottom_quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
args_bottom.vae_test_dim = args.vae_test_dim
self.bottom_decoder = VQDecoderV6(args_bottom)
# Top-level encoder and decoder
args_top = copy.deepcopy(args)
args_top.vae_layer = 3
args_top.vae_test_dim = args.vae_length
self.top_encoder = VQEncoderV3(args_top) # Adjust according to the top level's design
self.quantize_conv_t = nn.Conv1d(args.vae_length+args.vae_length, args.vae_length, 1)
self.top_quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
# self.upsample_t_up = nn.Upsample(scale_factor=2, mode='nearest')
layers = [
nn.Upsample(scale_factor=2, mode='nearest'),
nn.Conv1d(args.vae_length, args.vae_length, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Upsample(scale_factor=2, mode='nearest'),
nn.Conv1d(args.vae_length, args.vae_length, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Upsample(scale_factor=2, mode='nearest'),
nn.Conv1d(args.vae_length, args.vae_length, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.2, inplace=True)
]
self.upsample_t= nn.Sequential(*layers)
self.top_decoder = VQDecoderV3(args_top) # Adjust to handle top level features appropriately
def forward(self, inputs):
# Bottom-level processing
enc_b = self.bottom_encoder(inputs)
enc_t = self.top_encoder(enc_b)
#print(enc_b.shape, enc_t.shape)
top_embedding_loss, quant_t, _, top_perplexity = self.top_quantizer(enc_t)
#print(quant_t.shape)
dec_t = self.top_decoder(quant_t)
#print(dec_t.shape)
enc_b = torch.cat([dec_t, enc_b], dim=2).permute(0,2,1)
#print(enc_b.shape)
quant_b = self.quantize_conv_t(enc_b).permute(0,2,1)
#print("5",quant_b.shape)
bottom_embedding_loss, quant_b, _, bottom_perplexity = self.bottom_quantizer(quant_b)
#print("6",quant_b.shape)
upsample_t = self.upsample_t(quant_t.permute(0,2,1)).permute(0,2,1)
#print("7",upsample_t.shape)
quant = torch.cat([upsample_t, quant_b], 2)
rec_pose = self.bottom_decoder(quant)
# print(quant_t.shape, quant_b.shape, rec_pose.shape)
return {
"poses_feat_top": quant_t,
"pose_feat_bottom": quant_b,
"embedding_loss":top_embedding_loss+bottom_embedding_loss,
#"perplexity":perplexity,
"rec_pose": rec_pose
}
def map2index(self, inputs):
enc_b = self.bottom_encoder(inputs)
enc_t = self.top_encoder(enc_b)
_, quant_t, _, _ = self.top_quantizer(enc_t)
top_index = self.top_quantizer.map2index(enc_t)
dec_t = self.top_decoder(quant_t)
enc_b = torch.cat([dec_t, enc_b], dim=2).permute(0,2,1)
#print(enc_b.shape)
quant_b = self.quantize_conv_t(enc_b).permute(0,2,1)
# quant_b = self.quantize_conv_t(enc_b)
bottom_index = self.bottom_quantizer.map2index(quant_b)
return top_index, bottom_index
def get_top_laent(self, top_index):
z_q_top = self.top_quantizer.get_codebook_entry(top_index)
return z_q_top
def map2latent(self, inputs):
enc_b = self.bottom_encoder(inputs)
enc_t = self.top_encoder(enc_b)
_, quant_t, _, _ = self.top_quantizer(enc_t)
top_index = self.top_quantizer.map2index(enc_t)
dec_t = self.top_decoder(quant_t)
enc_b = torch.cat([dec_t, enc_b], dim=2).permute(0,2,1)
#print(enc_b.shape)
quant_b = self.quantize_conv_t(enc_b).permute(0,2,1)
# quant_b = self.quantize_conv_t(enc_b)
bottom_index = self.bottom_quantizer.map2index(quant_b)
z_q_top = self.top_quantizer.get_codebook_entry(top_index)
z_q_bottom = self.bottom_quantizer.get_codebook_entry(bottom_index)
return z_q_top, z_q_bottom
def map2latent_top(self, inputs):
enc_b = self.bottom_encoder(inputs)
enc_t = self.top_encoder(enc_b)
top_index = self.top_quantizer.map2index(enc_t)
z_q_top = self.top_quantizer.get_codebook_entry(top_index)
return z_q_top
def decode(self, top_index, bottom_index):
quant_t = self.top_quantizer.get_codebook_entry(top_index)
quant_b = self.bottom_quantizer.get_codebook_entry(bottom_index)
upsample_t = self.upsample_t(quant_t.permute(0,2,1)).permute(0,2,1)
#print("7",upsample_t.shape)
quant = torch.cat([upsample_t, quant_b], 2)
rec_pose = self.bottom_decoder(quant)
return rec_pose