import torch
from torch import nn

from src.audio2pose_models.cvae import CVAE
from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
from src.audio2pose_models.audio_encoder import AudioEncoder

class Audio2Pose(nn.Module):
    """Audio-to-head-pose generator: a frozen wav2lip-based audio encoder
    feeds a CVAE generator (netG); a pose-sequence discriminator
    (netD_motion) is instantiated for adversarial training."""

    def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
        super().__init__()
        self.cfg = cfg
        self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
        self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
        self.device = device

        # The audio encoder is loaded from a pretrained wav2lip checkpoint
        # and frozen: it is only used to extract audio embeddings.
        self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
        self.audio_encoder.eval()
        for param in self.audio_encoder.parameters():
            param.requires_grad = False

        self.netG = CVAE(cfg)
        self.netD_motion = PoseSequenceDiscriminator(cfg)

    def forward(self, x):
        batch = {}
        coeff_gt = x['gt'].to(self.device).squeeze(0)   # bs, frame_len+1, coeff_dim
        # Pose motion is expressed relative to the first (reference) frame;
        # the [-9:-3] slice selects the six pose coefficients.
        batch['pose_motion_gt'] = coeff_gt[:, 1:, -9:-3] - coeff_gt[:, :1, -9:-3]
        batch['ref'] = coeff_gt[:, 0, -9:-3]
        batch['class'] = x['class'].squeeze(0).to(self.device)
        indiv_mels = x['indiv_mels'].to(self.device).squeeze(0)

        # Embed the per-frame mel chunks for every frame except the reference.
        audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2))
        batch['audio_emb'] = audio_emb
        batch = self.netG(batch)

        pose_motion_pred = batch['pose_motion_pred']
        pose_gt = coeff_gt[:, 1:, -9:-3].clone()
        # Recover absolute poses by adding the predicted motion back onto
        # the reference-frame pose.
        pose_pred = coeff_gt[:, :1, -9:-3] + pose_motion_pred

        batch['pose_pred'] = pose_pred
        batch['pose_gt'] = pose_gt

        return batch

    def test(self, x):
        batch = {}
        ref = x['ref']                                  # bs, 1, coeff_dim
        batch['ref'] = ref[:, 0, -6:]
        batch['class'] = x['class']
        bs = ref.shape[0]

        indiv_mels = x['indiv_mels']
        indiv_mels_use = indiv_mels[:, 1:]              # drop the reference frame
        num_frames = int(x['num_frames']) - 1

        # Generate pose motion in windows of seq_len frames; `remainder`
        # frames are left over after the full windows.
        div = num_frames // self.seq_len
        remainder = num_frames % self.seq_len
        # The reference frame has zero motion by definition.
        pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape,
                                             dtype=batch['ref'].dtype,
                                             device=batch['ref'].device)]

        for i in range(div):
            # A fresh latent sample per window keeps the generated motion diverse.
            z = torch.randn(bs, self.latent_dim).to(ref.device)
            batch['z'] = z
            audio_emb = self.audio_encoder(indiv_mels_use[:, i * self.seq_len:(i + 1) * self.seq_len, :, :, :])
            batch['audio_emb'] = audio_emb
            batch = self.netG.test(batch)
            pose_motion_pred_list.append(batch['pose_motion_pred'])

        if remainder != 0:
            z = torch.randn(bs, self.latent_dim).to(ref.device)
            batch['z'] = z
            # Reuse the trailing seq_len mel windows for the leftover frames;
            # if fewer than seq_len frames exist in total, left-pad the
            # embeddings by repeating the first frame.
            audio_emb = self.audio_encoder(indiv_mels_use[:, -1 * self.seq_len:, :, :, :])
            if audio_emb.shape[1] != self.seq_len:
                pad_dim = self.seq_len - audio_emb.shape[1]
                pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
                audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
            batch['audio_emb'] = audio_emb
            batch = self.netG.test(batch)
            # Keep only the predictions for frames not covered by full windows.
            pose_motion_pred_list.append(batch['pose_motion_pred'][:, -1 * remainder:, :])

        pose_motion_pred = torch.cat(pose_motion_pred_list, dim=1)
        batch['pose_motion_pred'] = pose_motion_pred

        # Recover absolute poses by adding the motion to the reference pose.
        pose_pred = ref[:, :1, -6:] + pose_motion_pred

        batch['pose_pred'] = pose_pred
        return batch
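
# ---------------------------------------------------------------------------
# Minimal standalone sketch: replays the windowing arithmetic used by
# Audio2Pose.test with dummy tensors in place of real audio embeddings.
# The values of seq_len, num_frames, and the embedding width are assumed
# here purely for illustration; nothing below needs a checkpoint or a GPU.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    seq_len, num_frames, emb_dim = 32, 70, 512       # assumed sizes for the demo
    num_frames -= 1                                  # first frame is the reference
    div, remainder = divmod(num_frames, seq_len)

    audio_emb = torch.randn(1, num_frames, emb_dim)  # stand-in for encoder output
    windows = [audio_emb[:, i * seq_len:(i + 1) * seq_len] for i in range(div)]

    if remainder != 0:
        # Take the trailing seq_len frames; if fewer frames exist in total,
        # left-pad by repeating the first one, as Audio2Pose.test does.
        tail = audio_emb[:, -seq_len:]
        if tail.shape[1] != seq_len:
            pad = tail[:, :1].repeat(1, seq_len - tail.shape[1], 1)
            tail = torch.cat([pad, tail], dim=1)
        windows.append(tail[:, -remainder:])         # keep only the leftover frames

    total = sum(w.shape[1] for w in windows)
    assert total == num_frames, (total, num_frames)
    print(f'{div} full windows + {remainder} leftover frames -> {total} frames')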