| import torch |
| import torch.nn as nn |
| from transformers import Wav2Vec2Model, Wav2Vec2Config |
| import librosa |
| import numpy as np |
|
|
class AudioEmotionModel(nn.Module):
    """
    CNN + Transformer for audio emotion recognition with an auxiliary
    stress-regression head.

    Pipeline: frozen Wav2Vec2 backbone -> 1-D CNN over the hidden-state
    sequence -> Transformer encoder over time -> mean pooling over time ->
    two heads (emotion classification; stress score in [0, 1]).
    """

    def __init__(self, num_emotions=7, pretrained=True):
        """
        Args:
            num_emotions: number of emotion classes for the classifier head.
            pretrained: if True, load 'facebook/wav2vec2-base-960h' weights;
                otherwise build a randomly initialized Wav2Vec2 backbone.
        """
        super().__init__()
        self.num_emotions = num_emotions

        if pretrained:
            self.wav2vec = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')
        else:
            config = Wav2Vec2Config()
            self.wav2vec = Wav2Vec2Model(config)

        # The backbone is used as a fixed feature extractor; only the CNN,
        # transformer, and heads are trained.
        for param in self.wav2vec.parameters():
            param.requires_grad = False

        hidden_size = self.wav2vec.config.hidden_size

        # 1-D CNN over the wav2vec hidden-state sequence.
        # NOTE(fix): no pooling here — the time dimension is preserved so the
        # transformer below attends over a real sequence. The previous
        # AdaptiveAvgPool1d(1) collapsed time first, which made the
        # transformer a no-op on a length-1 sequence.
        self.cnn = nn.Sequential(
            nn.Conv1d(hidden_size, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(),
        )

        # batch_first=True is essential: inputs are (B, T, 128). With the
        # default seq-first layout the previous code was interpreted as
        # (seq=B, batch=1, d_model), so self-attention mixed samples
        # ACROSS the batch.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=128, nhead=8, dim_feedforward=512, batch_first=True
            ),
            num_layers=4,
        )

        self.emotion_classifier = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, num_emotions),
        )

        # Sigmoid keeps the stress estimate in [0, 1].
        self.stress_head = nn.Sequential(
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        )

    def forward(self, input_values):
        """
        Args:
            input_values: batch of raw audio waveforms, shape (B, T_samples).

        Returns:
            emotion_logits: (B, num_emotions) unnormalized class scores.
            stress_score: (B,) stress estimate in [0, 1].
        """
        # Backbone is frozen, so skip autograd bookkeeping entirely.
        with torch.no_grad():
            outputs = self.wav2vec(input_values)
        hidden_states = outputs.last_hidden_state  # (B, T, hidden_size)

        # Conv1d expects channels-first: (B, hidden_size, T) -> (B, 128, T).
        features = self.cnn(hidden_states.transpose(1, 2))

        # Transformer (batch_first) expects (B, T, 128).
        encoded = self.transformer(features.transpose(1, 2))

        # Mean-pool over time to get one 128-d vector per sample.
        pooled = encoded.mean(dim=1)

        emotion_logits = self.emotion_classifier(pooled)
        stress_score = self.stress_head(pooled)  # (B, 1)

        # squeeze(-1), not squeeze(): a bare squeeze() on a batch of size 1
        # would also collapse the batch dimension to a 0-d scalar.
        return emotion_logits, stress_score.squeeze(-1)

    def preprocess_audio(self, audio_path, sample_rate=16000, duration=3.0):
        """
        Load an audio file and pad/truncate it to a fixed length.

        Args:
            audio_path: path to the audio file.
            sample_rate: resampling rate in Hz.
            duration: target clip length in seconds.

        Returns:
            1-D float32 tensor of exactly ``int(sample_rate * duration)``
            samples.
        """
        audio, sr = librosa.load(audio_path, sr=sample_rate, duration=duration)

        target_length = int(sample_rate * duration)
        if len(audio) < target_length:
            # Zero-pad short clips on the right.
            audio = np.pad(audio, (0, target_length - len(audio)))
        else:
            audio = audio[:target_length]

        return torch.tensor(audio, dtype=torch.float32)

    def extract_prosody_features(self, audio):
        """
        Extract simple prosody statistics from a waveform tensor.

        Args:
            audio: 1-D waveform tensor. Assumed to be sampled at 16 kHz
                (matching ``preprocess_audio``'s default) — TODO confirm
                callers never pass other rates, since 16000 is hard-coded
                into the pitch tracker below.

        Returns:
            float32 tensor ``[mean_voiced_pitch_hz, mean_rms, mean_zcr]``.
            The pitch entry is 0.0 when no voiced frames are detected.
        """
        waveform = audio.numpy()

        pitches, magnitudes = librosa.piptrack(y=waveform, sr=16000)
        voiced = pitches[pitches > 0]
        # Guard against fully unvoiced input: the mean of an empty array
        # would be NaN (and emit a RuntimeWarning).
        pitch = float(voiced.mean()) if voiced.size > 0 else 0.0

        rms = librosa.feature.rms(y=waveform)[0].mean()
        zcr = librosa.feature.zero_crossing_rate(y=waveform)[0].mean()

        return torch.tensor([pitch, rms, zcr], dtype=torch.float32)