StutterModelAPI / app.py
import os
import uuid
import torch
import torchaudio
import transformers
import numpy as np
import librosa
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
import soundfile as sf
app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
app.config['MAX_CONTENT_LENGTH'] = 50 * 1024 * 1024 # 50MB max
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
# --------- MODEL DEFINITIONS ---------
class AttentiveStatsPool(torch.nn.Module):
    """Attentive statistics pooling: attention-weighted mean (and std) over time frames."""
    def __init__(self, in_dim, use_std=True):
        super().__init__()
        self.use_std = use_std
        self.att = torch.nn.Sequential(
            torch.nn.Linear(in_dim, in_dim // 2),
            torch.nn.Tanh(),
            torch.nn.Linear(in_dim // 2, 1)
        )

    def forward(self, H, mask=None):
        # H: (batch, frames, in_dim); mask (optional): True for valid frames
        if mask is not None:
            logits = self.att(H).squeeze(-1).masked_fill(~mask, float("-inf"))
            alpha = torch.softmax(logits, dim=1).unsqueeze(-1)
        else:
            alpha = torch.softmax(self.att(H), dim=1)
        mean = (alpha * H).sum(dim=1)
        if not self.use_std:
            return mean
        ex2 = (alpha * (H ** 2)).sum(dim=1)
        std = torch.sqrt(torch.clamp(ex2 - mean**2, min=1e-6))
        return torch.cat([mean, std], dim=-1)

class ASPMLPClassifier(torch.nn.Module):
    def __init__(self, hidden_dim=768, output_dim=2, dropout=0.3):
        super().__init__()
        self.pool = AttentiveStatsPool(hidden_dim, use_std=True)
        # Pooling concatenates mean and std, so the MLP input is hidden_dim * 2
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim * 2, 256),
            torch.nn.BatchNorm1d(256),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(256, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(128, output_dim)
        )

    def forward(self, H):
        z = self.pool(H)
        return self.classifier(z)

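# A minimal shape-check sketch (illustrative, not part of the service): wav2vec2-base
# emits (batch, frames, 768) hidden states, ASP pooling doubles the feature dim to 1536,
# and the classifier maps to 2 logits. The frame count below is an assumed example.
#   H = torch.randn(2, 149, 768)
#   ASPMLPClassifier().eval()(H).shape  # expected: torch.Size([2, 2])
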
class FullInferenceModel(torch.nn.Module):
    def __init__(self, asp_mlp_classifier, wav2vec_model):
        super().__init__()
        self.wav2vec = wav2vec_model
        self.asp_mlp = asp_mlp_classifier

    def forward(self, input_values):
        with torch.no_grad():
            H = self.wav2vec(input_values).last_hidden_state
        return self.asp_mlp(H)

# Global model and processor
model = None
processor = None
def load_model(model_path="stutter_detector_mio.pth"):
    """Lazily load the checkpoint and feature extractor once; cache them in module globals."""
    global model, processor
    if model is not None:
        return model, processor
    print("Loading model...")
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
    asp = ASPMLPClassifier()
    wav2vec = transformers.Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
    full = FullInferenceModel(asp, wav2vec)
    full.load_state_dict(checkpoint["full_model_state_dict"])
    full.eval()
    model = full
    processor = transformers.Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
    print("Model loaded successfully!")
    return model, processor

def preprocess_audio(audio, max_len=48000):
    """Truncate or zero-pad audio to a fixed 3-second window (48000 samples at 16 kHz)."""
    if len(audio) > max_len:
        audio = audio[:max_len]
    else:
        audio = np.pad(audio, (0, max_len - len(audio)))
    return audio

def predict(audio_np):
    """Run one forward pass and return (predicted_class, confidence)."""
    mdl, proc = load_model()
    inputs = proc(audio_np, sampling_rate=16000, return_tensors="pt").input_values
    with torch.no_grad():
        logits = mdl(inputs)
        probs = torch.softmax(logits, dim=1)
        pred = torch.argmax(probs, dim=1).item()
        confidence = probs[0, pred].item()
    return pred, confidence

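# A small standalone usage sketch outside Flask (the filename "sample.wav" is
# illustrative, not part of this service; assumes the checkpoint is present):
#   audio, _ = librosa.load("sample.wav", sr=16000)
#   pred, conf = predict(preprocess_audio(audio))
#   print("stutter" if pred == 1 else "no_stutter", round(conf, 4))
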
def analyze_segments(audio_path):
    """Mode 1: Returns label for each 3-second segment"""
    audio, sr = librosa.load(audio_path, sr=16000)
    segment_duration = 3
    segment_samples = segment_duration * sr
    total_samples = len(audio)
    total_duration = total_samples / sr
    results = []
    CONF_THRESHOLD = 0.5
    start = 0
    segment_idx = 0
    while start < total_samples:
        end = min(start + segment_samples, total_samples)
        segment = audio[start:end]
        if len(segment) >= sr:  # At least 1 second
            audio_np = preprocess_audio(segment)
            pred, conf = predict(audio_np)
            label = "stutter" if pred == 1 and conf >= CONF_THRESHOLD else "no_stutter"
            results.append({
                "segment": segment_idx,
                "start_time": round(start / sr, 2),
                "end_time": round(end / sr, 2),
                "label": label,
                "confidence": round(conf, 4)
            })
        start = end
        segment_idx += 1
    return {"duration": round(total_duration, 2), "segments": results}

def analyze_seconds(audio_path):
    """Mode 2: Returns label for each second using overlapping segments"""
    audio, sr = librosa.load(audio_path, sr=16000)
    segment_duration = 3
    segment_samples = segment_duration * sr
    hop = int(segment_samples * 0.5)
    total_samples = len(audio)
    total_seconds = int(total_samples // sr)
    # Get overlapping segment predictions (3-second windows, 50% overlap)
    segments = []
    start = 0
    while start + segment_samples <= total_samples:
        end = start + segment_samples
        segments.append((start, end))
        start += hop
    segment_predictions = []
    for start, end in segments:
        segment = audio[start:end]
        audio_np = preprocess_audio(segment)
        pred, conf = predict(audio_np)
        segment_predictions.append((pred, conf, start, end))
    # Vote per second; the first and last segments are excluded as edge windows
    votes = [[] for _ in range(total_seconds)]
    confs = [[] for _ in range(total_seconds)]
    for i, (pred, conf, start, end) in enumerate(segment_predictions):
        if i == 0 or i == len(segment_predictions) - 1:
            continue
        seg_start_sec = int(start // sr)
        seg_end_sec = int(end // sr)
        for sec in range(seg_start_sec, min(seg_end_sec, total_seconds)):
            votes[sec].append(pred)
            confs[sec].append(conf)
    results = []
    CONF_THRESHOLD = 0.6
    for sec in range(total_seconds):
        if len(votes[sec]) == 0:
            results.append({"second": sec, "label": "no_data", "confidence": None})
            continue
        majority = 1 if votes[sec].count(1) > votes[sec].count(0) else 0
        mean_conf = sum(confs[sec]) / len(confs[sec])
        label = "stutter" if majority == 1 and mean_conf >= CONF_THRESHOLD else "no_stutter"
        results.append({"second": sec, "label": label, "confidence": round(mean_conf, 4)})
    return {"duration": total_seconds, "seconds": results}

def analyze_percentage(audio_path):
    """Mode 3: Returns only stutter percentage"""
    data = analyze_seconds(audio_path)
    total_seconds = data["duration"]
    stutter_seconds = sum(1 for r in data["seconds"] if r["label"] == "stutter")
    stutter_percentage = (stutter_seconds / total_seconds) * 100 if total_seconds > 0 else 0
    return {
        "duration": total_seconds,
        "stutter_percentage": round(stutter_percentage, 2)
    }

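# Illustrative return shapes for the three modes (values are made up; keys mirror
# the dicts built above, and the /analyze route additionally adds "success": true):
#   segments   -> {"duration": 9.5, "segments": [{"segment": 0, "start_time": 0.0,
#                  "end_time": 3.0, "label": "no_stutter", "confidence": 0.8123}, ...]}
#   seconds    -> {"duration": 9, "seconds": [{"second": 0, "label": "no_data",
#                  "confidence": None}, ...]}
#   percentage -> {"duration": 9, "stutter_percentage": 22.22}
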
@app.route('/', methods=['GET'])
def home():
    return jsonify({
        'service': 'Stutter Detection API',
        'status': 'running',
        'endpoints': {
            '/analyze': 'POST - Analyze audio file',
            '/health': 'GET - Health check'
        },
        'modes': ['segments', 'seconds', 'percentage'],
        'usage': 'POST audio file to /analyze?mode=<mode>'
    })

@app.route('/analyze', methods=['POST'])
def analyze():
    if 'audio' not in request.files:
        return jsonify({'error': 'No audio file provided'}), 400
    file = request.files['audio']
    if file.filename == '':
        return jsonify({'error': 'No file selected'}), 400
    # Get mode parameter (default: segments)
    mode = request.args.get('mode', 'segments')
    if mode not in ['segments', 'seconds', 'percentage']:
        return jsonify({'error': 'Invalid mode. Choose: segments, seconds, or percentage'}), 400
    file_id = str(uuid.uuid4())
    original_ext = os.path.splitext(secure_filename(file.filename))[1] or '.webm'
    original_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{file_id}{original_ext}")
    wav_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{file_id}.wav")
    file.save(original_path)
    try:
        # Convert the upload to 16 kHz mono WAV
        audio, sr = librosa.load(original_path, sr=16000, mono=True)
        sf.write(wav_path, audio, sr, format='WAV')
        # Analyze based on mode
        if mode == 'segments':
            result = analyze_segments(wav_path)
        elif mode == 'seconds':
            result = analyze_seconds(wav_path)
        else:  # percentage
            result = analyze_percentage(wav_path)
        return jsonify({'success': True, **result})
    except Exception as e:
        return jsonify({'error': str(e)}), 500
    finally:
        # Cleanup temporary files on both success and failure
        for path in (original_path, wav_path):
            if os.path.exists(path):
                os.remove(path)

@app.route('/health', methods=['GET'])
def health():
    return jsonify({
        'status': 'ok',
        'model_loaded': model is not None
    })

if __name__ == '__main__':
    load_model()
    port = int(os.environ.get('PORT', 7860))  # Hugging Face Spaces uses port 7860
    app.run(host='0.0.0.0', port=port, debug=False)
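
# Example requests (a sketch; host, port, and the file name "clip.webm" are
# illustrative and assume the server is running locally on the default port):
#   curl http://localhost:7860/health
#   curl -X POST -F "audio=@clip.webm" "http://localhost:7860/analyze?mode=segments"
#   curl -X POST -F "audio=@clip.webm" "http://localhost:7860/analyze?mode=percentage"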