Spaces:
Running
Running
| # ============================================ | |
| # LOCAL PC STUTTER DETECTION SETUP | |
| # Run this on your local machine | |
| # ============================================ | |
| import os | |
| import numpy as np | |
| import librosa | |
| import torch | |
| import torch.nn as nn | |
| import transformers | |
| from typing import List, Tuple | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| # ============================================ | |
| # MODEL ARCHITECTURE (MUST MATCH TRAINING) | |
| # ============================================ | |
class ImprovedWav2VecClassifier(nn.Module):
    """Frozen Wav2Vec2 encoder followed by a small MLP classification head.

    The attribute names (``wav2vec``, ``classifier``) and the layer order
    inside ``classifier`` determine the ``state_dict`` keys, so this layout
    must stay identical to the one used at training time.
    """

    def __init__(self, hidden_dim=768, intermediate_dim=256, output_dim=2, dropout=0.3):
        """Build the encoder and head.

        Args:
            hidden_dim: Width of the encoder's hidden states fed to the head.
            intermediate_dim: Width of the first hidden layer; the second
                hidden layer uses half of it.
            output_dim: Number of output logits.
            dropout: Dropout probability after each hidden layer.
        """
        super().__init__()
        # Pre-trained encoder used purely as a fixed feature extractor:
        # every one of its parameters has gradients switched off.
        self.wav2vec = transformers.Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base')
        self.wav2vec.requires_grad_(False)

        half_dim = intermediate_dim // 2
        head_layers = []
        # Two identical Linear -> BatchNorm -> ReLU -> Dropout stages...
        for in_features, out_features in ((hidden_dim, intermediate_dim),
                                          (intermediate_dim, half_dim)):
            head_layers += [
                nn.Linear(in_features, out_features),
                nn.BatchNorm1d(out_features),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
        # ...then the final projection to class logits (Sequential index 8).
        head_layers.append(nn.Linear(half_dim, output_dim))
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for a batch of Wav2Vec2 input values.

        ``x`` is passed straight to the encoder. The caller in this file
        passes a ``(1, n_samples)`` tensor. The encoder's
        ``last_hidden_state`` is mean-pooled over dim 1 before the head is
        applied.
        """
        with torch.no_grad():
            hidden_states = self.wav2vec(x).last_hidden_state
        return self.classifier(hidden_states.mean(dim=1))
| # ============================================ | |
| # FEATURE EXTRACTOR | |
| # ============================================ | |
class Wav2VecFeatureExtractor:
    """Convert raw audio into Wav2Vec2 ``input_values`` tensors."""

    def __init__(self, model_name='facebook/wav2vec2-base', duration=3):
        """Load the pretrained Wav2Vec2 feature extractor.

        Args:
            model_name: Hugging Face model id whose feature extractor is loaded.
            duration: Stored on the instance; ``extract_features`` does not read it.
        """
        self.processor = transformers.Wav2Vec2FeatureExtractor.from_pretrained(model_name)
        self.duration = duration
        self.sample_rate = 16000  # rate the processor is called with

    def extract_features(self, audio_data, sr):
        """Return the processor's ``input_values`` for one clip, batch dim removed.

        Audio not already at ``self.sample_rate`` is resampled with librosa
        first. Any failure is printed and reported by returning ``None``
        instead of raising.
        """
        try:
            target_sr = self.sample_rate
            if sr == target_sr:
                waveform = audio_data
            else:
                waveform = librosa.resample(audio_data, orig_sr=sr, target_sr=target_sr)
            encoded = self.processor(waveform, sampling_rate=target_sr, return_tensors='pt')
            return encoded.input_values.squeeze(0)
        except Exception as e:
            print(f"Error extracting features: {e}")
            return None
| # ============================================ | |
| # AUDIO PROCESSING FUNCTIONS | |
| # ============================================ | |
def load_audio_file(file_path: str) -> Tuple[np.ndarray, int]:
    """Load an audio file at its native sampling rate.

    Args:
        file_path: Path to an audio file that librosa can decode.

    Returns:
        ``(audio_data, sr)`` from ``librosa.load(file_path, sr=None)``. With
        ``sr=None`` there is no resampling; ``sr`` is the file's own rate.

    Raises:
        Exception: if loading fails for any reason. The underlying error is
            chained as ``__cause__`` so its traceback is preserved.
    """
    try:
        audio_data, sr = librosa.load(file_path, sr=None)
    except Exception as e:
        # Keep the generic Exception type callers may rely on, but chain the
        # original error instead of discarding its traceback.
        raise Exception(f"Error loading audio file: {e}") from e
    return audio_data, sr
def segment_audio(audio_data: np.ndarray, sr: int, segment_duration: float = 3.0,
                  min_duration: float = 1.0) -> List[np.ndarray]:
    """Split audio into consecutive fixed-duration segments.

    The audio is cut every ``segment_duration`` seconds. The trailing
    fragment is zero-padded to full length if it is at least
    ``min_duration`` seconds long (capped at one full segment), and dropped
    otherwise.

    Args:
        audio_data: 1-D waveform samples.
        sr: Sampling rate of ``audio_data`` in Hz.
        segment_duration: Segment length in seconds.
        min_duration: Minimum length in seconds for the trailing fragment to
            be kept. Defaults to 1.0 (the original behaviour).

    Returns:
        List of arrays, each exactly ``int(segment_duration * sr)`` samples long.

    Raises:
        ValueError: if ``segment_duration * sr`` is less than one sample.
    """
    segment_samples = int(segment_duration * sr)
    if segment_samples <= 0:
        raise ValueError(
            f"segment_duration * sr must span at least one sample, "
            f"got {segment_duration} * {sr}")
    # A full segment must always pass the length check. Without this cap,
    # segment_duration < min_duration would reject every chunk.
    min_samples = min(int(min_duration * sr), segment_samples)
    segments = []
    for start in range(0, len(audio_data), segment_samples):
        segment = audio_data[start:start + segment_samples]
        if len(segment) < min_samples:
            continue  # only the trailing fragment can be this short
        if len(segment) < segment_samples:
            segment = np.pad(segment, (0, segment_samples - len(segment)), mode='constant')
        segments.append(segment)
    return segments
def pad_or_truncate_features(features: torch.Tensor, max_length: int = 32007) -> torch.Tensor:
    """Zero-pad or truncate ``features`` along dim 0 to exactly ``max_length``.

    Args:
        features: Tensor whose first dimension is the sequence length. Any
            trailing dimensions are preserved.
        max_length: Target length along dim 0.
            NOTE(review): the default 32007 presumably matches the input
            length used at training time; its origin isn't visible here.

    Returns:
        ``features`` itself if already the right length, a truncated view if
        longer, or a new zero-padded tensor if shorter.
    """
    length = features.size(0)
    if length < max_length:
        # new_zeros inherits dtype and device from `features`. torch.zeros
        # would default to CPU/float32 and break torch.cat for CUDA tensors.
        padding = features.new_zeros((max_length - length, *features.shape[1:]))
        return torch.cat([features, padding], dim=0)
    if length > max_length:
        return features[:max_length]
    return features
| # ============================================ | |
| # STUTTER DETECTOR CLASS | |
| # ============================================ | |
class ImprovedStutterDetector:
    """Stutter detector for all types: prolongations, blocks, repetitions, interjections.

    Loads an ``ImprovedWav2VecClassifier`` checkpoint and scores an audio
    file segment by segment as 'No Stutter' vs 'Stutter (All Types)'.
    """

    def __init__(self, model_path: str, device: str = None):
        """Load the classifier weights and the feature extractor.

        Args:
            model_path: Path to a ``state_dict`` saved from ``ImprovedWav2VecClassifier``.
            device: Torch device string. When ``None``, ``'cuda'`` is used if
                available, otherwise ``'cpu'``.
        """
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)
        print(f"Using device: {self.device}")

        print("Loading model...")
        self.model = ImprovedWav2VecClassifier()
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source (recent
        # PyTorch versions accept weights_only=True to restrict this).
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        # Inference mode for the BatchNorm/Dropout layers in the head.
        self.model.eval()
        print("β Model loaded successfully!")

        self.feature_extractor = Wav2VecFeatureExtractor(duration=3)
        # Index i names output logit i: 0 = no stutter, 1 = stutter.
        self.class_names = ['No Stutter', 'Stutter (All Types)']
        print("\nThis model detects ALL stutter types:")
        print(" β’ Prolongations (ssssso)")
        print(" β’ Blocks (getting stuck)")
        print(" β’ Sound Repetitions (b-b-ball)")
        print(" β’ Word Repetitions (I-I-I want)")
        print(" β’ Interjections (um, uh)")

    def _segment_probabilities(self, segment, sr):
        """Return ``(no_stutter_prob, stutter_prob)`` for one raw-audio segment.

        Returns ``None`` if feature extraction fails. In that case the
        extractor has already printed the error.
        """
        features = self.feature_extractor.extract_features(segment, sr)
        if features is None:
            return None
        # Fix the input length for the classifier, then add a batch dim of 1.
        batch = pad_or_truncate_features(features).unsqueeze(0).to(self.device)
        with torch.no_grad():
            probabilities = torch.softmax(self.model(batch), dim=1)
        return probabilities[0, 0].item(), probabilities[0, 1].item()

    def analyze_audio_file(self, file_path: str, segment_duration: float = 3.0,
                           stutter_threshold: float = 0.5, show_probabilities: bool = True) -> dict:
        """Analyze an entire audio file for stuttering.

        Args:
            file_path: Audio file to analyze.
            segment_duration: Length in seconds of each analyzed segment.
            stutter_threshold: A segment counts as stuttered when its stutter
                probability is >= this value. Both ``is_stutter`` and the
                reported ``prediction``/``confidence`` follow this decision.
            show_probabilities: Print per-segment probabilities while analyzing.

        Returns:
            Dict with ``file_path``, ``duration``, ``total_segments``,
            ``stutter_count``, ``no_stutter_count``, ``stutter_percentage``
            and per-segment ``segment_results``. If the audio yields no
            segments, returns ``{'error': ...}`` instead.
        """
        print(f"\n{'='*70}")
        print(f"ANALYZING: {os.path.basename(file_path)}")
        print(f"{'='*70}")

        audio_data, sr = load_audio_file(file_path)
        duration = len(audio_data) / sr
        print(f"π Audio duration: {duration:.2f} seconds")

        segments = segment_audio(audio_data, sr, segment_duration)
        print(f"π Number of segments: {len(segments)}")
        if not segments:
            return {'error': 'Audio too short for analysis (minimum 1 second required)'}

        results = []
        stutter_count = 0
        print(f"\n{'='*70}")
        print("SEGMENT ANALYSIS")
        print(f"{'='*70}")
        for i, segment in enumerate(segments):
            probs = self._segment_probabilities(segment, sr)
            if probs is None:
                continue
            no_stutter_prob, stutter_prob = probs
            is_stutter = stutter_prob >= stutter_threshold
            # Derive the label from the same thresholded decision as
            # `is_stutter`. An argmax label contradicts it whenever
            # stutter_threshold != 0.5, and on an exact 0.5 tie.
            predicted_class = 1 if is_stutter else 0
            confidence = stutter_prob if is_stutter else no_stutter_prob
            results.append({
                'segment': i + 1,
                'prediction': self.class_names[predicted_class],
                'confidence': confidence,
                'is_stutter': is_stutter,
                'no_stutter_probability': no_stutter_prob,
                'stutter_probability': stutter_prob
            })
            if is_stutter:
                stutter_count += 1
            if show_probabilities:
                status_emoji = "π΄" if is_stutter else "π’"
                status_text = "STUTTER DETECTED" if is_stutter else "Clear"
                print(f"{status_emoji} Segment {i+1}: {status_text}")
                print(f" No Stutter: {no_stutter_prob:.2%} | Stutter: {stutter_prob:.2%}")

        # Segments whose feature extraction failed are excluded from the totals.
        total_segments = len(results)
        stutter_percentage = (stutter_count / total_segments * 100) if total_segments > 0 else 0
        print(f"\n{'='*70}")
        print("FINAL RESULTS")
        print(f"{'='*70}")
        print(f"β Total segments analyzed: {total_segments}")
        print(f"π΄ Segments with stutter: {stutter_count}")
        print(f"π’ Segments without stutter: {total_segments - stutter_count}")
        print(f"π Stuttering percentage: {stutter_percentage:.1f}%")
        return {
            'file_path': file_path,
            'duration': duration,
            'total_segments': total_segments,
            'stutter_count': stutter_count,
            'no_stutter_count': total_segments - stutter_count,
            'stutter_percentage': stutter_percentage,
            'segment_results': results
        }
| # ============================================ | |
| # SEVERITY ANALYSIS | |
| # ============================================ | |
def calculate_stutter_severity(results):
    """Summarize per-segment stutter probabilities into severity metrics.

    Args:
        results: Analysis dict with ``'segment_results'`` (a list of dicts,
            each with a ``'stutter_probability'`` key), plus
            ``'total_segments'`` and ``'stutter_count'``. This is the shape
            returned by ``ImprovedStutterDetector.analyze_audio_file``.

    Returns:
        Dict with the average and max stutter probability, an overall
        severity label, ``severity_score`` (stutter segments / total
        segments), and the number of segments in each probability bucket.

    Raises:
        ValueError: if ``results`` has no segment results, e.g. the
            ``{'error': ...}`` dict returned for audio that is too short.
    """
    segment_results = results.get('segment_results') or []
    if not segment_results:
        reason = results.get('error', 'no segment results')
        raise ValueError(f"Cannot compute stutter severity: {reason}")

    stutter_probs = [seg['stutter_probability'] for seg in segment_results]
    avg_prob = sum(stutter_probs) / len(stutter_probs)
    max_prob = max(stutter_probs)
    min_prob = min(stutter_probs)

    # Bucket boundaries; the printed labels below must match them.
    severe = sum(1 for p in stutter_probs if p > 0.6)
    moderate = sum(1 for p in stutter_probs if 0.4 < p <= 0.6)
    mild = sum(1 for p in stutter_probs if 0.2 < p <= 0.4)
    minimal = sum(1 for p in stutter_probs if p <= 0.2)

    # Severity score = fraction of analyzed segments flagged as stutter.
    total_segments = results.get('total_segments', 0)
    stutter_count = results.get('stutter_count', 0)
    severity_score = stutter_count / total_segments if total_segments > 0 else 0.0

    print(f"\n{'='*70}")
    print("DETAILED SEVERITY ANALYSIS")
    print(f"{'='*70}")
    print(f"Average stutter probability: {avg_prob:.2%}")
    print(f"Peak stutter probability: {max_prob:.2%}")
    print(f"Minimum stutter probability: {min_prob:.2%}")
    print(f"\nSegment Severity Breakdown:")
    print(f" π΄ Severe (>60%): {severe} segments")
    print(f" π Moderate (40-60%): {moderate} segments")
    print(f" π‘ Mild (20-40%): {mild} segments")
    print(f" π’ Minimal (<20%): {minimal} segments")

    # The overall label uses cut-offs on the average probability
    # (0.15 / 0.35 / 0.60), which differ from the bucket boundaries above.
    if avg_prob < 0.15:
        severity = "β Minimal or No Stuttering"
    elif avg_prob < 0.35:
        severity = "β οΈ Mild Stuttering"
    elif avg_prob < 0.60:
        severity = "β οΈ Moderate Stuttering"
    else:
        severity = "π΄ Significant Stuttering"
    print(f"\nπ― Overall Assessment: {severity}")
    print(f"π Severity Score: {severity_score:.2%} (stutters/total segments)")
    return {
        'average_probability': avg_prob,
        'max_probability': max_prob,
        'severity_level': severity,
        'severity_score': severity_score,
        'severe_segments': severe,
        'moderate_segments': moderate,
        'mild_segments': mild,
        'minimal_segments': minimal
    }