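# NOTE: the following metadata is residue from the web file viewer this module was
# copied from, kept as a comment so that the file parses as Python:
#   vinthony's picture | commit message: "fixed req" | commit: a86a2b8 | raw | history blame | size: 3.76 kB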
import torch
from torch import nn
from src.audio2pose_models.cvae import CVAE
from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
from src.audio2pose_models.audio_encoder import AudioEncoder
class Audio2Pose(nn.Module):
    """Predict a head-pose sequence from audio features.

    The module owns a frozen ``AudioEncoder``, a ``CVAE`` generator (``netG``)
    and a ``PoseSequenceDiscriminator`` (``netD_motion``), plus the loss
    criteria used around them.  ``forward`` is the teacher-forced path that
    builds ground-truth motion from ``x['gt']``; ``test`` is the sampling
    path that generates motion window by window from random latents.

    Poses are handled as *motion relative to the first (reference) frame*:
    the generator output is added back onto the reference pose to obtain
    absolute poses.
    """

    def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
        """Build the sub-networks and loss criteria.

        Args:
            cfg: configuration node; ``cfg.MODEL.CVAE.SEQ_LEN`` and
                ``cfg.MODEL.CVAE.LATENT_SIZE`` are read here, and ``cfg`` is
                passed on to ``CVAE`` and ``PoseSequenceDiscriminator``.
            wav2lip_checkpoint: forwarded unchanged to ``AudioEncoder``.
            device: device that ``forward`` moves its input tensors to.
        """
        super().__init__()
        self.cfg = cfg
        self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
        self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
        self.device = device

        # The audio encoder is used as a fixed feature extractor: put it in
        # eval mode and exclude its weights from gradient computation.
        self.audio_encoder = AudioEncoder(wav2lip_checkpoint)
        self.audio_encoder.eval()
        for param in self.audio_encoder.parameters():
            param.requires_grad = False

        self.netG = CVAE(cfg)
        self.netD_motion = PoseSequenceDiscriminator(cfg)

        self.gan_criterion = nn.MSELoss()
        self.reg_criterion = nn.L1Loss(reduction='none')
        self.pair_criterion = nn.PairwiseDistance()
        self.cosine_loss = nn.CosineSimilarity(dim=1)

    def forward(self, x):
        """Run the generator on a ground-truth sample.

        Args:
            x: mapping with tensors ``'gt'``, ``'class'`` and
                ``'indiv_mels'``.  Each is moved to ``self.device`` and has
                its leading dimension squeezed.  Shapes noted inline come
                from the original annotations and are not checked here.

        Returns:
            The batch dict returned by ``self.netG`` with ``'pose_pred'``
            (reference pose + predicted motion) and ``'pose_gt'`` added.
        """
        batch = {}

        # Honour the device chosen at construction time instead of
        # unconditionally calling .cuda(); identical for the default 'cuda'.
        coeff_gt = x['gt'].to(self.device).squeeze(0)  # bs frame_len+1 73
        first_pose = coeff_gt[:, :1, -9:-3]  # bs 1 6, pose of the first frame

        batch['pose_motion_gt'] = coeff_gt[:, 1:, -9:-3] - first_pose  # bs frame_len 6
        batch['ref'] = coeff_gt[:, 0, -9:-3]  # bs 6
        batch['class'] = x['class'].squeeze(0).to(self.device)  # bs
        indiv_mels = x['indiv_mels'].to(self.device).squeeze(0)  # bs seq_len+1 80 16

        # The first mel window belongs to the reference frame and is skipped;
        # unsqueeze(2) inserts a singleton axis before the last two dims.
        audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2))  # bs seq_len 512
        batch['audio_emb'] = audio_emb
        batch = self.netG(batch)

        pose_motion_pred = batch['pose_motion_pred']  # bs frame_len 6
        pose_gt = coeff_gt[:, 1:, -9:-3].clone()  # bs frame_len 6
        pose_pred = first_pose + pose_motion_pred  # bs frame_len 6

        batch['pose_pred'] = pose_pred
        batch['pose_gt'] = pose_gt
        return batch

    def _generate_window(self, batch, mels, bs, device):
        """Sample a latent, encode one mel window and run ``netG.test`` on it.

        If the encoder yields fewer than ``self.seq_len`` steps (the clip is
        shorter than one window), the embedding is padded at the *front* by
        repeating its first step, so that the trailing steps still correspond
        to real audio.  This assumes the embedding is ``(bs, T, C)`` and that
        the generator expects exactly ``seq_len`` steps -- TODO confirm
        against ``CVAE``.
        """
        # Sampled on the default device and then moved, as in the original
        # code, so seeded runs keep drawing from the same random stream.
        batch['z'] = torch.randn(bs, self.latent_dim).to(device)

        audio_emb = self.audio_encoder(mels)  # bs seq_len 512
        missing = self.seq_len - audio_emb.shape[1]
        if missing > 0:
            front_pad = audio_emb[:, :1].repeat(1, missing, 1)
            audio_emb = torch.cat([front_pad, audio_emb], dim=1)
        batch['audio_emb'] = audio_emb
        return self.netG.test(batch)

    def test(self, x):
        """Generate a pose sequence for a whole clip, window by window.

        Args:
            x: mapping with ``'ref'`` (bs 1 70 per the original annotation;
                its last 6 values are used as the reference pose),
                ``'class'``, ``'indiv_mels'`` (bs T 1 80 16) and
                ``'num_frames'`` (anything ``int()`` accepts).

        Returns:
            The batch dict with ``'pose_motion_pred'`` (a zero motion for the
            reference frame followed by the generated motion) and
            ``'pose_pred'`` (reference pose + motion).
        """
        batch = {}
        ref = x['ref']  # bs 1 70
        batch['ref'] = ref[:, 0, -6:]
        batch['class'] = x['class']
        bs = ref.shape[0]

        indiv_mels = x['indiv_mels']  # bs T 1 80 16
        indiv_mels_use = indiv_mels[:, 1:]  # we regard the ref as the first frame
        num_frames = int(x['num_frames']) - 1

        full_windows, remainder = divmod(num_frames, self.seq_len)

        # The reference frame itself has zero motion.
        pose_motion_pred_list = [torch.zeros_like(batch['ref'].unsqueeze(1))]

        for i in range(full_windows):
            start = i * self.seq_len
            window = indiv_mels_use[:, start:start + self.seq_len, :, :, :]
            batch = self._generate_window(batch, window, bs, ref.device)
            pose_motion_pred_list.append(batch['pose_motion_pred'])  # list of bs seq_len 6

        if remainder != 0:
            # Re-run on the last seq_len frames (overlapping the previous
            # window) and keep only the frames not generated yet.
            window = indiv_mels_use[:, -1 * self.seq_len:, :, :, :]
            batch = self._generate_window(batch, window, bs, ref.device)
            pose_motion_pred_list.append(batch['pose_motion_pred'][:, -1 * remainder:, :])

        pose_motion_pred = torch.cat(pose_motion_pred_list, dim=1)
        batch['pose_motion_pred'] = pose_motion_pred

        pose_pred = ref[:, :1, -6:] + pose_motion_pred  # bs T 6
        batch['pose_pred'] = pose_pred
        return batch