"""
Speech Fluency Analysis - Hugging Face Gradio App

Detects stuttering events with a WavLM-based classifier and transcribes
speech with Whisper.
"""
|
|
|
|
|
import os |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torchaudio |
|
|
import numpy as np |
|
|
import gradio as gr |
|
|
from datetime import datetime |
|
|
from transformers import WavLMModel |
|
|
|
|
|
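# Multi-label stutter categories; the order must match the output order of the trained classifier head.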
STUTTER_LABELS = ["Prolongation", "Block", "SoundRep", "WordRep", "Interjection"] |
|
|
|
|
|
STUTTER_INFO = { |
|
|
"Prolongation": "Sound stretched longer than normal (e.g. 'Ssssnake')", |
|
|
"Block": "Complete stoppage of airflow/sound with tension", |
|
|
"SoundRep": "Sound/syllable repetition (e.g. 'B-b-b-ball')", |
|
|
"WordRep": "Whole word repetition (e.g. 'I-I-I want')", |
|
|
"Interjection": "Filler words like 'um', 'uh', 'like'", |
|
|
} |
|
|
|
|
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
|
|
|
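# Frozen WavLM encoder with a single linear head producing one logit per stutter type (multi-label).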
class WavLMStutterClassifier(nn.Module):
|
|
def __init__(self, num_labels=5): |
|
|
super().__init__() |
|
|
self.wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base") |
|
|
self.hidden_size = self.wavlm.config.hidden_size |
|
|
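        # Freeze the pretrained encoder so only the linear classifier head carries task-specific weights.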
for p in self.wavlm.parameters(): |
|
|
p.requires_grad = False |
|
|
self.classifier = nn.Linear(self.hidden_size, num_labels) |
|
|
|
|
|
def forward(self, x, attention_mask=None): |
|
|
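        # Mean-pool the frame-level hidden states over time, then project to per-label logits.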
h = self.wavlm(x, attention_mask=attention_mask).last_hidden_state |
|
|
return self.classifier(h.mean(dim=1)) |
|
|
|
|
|
|
|
|
wavlm_model = None |
|
|
whisper_model = None |
|
|
models_loaded = False |
|
|
|
|
|
|
|
|
def load_models(): |
|
|
"""Load WavLM checkpoint and Whisper once.""" |
|
|
global wavlm_model, whisper_model, models_loaded |
|
|
if models_loaded: |
|
|
return True |
|
|
|
|
|
print("Loading WavLM ...") |
|
|
    wavlm_model = WavLMStutterClassifier(num_labels=5)
|
|
ckpt = "wavlm_stutter_classification_best.pth" |
|
|
if os.path.exists(ckpt): |
|
|
state = torch.load(ckpt, map_location=DEVICE, weights_only=False) |
|
|
if isinstance(state, dict) and "model_state_dict" in state: |
|
|
wavlm_model.load_state_dict(state["model_state_dict"]) |
|
|
else: |
|
|
wavlm_model.load_state_dict(state) |
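    else:
        # Without the fine-tuned checkpoint the classifier head is untrained and
        # its predictions are meaningless, so make the failure visible.
        print(f"Warning: checkpoint '{ckpt}' not found; stutter detection will be unreliable.")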
|
|
wavlm_model.to(DEVICE).eval() |
|
|
|
|
|
print("Loading Whisper ...") |
|
|
import whisper |
|
|
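    # Whisper "base" keeps memory and load time modest; a larger model could be substituted for better transcripts.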
whisper_model = whisper.load_model("base", device=DEVICE) |
|
|
|
|
|
models_loaded = True |
|
|
print("Models ready.") |
|
|
return True |
|
|
|
|
def load_audio(path): |
|
|
"""Load any audio file to 16 kHz mono tensor via torchaudio (uses FFmpeg).""" |
|
|
waveform, sr = torchaudio.load(path) |
|
|
if waveform.size(0) > 1: |
|
|
waveform = waveform.mean(dim=0, keepdim=True) |
|
|
if sr != 16000: |
|
|
waveform = torchaudio.transforms.Resample(sr, 16000)(waveform) |
|
|
return waveform.squeeze(0), 16000 |
|
|
|
|
|
|
|
|
def analyze_chunk(chunk, threshold=0.5): |
|
|
"""Run WavLM on a single chunk.""" |
|
|
with torch.no_grad(): |
|
|
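        # Multi-label output: an independent sigmoid per stutter type (several types can co-occur in a chunk).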
logits = wavlm_model(chunk.unsqueeze(0).to(DEVICE)) |
|
|
probs = torch.sigmoid(logits).cpu().numpy()[0] |
|
|
detected = [STUTTER_LABELS[i] for i, p in enumerate(probs) if p > threshold] |
|
|
prob_dict = dict(zip(STUTTER_LABELS, [round(float(p), 3) for p in probs])) |
|
|
return detected, prob_dict |
|
|
|
|
|
|
|
|
def analyze_audio(audio_path, threshold, progress=gr.Progress()): |
|
|
"""Main pipeline: chunk -> WavLM -> Whisper -> formatted results.""" |
|
|
if audio_path is None: |
|
|
return "Upload an audio file first.", "", "", "" |
|
|
|
|
|
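    # gr.Audio(type="filepath") normally provides a path, but handle a (sample_rate, ndarray) tuple defensively.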
if isinstance(audio_path, tuple): |
|
|
import tempfile, soundfile as sf |
|
|
sr, data = audio_path |
|
|
tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) |
|
|
sf.write(tmp.name, data, sr) |
|
|
audio_path = tmp.name |
|
|
|
|
|
progress(0.05, desc="Loading models ...") |
|
|
if not models_loaded and not load_models(): |
|
|
return "Failed to load models.", "", "", "" |
|
|
|
|
|
progress(0.15, desc="Loading audio ...") |
|
|
waveform, sr = load_audio(audio_path) |
|
|
duration = len(waveform) / sr |
|
|
|
|
|
progress(0.25, desc="Detecting stutters ...") |
|
|
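    # Score the audio in non-overlapping 3-second windows; the last window is zero-padded to full length.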
chunk_samples = 3 * sr |
|
|
    counts = {label: 0 for label in STUTTER_LABELS}
|
|
timeline_rows = [] |
|
|
total_chunks = max(1, (len(waveform) + chunk_samples - 1) // chunk_samples) |
|
|
|
|
|
for i, start in enumerate(range(0, len(waveform), chunk_samples)): |
|
|
progress(0.25 + 0.45 * (i / total_chunks), desc=f"Chunk {i+1}/{total_chunks} ...") |
|
|
end = min(start + chunk_samples, len(waveform)) |
|
|
chunk = waveform[start:end] |
|
|
if len(chunk) < chunk_samples: |
|
|
chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk))) |
|
|
|
|
|
detected, probs = analyze_chunk(chunk, threshold) |
|
|
for label in detected: |
|
|
counts[label] += 1 |
|
|
|
|
|
time_str = f"{start/sr:.1f}-{end/sr:.1f}s" |
|
|
timeline_rows.append({"time": time_str, "detected": detected or ["Fluent"], "probs": probs}) |
|
|
|
|
|
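    # Whisper loads and resamples the original file itself (openai-whisper decodes via ffmpeg).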
progress(0.75, desc="Transcribing ...") |
|
|
transcription = whisper_model.transcribe(audio_path).get("text", "").strip() |
|
|
|
|
|
progress(0.90, desc="Building report ...") |
|
|
total_stutters = sum(counts.values()) |
|
|
chunks_with_stutter = sum(1 for r in timeline_rows if "Fluent" not in r["detected"]) |
|
|
stutter_pct = (chunks_with_stutter / total_chunks) * 100 if total_chunks else 0 |
|
|
word_count = len(transcription.split()) if transcription else 0 |
|
|
wpm = (word_count / duration) * 60 if duration > 0 else 0 |
|
|
|
|
|
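    # Severity bands are based on the share of 3-second chunks containing at least one detected stutter.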
severity = ( |
|
|
"Very Mild" if stutter_pct < 5 else |
|
|
"Mild" if stutter_pct < 10 else |
|
|
"Moderate" if stutter_pct < 20 else |
|
|
"Severe" if stutter_pct < 30 else |
|
|
"Very Severe" |
|
|
) |
|
|
|
|
|
summary_lines = [ |
|
|
"## Analysis Results\n", |
|
|
"| Metric | Value |", |
|
|
"|--------|-------|", |
|
|
f"| Duration | {duration:.1f}s |", |
|
|
f"| Words | {word_count} |", |
|
|
f"| Speaking Rate | {wpm:.0f} wpm |", |
|
|
f"| Stutter Events | {total_stutters} |", |
|
|
f"| Affected Chunks | {chunks_with_stutter}/{total_chunks} ({stutter_pct:.1f}%) |", |
|
|
f"| Severity | **{severity}** |", |
|
|
"", |
|
|
"### Stutter Counts", |
|
|
"", |
|
|
] |
|
|
for label in STUTTER_LABELS: |
|
|
c = counts[label] |
|
|
bar = "X" * min(c, 20) |
|
|
icon = "!" if c > 0 else "o" |
|
|
summary_lines.append(f"- {icon} **{label}**: {c} {bar}") |
|
|
|
|
|
summary_md = "\n".join(summary_lines) |
|
|
|
|
|
tl_lines = ["| Time | Detected |", "|------|----------|"] |
|
|
for row in timeline_rows: |
|
|
tl_lines.append(f"| {row['time']} | {', '.join(row['detected'])} |") |
|
|
timeline_md = "\n".join(tl_lines) |
|
|
|
|
|
recs = ["## Recommendations\n"] |
|
|
if severity in ("Very Mild", "Mild"): |
|
|
recs.append("- Stuttering is within the mild range. Regular monitoring is recommended.") |
|
|
elif severity == "Moderate": |
|
|
recs.append("- Consider speech therapy consultation for fluency-enhancing techniques.") |
|
|
else: |
|
|
recs.append("- Professional speech-language pathology evaluation is strongly recommended.") |
|
|
|
|
|
dominant = max(counts, key=counts.get) |
|
|
if counts[dominant] > 0: |
|
|
recs.append(f"- Most frequent type: **{dominant}** - {STUTTER_INFO[dominant]}") |
|
|
|
|
|
if wpm > 180: |
|
|
recs.append(f"- Speaking rate is high ({wpm:.0f} wpm). Slower speech may reduce stuttering.") |
|
|
|
|
|
recs.append("\n### Stutter Type Definitions\n") |
|
|
for label, desc in STUTTER_INFO.items(): |
|
|
recs.append(f"- **{label}**: {desc}") |
|
|
|
|
|
recs_md = "\n".join(recs) |
|
|
|
|
|
progress(1.0, desc="Done!") |
|
|
return summary_md, transcription, timeline_md, recs_md |
|
|
|
|
|
|
|
|
CUSTOM_CSS = """ |
|
|
.gradio-container { max-width: 960px !important; } |
|
|
.gr-button-primary { background: #0f766e !important; } |
|
|
""" |
|
|
|
|
|
with gr.Blocks(title="Speech Fluency Analysis", css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo: |
|
|
|
|
|
gr.Markdown( |
|
|
""" |
|
|
# Speech Fluency Analysis |
|
|
Upload an audio file to detect stuttering patterns using **WavLM** (stutter detection) |
|
|
and **Whisper** (transcription). |
|
|
|
|
|
Supported formats: **WAV, MP3, M4A, FLAC, OGG** |
|
|
""" |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
audio_in = gr.Audio(label="Upload Audio", type="filepath") |
|
|
threshold = gr.Slider( |
|
|
0.3, 0.7, value=0.5, step=0.05, |
|
|
label="Detection Threshold", |
|
|
info="Lower = more sensitive, Higher = more strict", |
|
|
) |
|
|
btn = gr.Button("Analyze", variant="primary", size="lg") |
|
|
|
|
|
with gr.Column(scale=2): |
|
|
summary_out = gr.Markdown(value="*Upload audio and click **Analyze** to start.*") |
|
|
|
|
|
with gr.Tabs(): |
|
|
with gr.TabItem("Transcription"): |
|
|
trans_out = gr.Textbox(label="Whisper Transcription", lines=6, interactive=False) |
|
|
with gr.TabItem("Timeline"): |
|
|
timeline_out = gr.Markdown() |
|
|
with gr.TabItem("Recommendations"): |
|
|
recs_out = gr.Markdown() |
|
|
|
|
|
gr.Markdown( |
|
|
"---\n*Disclaimer: AI-assisted analysis for clinical support only. " |
|
|
"Consult a qualified Speech-Language Pathologist for diagnosis.*" |
|
|
) |
|
|
|
|
|
btn.click( |
|
|
fn=analyze_audio, |
|
|
inputs=[audio_in, threshold], |
|
|
outputs=[summary_out, trans_out, timeline_out, recs_out], |
|
|
show_progress="full", |
|
|
) |
|
|
|
|
|
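# Warm-load the models at import time so the first request does not pay the initialization cost.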
print("Loading models at startup ...") |
|
|
load_models() |
|
|
|
|
|
print("Launching Gradio ...") |
|
|
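# Queue requests so long analyses don't block one another; gr.Progress updates also rely on the queue.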
demo.queue() |
|
|
demo.launch(ssr_mode=False) |
|
|
|