# stutter-detector/models/stutter_detector_local.py
# abdu-l7hman: "Initial commit with model and app" (commit 1154abd)
# ============================================
# LOCAL PC STUTTER DETECTION SETUP
# Run this on your local machine
# ============================================
import os
import numpy as np
import librosa
import torch
import torch.nn as nn
import transformers
from typing import List, Tuple
import warnings
warnings.filterwarnings('ignore')
# ============================================
# MODEL ARCHITECTURE (MUST MATCH TRAINING)
# ============================================
class ImprovedWav2VecClassifier(nn.Module):
    """Improved classifier matching training architecture.

    A frozen pre-trained Wav2Vec2 encoder ('facebook/wav2vec2-base') followed by
    a trainable MLP head. The encoder's last hidden state is mean-pooled over
    dim 1 and mapped to ``output_dim`` logits.

    NOTE: the attribute names (``wav2vec``, ``classifier``) and the exact layer
    order inside ``self.classifier`` determine the ``state_dict`` keys. They must
    stay unchanged, or previously saved checkpoints will no longer load.

    Args:
        hidden_dim: input width of the head. It must equal the encoder's hidden
            size (presumably 768 for wav2vec2-base; confirm against the model config).
        intermediate_dim: width of the first hidden layer; the second hidden
            layer is half this width.
        output_dim: number of output logits (2 here: no stutter / stutter).
        dropout: dropout probability applied after each hidden layer.
    """
    def __init__(self, hidden_dim=768, intermediate_dim=256, output_dim=2, dropout=0.3):
        super().__init__()
        # Load pre-trained Wav2Vec model (may download weights on first use)
        self.wav2vec = transformers.Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base')
        # Freeze Wav2Vec parameters: only the classification head is trained
        for param in self.wav2vec.parameters():
            param.requires_grad = False
        # Classification head: Linear -> BatchNorm -> ReLU -> Dropout, twice, then
        # the output projection. BatchNorm needs batch size > 1 in train mode;
        # single-sample inference relies on model.eval().
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, intermediate_dim),
            nn.BatchNorm1d(intermediate_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(intermediate_dim, intermediate_dim // 2),
            nn.BatchNorm1d(intermediate_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(intermediate_dim // 2, output_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return classifier logits for a batch of inputs.

        ``x`` is presumably (batch, num_samples) normalized waveform values from
        ``Wav2Vec2FeatureExtractor``; TODO confirm. The result has ``output_dim``
        logits per example along dim 1.
        """
        # The encoder is frozen, so skip building its autograd graph.
        with torch.no_grad():
            encoder_output = self.wav2vec(x).last_hidden_state
        # Average over dim 1 (presumably the time/frame axis; confirm) to get one
        # vector per example.
        pooled_features = encoder_output.mean(dim=1)
        return self.classifier(pooled_features)
# ============================================
# FEATURE EXTRACTOR
# ============================================
class Wav2VecFeatureExtractor:
    """Convert raw audio arrays into Wav2Vec2 model inputs.

    Wraps a Hugging Face ``Wav2Vec2FeatureExtractor`` and resamples incoming
    audio to 16 kHz before handing it to the processor.
    """

    def __init__(self, model_name='facebook/wav2vec2-base', duration=3):
        self.processor = transformers.Wav2Vec2FeatureExtractor.from_pretrained(model_name)
        # Kept for callers; this class does not read it.
        self.duration = duration
        self.sample_rate = 16000

    def extract_features(self, audio_data, sr):
        """Return the processor's ``input_values`` for one clip with the batch axis removed.

        Audio whose rate differs from ``self.sample_rate`` is resampled first.
        Any failure is reported on stdout and ``None`` is returned, so callers
        can skip the clip instead of aborting.
        """
        try:
            target_rate = self.sample_rate
            if sr != target_rate:
                audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=target_rate)
            encoded = self.processor(audio_data, sampling_rate=target_rate, return_tensors='pt')
            return encoded.input_values.squeeze(0)
        except Exception as err:
            print(f"Error extracting features: {err}")
            return None
# ============================================
# AUDIO PROCESSING FUNCTIONS
# ============================================
def load_audio_file(file_path: str) -> Tuple[np.ndarray, int]:
    """Load an audio file at its native sample rate.

    Args:
        file_path: path to any audio file librosa can read.

    Returns:
        ``(audio_data, sr)``, the samples returned by ``librosa.load`` and the
        file's own sample rate (``sr=None`` disables resampling).

    Raises:
        Exception: if the file cannot be loaded. The underlying error is chained
            as ``__cause__`` so its traceback is preserved.
    """
    try:
        audio_data, sr = librosa.load(file_path, sr=None)
    except Exception as e:
        raise Exception(f"Error loading audio file: {e}") from e
    return audio_data, sr
def segment_audio(audio_data: np.ndarray, sr: int, segment_duration: float = 3.0,
                  min_duration: float = 1.0) -> List[np.ndarray]:
    """Split audio into consecutive fixed-duration segments.

    Every full segment is kept. A trailing partial segment is kept only if it
    holds at least ``min_duration`` seconds of audio (capped at one full
    segment), and it is then zero-padded to the full segment length.

    Args:
        audio_data: 1-D array of samples.
        sr: sample rate of ``audio_data`` in Hz.
        segment_duration: length of each segment in seconds.
        min_duration: minimum length, in seconds, a trailing partial segment
            needs in order to be kept (default 1 second).

    Returns:
        List of arrays, each exactly ``int(segment_duration * sr)`` samples long.

    Raises:
        ValueError: if ``segment_duration * sr`` is less than one sample.
    """
    segment_samples = int(segment_duration * sr)
    if segment_samples <= 0:
        raise ValueError(
            f"segment_duration={segment_duration} at sr={sr} gives no samples per segment")
    # Cap the minimum at one segment; otherwise any segment_duration shorter
    # than min_duration would silently discard every segment.
    min_samples = min(int(min_duration * sr), segment_samples)

    segments = []
    for start in range(0, len(audio_data), segment_samples):
        segment = audio_data[start:start + segment_samples]
        if len(segment) < min_samples:
            continue
        if len(segment) < segment_samples:
            segment = np.pad(segment, (0, segment_samples - len(segment)), mode='constant')
        segments.append(segment)
    return segments
def pad_or_truncate_features(features: torch.Tensor, max_length: int = 32007) -> torch.Tensor:
    """Right-pad with zeros or truncate ``features`` along dim 0 to ``max_length``.

    The default length (32007) is presumably the input length used during
    training; confirm against the training pipeline before changing it.

    Args:
        features: tensor whose first dimension is the sequence axis.
        max_length: target size of dim 0.

    Returns:
        ``features`` unchanged if it already has the target length, otherwise a
        tensor with dim 0 equal to ``max_length``. Padding uses the same dtype
        and device as ``features``.
    """
    current_length = features.size(0)
    if current_length < max_length:
        # new_zeros matches dtype/device, so this also works for CUDA tensors;
        # the trailing shape is kept so N-D inputs pad along dim 0 only.
        padding = features.new_zeros((max_length - current_length, *features.shape[1:]))
        features = torch.cat([features, padding], dim=0)
    elif current_length > max_length:
        features = features[:max_length]
    return features
# ============================================
# STUTTER DETECTOR CLASS
# ============================================
class ImprovedStutterDetector:
    """Stutter detector for all types: prolongations, blocks, repetitions, interjections.

    Wraps a trained ``ImprovedWav2VecClassifier`` with two classes (index 0:
    no stutter, index 1: stutter of any type) and applies it to fixed-length
    segments of an audio file.
    """

    def __init__(self, model_path: str, device: str = None):
        """Load the classifier checkpoint and the Wav2Vec2 feature extractor.

        Args:
            model_path: path to a ``state_dict`` saved from ``ImprovedWav2VecClassifier``.
            device: torch device string such as ``'cpu'`` or ``'cuda'``. When
                ``None``, CUDA is used if available, otherwise CPU.
        """
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)
        print(f"Using device: {self.device}")

        print("Loading model...")
        self.model = ImprovedWav2VecClassifier()
        # NOTE(security): torch.load unpickles the file and can execute arbitrary
        # code. Only load trusted checkpoints (torch >= 1.13 also accepts
        # weights_only=True to restrict loading to tensors).
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        # eval() makes BatchNorm use running statistics and disables Dropout,
        # which single-segment (batch size 1) inference relies on.
        self.model.eval()
        print("✓ Model loaded successfully!")

        self.feature_extractor = Wav2VecFeatureExtractor(duration=3)

        # class_names[i] is the label of model output class i.
        self.class_names = ['No Stutter', 'Stutter (All Types)']
        print("\nThis model detects ALL stutter types:")
        print(" • Prolongations (ssssso)")
        print(" • Blocks (getting stuck)")
        print(" • Sound Repetitions (b-b-ball)")
        print(" • Word Repetitions (I-I-I want)")
        print(" • Interjections (um, uh)")

    @staticmethod
    def _print_section(title: str) -> None:
        """Print ``title`` between two 70-character ``=`` rules, preceded by a blank line."""
        print(f"\n{'='*70}")
        print(title)
        print(f"{'='*70}")

    def _predict_segment(self, segment: np.ndarray, sr: int):
        """Return the softmax class probabilities (1-D tensor) for one segment.

        Returns ``None`` when feature extraction fails.
        """
        features = self.feature_extractor.extract_features(segment, sr)
        if features is None:
            return None
        # Fix the input length, then add a batch axis of size one.
        features = pad_or_truncate_features(features).unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(features)
            probabilities = torch.softmax(outputs, dim=1)
        return probabilities[0]

    def analyze_audio_file(self, file_path: str, segment_duration: float = 3.0,
                           stutter_threshold: float = 0.5, show_probabilities: bool = True) -> dict:
        """Analyze an entire audio file for stuttering.

        The file is split with ``segment_audio`` and each segment is classified
        on its own. A segment counts as a stutter when its stutter probability
        is ``>= stutter_threshold``. The ``'prediction'`` field is the argmax
        class, so it can disagree with ``'is_stutter'`` (e.g. when the threshold
        is not 0.5). Segments whose feature extraction fails are skipped and
        excluded from the totals.

        Args:
            file_path: audio file to analyze.
            segment_duration: segment length in seconds.
            stutter_threshold: minimum stutter probability for a segment to be flagged.
            show_probabilities: print per-segment probabilities when True.

        Returns:
            Summary dict with keys ``file_path``, ``duration``, ``total_segments``,
            ``stutter_count``, ``no_stutter_count``, ``stutter_percentage`` and
            ``segment_results``. Returns ``{'error': ...}`` when the audio yields
            no segments.
        """
        self._print_section(f"ANALYZING: {os.path.basename(file_path)}")

        # Loaded at the file's native rate; the feature extractor resamples per segment.
        audio_data, sr = load_audio_file(file_path)
        duration = len(audio_data) / sr
        print(f"📊 Audio duration: {duration:.2f} seconds")

        segments = segment_audio(audio_data, sr, segment_duration)
        print(f"📊 Number of segments: {len(segments)}")
        if not segments:
            return {'error': 'Audio too short for analysis (minimum 1 second required)'}

        results = []
        stutter_count = 0
        self._print_section("SEGMENT ANALYSIS")

        for i, segment in enumerate(segments):
            probabilities = self._predict_segment(segment, sr)
            if probabilities is None:
                continue

            predicted_class = int(torch.argmax(probabilities).item())
            confidence = probabilities[predicted_class].item()
            no_stutter_prob = probabilities[0].item()
            stutter_prob = probabilities[1].item()
            is_stutter = stutter_prob >= stutter_threshold

            results.append({
                'segment': i + 1,
                'prediction': self.class_names[predicted_class],
                'confidence': confidence,
                'is_stutter': is_stutter,
                'no_stutter_probability': no_stutter_prob,
                'stutter_probability': stutter_prob
            })
            if is_stutter:
                stutter_count += 1

            if show_probabilities:
                status_emoji = "🔴" if is_stutter else "🟢"
                status_text = "STUTTER DETECTED" if is_stutter else "Clear"
                print(f"{status_emoji} Segment {i+1}: {status_text}")
                print(f" No Stutter: {no_stutter_prob:.2%} | Stutter: {stutter_prob:.2%}")

        total_segments = len(results)
        stutter_percentage = (stutter_count / total_segments * 100) if total_segments > 0 else 0

        self._print_section("FINAL RESULTS")
        print(f"✓ Total segments analyzed: {total_segments}")
        print(f"🔴 Segments with stutter: {stutter_count}")
        print(f"🟢 Segments without stutter: {total_segments - stutter_count}")
        print(f"📊 Stuttering percentage: {stutter_percentage:.1f}%")

        return {
            'file_path': file_path,
            'duration': duration,
            'total_segments': total_segments,
            'stutter_count': stutter_count,
            'no_stutter_count': total_segments - stutter_count,
            'stutter_percentage': stutter_percentage,
            'segment_results': results
        }
# ============================================
# SEVERITY ANALYSIS
# ============================================
def calculate_stutter_severity(results):
    """Calculate detailed stutter severity metrics.

    Args:
        results: summary dict as produced by
            ``ImprovedStutterDetector.analyze_audio_file``. Reads
            ``'segment_results'`` (each entry needs ``'stutter_probability'``),
            ``'total_segments'`` and ``'stutter_count'``.

    Returns:
        dict with ``average_probability``, ``max_probability``,
        ``severity_level`` (overall label based on the average probability),
        ``severity_score`` (``stutter_count / total_segments``) and per-bucket
        segment counts.

    Raises:
        ValueError: if there are no analyzed segments, e.g. for the
            ``{'error': ...}`` dict returned when the audio is too short.
    """
    segment_results = results.get('segment_results') or []
    stutter_probs = [seg['stutter_probability'] for seg in segment_results]
    if not stutter_probs:
        # Average, max and min are undefined without segments; fail with a clear message.
        reason = results.get('error', 'no segment results')
        raise ValueError(f"Cannot calculate stutter severity: {reason}")

    avg_prob = sum(stutter_probs) / len(stutter_probs)
    max_prob = max(stutter_probs)
    min_prob = min(stutter_probs)

    # Bucket each segment by stutter probability. The buckets do not overlap, so
    # every segment is counted exactly once.
    severe = sum(1 for p in stutter_probs if p > 0.6)
    moderate = sum(1 for p in stutter_probs if 0.4 < p <= 0.6)
    mild = sum(1 for p in stutter_probs if 0.2 < p <= 0.4)
    minimal = sum(1 for p in stutter_probs if p <= 0.2)

    # Severity score = fraction of segments the detector flagged as stutter.
    total_segments = results.get('total_segments', 0)
    stutter_count = results.get('stutter_count', 0)
    severity_score = stutter_count / total_segments if total_segments > 0 else 0.0

    print(f"\n{'='*70}")
    print("DETAILED SEVERITY ANALYSIS")
    print(f"{'='*70}")
    print(f"Average stutter probability: {avg_prob:.2%}")
    print(f"Peak stutter probability: {max_prob:.2%}")
    print(f"Minimum stutter probability: {min_prob:.2%}")
    print("\nSegment Severity Breakdown:")
    # Labels match the bucket thresholds above.
    print(f" 🔴 Severe (>60%): {severe} segments")
    print(f" 🟠 Moderate (40-60%): {moderate} segments")
    print(f" 🟡 Mild (20-40%): {mild} segments")
    print(f" 🟢 Minimal (≤20%): {minimal} segments")

    # The overall label uses its own cut-offs on the *average* probability
    # (0.15 / 0.35 / 0.60), not the per-segment buckets.
    if avg_prob < 0.15:
        severity = "✓ Minimal or No Stuttering"
    elif avg_prob < 0.35:
        severity = "⚠️ Mild Stuttering"
    elif avg_prob < 0.60:
        severity = "⚠️ Moderate Stuttering"
    else:
        severity = "🔴 Significant Stuttering"
    print(f"\n🎯 Overall Assessment: {severity}")
    print(f"📊 Severity Score: {severity_score:.2%} (stutters/total segments)")

    return {
        'average_probability': avg_prob,
        'max_probability': max_prob,
        'severity_level': severity,
        'severity_score': severity_score,
        'severe_segments': severe,
        'moderate_segments': moderate,
        'mild_segments': mild,
        'minimal_segments': minimal
    }