'''
not exactly the same as the official repo but the results are good
'''
import sys
import os

from transformers import Wav2Vec2Processor
from .wav2vec import Wav2Vec2Model
from torchaudio.sox_effects import apply_effects_tensor

# make repo-root packages (e.g. nets.layers) importable when running from the project root
sys.path.append(os.getcwd())

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio as ta
import math

from nets.layers import SeqEncoder1D, SeqTranslator1D, ConvNormRelu

""" from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """


def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000):
    """
    :param audio: 1 x T tensor containing a 16kHz audio signal
    :param frame_rate: frame rate for video (we need one audio chunk per video frame)
    :param chunk_size: number of audio samples per chunk
    :return: num_chunks x chunk_size tensor containing sliced audio
    """
    samples_per_frame = 16000 // frame_rate
    padding = (chunk_size - samples_per_frame) // 2
    audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0)
    anchor_points = list(range(chunk_size // 2, audio.shape[-1] - chunk_size // 2, samples_per_frame))
    audio = torch.cat([audio[:, i - chunk_size // 2:i + chunk_size // 2] for i in anchor_points], dim=0)
    return audio
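
# Illustrative shape sketch (not from the original file), assuming a 16 kHz mono waveform:
# for ~2 s of audio at frame_rate=30 the call below yields roughly 60 overlapping
# 1-second windows, one per video frame.
#
#   wav = torch.randn(1, 2 * 16000)                               # 1 x T waveform
#   chunks = audio_chunking(wav, frame_rate=30, chunk_size=16000)
#   # chunks.shape -> (num_chunks, 16000), num_chunks ~= duration_sec * frame_rate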


class MeshtalkEncoder(nn.Module):
    def __init__(self, latent_dim: int = 128, model_name: str = 'audio_encoder'):
        """
        :param latent_dim: size of the latent audio embedding
        :param model_name: name of the model, used to load and save the model
        """
        super().__init__()

        self.melspec = ta.transforms.MelSpectrogram(
            sample_rate=16000, n_fft=2048, win_length=800, hop_length=160, n_mels=80
        )

        conv_len = 5
        self.convert_dimensions = torch.nn.Conv1d(80, 128, kernel_size=conv_len)
        self.weights_init(self.convert_dimensions)
        self.receptive_field = conv_len

        # stack of dilated 1D convolutions; track the receptive field as it grows
        convs = []
        for i in range(6):
            dilation = 2 * (i % 3 + 1)
            self.receptive_field += (conv_len - 1) * dilation
            convs += [torch.nn.Conv1d(128, 128, kernel_size=conv_len, dilation=dilation)]
            self.weights_init(convs[-1])
        self.convs = torch.nn.ModuleList(convs)
        self.code = torch.nn.Linear(128, latent_dim)

        self.apply(lambda x: self.weights_init(x))

    def weights_init(self, m):
        if isinstance(m, torch.nn.Conv1d):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, .01)

    def forward(self, audio: torch.Tensor):
        """
        :param audio: B x T x 16000 Tensor containing 1 sec of audio centered around the current time frame
        :return: code: B x T x latent_dim Tensor containing a latent audio code/embedding
        """
        B, T = audio.shape[0], audio.shape[1]

        # log mel spectrogram for every 1-second chunk
        x = self.melspec(audio).squeeze(1)
        x = torch.log(x.clamp(min=1e-10, max=None))
        if T == 1:
            x = x.unsqueeze(1)

        # Convert to the right dimensionality
        x = x.view(-1, x.shape[2], x.shape[3])
        x = F.leaky_relu(self.convert_dimensions(x), .2)

        # Process stacks of dilated convolutions with residual-style averaging
        for conv in self.convs:
            x_ = F.leaky_relu(conv(x), .2)
            if self.training:
                x_ = F.dropout(x_, .2)
            trim = (x.shape[2] - x_.shape[2]) // 2  # center-crop x to match the conv output length
            x = (x[:, :, trim:-trim] + x_) / 2

        # pool over time within each chunk and project to the latent code
        x = torch.mean(x, dim=-1)
        x = x.view(B, T, x.shape[-1])
        x = self.code(x)

        return {"code": x}


class AudioEncoder(nn.Module):
    def __init__(self, in_dim, out_dim, identity=False, num_classes=0):
        super().__init__()
        self.identity = identity
        if self.identity:
            # a one-hot speaker id is embedded and concatenated to the audio feature
            in_dim = in_dim + 64
            self.id_mlp = nn.Conv1d(num_classes, 64, 1, 1)
        self.first_net = SeqTranslator1D(in_dim, out_dim,
                                         min_layers_num=3,
                                         residual=True,
                                         norm='ln'
                                         )
        self.grus = nn.GRU(out_dim, out_dim, 1, batch_first=True)  # kept for compatibility, unused below
        self.dropout = nn.Dropout(0.1)
        # self.att = nn.MultiheadAttention(out_dim, 4, dropout=0.1, batch_first=True)

    def forward(self, spectrogram, pre_state=None, id=None, time_steps=None):
        spectrogram = self.dropout(spectrogram)
        if self.identity:
            # broadcast the one-hot id over time and fuse it with the spectrogram features
            id = id.reshape(id.shape[0], -1, 1).repeat(1, 1, spectrogram.shape[2]).to(torch.float32)
            id = self.id_mlp(id)
            spectrogram = torch.cat([spectrogram, id], dim=1)
        x1 = self.first_net(spectrogram)  # .permute(0, 2, 1)
        if time_steps is not None:
            x1 = F.interpolate(x1, size=time_steps, align_corners=False, mode='linear')
        # x1, _ = self.att(x1, x1, x1)
        # x1, hidden_state = self.grus(x1)
        # x1 = x1.permute(0, 2, 1)
        hidden_state = None

        return x1, hidden_state
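
# Illustrative sketch (not from the original file): AudioEncoder expects a channel-first
# feature map (B x in_dim x T') plus, when identity=True, a one-hot speaker id of shape
# B x num_classes, and returns a B x out_dim x time_steps feature. The concrete sizes below
# are assumptions for illustration, and the layout relies on the repo's SeqTranslator1D
# keeping the B x C x T convention used elsewhere in this file.
#
#   enc = AudioEncoder(in_dim=64, out_dim=256)
#   feat, _ = enc(torch.randn(2, 64, 100), time_steps=88)         # feat: 2 x 256 x 88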


class Generator(nn.Module):
    def __init__(self,
                 n_poses,
                 each_dim: list,
                 dim_list: list,
                 training=False,
                 device=None,
                 identity=True,
                 num_classes=0,
                 ):
        super().__init__()

        self.training = training  # note: shadows nn.Module.training until train()/eval() is called
        self.device = device
        self.gen_length = n_poses
        self.identity = identity

        norm = 'ln'
        in_dim = 256
        out_dim = 256

        # audio front-end: 'meshtalk', 'faceformer' (wav2vec 2.0) or the default conv encoder
        self.encoder_choice = 'faceformer'

        if self.encoder_choice == 'meshtalk':
            self.audio_encoder = MeshtalkEncoder(latent_dim=in_dim)
        elif self.encoder_choice == 'faceformer':
            # wav2vec 2.0 weights initialization
            # alternative checkpoint: "vitouphy/wav2vec2-xls-r-300m-phoneme"
            self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
            self.audio_encoder.feature_extractor._freeze_parameters()
            self.audio_feature_map = nn.Linear(768, in_dim)
        else:
            self.audio_encoder = AudioEncoder(in_dim=64, out_dim=out_dim)

        self.audio_middle = AudioEncoder(in_dim, out_dim, identity, num_classes)

        self.dim_list = dim_list

        # two decoder branches, each ending in a 1x1 conv projecting to its output dimension
        self.decoder = nn.ModuleList()
        self.final_out = nn.ModuleList()

        self.decoder.append(nn.Sequential(
            ConvNormRelu(out_dim, 64, norm=norm),
            ConvNormRelu(64, 64, norm=norm),
            ConvNormRelu(64, 64, norm=norm),
        ))
        self.final_out.append(nn.Conv1d(64, each_dim[0], 1, 1))

        self.decoder.append(nn.Sequential(
            ConvNormRelu(out_dim, out_dim, norm=norm),
            ConvNormRelu(out_dim, out_dim, norm=norm),
            ConvNormRelu(out_dim, out_dim, norm=norm),
        ))
        self.final_out.append(nn.Conv1d(out_dim, each_dim[3], 1, 1))

    def forward(self, in_spec, gt_poses=None, id=None, pre_state=None, time_steps=None):
        if self.training:
            time_steps = gt_poses.shape[1]

        # encode the raw audio / spectrogram with the selected front-end
        if self.encoder_choice == 'meshtalk':
            in_spec = audio_chunking(in_spec.squeeze(-1), frame_rate=30, chunk_size=16000)
            feature = self.audio_encoder(in_spec.unsqueeze(0))["code"].transpose(1, 2)
        elif self.encoder_choice == 'faceformer':
            hidden_states = self.audio_encoder(in_spec.reshape(in_spec.shape[0], -1), frame_num=time_steps).last_hidden_state
            feature = self.audio_feature_map(hidden_states).transpose(1, 2)
        else:
            feature, _ = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)

        # fuse the audio feature with the (optional) speaker identity
        feature, _ = self.audio_middle(feature, id=id)

        # run every decoder branch and concatenate their outputs along the channel axis
        out = []
        for i in range(len(self.decoder)):
            mid = self.decoder[i](feature)
            mid = self.final_out[i](mid)
            out.append(mid)
        out = torch.cat(out, dim=1)
        out = out.transpose(1, 2)

        return out, None
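

if __name__ == "__main__":
    # Minimal smoke test (not part of the original file). It only exercises audio_chunking
    # and MeshtalkEncoder, which depend solely on torch/torchaudio; Generator additionally
    # needs nets.layers and the pretrained wav2vec 2.0 checkpoint, so it is not built here.
    # Because this file uses a relative import, run it as a module inside its package
    # (python -m ...) rather than as a standalone script.
    wav = torch.randn(1, 2 * 16000)                               # 2 s of fake 16 kHz audio
    chunks = audio_chunking(wav, frame_rate=30, chunk_size=16000)
    print("chunks:", tuple(chunks.shape))                         # (num_chunks, 16000)

    encoder = MeshtalkEncoder(latent_dim=128).eval()
    with torch.no_grad():
        code = encoder(chunks.unsqueeze(0))["code"]
    print("code:", tuple(code.shape))                             # (1, num_chunks, 128)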