import os
import uuid
import torch
import torchaudio
import transformers
import numpy as np
import librosa
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
import soundfile as sf

app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
app.config['MAX_CONTENT_LENGTH'] = 50 * 1024 * 1024  # 50 MB max
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

# --------- MODEL DEFINITIONS ---------

class AttentiveStatsPool(torch.nn.Module):
    def __init__(self, in_dim, use_std=True):
        super().__init__()
        self.use_std = use_std
        self.att = torch.nn.Sequential(
            torch.nn.Linear(in_dim, in_dim // 2),
            torch.nn.Tanh(),
            torch.nn.Linear(in_dim // 2, 1)
        )

    def forward(self, H, mask=None):
        if mask is not None:
            logits = self.att(H).squeeze(-1).masked_fill(~mask, float("-inf"))
            alpha = torch.softmax(logits, dim=1).unsqueeze(-1)
        else:
            alpha = torch.softmax(self.att(H), dim=1)
        mean = (alpha * H).sum(dim=1)
        if not self.use_std:
            return mean
        ex2 = (alpha * (H ** 2)).sum(dim=1)
        std = torch.sqrt(torch.clamp(ex2 - mean ** 2, min=1e-6))
        return torch.cat([mean, std], dim=-1)
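
# For reference, the pooling above is the standard attentive-statistics
# formulation (a sketch of the math derived from the code, not quoted from
# any external source):
#   e_t   = w2^T tanh(W1 h_t + b1) + b2     (scalar score per frame t)
#   a_t   = softmax_t(e_t)                  (attention weights over frames)
#   mu    = sum_t a_t * h_t                 (attention-weighted mean)
#   sigma = sqrt(sum_t a_t * h_t^2 - mu^2)  (weighted std, clamped at 1e-6)
# With use_std=True the module returns [mu; sigma], i.e. 2 * in_dim features.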

class ASPMLPClassifier(torch.nn.Module):
    def __init__(self, hidden_dim=768, output_dim=2, dropout=0.3):
        super().__init__()
        self.pool = AttentiveStatsPool(hidden_dim, use_std=True)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim * 2, 256),
            torch.nn.BatchNorm1d(256),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(256, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(128, output_dim)
        )

    def forward(self, H):
        z = self.pool(H)
        return self.classifier(z)

class FullInferenceModel(torch.nn.Module):
    def __init__(self, asp_mlp_classifier, wav2vec_model):
        super().__init__()
        self.wav2vec = wav2vec_model
        self.asp_mlp = asp_mlp_classifier

    def forward(self, input_values):
        with torch.no_grad():
            H = self.wav2vec(input_values).last_hidden_state
            return self.asp_mlp(H)
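
# Note (not from the original source): facebook/wav2vec2-base emits 768-dim
# hidden states at roughly one frame per 20 ms of 16 kHz audio, so a 3 s
# input (48000 samples) produces on the order of 150 frames in
# last_hidden_state before AttentiveStatsPool collapses them to one vector.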

# Global model and processor
model = None
processor = None


def load_model(model_path="stutter_detector_mio.pth"):
    global model, processor
    if model is not None:
        return model, processor
    print("Loading model...")
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
    asp = ASPMLPClassifier()
    wav2vec = transformers.Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
    full = FullInferenceModel(asp, wav2vec)
    full.load_state_dict(checkpoint["full_model_state_dict"])
    full.eval()
    model = full
    processor = transformers.Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
    print("Model loaded successfully!")
    return model, processor

def preprocess_audio(audio, max_len=48000):
    # Crop or zero-pad to a fixed 3 s window (48000 samples at 16 kHz)
    if len(audio) > max_len:
        audio = audio[:max_len]
    else:
        audio = np.pad(audio, (0, max_len - len(audio)))
    return audio

def predict(audio_np):
    mdl, proc = load_model()
    inputs = proc(audio_np, sampling_rate=16000, return_tensors="pt").input_values
    with torch.no_grad():
        logits = mdl(inputs)
        probs = torch.softmax(logits, dim=1)
        pred = torch.argmax(probs, dim=1).item()
        confidence = probs[0, pred].item()
    return pred, confidence
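
# Illustrative call (values are placeholders, not real output):
#   predict(np.zeros(48000, dtype=np.float32)) -> (pred, confidence)
# where pred is the argmax class index (1 is treated as "stutter" by the
# callers below) and confidence is that class's softmax probability.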

def analyze_segments(audio_path):
    """Mode 1: Returns label for each 3-second segment"""
    audio, sr = librosa.load(audio_path, sr=16000)
    segment_duration = 3
    segment_samples = segment_duration * sr
    total_samples = len(audio)
    total_duration = total_samples / sr
    results = []
    CONF_THRESHOLD = 0.5
    start = 0
    segment_idx = 0
    while start < total_samples:
        end = min(start + segment_samples, total_samples)
        segment = audio[start:end]
        if len(segment) >= sr:  # At least 1 second
            audio_np = preprocess_audio(segment)
            pred, conf = predict(audio_np)
            label = "stutter" if pred == 1 and conf >= CONF_THRESHOLD else "no_stutter"
            results.append({
                "segment": segment_idx,
                "start_time": round(start / sr, 2),
                "end_time": round(end / sr, 2),
                "label": label,
                "confidence": round(conf, 4)
            })
        start = end
        segment_idx += 1
    return {"duration": round(total_duration, 2), "segments": results}

def analyze_seconds(audio_path):
    """Mode 2: Returns a label for each second using overlapping segments"""
    audio, sr = librosa.load(audio_path, sr=16000)
    segment_duration = 3
    segment_samples = segment_duration * sr
    hop = int(segment_samples * 0.5)
    total_samples = len(audio)
    total_seconds = int(total_samples // sr)

    # Collect overlapping 3 s windows with a 50% hop
    segments = []
    start = 0
    while start + segment_samples <= total_samples:
        end = start + segment_samples
        segments.append((start, end))
        start += hop

    # Run the model on every window
    segment_predictions = []
    for start, end in segments:
        segment = audio[start:end]
        audio_np = preprocess_audio(segment)
        pred, conf = predict(audio_np)
        segment_predictions.append((pred, conf, start, end))

    # Vote per second
    votes = [[] for _ in range(total_seconds)]
    confs = [[] for _ in range(total_seconds)]
    for i, (pred, conf, start, end) in enumerate(segment_predictions):
        # Skip the first and last windows
        if i == 0 or i == len(segment_predictions) - 1:
            continue
        seg_start_sec = int(start // sr)
        seg_end_sec = int(end // sr)
        for sec in range(seg_start_sec, min(seg_end_sec, total_seconds)):
            votes[sec].append(pred)
            confs[sec].append(conf)

    results = []
    CONF_THRESHOLD = 0.6
    for sec in range(total_seconds):
        if len(votes[sec]) == 0:
            results.append({"second": sec, "label": "no_data", "confidence": None})
            continue
        majority = 1 if votes[sec].count(1) > votes[sec].count(0) else 0
        mean_conf = sum(confs[sec]) / len(confs[sec])
        label = "stutter" if majority == 1 and mean_conf >= CONF_THRESHOLD else "no_stutter"
        results.append({"second": sec, "label": label, "confidence": round(mean_conf, 4)})

    return {"duration": total_seconds, "seconds": results}

def analyze_percentage(audio_path):
    """Mode 3: Returns only stutter percentage"""
    data = analyze_seconds(audio_path)
    total_seconds = data["duration"]
    stutter_seconds = sum(1 for r in data["seconds"] if r["label"] == "stutter")
    stutter_percentage = (stutter_seconds / total_seconds) * 100 if total_seconds > 0 else 0
    return {
        "duration": total_seconds,
        "stutter_percentage": round(stutter_percentage, 2)
    }

@app.route('/')
def home():
    return jsonify({
        'service': 'Stutter Detection API',
        'status': 'running',
        'endpoints': {
            '/analyze': 'POST - Analyze audio file',
            '/health': 'GET - Health check'
        },
        'modes': ['segments', 'seconds', 'percentage'],
        'usage': 'POST audio file to /analyze?mode=<mode>'
    })

@app.route('/analyze', methods=['POST'])
def analyze():
    if 'audio' not in request.files:
        return jsonify({'error': 'No audio file provided'}), 400
    file = request.files['audio']
    if file.filename == '':
        return jsonify({'error': 'No file selected'}), 400

    # Get mode parameter (default: segments)
    mode = request.args.get('mode', 'segments')
    if mode not in ['segments', 'seconds', 'percentage']:
        return jsonify({'error': 'Invalid mode. Choose: segments, seconds, or percentage'}), 400

    file_id = str(uuid.uuid4())
    original_ext = os.path.splitext(secure_filename(file.filename))[1] or '.webm'
    original_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{file_id}{original_ext}")
    wav_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{file_id}.wav")
    file.save(original_path)

    try:
        # Convert to 16 kHz mono WAV
        audio, sr = librosa.load(original_path, sr=16000, mono=True)
        sf.write(wav_path, audio, sr, format='WAV')

        # Analyze based on mode
        if mode == 'segments':
            result = analyze_segments(wav_path)
        elif mode == 'seconds':
            result = analyze_seconds(wav_path)
        else:  # percentage
            result = analyze_percentage(wav_path)

        # Cleanup
        if os.path.exists(original_path):
            os.remove(original_path)
        if os.path.exists(wav_path):
            os.remove(wav_path)

        return jsonify({'success': True, **result})
    except Exception as e:
        # Cleanup on error
        if os.path.exists(original_path):
            os.remove(original_path)
        if os.path.exists(wav_path):
            os.remove(wav_path)
        return jsonify({'error': str(e)}), 500

@app.route('/health')
def health():
    return jsonify({
        'status': 'ok',
        'model_loaded': model is not None
    })

if __name__ == '__main__':
    load_model()
    port = int(os.environ.get('PORT', 7860))  # Hugging Face uses port 7860
    app.run(host='0.0.0.0', port=port, debug=False)
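
# ---------------------------------------------------------------------------
# Example client call (a minimal sketch, not part of the service itself;
# it assumes a local run on the default port 7860 and a file named
# sample.wav):
#
#   import requests
#   with open("sample.wav", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/analyze",
#           params={"mode": "percentage"},  # or "segments" / "seconds"
#           files={"audio": ("sample.wav", f, "audio/wav")},
#       )
#   print(resp.json())
# ---------------------------------------------------------------------------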