# speech-fast-api / models / model_wav2vec.py
# NOTE: header below is Hugging Face Hub page residue, kept for provenance:
#   uploaded by avi292423 ("Upload 3 files"), commit ae94ca3 (verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Wav2Vec2Model
class Wav2VecIntent(nn.Module):
    """Intent classifier on top of a pretrained wav2vec 2.0 encoder.

    Pipeline: wav2vec features -> LayerNorm -> learned attention pooling
    over time -> dropout -> linear classification head.

    Args:
        num_classes: Number of intent classes for the classification head.
        pretrained_model: HuggingFace model id of the wav2vec 2.0 checkpoint.
    """

    def __init__(self, num_classes=31, pretrained_model="facebook/wav2vec2-large"):
        super().__init__()
        # Pretrained acoustic encoder (downloaded from the HuggingFace Hub).
        self.wav2vec = Wav2Vec2Model.from_pretrained(pretrained_model)
        # Hidden size comes from the checkpoint config (1024 for -large).
        hidden_size = self.wav2vec.config.hidden_size
        # Normalize frame features before pooling.
        self.layer_norm = nn.LayerNorm(hidden_size)
        # One scalar attention logit per time frame.
        self.attention = nn.Linear(hidden_size, 1)
        # Regularization on the pooled utterance embedding.
        self.dropout = nn.Dropout(p=0.5)
        # Classification head.
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, input_values, attention_mask=None):
        """Classify a batch of raw waveforms.

        Args:
            input_values: Float tensor of raw audio, shape [batch, samples].
            attention_mask: Optional [batch, samples] mask (1 = real audio,
                0 = padding) at sample resolution, as produced by the
                Wav2Vec2 feature extractor.

        Returns:
            Logits of shape [batch, num_classes].
        """
        outputs = self.wav2vec(
            input_values,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden_states = outputs.last_hidden_state  # [batch, frames, hidden]
        hidden_states = self.layer_norm(hidden_states)

        # Per-frame attention logits, [batch, frames, 1].
        attn_logits = self.attention(hidden_states)
        if attention_mask is not None:
            # BUGFIX: previously the softmax ran over *all* frames, so padded
            # frames received nonzero attention weight and polluted the pooled
            # embedding. Map the sample-level mask down to feature-frame
            # resolution (the conv feature extractor downsamples time) and
            # exclude padded frames from the softmax.
            # NOTE(review): uses the transformers helper
            # `_get_feature_vector_attention_mask`; underscore-prefixed but the
            # standard way to derive the frame mask — confirm on upgrade.
            frame_mask = self.wav2vec._get_feature_vector_attention_mask(
                hidden_states.shape[1], attention_mask
            )
            attn_logits = attn_logits.masked_fill(
                ~frame_mask.bool().unsqueeze(-1), float("-inf")
            )

        # Softmax over time, then attention-weighted sum -> [batch, hidden].
        attn_weights = F.softmax(attn_logits, dim=1)
        x = torch.sum(hidden_states * attn_weights, dim=1)
        x = self.dropout(x)
        return self.fc(x)