File size: 1,492 Bytes
c23173c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from .transcription import SpeechEncoder
from .sentiment     import TextEncoder
import torch
import torch.nn as nn

class MultimodalSentimentClassifier(nn.Module):
    """Late-fusion sentiment classifier over speech and text features.

    A ``SpeechEncoder`` (default: a wav2vec2 checkpoint) and a ``TextEncoder``
    (default: a BERT checkpoint) each produce a feature tensor. The two are
    concatenated and passed through a small MLP head
    (Linear -> ReLU -> Dropout -> Linear) that outputs raw class logits.
    """

    def __init__(
        self,
        wav2vec_name: str = "jonatasgrosman/wav2vec2-large-xlsr-53-french",
        bert_name:    str = "nlptown/bert-base-multilingual-uncased-sentiment",
        hidden_dim:   int = 256,
        n_classes:    int = 3,
        dropout:      float = 0.2,
    ) -> None:
        """Build both encoders and the fusion head.

        Args:
            wav2vec_name: Model identifier forwarded to ``SpeechEncoder``.
            bert_name: Model identifier forwarded to ``TextEncoder``.
            hidden_dim: Width of the hidden layer in the fusion MLP.
            n_classes: Number of output classes (size of the logit vector).
            dropout: Dropout probability applied after the hidden layer.
                Defaults to 0.2, the value that was previously hard-coded.
        """
        super().__init__()
        # NOTE(review): fine-tuned alternatives were previously left here as
        # commented-out defaults ("alec228/audio-sentiment/tree/main/wav2vec2"
        # and "alec228/audio-sentiment/tree/main/bert-sentiment"). Those
        # strings are Hub web-UI paths, not repository ids. Loading them
        # presumably needs the repo id plus a subfolder; confirm against
        # SpeechEncoder/TextEncoder before switching.
        self.speech_encoder = SpeechEncoder(model_name=wav2vec_name)
        self.text_encoder = TextEncoder(model_name=bert_name)

        # The fusion input width is the sum of both encoders' hidden sizes,
        # read from their underlying model configs.
        dim_a = self.speech_encoder.model.config.hidden_size
        dim_t = self.text_encoder.model.config.hidden_size

        self.classifier = nn.Sequential(
            nn.Linear(dim_a + dim_t, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, n_classes),
        )

    def forward(self, audio_path: str, text: str) -> torch.Tensor:
        """Classify a single (audio file, text) pair.

        Args:
            audio_path: Path to the audio file, passed as-is to
                ``SpeechEncoder.extract_features``.
            text: Text to classify. It is wrapped in a one-element list,
                presumably because ``TextEncoder.extract_features`` takes a
                batch of strings.

        Returns:
            Raw logits from the fusion head (no softmax is applied). The
            last dimension has size ``n_classes``.
        """
        a_feat = self.speech_encoder.extract_features(audio_path)
        t_feat = self.text_encoder.extract_features([text])
        # Concatenate along dim=1. This assumes both features are
        # (batch, dim) with matching batch size, device and dtype.
        # TODO confirm against the encoders' extract_features.
        fused = torch.cat([a_feat, t_feat], dim=1)
        return self.classifier(fused)