from transformers import WhisperModel
from torch import nn
import torch
from jyutping import jyutping_initials, jyutping_nuclei, jyutping_codas

class WhisperAudioClassifier(nn.Module):
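    """Multi-task audio classifier built on a frozen Cantonese Whisper encoder.

    Four self-attention blocks each attend over the encoder output and are
    average-pooled over time; each pooled representation feeds two linear
    classification heads for its Jyutping component (initial, nucleus, coda, or tone).
    """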
    def __init__(self):
        super().__init__()
        # Load the Whisper model encoder
        self.whisper_encoder = WhisperModel.from_pretrained("alvanlii/whisper-small-cantonese", device_map="auto").get_encoder()
        self.whisper_encoder.eval()  # Keep the encoder in evaluation mode; it is frozen and not fine-tuned here

        # Hidden size (d_model) of the whisper-small encoder output
        whisper_output_size = 768

        self.tone_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
        self.initial_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
        self.nucleus_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
        self.coda_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)

        self.pool = nn.AdaptiveAvgPool1d(1)

        # Two classification heads per Jyutping component (initial, nucleus, coda, tone)
        self.initial_fc1 = nn.Linear(whisper_output_size, len(jyutping_initials))
        self.nucleus_fc1 = nn.Linear(whisper_output_size, len(jyutping_nuclei))
        self.coda_fc1 = nn.Linear(whisper_output_size, len(jyutping_codas))
        self.tone_fc1 = nn.Linear(whisper_output_size, 6)

        self.initial_fc2 = nn.Linear(whisper_output_size, len(jyutping_initials))
        self.nucleus_fc2 = nn.Linear(whisper_output_size, len(jyutping_nuclei))
        self.coda_fc2 = nn.Linear(whisper_output_size, len(jyutping_codas))
        self.tone_fc2 = nn.Linear(whisper_output_size, 6)
        
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # x: log-Mel input features in the shape the Whisper encoder expects ([batch_size, n_mels, n_frames])
        with torch.no_grad():  # The frozen encoder is excluded from the gradient graph
            x = self.whisper_encoder(x).last_hidden_state  # [batch_size, seq_len, hidden_size]

        initial, _ = self.initial_attention(x, x, x, need_weights=False)
        initial = initial.permute(0, 2, 1)  # [batch_size, channels, seq_len]
        initial = self.pool(initial)  # [batch_size, channels, 1]
        initial = initial.squeeze(-1) # [batch_size, channels]
        initial_out1 = self.initial_fc1(initial)
        initial_out2 = self.initial_fc2(initial)

        nucleus, _ = self.nucleus_attention(x, x, x, need_weights=False)
        nucleus = nucleus.permute(0, 2, 1)  # [batch_size, channels, seq_len]
        nucleus = self.pool(nucleus)  # [batch_size, channels, 1]
        nucleus = nucleus.squeeze(-1) # [batch_size, channels]
        nucleus_out1 = self.nucleus_fc1(nucleus)
        nucleus_out2 = self.nucleus_fc2(nucleus)

        coda, _ = self.coda_attention(x, x, x, need_weights=False)
        coda = coda.permute(0, 2, 1)  # [batch_size, channels, seq_len]
        coda = self.pool(coda)  # [batch_size, channels, 1]
        coda = coda.squeeze(-1) # [batch_size, channels]
        coda_out1 = self.coda_fc1(coda)
        coda_out2 = self.coda_fc2(coda)

        tone, _ = self.tone_attention(x, x, x, need_weights=False)
        tone = tone.permute(0, 2, 1)  # [batch_size, channels, seq_len]
        tone = self.pool(tone)  # [batch_size, channels, 1]
        tone = tone.squeeze(-1) # [batch_size, channels]
        tone_out1 = self.tone_fc1(tone)
        tone_out2 = self.tone_fc2(tone)
        
        return initial_out1, nucleus_out1, coda_out1, tone_out1, initial_out2, nucleus_out2, coda_out2, tone_out2
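
# Minimal smoke-test sketch (assumptions: the whisper-small encoder takes log-Mel features
# of shape [batch, 80, 3000] and fits on a single device; random features stand in for real
# audio, which would normally come from WhisperFeatureExtractor):
if __name__ == "__main__":
    model = WhisperAudioClassifier()
    device = next(model.whisper_encoder.parameters()).device  # encoder may have been placed by device_map="auto"
    model.to(device)  # keep the attention blocks and classification heads on the encoder's device
    dummy_features = torch.randn(1, 80, 3000, device=device)  # placeholder log-Mel input
    with torch.no_grad():
        outputs = model(dummy_features)
    names = ["initial_1", "nucleus_1", "coda_1", "tone_1",
             "initial_2", "nucleus_2", "coda_2", "tone_2"]
    for name, logits in zip(names, outputs):
        print(name, tuple(logits.shape))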