import gradio as gr
import os
import io
import re
import random
import librosa
import soundfile as sf
import pandas as pd
from gradio_client import Client, handle_file
from transformers import pipeline
from datasets import load_dataset, Audio
from stats_data import (
    get_loss_data,
    get_loso_f01_data,
    get_zeroshot_ua_data,
    get_arbitration_table,
    SPEAKER_META
)

# 1. Initialize Baseline ASR
print("Initializing Whisper Tiny Baseline...")
whisper_asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    generate_kwargs={"language": "en", "task": "transcribe", "repetition_penalty": 3.0}
)
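
# Credentials for the private DSR backend Space (queried via gradio_client in Step 2)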
HF_TOKEN = os.getenv("HF_TOKEN")
PRIVATE_BACKEND_URL = "st192011/Torgo-DSR-Private"
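
# Lowercase, strip punctuation, and collapse whitespace so predictions and references
# can be compared with the Exact Match criterion used throughout this app.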
def normalize_text(text):
    if not text:
        return ""
    text = re.sub(r'[^\w\s]', '', text).lower().strip()
    return " ".join(text.split())
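
# Resample input audio to 16 kHz mono and write a temporary WAV for the ASR pipeline.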
def format_audio(audio_path):
    y, sr = librosa.load(audio_path, sr=16000)
    out_path = "formatted_input.wav"
    sf.write(out_path, y, sr)
    return out_path
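
# Stream a random utterance for the selected speaker profile from UA-Speech (female subset)
# or the TORGO database, returning (audio path, ground-truth text, speaker metadata).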
def get_sample_logic(speaker_id):
    try:
        # Standardize matching for metadata keys
        meta = SPEAKER_META[speaker_id]
        if "UA" in meta["Dataset"]:
            dataset = load_dataset("resproj007/uaspeech_female", split="train", streaming=True)
            dataset = dataset.cast_column("audio", Audio(decode=False))
            sample = next(iter(dataset.skip(random.randint(0, 20))))
            gt_text = sample.get('text') or sample.get('transcription') or sample.get('sentence')
        else:
            dataset = load_dataset("abnerh/TORGO-database", split="train", streaming=True)
            dataset = dataset.cast_column("audio", Audio(decode=False))

            def filter_spk(x):
                # Handle dropdown display mapping back to dataset IDs
                sid_clean = speaker_id.split(" (")[0].upper()
                sid = str(x.get('speaker_id', '')).upper()
                if not sid or sid == "NONE":
                    sid = os.path.basename(x['audio']['path']).split('_')[0].upper()
                return sid == sid_clean

            speaker_ds = dataset.filter(filter_spk)
            sample = next(iter(speaker_ds.shuffle(buffer_size=10)))
            gt_text = sample.get('transcription') or sample.get('text')

        audio_bytes = sample['audio']['bytes']
        audio_data, sr = librosa.load(io.BytesIO(audio_bytes), sr=16000)
        temp_path = "dataset_sample.wav"
        sf.write(temp_path, audio_data, sr)
        return temp_path, gt_text.lower().strip(), meta
    except Exception as e:
        return None, f"Dataset Error: {e}", {}
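
# Step 1: baseline transcription with Whisper Tiny, plus its normalized form.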
def process_audio_step_1(audio_path):
    if not audio_path:
        return "No audio loaded", ""
    formatted_path = format_audio(audio_path)
    result = whisper_asr(formatted_path)
    raw_w = result["text"]
    norm_w = normalize_text(raw_w)
    return raw_w, norm_w
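
# Step 2: forward the audio and the normalized Whisper hypothesis to the private
# backend Space, which runs the Gemma 2 neural arbitrator and returns the reconstruction.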
def process_audio_step_2(audio_path, norm_whisper):
    if not audio_path or not norm_whisper:
        return "Please load data and run Whisper (Step 1) first."
    try:
        client = Client(PRIVATE_BACKEND_URL, token=HF_TOKEN)
        prediction = client.predict(
            audio_path=handle_file(audio_path),
            whisper_norm=norm_whisper,
            api_name="/predict_dsr"
        )
        return prediction
    except Exception as e:
        return f"Backend Connection Required. Details: {e}"

# --- UI Construction ---
# Note: theme must be passed to gr.Blocks(), not to launch().
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ⚗️ Torgo DSR Lab")
    active_audio_path = gr.State("")

    with gr.Tab("🔬 Laboratory"):
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Channel A: Research Datasets")
                    speaker_input = gr.Dropdown(sorted(list(SPEAKER_META.keys())), label="Select Speaker Profile", value="F01")
                    load_btn = gr.Button("Load Sample from Dataset")
                    gt_box = gr.Textbox(label="Ground Truth (Reference)", interactive=False)
                    meta_display = gr.JSON(label="Speaker Metadata")
                gr.Markdown("---")
                with gr.Group():
                    gr.Markdown("### Channel B: Personal Input")
                    user_audio = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record or Upload Audio")
                    user_load_btn = gr.Button("Use This Audio")
            with gr.Column(scale=2):
                gr.Markdown("### Analysis & Reconstruction")
                with gr.Group():
                    gr.Markdown("#### Step 1: ASR Baseline")
                    whisper_btn = gr.Button("Run Whisper Tiny")
                    w_raw = gr.Textbox(label="Whisper Raw Transcript")
                    w_norm = gr.Textbox(label="Whisper Normalized")
                gr.Markdown("---")
                with gr.Group():
                    gr.Markdown("#### Step 2: Neural Reconstruction (Gemma 2)")
                    model_btn = gr.Button("Run Neural Arbitrator (temperature=0.8)", variant="primary")
                    final_out = gr.Textbox(label="DSR Lab Prediction (Gemma 2)")

    with gr.Tab("📊 Research Statistics"):
        gr.Markdown("## 1. Learning Dynamics (Loss Comparison)")
        gr.LinePlot(
            get_loss_data(),
            x="Step", y="Loss", color="Metric",
            title="Training vs. Validation (F01) Loss",
            tooltip=["Step", "Loss", "Metric"],
            color_title="Metric Type"
        )
        gr.Markdown("---")
        gr.Markdown("## 2. Metric Trends across Checkpoints")
        # Accuracy Definition
        gr.Markdown("""
### 📏 Defining Accuracy: Exact Match (EM)
Accuracy in this research is defined as the **Exact Match** of normalized text.
Before comparison, both the Ground Truth and the Model Prediction are converted to lowercase and all punctuation is removed.
A "Success" is only recorded if the two strings match perfectly word-for-word.
""")
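        # Illustrative sketch (assumption, not the evaluation script itself): with
        # normalize_text defined above, an utterance counts as a "Success" when
        #     normalize_text(ground_truth) == normalize_text(prediction)
        # and Accuracy (%) = 100 * (number of successes) / (total utterances).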
        with gr.Row():
            # LOSO F01
            with gr.Column():
                gr.Markdown("### Torgo F01 (LOSO Severe)")
                f01_df = get_loso_f01_data()
                gr.LinePlot(f01_df[f01_df["Metric"] == "Accuracy (%)"], x="Step", y="Value", color="Model", title="F01 Accuracy Trend", y_title="Accuracy (%)", color_title="Legend")
                gr.LinePlot(f01_df[f01_df["Metric"] == "WER"], x="Step", y="Value", color="Model", title="F01 WER Trend (Lower is Better)", y_title="WER", color_title="Legend")
            # UA F02
            with gr.Column():
                gr.Markdown("### UA-Speech F02 (Zero-Shot Severe)")
                ua_df = get_zeroshot_ua_data()
                gr.LinePlot(ua_df[ua_df["Metric"] == "Accuracy (%)"], x="Step", y="Value", color="Model", title="F02 Accuracy Trend", y_title="Accuracy (%)", color_title="Legend")
                gr.LinePlot(ua_df[ua_df["Metric"] == "WER"], x="Step", y="Value", color="Model", title="F02 WER Trend (Lower is Better)", y_title="WER", color_title="Legend")
        gr.Markdown("---")
        gr.Markdown("## 3. Neural Arbitration Logic")
        with gr.Row():
            with gr.Column(scale=2):
                gr.DataFrame(get_arbitration_table())
            with gr.Column(scale=1):
                gr.Markdown("""
### 🧠 Logic Metrics
* **Whisper Retention (n=11):** Cases where Whisper was correct. A drop indicates high 'aggression', where the model prioritizes phonetic signals over ASR context.
* **Pure Correction (n=205):** Cases where ASR baselines failed. The model successfully recovered the word using phonetic witnesses.
""")
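        # Illustrative sketch (assumption, not the actual analysis code): given triples of
        # (ground truth, Whisper hypothesis, model output) compared via normalize_text,
        #     Whisper Retention counts utterances Whisper already got right that the model keeps correct;
        #     Pure Correction counts utterances Whisper got wrong that the model matches to the ground truth.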
        gr.Markdown("---")
        gr.Markdown("## 4. Research Journey: Defining a Standardized Protocol for DSR")
        gr.Markdown("""
The current state of this research in generative error correction (GER) for Dysarthric Speech Reconstruction (DSR) is the result of a multi-stage ablation study across model architectures, data provenance, and training methodologies, conducted to establish a robust DSR protocol.
### **I. Architectural Trajectory: From Encoder-Decoder to LLM Reasoning**
The research evaluated several architectural paradigms to find the optimal balance between token reconstruction and linguistic reasoning:
* **Encoder-Decoder Benchmarks:** Initial experiments used **T5 and ByT5 (Base/Large)**. While capable of phonetic smoothing, they lacked the deep semantic "arbitration" required for severe distortions.
* **Decoder-Only Scaling:** The project transitioned to LLM backbones, evaluating **Qwen** and **Gemma 3 (4B)**.
* **Optimization Target:** Due to environment-specific constraints and the need for high-parameter reasoning, **Gemma 2 (9B)** was selected as the final backbone, providing the best trade-off for complex phonetic-to-semantic reconstruction.
### **II. Data Provenance and the 'Real-Data Floor'**
We conducted extensive experiments on data mixture and diversity:
* **Synthetic Augmentation:** Large-scale synthetic datasets were generated to improve quantity, domain diversity, and class balance.
* **The Purity Discovery:** Ablation results revealed a critical **Acoustic Floor**: as the proportion of real dysarthric samples (Torgo) decreased in favor of synthetic data, performance dropped significantly. This showed that real-world articulatory distortions carry a unique "phonetic DNA" that current synthesis methods cannot yet fully replicate.
### **III. The Training Stack: Ablation & Optimization**
To reach the final protocol, we systematically tested a suite of PEFT and regularization techniques:
* **Parameter-Efficient Fine-Tuning:** Comparison of **LoRA vs. DoRA (Weight-Decomposed LoRA)** for improved weight-update stability.
* **Regularization & Safeguards:** Integration of **NEFTune** (noise injection) to prevent overfitting and **Modality Dropout** to force the model to prioritize the Phonetic witness over the noisy Semantic witness.
* **Curriculum & Loss Logic:** Implementation of **Curriculum Learning** (Phonetic Anchoring first) combined with **Specialized Loss Masking** to ensure the model learns to reconstruct meaning rather than merely copying inputs.
**Outcome:** This journey has culminated in a **Standardized DSR Protocol**, providing a blueprint for training robust correction layers for atypical speech by prioritizing real-world phonetic grounding and multi-modal arbitration logic.
""")

    load_btn.click(get_sample_logic, inputs=speaker_input, outputs=[active_audio_path, gt_box, meta_display])
    user_load_btn.click(lambda x: (x, "User Recorded", {"Dataset": "Custom", "Severity": "N/A"}), inputs=user_audio, outputs=[active_audio_path, gt_box, meta_display])
    whisper_btn.click(process_audio_step_1, inputs=active_audio_path, outputs=[w_raw, w_norm])
    model_btn.click(process_audio_step_2, inputs=[active_audio_path, w_norm], outputs=final_out)

demo.launch()