import torch
import torch.nn as nn

from .sentiment import TextEncoder
from .transcription import SpeechEncoder
class MultimodalSentimentClassifier(nn.Module):
    """Late-fusion sentiment classifier: concatenates wav2vec2 audio features
    and BERT text features, then classifies with a small MLP head."""

    def __init__(
        self,
        wav2vec_name: str = "jonatasgrosman/wav2vec2-large-xlsr-53-french",
        # wav2vec_name: str = "alec228/audio-sentiment/tree/main/wav2vec2",
        bert_name: str = "nlptown/bert-base-multilingual-uncased-sentiment",
        # bert_name: str = "alec228/audio-sentiment/tree/main/bert-sentiment",
        # cache_dir: str = "./models",
        hidden_dim: int = 256,
        n_classes: int = 3,
    ):
        super().__init__()
        # One pretrained encoder per modality.
        self.speech_encoder = SpeechEncoder(
            model_name=wav2vec_name,
            # cache_dir=cache_dir,
        )
        self.text_encoder = TextEncoder(
            model_name=bert_name,
            # cache_dir=cache_dir,
        )
        # Fusion head: concatenated embeddings -> hidden layer -> class logits.
        dim_a = self.speech_encoder.model.config.hidden_size
        dim_t = self.text_encoder.model.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(dim_a + dim_t, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, n_classes),
        )

    def forward(self, audio_path: str, text: str) -> torch.Tensor:
        # Extract one embedding per modality, concatenate, and classify.
        a_feat = self.speech_encoder.extract_features(audio_path)
        t_feat = self.text_encoder.extract_features([text])
        fused = torch.cat([a_feat, t_feat], dim=1)
        return self.classifier(fused)
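

# --- Usage sketch (illustrative addition, not part of the original file) ---
# Assumes SpeechEncoder.extract_features(path) and
# TextEncoder.extract_features([text]) each return a [1, hidden_size]
# tensor, as the forward pass requires, and that "sample.wav" is a
# placeholder path to a real audio clip.
if __name__ == "__main__":
    model = MultimodalSentimentClassifier()
    model.eval()  # disable dropout for inference
    with torch.no_grad():
        logits = model("sample.wav", "Je suis très satisfait du service.")
    probs = torch.softmax(logits, dim=1)
    print(probs)  # probabilities over the n_classes sentiment labels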