import torch
import torch.nn.functional as F
from torch import nn

from src.audio2pose_models.res_unet import ResUnet


def class2onehot(idx, class_num):
    # idx: (bs, 1) long tensor of class indices (2-D, as required by scatter_)
    assert torch.max(idx).item() < class_num

    onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
    onehot.scatter_(1, idx, 1)
    return onehot
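# Example with hypothetical values: two samples, three classes.
#   idx = torch.tensor([[0], [2]])   # shape (bs, 1)
#   class2onehot(idx, 3)
#   # -> tensor([[1., 0., 0.],
#   #            [0., 0., 1.]])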
class CVAE(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
        decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
        latent_size = cfg.MODEL.CVAE.LATENT_SIZE
        num_classes = cfg.DATASET.NUM_CLASSES
        audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
        audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
        seq_len = cfg.MODEL.CVAE.SEQ_LEN

        self.latent_size = latent_size

        self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
                               audio_emb_in_size, audio_emb_out_size, seq_len)
        self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
                               audio_emb_in_size, audio_emb_out_size, seq_len)
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
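    # Why reparameterize (the standard VAE trick, not specific to this repo):
    # sampling z ~ N(mu, sigma^2) directly is not differentiable w.r.t. mu and
    # logvar; rewriting the sample as z = mu + exp(0.5*logvar) * eps with
    # eps ~ N(0, I) moves the randomness into eps, so gradients flow through
    # the encoder outputs during training.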
    def forward(self, batch):
        batch = self.encoder(batch)
        mu = batch['mu']
        logvar = batch['logvar']
        z = self.reparameterize(mu, logvar)
        batch['z'] = z
        return self.decoder(batch)
    def test(self, batch):
        # The sampling below is left disabled; the decoder reads batch['z'],
        # so the caller must supply a latent code before calling test().
        '''
        class_id = batch['class']
        z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
        batch['z'] = z
        '''
        return self.decoder(batch)
class ENCODER(nn.Module):
    def __init__(self, layer_sizes, latent_size, num_classes,
                 audio_emb_in_size, audio_emb_out_size, seq_len):
        super().__init__()
        self.resunet = ResUnet()
        self.num_classes = num_classes
        self.seq_len = seq_len

        self.MLP = nn.Sequential()
        # Widen the first layer to fit the concatenated input:
        # pose embedding (layer_sizes[0]) + class bias (latent_size)
        # + flattened audio (seq_len*audio_emb_out_size) + ref pose (6).
        # Note: this mutates the list that was passed in.
        layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
        for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            self.MLP.add_module(
                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
            self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())

        self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
        self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)

        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
    def forward(self, batch):
        class_id = batch['class']
        pose_motion_gt = batch['pose_motion_gt']  # bs seq_len 6
        ref = batch['ref']  # bs 6
        bs = pose_motion_gt.shape[0]
        audio_in = batch['audio_emb']  # bs seq_len audio_emb_in_size

        # pose encode
        pose_emb = self.resunet(pose_motion_gt.unsqueeze(1))  # bs 1 seq_len 6
        pose_emb = pose_emb.reshape(bs, -1)  # bs seq_len*6

        # audio mapping
        audio_out = self.linear_audio(audio_in)  # bs seq_len audio_emb_out_size
        audio_out = audio_out.reshape(bs, -1)  # bs seq_len*audio_emb_out_size

        class_bias = self.classbias[class_id]  # bs latent_size
        x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1)  # bs 6+seq_len*6+seq_len*audio_emb_out_size+latent_size
        x_out = self.MLP(x_in)

        mu = self.linear_means(x_out)
        logvar = self.linear_logvar(x_out)  # bs latent_size

        batch.update({'mu': mu, 'logvar': logvar})
        return batch
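# Shape bookkeeping for ENCODER.forward, inferred from the comments above.
# For the concatenation to match the widened first MLP layer, layer_sizes[0]
# must start out as seq_len*6 (the size of the flattened pose embedding):
#   ref        (bs, 6)
#   pose_emb   (bs, seq_len*6)
#   audio_out  (bs, seq_len*audio_emb_out_size)
#   class_bias (bs, latent_size)
#   x_in       (bs, 6 + seq_len*6 + seq_len*audio_emb_out_size + latent_size)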
class DECODER(nn.Module):
    def __init__(self, layer_sizes, latent_size, num_classes,
                 audio_emb_in_size, audio_emb_out_size, seq_len):
        super().__init__()
        self.resunet = ResUnet()
        self.num_classes = num_classes
        self.seq_len = seq_len

        self.MLP = nn.Sequential()
        input_size = latent_size + seq_len*audio_emb_out_size + 6
        for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
            self.MLP.add_module(
                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
            if i+1 < len(layer_sizes):
                self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
            else:
                self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())

        self.pose_linear = nn.Linear(6, 6)
        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)

        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
    def forward(self, batch):
        z = batch['z']  # bs latent_size
        bs = z.shape[0]
        class_id = batch['class']
        ref = batch['ref']  # bs 6
        audio_in = batch['audio_emb']  # bs seq_len audio_emb_in_size

        audio_out = self.linear_audio(audio_in)  # bs seq_len audio_emb_out_size
        audio_out = audio_out.reshape([bs, -1])  # bs seq_len*audio_emb_out_size
        class_bias = self.classbias[class_id]  # bs latent_size

        # condition the latent code on the class, then decode
        z = z + class_bias
        x_in = torch.cat([ref, z, audio_out], dim=-1)
        x_out = self.MLP(x_in)  # bs layer_sizes[-1]
        x_out = x_out.reshape((bs, self.seq_len, -1))

        pose_emb = self.resunet(x_out.unsqueeze(1))  # bs 1 seq_len 6
        pose_motion_pred = self.pose_linear(pose_emb.squeeze(1))  # bs seq_len 6

        batch.update({'pose_motion_pred': pose_motion_pred})
        return batch
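

# Minimal smoke-test sketch. The config values below are placeholders chosen
# to be self-consistent (the real ones live in the repo's config), and it
# assumes ResUnet preserves the (seq_len, 6) spatial shape, as the shape
# comments above indicate.
if __name__ == '__main__':
    from types import SimpleNamespace as NS

    bs, seq_len = 2, 32
    latent_size, audio_in_size, audio_out_size = 64, 512, 64
    cfg = NS(
        MODEL=NS(CVAE=NS(
            ENCODER_LAYER_SIZES=[seq_len * 6, 256],
            DECODER_LAYER_SIZES=[256, seq_len * 6],
            LATENT_SIZE=latent_size,
            AUDIO_EMB_IN_SIZE=audio_in_size,
            AUDIO_EMB_OUT_SIZE=audio_out_size,
            SEQ_LEN=seq_len)),
        DATASET=NS(NUM_CLASSES=46))

    net = CVAE(cfg)
    batch = {
        'class': torch.zeros(bs, dtype=torch.long),           # style/class id
        'ref': torch.randn(bs, 6),                            # reference pose
        'pose_motion_gt': torch.randn(bs, seq_len, 6),        # GT pose motion
        'audio_emb': torch.randn(bs, seq_len, audio_in_size),
    }
    out = net(batch)                      # encode -> sample z -> decode
    print(out['pose_motion_pred'].shape)  # expected: (bs, seq_len, 6)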