Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from transformers import Wav2Vec2Config | |
from .torch_utils import get_mask_from_lengths | |
from .wav2vec2 import Wav2Vec2Model | |
class Audio2MeshModel(nn.Module): | |
def __init__( | |
self, | |
config | |
): | |
super().__init__() | |
out_dim = config['out_dim'] | |
latent_dim = config['latent_dim'] | |
model_path = config['model_path'] | |
only_last_fetures = config['only_last_fetures'] | |
from_pretrained = config['from_pretrained'] | |
self._only_last_features = only_last_fetures | |
self.audio_encoder_config = Wav2Vec2Config.from_pretrained(model_path, local_files_only=True) | |
if from_pretrained: | |
self.audio_encoder = Wav2Vec2Model.from_pretrained(model_path, local_files_only=True) | |
else: | |
self.audio_encoder = Wav2Vec2Model(self.audio_encoder_config) | |
self.audio_encoder.feature_extractor._freeze_parameters() | |
hidden_size = self.audio_encoder_config.hidden_size | |
self.in_fn = nn.Linear(hidden_size, latent_dim) | |
self.out_fn = nn.Linear(latent_dim, out_dim) | |
nn.init.constant_(self.out_fn.weight, 0) | |
nn.init.constant_(self.out_fn.bias, 0) | |
def forward(self, audio, label, audio_len=None): | |
attention_mask = ~get_mask_from_lengths(audio_len) if audio_len else None | |
seq_len = label.shape[1] | |
embeddings = self.audio_encoder(audio, seq_len=seq_len, output_hidden_states=True, | |
attention_mask=attention_mask) | |
if self._only_last_features: | |
hidden_states = embeddings.last_hidden_state | |
else: | |
hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states) | |
layer_in = self.in_fn(hidden_states) | |
out = self.out_fn(layer_in) | |
return out, None | |
def infer(self, input_value, seq_len): | |
embeddings = self.audio_encoder(input_value, seq_len=seq_len, output_hidden_states=True) | |
if self._only_last_features: | |
hidden_states = embeddings.last_hidden_state | |
else: | |
hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states) | |
layer_in = self.in_fn(hidden_states) | |
out = self.out_fn(layer_in) | |
return out | |