import torch
from torch import nn


class Audio2Exp(nn.Module):
    """Wraps an audio-to-expression generator (netG) and runs it at inference time."""

    def __init__(self, netG, cfg, device, prepare_training_loss=False):
        super(Audio2Exp, self).__init__()
        self.cfg = cfg
        self.device = device
        self.netG = netG.to(device)

    def test(self, batch):
        # Per-frame mel spectrogram chunks: (bs, T, 1, 80, 16)
        mel_input = batch['indiv_mels']
        bs = mel_input.shape[0]
        T = mel_input.shape[1]

        # Reference expression coefficients (first 64 dims), repeated for every frame.
        ref = batch['ref'][:, :, :64].repeat((1, T, 1))
        # Ratio conditioning signal passed through to the generator.
        ratio = batch['ratio_gt']

        # Flatten batch and time so each frame becomes an independent sample for netG.
        audiox = mel_input.view(-1, 1, 80, 16)
        exp_coeff_pred = self.netG(audiox, ref, ratio)

        results_dict = {
            'exp_coeff_pred': exp_coeff_pred
        }
        return results_dict
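

# Minimal usage sketch (not part of the original file; names and shapes below are
# assumptions). A hypothetical stand-in generator replaces the real netG so the
# example stays self-contained; it only mirrors the (audiox, ref, ratio) call
# signature that Audio2Exp.test() relies on.
if __name__ == '__main__':

    class _DummyNetG(nn.Module):
        """Hypothetical generator with the same interface assumed for netG."""
        def forward(self, audiox, ref, ratio):
            bs, T = ref.shape[0], ref.shape[1]
            # Emit one 64-dim expression coefficient vector per frame.
            return torch.zeros(bs, T, 64)

    device = 'cpu'
    model = Audio2Exp(_DummyNetG(), cfg=None, device=device)

    bs, T = 1, 5
    batch = {
        'indiv_mels': torch.randn(bs, T, 1, 80, 16),  # per-frame mel chunks
        'ref': torch.randn(bs, 1, 70),                # assumed reference coeffs (>= 64 dims)
        'ratio_gt': torch.randn(bs, T, 1),            # assumed ratio signal shape
    }
    out = model.test(batch)
    print(out['exp_coeff_pred'].shape)  # -> torch.Size([1, 5, 64]) with the dummy netG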