# Source: cantone / whisper_audio_classifier.py
# Commit b256b6f ("Add app.py and model") by AlienKevin, 3.59 kB.
from transformers import WhisperModel
from torch import nn
import torch
from jyutping import jyutping_initials, jyutping_nuclei, jyutping_codas
class WhisperAudioClassifier(nn.Module):
    """Classify Jyutping syllable components from audio using a frozen Whisper encoder.

    A pretrained Whisper encoder turns the input into a sequence of hidden
    states. Four independent heads (initial, nucleus, coda, tone) each apply
    multi-head self-attention over that sequence, average it over time, and
    feed the pooled vector to two linear classifiers (``*_fc1`` and ``*_fc2``).
    NOTE(review): the two classifiers per component presumably predict two
    successive syllables -- confirm against the training code.

    Submodule attribute names double as ``state_dict`` keys of saved
    checkpoints, so they must not be renamed.
    """

    def __init__(self, model_name: str = "alvanlii/whisper-small-cantonese"):
        """Load the Whisper encoder and build the per-component heads.

        Args:
            model_name: Hugging Face identifier of the Whisper checkpoint whose
                encoder is used as the (frozen) feature extractor.
        """
        super().__init__()
        # Only the encoder of the pretrained Whisper model is kept.
        self.whisper_encoder = WhisperModel.from_pretrained(model_name, device_map="auto").get_encoder()
        # eval() disables dropout inside the encoder. Note that calling
        # .train() on this classifier also puts the encoder back in training mode.
        self.whisper_encoder.eval()
        # Width of the encoder's hidden states (768 for whisper-small). Read it
        # from the config so other Whisper sizes work without edits.
        whisper_output_size = self.whisper_encoder.config.d_model
        num_heads = 8
        attention_dropout = 0.1

        # One self-attention block per predicted component. Creation order is
        # kept stable so fresh initialization under a fixed seed is unchanged.
        self.tone_attention = nn.MultiheadAttention(whisper_output_size, num_heads, dropout=attention_dropout, batch_first=True)
        self.initial_attention = nn.MultiheadAttention(whisper_output_size, num_heads, dropout=attention_dropout, batch_first=True)
        self.nucleus_attention = nn.MultiheadAttention(whisper_output_size, num_heads, dropout=attention_dropout, batch_first=True)
        self.coda_attention = nn.MultiheadAttention(whisper_output_size, num_heads, dropout=attention_dropout, batch_first=True)
        # Averages over the time axis: [batch, channels, seq_len] -> [batch, channels, 1].
        self.pool = nn.AdaptiveAvgPool1d(1)

        # First classifier of each component.
        self.initial_fc1 = nn.Linear(whisper_output_size, len(jyutping_initials))
        self.nucleus_fc1 = nn.Linear(whisper_output_size, len(jyutping_nuclei))
        self.coda_fc1 = nn.Linear(whisper_output_size, len(jyutping_codas))
        # 6 tone classes. NOTE(review): presumably Cantonese tones 1-6 -- confirm.
        self.tone_fc1 = nn.Linear(whisper_output_size, 6)
        # Second classifier of each component.
        self.initial_fc2 = nn.Linear(whisper_output_size, len(jyutping_initials))
        self.nucleus_fc2 = nn.Linear(whisper_output_size, len(jyutping_nuclei))
        self.coda_fc2 = nn.Linear(whisper_output_size, len(jyutping_codas))
        self.tone_fc2 = nn.Linear(whisper_output_size, 6)
        # NOTE(review): this dropout module is never applied in forward(); it is
        # kept so the module structure stays unchanged.
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Encode ``x`` with Whisper and return logits for every component.

        Args:
            x: input passed straight to the Whisper encoder.
                NOTE(review): presumably log-mel ``input_features`` of shape
                [batch, n_mels, frames] -- confirm against the caller.

        Returns:
            An 8-tuple ``(initial_out1, nucleus_out1, coda_out1, tone_out1,
            initial_out2, nucleus_out2, coda_out2, tone_out2)`` of logits, each
            of shape [batch, num_classes].
        """
        # The encoder is frozen: do not build an autograd graph through it.
        with torch.no_grad():
            x = self.whisper_encoder(x).last_hidden_state

        # Evaluation order (initial, nucleus, coda, tone) is deliberate: it fixes
        # the sequence of dropout RNG draws in training mode.
        initial_out1, initial_out2 = self._head(x, self.initial_attention, self.initial_fc1, self.initial_fc2)
        nucleus_out1, nucleus_out2 = self._head(x, self.nucleus_attention, self.nucleus_fc1, self.nucleus_fc2)
        coda_out1, coda_out2 = self._head(x, self.coda_attention, self.coda_fc1, self.coda_fc2)
        tone_out1, tone_out2 = self._head(x, self.tone_attention, self.tone_fc1, self.tone_fc2)
        return (
            initial_out1, nucleus_out1, coda_out1, tone_out1,
            initial_out2, nucleus_out2, coda_out2, tone_out2,
        )

    def _head(self, x, attention, fc1, fc2):
        """Self-attend over ``x``, average over time, and apply both classifiers."""
        attended, _ = attention(x, x, x, need_weights=False)  # [batch, seq_len, channels]
        pooled = self.pool(attended.permute(0, 2, 1)).squeeze(-1)  # [batch, channels]
        return fc1(pooled), fc2(pooled)