File size: 11,266 Bytes
1154abd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
# ============================================
# LOCAL PC STUTTER DETECTION SETUP
# Run this on your local machine
# ============================================

import os
import numpy as np
import librosa
import torch
import torch.nn as nn
import transformers
from typing import List, Tuple
import warnings
warnings.filterwarnings('ignore')

# ============================================
# MODEL ARCHITECTURE (MUST MATCH TRAINING)
# ============================================

class ImprovedWav2VecClassifier(nn.Module):
    """Frozen wav2vec2-base encoder followed by a small MLP classification head.

    The attribute names (``wav2vec``, ``classifier``) and the order of modules
    inside ``classifier`` define the ``state_dict`` keys, so they must stay as
    they were at training time for saved weights to load.

    Args:
        hidden_dim: Width of the pooled encoder features fed to the head.
        intermediate_dim: Width of the first hidden layer; the second hidden
            layer is half as wide.
        output_dim: Number of output logits.
        dropout: Dropout probability applied after each hidden layer.
    """

    def __init__(self, hidden_dim=768, intermediate_dim=256, output_dim=2, dropout=0.3):
        super().__init__()

        # Pre-trained speech encoder; none of its weights are trained here.
        self.wav2vec = transformers.Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base')
        self.wav2vec.requires_grad_(False)

        half_dim = intermediate_dim // 2
        head_layers = [
            nn.Linear(hidden_dim, intermediate_dim),
            nn.BatchNorm1d(intermediate_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(intermediate_dim, half_dim),
            nn.BatchNorm1d(half_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(half_dim, output_dim),
        ]
        # Position in this Sequential becomes part of each parameter's key.
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return classification logits for ``x``.

        ``x`` is passed directly to the wav2vec2 model (presumably a batch of
        normalised waveforms, shape (batch, samples) -- confirm). The encoder's
        last hidden state is averaged over dim 1 before the head is applied.
        """
        with torch.no_grad():
            hidden_states = self.wav2vec(x).last_hidden_state
        return self.classifier(hidden_states.mean(dim=1))


# ============================================
# FEATURE EXTRACTOR
# ============================================

class Wav2VecFeatureExtractor:
    """Turn raw audio into wav2vec2 input values using the Hugging Face processor.

    Args:
        model_name: Hub identifier the processor configuration is loaded from.
        duration: Stored as ``self.duration``; ``extract_features`` does not use it.
    """

    def __init__(self, model_name='facebook/wav2vec2-base', duration=3):
        self.processor = transformers.Wav2Vec2FeatureExtractor.from_pretrained(model_name)
        self.duration = duration
        self.sample_rate = 16000  # audio is resampled to this rate before processing

    def extract_features(self, audio_data, sr):
        """Resample ``audio_data`` to ``self.sample_rate`` when needed and encode it.

        Args:
            audio_data: Audio samples as accepted by ``librosa.resample`` and the processor.
            sr: Sample rate of ``audio_data`` in Hz.

        Returns:
            The processor's ``input_values`` with the leading batch dimension
            squeezed away, or ``None`` if resampling or processing raised (the
            error is printed rather than propagated).
        """
        target_sr = self.sample_rate
        try:
            waveform = audio_data if sr == target_sr else librosa.resample(
                audio_data, orig_sr=sr, target_sr=target_sr
            )
            encoded = self.processor(waveform, sampling_rate=target_sr, return_tensors='pt')
            return encoded.input_values.squeeze(0)
        except Exception as e:
            print(f"Error extracting features: {e}")
            return None


# ============================================
# AUDIO PROCESSING FUNCTIONS
# ============================================

def load_audio_file(file_path: str) -> Tuple[np.ndarray, int]:
    """Load an audio file at its native sample rate.

    Args:
        file_path: Path to an audio file readable by ``librosa.load``.

    Returns:
        ``(audio_data, sr)`` exactly as returned by ``librosa.load(file_path, sr=None)``.

    Raises:
        Exception: Wrapping whatever the loader raised. The original error is
            chained as ``__cause__`` so its traceback is not lost.
    """
    try:
        return librosa.load(file_path, sr=None)
    except Exception as e:
        raise Exception(f"Error loading audio file: {e}") from e


def segment_audio(audio_data: np.ndarray, sr: int, segment_duration: float = 3.0,
                  min_duration: float = 1.0) -> List[np.ndarray]:
    """Split audio into consecutive, non-overlapping fixed-duration segments.

    Only the final chunk can be shorter than a full segment. It is zero-padded
    to full length when it holds at least ``min_duration`` seconds of audio and
    dropped otherwise. The minimum is capped at one full segment, so segment
    durations shorter than ``min_duration`` still produce output.

    Args:
        audio_data: Sample array (assumed 1-D -- padding is applied along the end).
        sr: Sample rate of ``audio_data`` in Hz.
        segment_duration: Length of each segment in seconds.
        min_duration: Minimum length in seconds a trailing chunk needs to be kept.

    Returns:
        List of arrays, each ``int(segment_duration * sr)`` samples long.

    Raises:
        ValueError: If ``segment_duration * sr`` is less than one sample.
    """
    segment_samples = int(segment_duration * sr)
    if segment_samples <= 0:
        raise ValueError(
            f"segment_duration * sr must be at least one sample "
            f"(got {segment_duration} s at {sr} Hz)"
        )
    # Without the cap, any segment shorter than min_duration would be discarded.
    min_samples = min(int(min_duration * sr), segment_samples)

    segments = []
    for start in range(0, len(audio_data), segment_samples):
        segment = audio_data[start:start + segment_samples]
        if len(segment) < min_samples:
            continue
        shortfall = segment_samples - len(segment)
        if shortfall:
            segment = np.pad(segment, (0, shortfall), mode='constant')
        segments.append(segment)

    return segments


def pad_or_truncate_features(features: torch.Tensor, max_length: int = 32007) -> torch.Tensor:
    """Zero-pad or truncate ``features`` along dim 0 to exactly ``max_length``.

    Padding is allocated with ``features.new_zeros``, so it has the same dtype
    and device as the input, and any trailing dimensions are preserved.

    Args:
        features: Tensor whose first dimension is the length to adjust.
        max_length: Target size of dim 0. The default presumably mirrors the
            input length used during training -- confirm before changing.

    Returns:
        A tensor of size ``max_length`` along dim 0. The input object itself is
        returned unchanged when it already has that length.
    """
    current_length = features.size(0)
    if current_length < max_length:
        padding = features.new_zeros((max_length - current_length, *features.shape[1:]))
        features = torch.cat([features, padding], dim=0)
    elif current_length > max_length:
        features = features[:max_length]
    return features


# ============================================
# STUTTER DETECTOR CLASS
# ============================================

class ImprovedStutterDetector:
    """Stutter detector for all types: prolongations, blocks, repetitions, interjections.

    Wraps a trained ``ImprovedWav2VecClassifier``. An audio file is split into
    fixed-length segments and each segment is classified. A segment counts as
    stuttered when its stutter probability reaches ``stutter_threshold``.
    """

    def __init__(self, model_path: str, device: str = None):
        """Load the classifier weights and set up feature extraction.

        Args:
            model_path: Path to a ``state_dict`` saved for ``ImprovedWav2VecClassifier``.
            device: Torch device string (e.g. ``'cpu'``, ``'cuda'``). When ``None``,
                CUDA is used if available, otherwise CPU.
        """
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        print(f"Using device: {self.device}")

        print("Loading model...")
        self.model = ImprovedWav2VecClassifier()
        # NOTE(review): torch.load unpickles the checkpoint, which can execute
        # arbitrary code -- only load model files from trusted sources.
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()  # inference mode: dropout off, BatchNorm uses running stats
        print("✓ Model loaded successfully!")

        self.feature_extractor = Wav2VecFeatureExtractor(duration=3)

        # Entry i labels output logit i; this class treats index 1 as "stutter".
        self.class_names = ['No Stutter', 'Stutter (All Types)']

        print("\nThis model detects ALL stutter types:")
        print("  • Prolongations (ssssso)")
        print("  • Blocks (getting stuck)")
        print("  • Sound Repetitions (b-b-ball)")
        print("  • Word Repetitions (I-I-I want)")
        print("  • Interjections (um, uh)")

    def _classify_segment(self, segment: np.ndarray, sr: int, stutter_threshold: float):
        """Run one audio segment through the model.

        Returns:
            ``(no_stutter_prob, stutter_prob, is_stutter)``, or ``None`` when
            feature extraction failed for this segment.
        """
        features = self.feature_extractor.extract_features(segment, sr)
        if features is None:
            return None

        # NOTE(review): the default fixed length (32007 samples, ~2 s at the
        # extractor's 16 kHz) is shorter than a 3 s segment, so each segment's
        # tail is truncated. This presumably mirrors training -- confirm.
        features = pad_or_truncate_features(features).unsqueeze(0).to(self.device)

        with torch.no_grad():
            probabilities = torch.softmax(self.model(features), dim=1)[0]

        no_stutter_prob = probabilities[0].item()
        stutter_prob = probabilities[1].item()
        return no_stutter_prob, stutter_prob, stutter_prob >= stutter_threshold

    def analyze_audio_file(self, file_path: str, segment_duration: float = 3.0,
                          stutter_threshold: float = 0.5, show_probabilities: bool = True) -> dict:
        """Analyze an entire audio file for stuttering.

        Args:
            file_path: Audio file to analyze.
            segment_duration: Segment length in seconds, passed to ``segment_audio``.
            stutter_threshold: A segment is flagged when its stutter probability
                is >= this value. Each segment's ``prediction`` and ``confidence``
                follow this thresholded decision, so they always agree with
                ``is_stutter``.
            show_probabilities: Print per-segment probabilities during analysis.

        Returns:
            A summary dict with ``file_path``, ``duration``, segment counts,
            ``stutter_percentage`` and per-segment ``segment_results``. Returns
            ``{'error': ...}`` when the audio yields no segments.
        """
        print(f"\n{'='*70}")
        print(f"ANALYZING: {os.path.basename(file_path)}")
        print(f"{'='*70}")

        audio_data, sr = load_audio_file(file_path)
        duration = len(audio_data) / sr
        print(f"📊 Audio duration: {duration:.2f} seconds")

        segments = segment_audio(audio_data, sr, segment_duration)
        print(f"📊 Number of segments: {len(segments)}")

        if len(segments) == 0:
            return {'error': 'Audio too short for analysis (minimum 1 second required)'}

        results = []
        stutter_count = 0

        print(f"\n{'='*70}")
        print("SEGMENT ANALYSIS")
        print(f"{'='*70}")

        for i, segment in enumerate(segments):
            classified = self._classify_segment(segment, sr, stutter_threshold)
            if classified is None:
                continue  # feature extraction failed; the error was already printed
            no_stutter_prob, stutter_prob, is_stutter = classified

            # Label and confidence describe the thresholded decision, so they
            # never contradict 'is_stutter' when the threshold is not 0.5.
            predicted_class = 1 if is_stutter else 0
            confidence = stutter_prob if is_stutter else no_stutter_prob

            results.append({
                'segment': i + 1,
                'prediction': self.class_names[predicted_class],
                'confidence': confidence,
                'is_stutter': is_stutter,
                'no_stutter_probability': no_stutter_prob,
                'stutter_probability': stutter_prob
            })

            if is_stutter:
                stutter_count += 1

            if show_probabilities:
                status_emoji = "🔴" if is_stutter else "🟢"
                status_text = "STUTTER DETECTED" if is_stutter else "Clear"
                print(f"{status_emoji} Segment {i+1}: {status_text}")
                print(f"    No Stutter: {no_stutter_prob:.2%} | Stutter: {stutter_prob:.2%}")

        total_segments = len(results)
        stutter_percentage = (stutter_count / total_segments * 100) if total_segments > 0 else 0

        print(f"\n{'='*70}")
        print("FINAL RESULTS")
        print(f"{'='*70}")
        print(f"✓ Total segments analyzed: {total_segments}")
        print(f"🔴 Segments with stutter: {stutter_count}")
        print(f"🟢 Segments without stutter: {total_segments - stutter_count}")
        print(f"📊 Stuttering percentage: {stutter_percentage:.1f}%")

        return {
            'file_path': file_path,
            'duration': duration,
            'total_segments': total_segments,
            'stutter_count': stutter_count,
            'no_stutter_count': total_segments - stutter_count,
            'stutter_percentage': stutter_percentage,
            'segment_results': results
        }


# ============================================
# SEVERITY ANALYSIS
# ============================================

def calculate_stutter_severity(results):
    """Summarise segment-level stutter probabilities into severity metrics.

    Args:
        results: A result dict from ``ImprovedStutterDetector.analyze_audio_file``.
            Reads ``segment_results`` (each entry needs ``stutter_probability``)
            and, when present, ``total_segments`` and ``stutter_count``.

    Returns:
        A dict with the average, maximum and minimum stutter probability, an
        overall severity label, and ``severity_score`` (stutter_count /
        total_segments). It also holds per-band segment counts: severe > 0.6,
        moderate (0.4, 0.6], mild (0.2, 0.4], minimal <= 0.2.

    Raises:
        ValueError: If ``results`` is an error result or contains no segment results.
    """
    if 'error' in results:
        raise ValueError(f"Cannot compute severity: {results['error']}")

    stutter_probs = [seg['stutter_probability'] for seg in results.get('segment_results', [])]
    if not stutter_probs:
        raise ValueError("Cannot compute severity: no analyzed segments")

    avg_prob = sum(stutter_probs) / len(stutter_probs)
    max_prob = max(stutter_probs)
    min_prob = min(stutter_probs)

    # Count segments by probability band (labels printed below use the same cut-offs).
    severe = sum(1 for p in stutter_probs if p > 0.6)
    moderate = sum(1 for p in stutter_probs if 0.4 < p <= 0.6)
    mild = sum(1 for p in stutter_probs if 0.2 < p <= 0.4)
    minimal = sum(1 for p in stutter_probs if p <= 0.2)

    # Severity score: fraction of segments flagged as stuttered by the detector.
    total_segments = results.get('total_segments', 0)
    stutter_count = results.get('stutter_count', 0)
    severity_score = stutter_count / total_segments if total_segments > 0 else 0.0

    print(f"\n{'='*70}")
    print("DETAILED SEVERITY ANALYSIS")
    print(f"{'='*70}")
    print(f"Average stutter probability: {avg_prob:.2%}")
    print(f"Peak stutter probability: {max_prob:.2%}")
    print(f"Minimum stutter probability: {min_prob:.2%}")
    print("\nSegment Severity Breakdown:")
    print(f"  🔴 Severe (>60%): {severe} segments")
    print(f"  🟠 Moderate (40-60%): {moderate} segments")
    print(f"  🟡 Mild (20-40%): {mild} segments")
    print(f"  🟢 Minimal (≤20%): {minimal} segments")

    # Overall label is driven by the mean probability, not by segment counts.
    if avg_prob < 0.15:
        severity = "✓ Minimal or No Stuttering"
    elif avg_prob < 0.35:
        severity = "⚠️ Mild Stuttering"
    elif avg_prob < 0.60:
        severity = "⚠️ Moderate Stuttering"
    else:
        severity = "🔴 Significant Stuttering"

    print(f"\n🎯 Overall Assessment: {severity}")
    print(f"📊 Severity Score: {severity_score:.2%} (stutters/total segments)")

    return {
        'average_probability': avg_prob,
        'max_probability': max_prob,
        'min_probability': min_prob,
        'severity_level': severity,
        'severity_score': severity_score,
        'severe_segments': severe,
        'moderate_segments': moderate,
        'mild_segments': mild,
        'minimal_segments': minimal
    }