Torgo-DSR-Lab / app.py
st192011's picture
Update app.py
9f40c62 verified
import gradio as gr
import os
import io
import re
import random
import librosa
import soundfile as sf
import pandas as pd
from gradio_client import Client, handle_file
from transformers import pipeline
from datasets import load_dataset, Audio
from stats_data import (
get_loss_data,
get_loso_f01_data,
get_zeroshot_ua_data,
get_arbitration_table,
SPEAKER_META
)
# 1. Initialize Baseline ASR
print("Initializing Whisper Tiny Baseline...")
whisper_asr = pipeline(
"automatic-speech-recognition",
model="openai/whisper-tiny",
generate_kwargs={"language": "en", "task": "transcribe", "repetition_penalty": 3.0}
)
HF_TOKEN = os.getenv("HF_TOKEN")
PRIVATE_BACKEND_URL = "st192011/Torgo-DSR-Private"
def normalize_text(text):
if not text: return ""
text = re.sub(r'[^\w\s]', '', text).lower().strip()
return " ".join(text.split())
def format_audio(audio_path):
y, sr = librosa.load(audio_path, sr=16000)
out_path = "formatted_input.wav"
sf.write(out_path, y, sr)
return out_path
def get_sample_logic(speaker_id):
try:
# Standardize matching for metadata keys
meta = SPEAKER_META[speaker_id]
if "UA" in meta["Dataset"]:
dataset = load_dataset("resproj007/uaspeech_female", split="train", streaming=True)
dataset = dataset.cast_column("audio", Audio(decode=False))
sample = next(iter(dataset.skip(random.randint(0, 20))))
gt_text = sample.get('text') or sample.get('transcription') or sample.get('sentence')
else:
dataset = load_dataset("abnerh/TORGO-database", split="train", streaming=True)
dataset = dataset.cast_column("audio", Audio(decode=False))
def filter_spk(x):
# Handle dropdown display mapping back to dataset IDs
sid_clean = speaker_id.split(" (")[0].upper()
sid = str(x.get('speaker_id', '')).upper()
if not sid or sid == "NONE":
sid = os.path.basename(x['audio']['path']).split('_')[0].upper()
return sid == sid_clean
speaker_ds = dataset.filter(filter_spk)
sample = next(iter(speaker_ds.shuffle(buffer_size=10)))
gt_text = sample.get('transcription') or sample.get('text')
audio_bytes = sample['audio']['bytes']
audio_data, sr = librosa.load(io.BytesIO(audio_bytes), sr=16000)
temp_path = "dataset_sample.wav"
sf.write(temp_path, audio_data, sr)
return temp_path, gt_text.lower().strip(), meta
except Exception as e:
return None, f"Dataset Error: {e}", {}
def process_audio_step_1(audio_path):
if not audio_path: return "No audio loaded", ""
formatted_path = format_audio(audio_path)
result = whisper_asr(formatted_path)
raw_w = result["text"]
norm_w = normalize_text(raw_w)
return raw_w, norm_w
def process_audio_step_2(audio_path, norm_whisper):
if not audio_path or not norm_whisper:
return "Please load data and run Whisper (Step 1) first."
try:
client = Client(PRIVATE_BACKEND_URL, token=HF_TOKEN)
prediction = client.predict(
audio_path=handle_file(audio_path),
whisper_norm=norm_whisper,
api_name="/predict_dsr"
)
return prediction
except Exception as e:
return f"Backend Connection Required. Details: {e}"
# --- UI Construction ---
with gr.Blocks() as demo:
gr.Markdown("# βš—οΈ Torgo DSR Lab")
active_audio_path = gr.State("")
with gr.Tab("πŸ”¬ Laboratory"):
with gr.Row():
with gr.Column(scale=1):
with gr.Group():
gr.Markdown("### Channel A: Research Datasets")
speaker_input = gr.Dropdown(sorted(list(SPEAKER_META.keys())), label="Select Speaker Profile", value="F01")
load_btn = gr.Button("Load Sample from Dataset")
gt_box = gr.Textbox(label="Ground Truth (Reference)", interactive=False)
meta_display = gr.JSON(label="Speaker Metadata")
gr.Markdown("---")
with gr.Group():
gr.Markdown("### Channel B: Personal Input")
user_audio = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record or Upload Audio")
user_load_btn = gr.Button("Use This Audio")
with gr.Column(scale=2):
gr.Markdown("### Analysis & Reconstruction")
with gr.Group():
gr.Markdown("#### Step 1: ASR Baseline")
whisper_btn = gr.Button("Run Whisper Tiny")
w_raw = gr.Textbox(label="Whisper Raw Transcript")
w_norm = gr.Textbox(label="Whisper Normalized")
gr.Markdown("---")
with gr.Group():
gr.Markdown("#### Step 2: Neural Reconstruction (Gemma 2)")
model_btn = gr.Button("Run Neural Arbitrator (temperature=0.8)", variant="primary")
final_out = gr.Textbox(label="DSR Lab Prediction (Gemma 2)")
with gr.Tab("πŸ“Š Research Statistics"):
gr.Markdown("## 1. Learning Dynamics (Loss Comparison)")
gr.LinePlot(
get_loss_data(),
x="Step", y="Loss", color="Metric",
title="Training vs. Validation (F01) Loss",
tooltip=["Step", "Loss", "Metric"],
color_title="Metric Type"
)
gr.Markdown("---")
gr.Markdown("## 2. Metric Trends across Checkpoints")
# Accuracy Definition
gr.Markdown("""
### πŸ” Defining Accuracy: Exact Match (EM)
Accuracy in this research is defined as the **Exact Match** of normalized text.
Before comparison, both the Ground Truth and the Model Prediction are converted to lowercase and all punctuation is removed.
A "Success" is only recorded if the two strings match perfectly word-for-word.
""")
with gr.Row():
# LOSO F01
with gr.Column():
gr.Markdown("### Torgo F01 (LOSO Severe)")
f01_df = get_loso_f01_data()
gr.LinePlot(f01_df[f01_df["Metric"] == "Accuracy (%)"], x="Step", y="Value", color="Model", title="F01 Accuracy Trend", y_title="Accuracy (%)", color_title="Legend")
gr.LinePlot(f01_df[f01_df["Metric"] == "WER"], x="Step", y="Value", color="Model", title="F01 WER Trend (Lower is Better)", y_title="WER", color_title="Legend")
# UA F02
with gr.Column():
gr.Markdown("### UA-Speech F02 (Zero-Shot Severe)")
ua_df = get_zeroshot_ua_data()
gr.LinePlot(ua_df[ua_df["Metric"] == "Accuracy (%)"], x="Step", y="Value", color="Model", title="F02 Accuracy Trend", y_title="Accuracy (%)", color_title="Legend")
gr.LinePlot(ua_df[ua_df["Metric"] == "WER"], x="Step", y="Value", color="Model", title="F02 WER Trend (Lower is Better)", y_title="WER", color_title="Legend")
gr.Markdown("---")
gr.Markdown("## 3. Neural Arbitration Logic")
with gr.Row():
with gr.Column(scale=2):
gr.DataFrame(get_arbitration_table())
with gr.Column(scale=1):
gr.Markdown("""
### 🧠 Logic Metrics
* **Whisper Retention (n=11):** Cases where Whisper was correct. A drop indicates high 'aggression' where the model prioritizes phonetic signals over ASR context.
* **Pure Correction (n=205):** Cases where ASR baselines failed. The model successfully recovered the word using phonetic witnesses.
""")
gr.Markdown("---")
gr.Markdown("## 4. Research Journey: Defining a Standardized Protocol for DSR")
gr.Markdown("""
The current state of this DSR GER research is the result of a multi-stage ablation study across model architectures, data provenance, and training methodologies to establish a robust protocol for Dysarthric Speech Reconstruction (DSR).
### **I. Architectural Trajectory: From Encoder-Decoder to LLM Reasoning**
The research evaluated several architectural paradigms to find the optimal balance between token reconstruction and linguistic reasoning:
* **Encoder-Decoder Benchmarks:** Initial experiments utilized **T5 and Byte-T5 (Base/Large)**. While capable of phonetic smoothing, they lacked the deep semantic "arbitration" required for severe distortions.
* **Decoder-Only Scaling:** The project transitioned to LLM backbones, evaluating **Qwen** and **Gemma 3 (4B)**.
* **Optimization Target:** Due to environment-specific constraints and the need for high-parameter reasoning, **Gemma 2 (9B)** was selected as the final backbone, providing the best trade-off for complex phonetic-to-semantic reconstruction.
### **II. Data Provenance and the 'Real-Data Floor'**
We conducted extensive experiments on data mixture and diversity:
* **Synthetic Augmentation:** Large-scale synthetic datasets were generated to improve quantity, domain diversity, and class balance.
* **The Purity Discovery:** Ablation results revealed a critical **Acoustic Floor**: as the proportion of real dysarthric samples (Torgo) decreased in favor of synthetic data, performance dropped significantly. This proved that real-world articulatory distortions possess a unique "phonetic DNA" that cannot yet be fully replicated by current synthetic synthesis.
### **III. The Training Stack: Ablation & Optimization**
To reach the final protocol, we systematically tested a suite of PEFT and regularization techniques:
* **Parameter-Efficient Fine-Tuning:** Comparison of **LoRA vs. DoRA (Weight-Decomposed LoRA)** for improved weight update stability.
* **Regularization & Safeguards:** Integration of **NEFTune** (noise injection) to prevent overfitting and **Modality Dropout** to force the model to prioritize the Phonetic witness over the noisy Semantic witness.
* **Curriculum & Loss Logic:** Implementation of **Curriculum Learning** (Phonetic Anchoring first) combined with **Specialized Loss Masking** to ensure the model learns to reconstruct meaning rather than merely copying inputs.
**Outcome:** This journey has culminated in a **Standardized DSR Protocol**, providing a blueprint for training robust correction layers for atypical speech by prioritizing real-world phonetic grounding and multi-modal arbitration logic.
""")
load_btn.click(get_sample_logic, inputs=speaker_input, outputs=[active_audio_path, gt_box, meta_display])
user_load_btn.click(lambda x: (x, "User Recorded", {"Dataset": "Custom", "Severity": "N/A"}), inputs=user_audio, outputs=[active_audio_path, gt_box, meta_display])
whisper_btn.click(process_audio_step_1, inputs=active_audio_path, outputs=[w_raw, w_norm])
model_btn.click(process_audio_step_2, inputs=[active_audio_path, w_norm], outputs=final_out)
demo.launch(theme=gr.themes.Soft())