import torch
import torch.nn.functional as F
from torch import nn
from src.audio2pose_models.res_unet import ResUnet


def class2onehot(idx, class_num):
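    """Convert integer class ids into one-hot vectors.

    idx is assumed to have shape (bs, 1), as required by the scatter_ call
    below; the result has shape (bs, class_num). For example, with
    class_num=4, idx=[[2], [0]] yields [[0,0,1,0], [1,0,0,0]].
    """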
    assert torch.max(idx).item() < class_num
    onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
    onehot.scatter_(1, idx, 1)
    return onehot


class CVAE(nn.Module):
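    """Conditional VAE over pose motion sequences.

    Conditions on a reference pose ('ref'), an audio embedding ('audio_emb'),
    and a style-class id ('class'); layer sizes and embedding dimensions all
    come from the cfg node read below.
    """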
    def __init__(self, cfg):
        super().__init__()
        encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
        decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
        latent_size = cfg.MODEL.CVAE.LATENT_SIZE
        num_classes = cfg.DATASET.NUM_CLASSES
        audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
        audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
        seq_len = cfg.MODEL.CVAE.SEQ_LEN

        self.latent_size = latent_size

        self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
                               audio_emb_in_size, audio_emb_out_size, seq_len)
        self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
                               audio_emb_in_size, audio_emb_out_size, seq_len)

    def reparameterize(self, mu, logvar):
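        # Standard VAE reparameterization trick: z = mu + sigma * eps with
        # eps ~ N(0, I), keeping the sampling step differentiable w.r.t.
        # mu and logvar.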
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, batch):
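        # batch is expected to provide 'class', 'pose_motion_gt', 'ref' and
        # 'audio_emb'; the encoder adds 'mu'/'logvar' and the decoder adds
        # 'pose_motion_pred'.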
        batch = self.encoder(batch)
        mu = batch['mu']
        logvar = batch['logvar']
        z = self.reparameterize(mu, logvar)
        batch['z'] = z
        return self.decoder(batch)

    def test(self, batch):
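        # At inference time 'z' is assumed to be sampled by the caller and
        # stored in batch beforehand; the commented-out block below is the
        # in-module alternative that was left disabled.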
        # class_id = batch['class']
        # z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
        # batch['z'] = z
        return self.decoder(batch)


class ENCODER(nn.Module):
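    """Posterior network q(z | pose motion, ref, audio, class).

    Encodes the ground-truth pose motion with a ResUnet, projects the audio
    embedding, concatenates both with the reference pose and a learned
    per-class bias, and regresses mu/logvar of the latent Gaussian.
    """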
    def __init__(self, layer_sizes, latent_size, num_classes,
                 audio_emb_in_size, audio_emb_out_size, seq_len):
        super().__init__()
        self.resunet = ResUnet()
        self.num_classes = num_classes
        self.seq_len = seq_len

        self.MLP = nn.Sequential()
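        # Widen the first layer to fit the concatenation built in forward():
        # pose embedding (layer_sizes[0]) + class bias (latent_size)
        # + flattened audio (seq_len * audio_emb_out_size) + ref pose (6).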
        layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
        for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            self.MLP.add_module(
                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
            self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())

        self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
        self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)

        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))

    def forward(self, batch):
        class_id = batch['class']
        pose_motion_gt = batch['pose_motion_gt']
        ref = batch['ref']
        bs = pose_motion_gt.shape[0]
        audio_in = batch['audio_emb']
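        # audio_in is assumed to be (bs, seq_len, audio_emb_in_size), so the
        # flattened audio_out below has seq_len * audio_emb_out_size features.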

        pose_emb = self.resunet(pose_motion_gt.unsqueeze(1))
        pose_emb = pose_emb.reshape(bs, -1)

        audio_out = self.linear_audio(audio_in)
        audio_out = audio_out.reshape(bs, -1)

        class_bias = self.classbias[class_id]
        x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1)
        x_out = self.MLP(x_in)

        mu = self.linear_means(x_out)
        logvar = self.linear_logvar(x_out)

        batch.update({'mu': mu, 'logvar': logvar})
        return batch


class DECODER(nn.Module):
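    """Generator p(pose motion | z, ref, audio, class).

    Shifts z by a learned per-class bias, concatenates it with the reference
    pose and the projected audio embedding, decodes through the MLP and a
    ResUnet, and maps each frame to a 6-dim pose motion.
    """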
    def __init__(self, layer_sizes, latent_size, num_classes,
                 audio_emb_in_size, audio_emb_out_size, seq_len):
        super().__init__()
        self.resunet = ResUnet()
        self.num_classes = num_classes
        self.seq_len = seq_len

        self.MLP = nn.Sequential()
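        # Decoder input: class-biased z (latent_size) + flattened audio
        # (seq_len * audio_emb_out_size) + ref pose (6); the final layer is
        # followed by a Sigmoid instead of a ReLU.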
        input_size = latent_size + seq_len*audio_emb_out_size + 6
        for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
            self.MLP.add_module(
                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
            if i+1 < len(layer_sizes):
                self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
            else:
                self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())

        self.pose_linear = nn.Linear(6, 6)
        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)

        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))

    def forward(self, batch):
        z = batch['z']
        bs = z.shape[0]
        class_id = batch['class']
        ref = batch['ref']
        audio_in = batch['audio_emb']

        audio_out = self.linear_audio(audio_in)
        audio_out = audio_out.reshape([bs, -1])
        class_bias = self.classbias[class_id]

        z = z + class_bias
        x_in = torch.cat([ref, z, audio_out], dim=-1)
        x_out = self.MLP(x_in)
        x_out = x_out.reshape((bs, self.seq_len, -1))
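        # The flat MLP output is reshaped to (bs, seq_len, -1) and refined by
        # the ResUnet as a single-channel 2D map (via unsqueeze(1)), mirroring
        # the encoder side.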

        pose_emb = self.resunet(x_out.unsqueeze(1))
        pose_motion_pred = self.pose_linear(pose_emb.squeeze(1))

        batch.update({'pose_motion_pred': pose_motion_pred})
        return batch
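

# Minimal smoke-test sketch (assumed shapes; cfg is a yacs-style config node
# carrying the keys read in CVAE.__init__). Not part of the training pipeline.
#
#   model = CVAE(cfg)
#   batch = {
#       'class': torch.zeros(2, dtype=torch.long),
#       'pose_motion_gt': torch.randn(2, cfg.MODEL.CVAE.SEQ_LEN, 6),
#       'ref': torch.randn(2, 6),
#       'audio_emb': torch.randn(2, cfg.MODEL.CVAE.SEQ_LEN,
#                                cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE),
#   }
#   out = model(batch)  # adds 'mu', 'logvar', 'z', 'pose_motion_pred'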