jefffffff9 committed on
Commit · 76db545
0 Parent(s)
Initial commit: Sahel-Agri Voice AI
This view is limited to 50 files because it contains too many changes. See raw diff.
- .claude/settings.local.json +7 -0
- .env.example +20 -0
- .gitignore +66 -0
- .vscode/extensions.json +5 -0
- README.md +39 -0
- app.py +611 -0
- configs/api_config.yaml +21 -0
- configs/base_config.yaml +30 -0
- configs/lora_bambara.yaml +19 -0
- configs/lora_fula.yaml +19 -0
- noise_samples/README.md +20 -0
- notebooks/bootstrap_repos.ipynb +308 -0
- notebooks/train_colab.ipynb +283 -0
- packages.txt +1 -0
- requirements.txt +50 -0
- scripts/export_onnx.py +67 -0
- scripts/run_data_pipeline.py +76 -0
- scripts/run_server.py +42 -0
- scripts/train_bambara.py +28 -0
- scripts/train_fula.py +29 -0
- scripts/verify_baseline.py +78 -0
- src/__init__.py +0 -0
- src/api/__init__.py +0 -0
- src/api/app.py +98 -0
- src/api/dependencies.py +20 -0
- src/api/middleware.py +47 -0
- src/api/routes/__init__.py +0 -0
- src/api/routes/health.py +25 -0
- src/api/routes/iot.py +90 -0
- src/api/routes/transcribe.py +74 -0
- src/api/schemas.py +36 -0
- src/data/__init__.py +0 -0
- src/data/agri_dictionary.py +92 -0
- src/data/augmentation.py +84 -0
- src/data/feature_extractor.py +89 -0
- src/data/waxal_loader.py +119 -0
- src/engine/__init__.py +0 -0
- src/engine/adapter_manager.py +106 -0
- src/engine/transcriber.py +132 -0
- src/engine/whisper_base.py +77 -0
- src/iot/__init__.py +0 -0
- src/iot/intent_parser.py +75 -0
- src/iot/sensor_bridge.py +121 -0
- src/iot/voice_responder.py +260 -0
- src/optimization/__init__.py +0 -0
- src/optimization/onnx_exporter.py +106 -0
- src/optimization/quantizer.py +95 -0
- src/optimization/tflite_converter.py +76 -0
- src/training/__init__.py +0 -0
- src/training/callbacks.py +83 -0
.claude/settings.local.json
ADDED
@@ -0,0 +1,7 @@
{
  "permissions": {
    "allow": [
      "Bash(pip show:*)"
    ]
  }
}
.env.example
ADDED
@@ -0,0 +1,20 @@
# HuggingFace read token (required for accessing google/waxal dataset)
HF_TOKEN=hf_your_token_here

# Model
MODEL_ID=openai/whisper-large-v3-turbo

# Adapter paths (relative to project root)
BAMBARA_ADAPTER_PATH=./adapters/bambara
FULA_ADAPTER_PATH=./adapters/fula

# IoT sensor API endpoint (leave empty to use mock data in development)
SENSOR_API_URL=

# FastAPI server
API_HOST=0.0.0.0
API_PORT=8000
LOG_LEVEL=INFO

# Device: "cuda" for GPU, "cpu" for CPU-only
DEVICE=cuda
.gitignore
ADDED
@@ -0,0 +1,66 @@
# Python
__pycache__/
*.py[cod]
*.pyo
*.pyd
.Python
*.egg-info/
dist/
build/
.eggs/

# Environment
.env
venv/
.venv/
env/

# Model weights (large binary files)
*.pt
*.pth
*.bin
*.safetensors
*.ckpt

# ONNX / TFLite exports
*.onnx
*.tflite
models/onnx/
models/tflite/

# HuggingFace cache
data_cache/
.cache/

# Audio noise samples (user must provide their own)
noise_samples/*.wav
noise_samples/*.mp3
noise_samples/*.ogg

# Trained adapters (tracked separately or via DVC)
adapters/bambara/
adapters/fula/

# IDE
.vscode/settings.json
.idea/
*.code-workspace

# OS
.DS_Store
Thumbs.db

# Logs
*.log
logs/

# Local feedback data (audio + corrections live in HF Dataset repo, not git)
feedback/

# Local model downloads
models/

# Pytest
.pytest_cache/
htmlcov/
.coverage
.vscode/extensions.json
ADDED
@@ -0,0 +1,5 @@
{
  "recommendations": [
    "anthropic.claude-code"
  ]
}
README.md
ADDED
@@ -0,0 +1,39 @@
---
title: Sahel-Agri Voice AI
emoji: 🌾
colorFrom: green
colorTo: yellow
sdk: gradio
sdk_version: "4.44.0"
app_file: app.py
hardware: cpu-basic
pinned: false
license: mit
tags:
  - agriculture
  - bambara
  - fula
  - speech-recognition
  - text-to-speech
  - west-africa
  - low-resource-nlp
---

# 🌾 Sahel-Agri Voice AI

Two-way voice assistant for Malian and Guinean farmers. Speak in **Bambara** or **Fula** — get agricultural insights spoken back in your language.

## Features
- 🎙️ Voice input via microphone or file upload
- 🌍 Bambara (bam) and Fula (ful) speech recognition via Whisper + LoRA adapters
- 🔊 Native-language voice responses via Facebook MMS-TTS
- 📊 Soil, weather, irrigation, and pest alerts from IoT sensors
- 💾 Feedback saved to HuggingFace Dataset for continuous improvement

## Languages supported
| Language | STT | TTS |
|----------|-----|-----|
| Bambara (bam) | ✅ Whisper + LoRA | ✅ facebook/mms-tts-bam |
| Fula (ful) | ✅ Whisper + LoRA | ✅ facebook/mms-tts-ful |
| French (fr) | ✅ Whisper | ✅ facebook/mms-tts-fra |
| English (en) | ✅ Whisper | ✅ facebook/mms-tts-eng |
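A note on the TTS column above: the facebook/mms-tts-* checkpoints are plain VITS models, so a response string can be synthesized with the stock transformers API. This is a minimal sketch only; the Space itself wraps this logic in src/tts/mms_tts.py, which is not shown in this commit view, and the example text is illustrative.

# Sketch: synthesize speech from text with an MMS-TTS checkpoint (public transformers API).
import torch
from transformers import AutoTokenizer, VitsModel

tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-bam")
model = VitsModel.from_pretrained("facebook/mms-tts-bam")

text = "I ni ce"  # illustrative input only
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    waveform = model(**inputs).waveform      # shape: (batch, samples)

sample_rate = model.config.sampling_rate     # 16 kHz for MMS checkpoints
audio = waveform[0].cpu().numpy()            # suitable for gr.Audio as (sample_rate, audio)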
app.py
ADDED
@@ -0,0 +1,611 @@
"""
Sahel-Agri Voice AI — HuggingFace Spaces (ZeroGPU)
Two-way voice assistant: Bambara / Fula / French / English → voice response

Environment variables (set in Space Settings → Secrets):
    HF_TOKEN          — HF write-access token
    FEEDBACK_REPO_ID  — e.g. ous-sow/sahel-agri-feedback (dataset, private)
    ADAPTER_REPO_ID   — e.g. ous-sow/sahel-agri-adapters (model, private)
    WHISPER_MODEL_ID  — default: openai/whisper-large-v3-turbo
                        (use openai/whisper-base for local CPU testing)
"""

from __future__ import annotations

import io
import json
import os
import sys
import tempfile
import threading
from datetime import datetime, timezone
from pathlib import Path

import gradio as gr
import numpy as np

ROOT = Path(__file__).parent
sys.path.insert(0, str(ROOT))

# ── env ──────────────────────────────────────────────────────────────────────
HF_TOKEN = os.environ.get("HF_TOKEN")
FEEDBACK_REPO_ID = os.environ.get("FEEDBACK_REPO_ID", "ous-sow/sahel-agri-feedback")
ADAPTER_REPO_ID = os.environ.get("ADAPTER_REPO_ID", "ous-sow/sahel-agri-adapters")
# whisper-small: ~10s on cpu-basic, good multilingual quality.
# Override via WHISPER_MODEL_ID env var if you upgrade to a GPU Space later.
WHISPER_MODEL_ID = os.environ.get("WHISPER_MODEL_ID", "openai/whisper-small")

# On local CPU (no HF_TOKEN / no spaces package) fall back gracefully
_ON_SPACES = os.environ.get("SPACE_ID") is not None

SUPPORTED_LANGUAGES = {
    "Bambara (bam)": "bam",
    "Fula (ful)": "ful",
    "French / Français": "fr",
    "English": "en",
}

# ── ZeroGPU decorator (no-op locally) ────────────────────────────────────────
try:
    import spaces  # type: ignore
    _gpu = spaces.GPU(duration=55)
except ImportError:
    def _gpu(fn):  # local fallback: plain function
        return fn

# ── Module-level model state (CPU-resident between requests) ─────────────────
_whisper_model = None       # WhisperForConditionalGeneration (base)
_whisper_processor = None
_adapter_manager = None     # AdapterManager (wraps base model with PEFT if adapters loaded)
_model_lock = threading.Lock()
_model_status = "not loaded"
_adapters_loaded = set()    # set of language codes with loaded adapters, e.g. {"bam", "ful"}

from src.tts.mms_tts import MMSTTSEngine
from src.iot.intent_parser import IntentParser
from src.iot.sensor_bridge import SensorBridge
from src.iot.voice_responder import VoiceResponder

_tts = MMSTTSEngine()
_intent_parser = IntentParser()
_sensor_bridge = SensorBridge()

# HF API — only instantiate when token present
_hf_api = None
if HF_TOKEN:
    from huggingface_hub import HfApi
    _hf_api = HfApi(token=HF_TOKEN)


# ── Model loading ─────────────────────────────────────────────────────────────

def _do_load_whisper():
    global _whisper_model, _whisper_processor, _adapter_manager, _model_status
    import torch
    from transformers import WhisperForConditionalGeneration, WhisperProcessor
    from src.engine.adapter_manager import AdapterManager

    _model_status = "loading…"
    try:
        _whisper_processor = WhisperProcessor.from_pretrained(
            WHISPER_MODEL_ID, token=HF_TOKEN
        )
        _whisper_model = WhisperForConditionalGeneration.from_pretrained(
            WHISPER_MODEL_ID,
            torch_dtype=torch.float32,
            token=HF_TOKEN,
        )
        _whisper_model.eval()

        # Create the AdapterManager wrapping the base model
        _adapter_manager = AdapterManager(base_model=_whisper_model, config={})

        # Try to load adapters from the local adapter repo snapshot (if already downloaded)
        _try_load_local_adapters()

        _model_status = f"ready ({WHISPER_MODEL_ID})"
    except Exception as e:
        _model_status = f"error: {e}"


def _try_load_local_adapters() -> None:
    """Load any adapter snapshots that are already on disk (downloaded previously)."""
    global _adapters_loaded
    if _adapter_manager is None:
        return
    if not ADAPTER_REPO_ID:
        return
    try:
        from huggingface_hub import try_to_load_from_cache
        lang_dirs = {"bam": "adapters/bambara", "ful": "adapters/fula"}
        for lang, subdir in lang_dirs.items():
            cached = try_to_load_from_cache(
                repo_id=ADAPTER_REPO_ID,
                filename=f"{subdir}/adapter_config.json",
                repo_type="model",
                token=HF_TOKEN,
            )
            if cached:
                import os
                adapter_path = str(os.path.dirname(cached))
                _adapter_manager.register(lang, adapter_path)
                try:
                    _adapter_manager.load_adapter(lang)
                    _adapters_loaded.add(lang)
                except Exception:
                    pass
    except Exception:
        pass  # Adapters not cached yet — will load after first Hub download


def _ensure_whisper_loaded():
    """Load Whisper to CPU in a background thread on first call. Non-blocking."""
    global _model_status
    with _model_lock:
        if _whisper_model is None and "loading" not in _model_status and "error" not in _model_status:
            t = threading.Thread(target=_do_load_whisper, daemon=True)
            t.start()
    return _model_status


def get_model_status() -> str:
    s = _ensure_whisper_loaded()
    if "ready" in s:
        return f"🟢 {s}"
    if "loading" in s:
        return f"🟡 {s}"
    if "error" in s:
        return f"🔴 {s}"
    return f"⚪ {s}"


# ── Core GPU pipeline ─────────────────────────────────────────────────────────

@_gpu
def _run_pipeline(audio_path: str, language_code: str):
    """
    Full STT → Intent → Sensor → TTS pipeline.
    Decorated with @spaces.GPU(duration=55) on HF Spaces; plain function locally.
    Returns: (transcript, response_text, (sample_rate, wav_np))
    """
    import asyncio
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ── 1. Whisper STT ────────────────────────────────────────────────────────
    if _whisper_model is None:
        return "⏳ Model still loading…", "", None

    import librosa

    audio_np, _ = librosa.load(audio_path, sr=16000, mono=True)

    # Use adapter-wrapped model if an adapter for this language is loaded;
    # otherwise fall back to base Whisper.
    if _adapter_manager is not None and language_code in _adapters_loaded:
        _adapter_manager.activate(language_code)
        active_model = _adapter_manager.get_model()
    else:
        active_model = _whisper_model

    active_model.to(device)
    with _model_lock:
        inputs = _whisper_processor.feature_extractor(
            audio_np, sampling_rate=16000, return_tensors="pt"
        )
        input_features = inputs.input_features.to(device)

    # Bambara and Fula have no Whisper language token — pass None so the model
    # auto-detects or falls back to multilingual decoding.
    if language_code in ("bam", "ful"):
        forced_ids = None
    else:
        forced_ids = _whisper_processor.get_decoder_prompt_ids(
            language=language_code, task="transcribe"
        )

    with torch.no_grad():
        predicted_ids = active_model.generate(
            input_features,
            forced_decoder_ids=forced_ids if forced_ids else None,
            max_new_tokens=256,
        )

    transcript = _whisper_processor.batch_decode(
        predicted_ids, skip_special_tokens=True
    )[0].strip()

    # Free GPU VRAM before TTS
    active_model.to("cpu")
    if device == "cuda":
        torch.cuda.empty_cache()

    # ── 2. Intent + sensor data (CPU) ─────────────────────────────────────────
    intent = _intent_parser.parse(transcript, language=language_code)

    try:
        loop = asyncio.new_event_loop()
        sensor_data = loop.run_until_complete(_sensor_bridge.fetch(intent))
        loop.close()
    except Exception:
        from src.iot.sensor_bridge import SensorData
        sensor_data = SensorData(sensor_type="soil", values={
            "moisture_pct": 45.0, "ph": 6.5, "temperature_c": 28.0
        })

    responder = VoiceResponder(language=language_code)
    response_text = responder.generate_response(intent, sensor_data)

    # ── 3. MMS-TTS (GPU) ──────────────────────────────────────────────────────
    wav_np, sample_rate = _tts.synthesize(response_text, language_code, device=device)

    return transcript, response_text, (sample_rate, wav_np)


# ── HF Hub feedback persistence ───────────────────────────────────────────────

def _save_feedback_to_hub(
    audio_path: str | None,
    transcript: str,
    corrected_text: str,
    response_text: str,
    rating: int,
    notes: str,
    language_label: str,
) -> str:
    language_code = SUPPORTED_LANGUAGES.get(language_label, "bam")

    if not corrected_text.strip():
        return "⚠️ Corrected text is empty."

    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S_%f")

    record = {
        "id": timestamp,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "language": language_code,
        "audio_file": f"audio/{language_code}_{timestamp}.wav",
        "whisper_output": transcript,
        "corrected_text": corrected_text.strip(),
        "response_text": response_text,
        "rating": rating,
        "notes": notes.strip(),
        "is_correction": transcript.strip() != corrected_text.strip(),
        "model": WHISPER_MODEL_ID,
    }

    if _hf_api is None:
        # Local: save to disk instead
        fb_dir = ROOT / "feedback"
        fb_dir.mkdir(exist_ok=True)
        (fb_dir / "audio").mkdir(exist_ok=True)
        corrections_path = fb_dir / "corrections.jsonl"
        if audio_path:
            import shutil
            shutil.copy2(audio_path, fb_dir / "audio" / f"{language_code}_{timestamp}.wav")
        with open(corrections_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
        total = sum(1 for _ in open(corrections_path, encoding="utf-8"))
        return f"✅ Saved locally (#{total}) — HF_TOKEN not set, Hub upload skipped."

    try:
        # Upload audio
        if audio_path:
            _hf_api.upload_file(
                path_or_fileobj=audio_path,
                path_in_repo=f"audio/{language_code}_{timestamp}.wav",
                repo_id=FEEDBACK_REPO_ID,
                repo_type="dataset",
            )

        # Download → append → re-upload corrections.jsonl (with retry on conflict)
        from huggingface_hub import hf_hub_download
        for attempt in range(2):
            try:
                local_jsonl = hf_hub_download(
                    repo_id=FEEDBACK_REPO_ID,
                    filename="corrections.jsonl",
                    repo_type="dataset",
                    token=HF_TOKEN,
                )
                with open(local_jsonl, encoding="utf-8") as f:
                    existing = f.read()
            except Exception:
                existing = ""

            updated = existing + json.dumps(record, ensure_ascii=False) + "\n"
            buf = io.BytesIO(updated.encode("utf-8"))

            try:
                _hf_api.upload_file(
                    path_or_fileobj=buf,
                    path_in_repo="corrections.jsonl",
                    repo_id=FEEDBACK_REPO_ID,
                    repo_type="dataset",
                )
                break
            except Exception as e:
                if attempt == 1:
                    return f"⚠️ Audio uploaded but corrections.jsonl update failed: {e}"

        total = updated.count("\n")
        return f"✅ Saved to Hub (#{total}) — {FEEDBACK_REPO_ID}"

    except Exception as e:
        return f"❌ Hub upload error: {e}"


# ── Adapter reload ────────────────────────────────────────────────────────────

def _reload_adapters_from_hub() -> str:
    global _adapters_loaded
    if _hf_api is None:
        return "⚠️ HF_TOKEN not set — cannot download adapters."
    if _adapter_manager is None:
        return "⏳ Base model not loaded yet — wait for model to finish loading and try again."
    try:
        from huggingface_hub import snapshot_download
        local_dir = snapshot_download(
            repo_id=ADAPTER_REPO_ID, repo_type="model", token=HF_TOKEN
        )
        results = []
        for lang, subdir in (("bam", "adapters/bambara"), ("ful", "adapters/fula")):
            adapter_path = Path(local_dir) / subdir
            if not adapter_path.exists():
                results.append(f"⚠️ {lang}: `{subdir}` not found in repo")
                continue
            # Check that this looks like a valid PEFT adapter
            if not (adapter_path / "adapter_config.json").exists():
                results.append(f"⚠️ {lang}: `{subdir}` missing adapter_config.json — run training first")
                continue
            try:
                _adapter_manager.register(lang, str(adapter_path))
                _adapter_manager.load_adapter(lang)
                _adapters_loaded.add(lang)
                results.append(f"✅ {lang}: adapter loaded from `{subdir}`")
            except Exception as e:
                results.append(f"❌ {lang}: load failed — {e}")

        summary = "\n".join(results)
        active = ", ".join(_adapters_loaded) if _adapters_loaded else "none"
        return f"{summary}\n\n**Active adapters:** {active}\n**Repo:** `{ADAPTER_REPO_ID}`"
    except Exception as e:
        return f"❌ Adapter reload failed: {e}"


def _get_adapter_status() -> str:
    lines = []

    # Show which adapters are currently active in memory
    if _adapters_loaded:
        lines.append(f"**Active adapters (in memory):** {', '.join(sorted(_adapters_loaded))}")
    else:
        lines.append("**Active adapters:** none — using base Whisper")

    if _hf_api is None:
        lines.append("_HF_TOKEN not set — Hub check skipped._")
        return "\n".join(lines)

    try:
        from huggingface_hub import list_repo_files
        files = list(list_repo_files(ADAPTER_REPO_ID, repo_type="model", token=HF_TOKEN))
        bam_ok = any("bambara" in f and "adapter_config" in f for f in files)
        ful_ok = any("fula" in f and "adapter_config" in f for f in files)
        lines += [
            f"\n**Hub repo:** `{ADAPTER_REPO_ID}`",
            f"- Bambara (bam): {'✅ trained adapter present' if bam_ok else '⚠️ not yet trained — run bootstrap notebook'}",
            f"- Fula (ful): {'✅ trained adapter present' if ful_ok else '⚠️ not yet trained — run bootstrap notebook'}",
        ]
        if bam_ok or ful_ok:
            lines.append("\n_Click **Reload Adapters** to activate them._")
    except Exception as e:
        lines.append(f"_Could not read Hub repo: {e}_")

    return "\n".join(lines)


# ── Main ask handler ──────────────────────────────────────────────────────────

def handle_ask(audio_path, language_label):
    if audio_path is None:
        return "⚠️ No audio — press Record or upload a file.", "", None

    language_code = SUPPORTED_LANGUAGES.get(language_label, "bam")
    status = _ensure_whisper_loaded()

    if _whisper_model is None:
        return f"⏳ Model loading ({status}). Wait a moment and try again.", "", None

    try:
        transcript, response_text, audio_out = _run_pipeline(audio_path, language_code)
        return transcript, response_text, audio_out
    except Exception as e:
        return f"❌ {e}", "", None


# ── Gradio UI ─────────────────────────────────────────────────────────────────

def build_ui() -> gr.Blocks:
    with gr.Blocks(title="Sahel-Agri Voice AI") as demo:
        gr.Markdown("# 🌾 Sahel-Agri Voice AI")
        gr.Markdown(
            "Speak in **Bambara** or **Fula** — get agricultural insights spoken back "
            "in your language. Also supports French and English."
        )

        model_status_box = gr.Textbox(
            value=get_model_status,
            label="Model status",
            interactive=False,
            every=3,
        )

        with gr.Tabs():

            # ── Tab 1: Voice Assistant ────────────────────────────────────────
            with gr.TabItem("🎙️ Voice Assistant"):
                with gr.Row():
                    with gr.Column(scale=1):
                        language_dd = gr.Dropdown(
                            choices=list(SUPPORTED_LANGUAGES.keys()),
                            value="Bambara (bam)",
                            label="Language / Kan",
                        )
                        audio_input = gr.Audio(
                            sources=["microphone", "upload"],
                            type="filepath",
                            label="Record or upload audio",
                        )
                        ask_btn = gr.Button("▶ Ask / Ɲinɛ", variant="primary")

                    with gr.Column(scale=1):
                        transcript_box = gr.Textbox(
                            label="Whisper heard",
                            lines=3,
                            placeholder="Your words will appear here…",
                            interactive=False,
                        )
                        response_box = gr.Textbox(
                            label="Response / Jaabi",
                            lines=3,
                            placeholder="Agricultural advice will appear here…",
                            interactive=False,
                        )
                        audio_output = gr.Audio(
                            label="Voice response",
                            autoplay=True,
                            interactive=False,
                        )

                ask_btn.click(
                    fn=handle_ask,
                    inputs=[audio_input, language_dd],
                    outputs=[transcript_box, response_box, audio_output],
                )

            # ── Tab 2: Feedback & Correction ─────────────────────────────────
            with gr.TabItem("📝 Feedback & Correction"):
                gr.Markdown(
                    "Help improve the model by correcting transcription errors. "
                    "Your audio and corrections are saved to the training dataset."
                )
                with gr.Row():
                    with gr.Column():
                        fb_lang = gr.Dropdown(
                            choices=list(SUPPORTED_LANGUAGES.keys()),
                            value="Bambara (bam)",
                            label="Language",
                        )
                        fb_audio = gr.Audio(
                            sources=["microphone", "upload"],
                            type="filepath",
                            label="Audio (re-record or upload)",
                        )
                        fb_transcript = gr.Textbox(
                            label="Whisper output (what it heard)",
                            lines=3,
                            placeholder="Paste or type what Whisper said…",
                        )
                        fb_corrected = gr.Textbox(
                            label="Corrected transcription (what was actually said)",
                            lines=3,
                            placeholder="Type the correct text here…",
                        )

                    with gr.Column():
                        fb_response = gr.Textbox(
                            label="Response text (optional — for rating)",
                            lines=2,
                            placeholder="Copy the response from Tab 1…",
                        )
                        fb_rating = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3,
                            label="Response quality (1 = poor, 5 = excellent)",
                        )
                        fb_notes = gr.Textbox(
                            label="Notes (optional)",
                            lines=2,
                            placeholder="e.g. noisy background, strong accent…",
                        )
                        save_btn = gr.Button("💾 Save to Dataset", variant="secondary")
                        save_status = gr.Textbox(
                            label="Save status", interactive=False, lines=2
                        )

                save_btn.click(
                    fn=_save_feedback_to_hub,
                    inputs=[
                        fb_audio, fb_transcript, fb_corrected,
                        fb_response, fb_rating, fb_notes, fb_lang,
                    ],
                    outputs=[save_status],
                )

            # ── Tab 3: Training Status ────────────────────────────────────────
            with gr.TabItem("🔧 Training Status"):
                gr.Markdown(
                    "After collecting ≥10 corrections per language, run the training "
                    "notebook on Google Colab (free GPU), then reload adapters here."
                )
                adapter_status_md = gr.Markdown(value=_get_adapter_status())
                reload_btn = gr.Button("🔄 Reload Adapters from Hub")
                reload_out = gr.Markdown()

                gr.Markdown("---")
                gr.Markdown(
                    "**Training notebook**: "
                    "`notebooks/train_colab.ipynb` — open in Colab, run all cells."
                )
                gr.Markdown(
                    "**Feedback dataset**: "
                    f"`{FEEDBACK_REPO_ID}` (private, auto-updated on each save)"
                )
                gr.Markdown(
                    "**Adapter repo**: "
                    f"`{ADAPTER_REPO_ID}` (private, updated after each training run)"
                )

                reload_btn.click(
                    fn=_reload_adapters_from_hub,
                    outputs=[reload_out],
                )
                reload_btn.click(
                    fn=_get_adapter_status,
                    outputs=[adapter_status_md],
                )

    return demo


# ── Entry point ───────────────────────────────────────────────────────────────

if __name__ == "__main__":
    from dotenv import load_dotenv
    load_dotenv()

    # Re-read env after dotenv
    HF_TOKEN = os.environ.get("HF_TOKEN")
    FEEDBACK_REPO_ID = os.environ.get("FEEDBACK_REPO_ID", "ous-sow/sahel-agri-feedback")
    ADAPTER_REPO_ID = os.environ.get("ADAPTER_REPO_ID", "ous-sow/sahel-agri-adapters")
    WHISPER_MODEL_ID = os.environ.get("WHISPER_MODEL_ID", "openai/whisper-small")

    if HF_TOKEN:
        from huggingface_hub import HfApi
        _hf_api = HfApi(token=HF_TOKEN)

    # Kick off background model load immediately
    _ensure_whisper_loaded()

    print(f"Whisper model : {WHISPER_MODEL_ID}")
    print(f"Feedback repo : {FEEDBACK_REPO_ID}")
    print(f"Adapter repo  : {ADAPTER_REPO_ID}")
    print(f"HF_TOKEN set  : {'yes' if HF_TOKEN else 'no (local-only mode)'}")
    print()

    demo = build_ui()
    demo.launch(
        server_port=9001,
        inbrowser=True,
        share=False,
    )
configs/api_config.yaml
ADDED
@@ -0,0 +1,21 @@
server:
  host: "0.0.0.0"
  port: 8000
  workers: 1  # Single worker: shares GPU model in memory
  timeout_keep_alive: 30

inference:
  default_language: "bam"
  max_audio_size_mb: 10
  supported_languages:
    - "bam"
    - "ful"

iot:
  sensor_poll_timeout_s: 5
  response_language: "fr"  # French for farmer-facing TTS output
  intent_confidence_threshold: 0.7

rate_limit:
  requests_per_minute: 60
  burst: 10
configs/base_config.yaml
ADDED
@@ -0,0 +1,30 @@
model:
  id: "openai/whisper-large-v3-turbo"
  task: "transcribe"
  max_new_tokens: 128
  chunk_length_s: 30

training:
  output_dir: "./adapters"
  per_device_train_batch_size: 4
  gradient_accumulation_steps: 4
  warmup_steps: 200
  max_steps: 4000
  save_steps: 500
  eval_steps: 500
  learning_rate: 1.0e-4
  fp16: true
  # CRITICAL on Windows: multiprocessing spawn breaks with tokenizers
  dataloader_num_workers: 0

audio:
  sample_rate: 16000
  max_duration_s: 30
  noise_snr_db_range: [5, 20]
  augmentation_prob: 0.6

paths:
  data_cache: "./data_cache"
  adapters: "./adapters"
  models: "./models"
  noise_samples: "./noise_samples"
configs/lora_bambara.yaml
ADDED
@@ -0,0 +1,19 @@
language: "bam"
language_code: "bm"  # ISO 639-1 code used for Whisper forced_decoder_ids
dataset_subset: "bam"
adapter_name: "bambara"
output_dir: "./adapters/bambara"

lora:
  r: 32
  lora_alpha: 64
  target_modules:
    - "q_proj"
    - "v_proj"
    - "k_proj"
    - "out_proj"
    - "fc1"
    - "fc2"
  lora_dropout: 0.05
  bias: "none"
  task_type: "SEQ_2_SEQ_LM"
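The lora block above maps one-to-one onto a PEFT LoraConfig. A minimal sketch of consuming it with the standard peft/transformers API; only the library calls are standard, the surrounding variable names and the whisper-small base are illustrative assumptions:

# Sketch: build a LoRA-wrapped Whisper model from configs/lora_bambara.yaml.
import yaml
from peft import LoraConfig, get_peft_model
from transformers import WhisperForConditionalGeneration

with open("configs/lora_bambara.yaml") as f:
    cfg = yaml.safe_load(f)["lora"]

lora_config = LoraConfig(
    r=cfg["r"],
    lora_alpha=cfg["lora_alpha"],
    target_modules=cfg["target_modules"],
    lora_dropout=cfg["lora_dropout"],
    bias=cfg["bias"],
    task_type=cfg["task_type"],  # "SEQ_2_SEQ_LM"
)

base = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()  # only the LoRA matrices are trainable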
configs/lora_fula.yaml
ADDED
@@ -0,0 +1,19 @@
language: "ful"
language_code: "ff"  # ISO 639-1 code used for Whisper forced_decoder_ids
dataset_subset: "ful"
adapter_name: "fula"
output_dir: "./adapters/fula"

lora:
  r: 16  # Smaller rank — Fula dataset is smaller than Bambara
  lora_alpha: 32
  target_modules:
    - "q_proj"
    - "v_proj"
    - "k_proj"
    - "out_proj"
    - "fc1"
    - "fc2"
  lora_dropout: 0.05
  bias: "none"
  task_type: "SEQ_2_SEQ_LM"
noise_samples/README.md
ADDED
@@ -0,0 +1,20 @@
# Field Noise Samples

Place `.wav` audio files here to enable realistic field-noise augmentation during training.

## Required Files (16kHz mono, any duration ≥5s)
- `tractor_engine.wav` — diesel tractor idling or working
- `wind_field.wav` — wind in open farmland
- `livestock_ambient.wav` — cattle, goats, or chickens in background

## Suggested Sources
- [Freesound.org](https://freesound.org) — search "tractor", "wind field", "livestock ambient" (filter by CC0 / CC-BY)
- Field recordings from partner NGOs or agricultural organizations in Mali/Guinea

## Licensing Note
Ensure all audio files are licensed for use in ML training datasets.
CC0 (public domain) or CC-BY are preferred.

## Without Noise Files
The augmenter will fall back to Gaussian noise only.
Training will still work but model robustness to real-world conditions may be reduced.
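For reference, mixing one of these noise files into a clean clip at a target SNR takes only a few lines. This is a generic sketch, not the project's src/data/augmentation.py (which is not shown in this view); the clip path is hypothetical and the SNR range follows noise_snr_db_range in configs/base_config.yaml:

# Sketch: add field noise to a clean 16 kHz clip at a random SNR in [5, 20] dB,
# falling back to Gaussian noise when no noise file is available (as described above).
import numpy as np
import librosa

def add_noise(clean: np.ndarray, snr_db: float, noise_path: str | None) -> np.ndarray:
    if noise_path:
        noise, _ = librosa.load(noise_path, sr=16000, mono=True)
        reps = int(np.ceil(len(clean) / len(noise)))
        noise = np.tile(noise, reps)[: len(clean)]   # tile/trim to clip length
    else:
        noise = np.random.randn(len(clean)).astype(np.float32)

    clean_power = np.mean(clean ** 2) + 1e-10
    noise_power = np.mean(noise ** 2) + 1e-10
    # Scale noise so 10*log10(clean_power / scaled_noise_power) == snr_db
    scale = np.sqrt(clean_power / (noise_power * 10 ** (snr_db / 10)))
    return clean + scale * noise

clean, _ = librosa.load("some_clip.wav", sr=16000, mono=True)   # hypothetical path
noisy = add_noise(clean, np.random.uniform(5, 20), "noise_samples/tractor_engine.wav")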
notebooks/bootstrap_repos.ipynb
ADDED
@@ -0,0 +1,308 @@
{
 "nbformat": 4,
 "nbformat_minor": 5,
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  },
  "colab": {
   "provenance": [],
   "gpuType": "T4"
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cell-title",
   "metadata": {},
   "source": [
    "# 🌾 Sahel-Agri Voice AI — One-Time Bootstrap\n",
    "\n",
    "**Run this notebook ONCE** before deploying your Space. It:\n",
    "\n",
    "1. Creates the three HuggingFace repos (`sahel-agri-feedback`, `sahel-agri-adapters`, `sahel-agri-voice`)\n",
    "2. Seeds the feedback dataset with a `corrections.jsonl` placeholder\n",
    "3. Trains v0 LoRA adapters for **Bambara** and **Fula** on the full Google Waxal dataset\n",
    "4. Pushes adapters to `ous-sow/sahel-agri-adapters`\n",
    "\n",
    "After this notebook completes, push your project code to the Space and your app will start\n",
    "with working Bambara/Fula speech recognition from day 1 — **no user corrections needed yet**.\n",
    "\n",
    "For subsequent improvement runs (after collecting farmer feedback), use `train_colab.ipynb`.\n",
    "\n",
    "---\n",
    "**Before running:** Runtime → Change runtime type → **T4 GPU**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-gpu-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 1 — GPU check\n",
    "import subprocess\n",
    "result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)\n",
    "if result.returncode != 0:\n",
    "    raise RuntimeError('No GPU! Runtime → Change runtime type → T4 GPU')\n",
    "print(result.stdout[:500])\n",
    "print('✅ GPU ready')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-install",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 2 — Install dependencies\n",
    "!pip install -q \\\n",
    "    torch==2.11.0 torchaudio==2.11.0 \\\n",
    "    transformers==5.5.0 datasets==4.8.4 \\\n",
    "    accelerate==1.13.0 evaluate==0.4.2 \\\n",
    "    huggingface-hub==1.9.0 peft==0.18.1 \\\n",
    "    librosa==0.10.2 soundfile==0.12.1 \\\n",
    "    jiwer==3.0.4 pyyaml==6.0.2\n",
    "print('✅ Packages installed')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-hf-login",
   "metadata": {},
   "outputs": [],
   "source": "# Cell 3 — HuggingFace login\n# Colab: 🔑 icon (left sidebar) → Add new secret → name=HF_TOKEN\nimport os\ntry:\n    from google.colab import userdata  # type: ignore\n    HF_TOKEN = userdata.get('HF_TOKEN')\nexcept Exception:\n    HF_TOKEN = os.environ.get('HF_TOKEN', '')\n\nif not HF_TOKEN:\n    raise ValueError(\n        'HF_TOKEN not found.\\n'\n        'Colab: click the 🔑 icon → Add new secret → name=HF_TOKEN'\n    )\n\nfrom huggingface_hub import login, HfApi\nlogin(token=HF_TOKEN, add_to_git_credential=False)\napi = HfApi(token=HF_TOKEN)\n\nHF_USERNAME = 'ous-sow'\nFEEDBACK_REPO_ID = f'{HF_USERNAME}/sahel-agri-feedback'\nADAPTER_REPO_ID = f'{HF_USERNAME}/sahel-agri-adapters'\nSPACE_REPO_ID = f'{HF_USERNAME}/sahel-agri-voice'\n# whisper-small trains on Colab T4 in ~25 min and runs on CPU in ~10s.\n# Change to 'openai/whisper-large-v3-turbo' only if you upgrade to a GPU Space.\nWHISPER_MODEL_ID = 'openai/whisper-small'\n\nprint(f'✅ Logged in as {HF_USERNAME}')"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-create-repos",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 4 — Create HuggingFace repos (skips if they already exist)\n",
    "from huggingface_hub import RepoUrl\n",
    "\n",
    "def create_repo_if_missing(repo_id, repo_type, private=True):\n",
    "    try:\n",
    "        url = api.create_repo(\n",
    "            repo_id=repo_id,\n",
    "            repo_type=repo_type,\n",
    "            private=private,\n",
    "            exist_ok=True,\n",
    "        )\n",
    "        print(f'  ✅ {repo_type}: {repo_id}')\n",
    "        return url\n",
    "    except Exception as e:\n",
    "        print(f'  ⚠️ {repo_id}: {e}')\n",
    "\n",
    "print('Creating repos...')\n",
    "create_repo_if_missing(FEEDBACK_REPO_ID, 'dataset', private=True)\n",
    "create_repo_if_missing(ADAPTER_REPO_ID, 'model', private=True)\n",
    "create_repo_if_missing(SPACE_REPO_ID, 'space', private=False)\n",
    "\n",
    "# Seed the feedback dataset with an empty corrections.jsonl\n",
    "import io\n",
    "try:\n",
    "    api.upload_file(\n",
    "        path_or_fileobj=io.BytesIO(b''),\n",
    "        path_in_repo='corrections.jsonl',\n",
    "        repo_id=FEEDBACK_REPO_ID,\n",
    "        repo_type='dataset',\n",
    "        commit_message='Init: empty corrections.jsonl',\n",
    "    )\n",
    "    print(f'  ✅ {FEEDBACK_REPO_ID}/corrections.jsonl initialised')\n",
    "except Exception as e:\n",
    "    print(f'  ⚠️ corrections.jsonl upload: {e} (may already exist)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-clone-space",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 5 — Clone Space code (so we can use src/ and configs/)\n",
    "# If the Space is brand new and has no code yet, clone from the local zip instead.\n",
    "import sys\n",
    "from pathlib import Path\n",
    "from huggingface_hub import snapshot_download\n",
    "\n",
    "try:\n",
    "    space_dir = Path(snapshot_download(\n",
    "        repo_id=SPACE_REPO_ID, repo_type='space', token=HF_TOKEN\n",
    "    ))\n",
    "    print(f'Space code: {space_dir}')\n",
    "except Exception as e:\n",
    "    print(f'Could not download Space ({e})')\n",
    "    print('Uploading project code to Space first...')\n",
    "    # If you have the project on Colab already (e.g. mounted Drive), set:\n",
    "    #   space_dir = Path('/content/drive/MyDrive/voice-model')\n",
    "    # Otherwise upload via git (see README step 6) and re-run this cell.\n",
    "    raise RuntimeError(\n",
    "        'Push your project to the Space first:\\n'\n",
    "        '  git remote add space https://huggingface.co/spaces/ous-sow/sahel-agri-voice\\n'\n",
    "        '  git push space main\\n'\n",
    "        'Then re-run this notebook.'\n",
    "    )\n",
    "\n",
    "sys.path.insert(0, str(space_dir))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-train-bam",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 6 — Train v0 Bambara adapter on full Waxal (bam)\n",
    "#\n",
    "# Uses streaming — Waxal is ~4h of audio, we cap at 2000 samples for Colab budget.\n",
    "# Full training (~4000 steps) on the entire dataset: use a Kaggle P100 (12h limit).\n",
    "import os, yaml\n",
    "os.environ['HF_TOKEN'] = HF_TOKEN\n",
    "\n",
    "from src.training.trainer import WhisperLoRATrainer\n",
    "\n",
    "WAXAL_CAP = 2000  # raise to 10000+ on Kaggle for a stronger v0 model\n",
    "\n",
    "base_cfg = str(space_dir / 'configs' / 'base_config.yaml')\n",
    "bam_cfg_src = str(space_dir / 'configs' / 'lora_bambara.yaml')\n",
    "bam_out = '/tmp/sahel_adapter_bam'\n",
    "\n",
    "# Override output_dir\n",
    "with open(bam_cfg_src) as f:\n",
    "    bam_config = yaml.safe_load(f)\n",
    "bam_config['output_dir'] = bam_out\n",
    "tmp_bam_cfg = '/tmp/lora_bam.yaml'\n",
    "with open(tmp_bam_cfg, 'w') as f:\n",
    "    yaml.dump(bam_config, f)\n",
    "\n",
    "# Also override max_steps in base config to match Waxal cap\n",
    "with open(base_cfg) as f:\n",
    "    base_config = yaml.safe_load(f)\n",
    "# ~2 steps per sample @ batch_size=4, gradient_acc=4\n",
    "base_config['training']['max_steps'] = max(500, WAXAL_CAP // 8)\n",
    "tmp_base_cfg = '/tmp/base_config.yaml'\n",
    "with open(tmp_base_cfg, 'w') as f:\n",
    "    yaml.dump(base_config, f)\n",
    "\n",
    "print(f'Training Bambara v0 adapter (Waxal cap={WAXAL_CAP}, max_steps={base_config[\"training\"][\"max_steps\"]})...')\n",
    "trainer_bam = WhisperLoRATrainer(\n",
    "    base_config_path=tmp_base_cfg,\n",
    "    language_config_path=tmp_bam_cfg,\n",
    ")\n",
    "trainer_bam.setup()\n",
    "\n",
    "# No feedback yet — materialise Waxal and train\n",
    "trainer_bam.merge_extra_data([], repeat=1, waxal_cap=WAXAL_CAP)\n",
    "\n",
    "trainer_bam.train()\n",
    "print(f'✅ Bambara v0 adapter saved to {bam_out}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-train-ful",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 7 — Train v0 Fula adapter on full Waxal (ful)\n",
    "ful_cfg_src = str(space_dir / 'configs' / 'lora_fula.yaml')\n",
    "ful_out = '/tmp/sahel_adapter_ful'\n",
    "\n",
    "with open(ful_cfg_src) as f:\n",
    "    ful_config = yaml.safe_load(f)\n",
    "ful_config['output_dir'] = ful_out\n",
    "tmp_ful_cfg = '/tmp/lora_ful.yaml'\n",
    "with open(tmp_ful_cfg, 'w') as f:\n",
    "    yaml.dump(ful_config, f)\n",
    "\n",
    "print(f'Training Fula v0 adapter (Waxal cap={WAXAL_CAP})...')\n",
    "trainer_ful = WhisperLoRATrainer(\n",
    "    base_config_path=tmp_base_cfg,\n",
    "    language_config_path=tmp_ful_cfg,\n",
    ")\n",
    "trainer_ful.setup()\n",
    "trainer_ful.merge_extra_data([], repeat=1, waxal_cap=WAXAL_CAP)\n",
    "trainer_ful.train()\n",
    "print(f'✅ Fula v0 adapter saved to {ful_out}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-push-adapters",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 8 — Push both adapters to HF Model repo\n",
    "from huggingface_hub import HfApi\n",
    "api = HfApi(token=HF_TOKEN)\n",
    "\n",
    "for lang, out_dir, path_in_repo in [\n",
    "    ('bam', bam_out, 'adapters/bambara'),\n",
    "    ('ful', ful_out, 'adapters/fula'),\n",
    "]:\n",
    "    api.upload_folder(\n",
    "        folder_path=out_dir,\n",
    "        repo_id=ADAPTER_REPO_ID,\n",
    "        repo_type='model',\n",
    "        path_in_repo=path_in_repo,\n",
    "        commit_message=f'v0 {lang} adapter trained on Waxal (cap={WAXAL_CAP} samples)',\n",
    "    )\n",
    "    print(f'✅ {lang} → {ADAPTER_REPO_ID}/{path_in_repo}')\n",
    "\n",
    "print()\n",
    "print('Bootstrap complete!')\n",
    "print()\n",
    "print('Next steps:')\n",
    "print('  1. Push your project code to the Space (git push space main)')\n",
    "print('  2. In Space Settings → Secrets, add HF_TOKEN, FEEDBACK_REPO_ID, ADAPTER_REPO_ID')\n",
    "print('  3. Space will build — your app at https://huggingface.co/spaces/ous-sow/sahel-agri-voice')\n",
    "print('  4. Tab 3 → Reload Adapters — Bambara + Fula adapters will be loaded')\n",
    "print('  5. Collect farmer corrections, then run train_colab.ipynb to keep improving')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-verify",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 9 — Quick verification: list what was pushed to the adapter repo\n",
    "from huggingface_hub import list_repo_files\n",
    "\n",
    "files = sorted(list_repo_files(ADAPTER_REPO_ID, repo_type='model', token=HF_TOKEN))\n",
    "print(f'Files in {ADAPTER_REPO_ID}:')\n",
    "for f in files:\n",
    "    print(f'  {f}')\n",
    "\n",
    "bam_ok = any('bambara/adapter_config.json' in f for f in files)\n",
    "ful_ok = any('fula/adapter_config.json' in f for f in files)\n",
    "print()\n",
    "print(f'Bambara adapter: {\"✅\" if bam_ok else \"❌\"}')\n",
    "print(f'Fula adapter:    {\"✅\" if ful_ok else \"❌\"}')\n",
    "\n",
    "if bam_ok and ful_ok:\n",
    "    print('\\n🎉 Both adapters ready. Your Space will use them automatically on the next reload.')\n",
    "else:\n",
    "    print('\\n⚠️ Some adapters are missing — check the training cells above for errors.')"
   ]
  }
 ]
}
notebooks/train_colab.ipynb
ADDED
|
@@ -0,0 +1,283 @@
|
| 1 |
+
{
|
| 2 |
+
"nbformat": 4,
|
| 3 |
+
"nbformat_minor": 5,
|
| 4 |
+
"metadata": {
|
| 5 |
+
"kernelspec": {
|
| 6 |
+
"display_name": "Python 3",
|
| 7 |
+
"language": "python",
|
| 8 |
+
"name": "python3"
|
| 9 |
+
},
|
| 10 |
+
"language_info": {
|
| 11 |
+
"name": "python",
|
| 12 |
+
"version": "3.10.0"
|
| 13 |
+
},
|
| 14 |
+
"colab": {
|
| 15 |
+
"provenance": [],
|
| 16 |
+
"gpuType": "T4"
|
| 17 |
+
},
|
| 18 |
+
"accelerator": "GPU"
|
| 19 |
+
},
|
| 20 |
+
"cells": [
|
| 21 |
+
{
|
| 22 |
+
"cell_type": "markdown",
|
| 23 |
+
"id": "cell-title",
|
| 24 |
+
"metadata": {},
|
| 25 |
+
"source": [
|
| 26 |
+
"# 🌾 Sahel-Agri Voice AI — Fine-tune on Farmer Feedback\n",
|
| 27 |
+
"\n",
|
| 28 |
+
"**Run after collecting ≥10 corrections in the Space.** \n",
|
| 29 |
+
"First run? Use `bootstrap_repos.ipynb` instead to train the v0 Waxal adapter.\n",
|
| 30 |
+
"\n",
|
| 31 |
+
"This notebook fine-tunes the existing LoRA adapter using:\n",
|
| 32 |
+
"- **Waxal baseline** (up to 500 samples) — keeps the model grounded\n",
|
| 33 |
+
"- **Farmer corrections** (3× upsampled) — targeted improvement from real field use\n",
|
| 34 |
+
"\n",
|
| 35 |
+
"**Before running:** Runtime → Change runtime type → **T4 GPU**"
|
| 36 |
+
]
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"cell_type": "code",
|
| 40 |
+
"execution_count": null,
|
| 41 |
+
"id": "cell-gpu-check",
|
| 42 |
+
"metadata": {},
|
| 43 |
+
"outputs": [],
|
| 44 |
+
"source": [
|
| 45 |
+
"# Cell 1 — GPU check\n",
|
| 46 |
+
"import subprocess\n",
|
| 47 |
+
"result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)\n",
|
| 48 |
+
"if result.returncode != 0:\n",
|
| 49 |
+
" raise RuntimeError('No GPU! Runtime → Change runtime type → T4 GPU')\n",
|
| 50 |
+
"print(result.stdout[:500])\n",
|
| 51 |
+
"print('✅ GPU ready')"
|
| 52 |
+
]
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"cell_type": "code",
|
| 56 |
+
"execution_count": null,
|
| 57 |
+
"id": "cell-install",
|
| 58 |
+
"metadata": {},
|
| 59 |
+
"outputs": [],
|
| 60 |
+
"source": [
|
| 61 |
+
"# Cell 2 — Install dependencies (matching Space versions)\n",
|
| 62 |
+
"!pip install -q \\\n",
|
| 63 |
+
" torch==2.11.0 torchaudio==2.11.0 \\\n",
|
| 64 |
+
" transformers==5.5.0 datasets==4.8.4 \\\n",
|
| 65 |
+
" accelerate==1.13.0 evaluate==0.4.2 \\\n",
|
| 66 |
+
" huggingface-hub==1.9.0 peft==0.18.1 \\\n",
|
| 67 |
+
" librosa==0.10.2 soundfile==0.12.1 \\\n",
|
| 68 |
+
" jiwer==3.0.4 pyyaml==6.0.2\n",
|
| 69 |
+
"print('✅ Packages installed')"
|
| 70 |
+
]
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"cell_type": "code",
|
| 74 |
+
"execution_count": null,
|
| 75 |
+
"id": "cell-hf-login",
|
| 76 |
+
"metadata": {},
|
| 77 |
+
"outputs": [],
|
| 78 |
+
"source": "# Cell 3 — HuggingFace login\n# Colab: 🔑 icon (left sidebar) → Add new secret → name=HF_TOKEN\n# Kaggle: Add Data → add as Kaggle secret named HF_TOKEN\nimport os\ntry:\n from google.colab import userdata # type: ignore\n HF_TOKEN = userdata.get('HF_TOKEN')\nexcept Exception:\n HF_TOKEN = os.environ.get('HF_TOKEN', '')\n\nif not HF_TOKEN:\n raise ValueError('HF_TOKEN not found — see instructions above.')\n\nfrom huggingface_hub import login\nlogin(token=HF_TOKEN, add_to_git_credential=False)\n\nSPACE_REPO_ID = 'ous-sow/sahel-agri-voice'\nFEEDBACK_REPO_ID = 'ous-sow/sahel-agri-feedback'\nADAPTER_REPO_ID = 'ous-sow/sahel-agri-adapters'\n# Must match what the Space uses — whisper-small for cpu-basic, whisper-large-v3-turbo for GPU.\nWHISPER_MODEL_ID = 'openai/whisper-small'\nTRAIN_LANG = 'bam' # ← change to 'ful' for Fula\n\nprint(f'✅ Logged in | training language: {TRAIN_LANG}')"
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"cell_type": "code",
|
| 82 |
+
"execution_count": null,
|
| 83 |
+
"id": "cell-download",
|
| 84 |
+
"metadata": {},
|
| 85 |
+
"outputs": [],
|
| 86 |
+
"source": [
|
| 87 |
+
"# Cell 4 — Download Space code and feedback corrections\n",
|
| 88 |
+
"import json, shutil, sys\n",
|
| 89 |
+
"from pathlib import Path\n",
|
| 90 |
+
"from huggingface_hub import snapshot_download, hf_hub_download\n",
|
| 91 |
+
"\n",
|
| 92 |
+
"# Get Space code (contains src/, configs/)\n",
|
| 93 |
+
"space_dir = Path(snapshot_download(\n",
|
| 94 |
+
" repo_id=SPACE_REPO_ID, repo_type='space', token=HF_TOKEN\n",
|
| 95 |
+
"))\n",
|
| 96 |
+
"sys.path.insert(0, str(space_dir))\n",
|
| 97 |
+
"print(f'Space code: {space_dir}')\n",
|
| 98 |
+
"\n",
|
| 99 |
+
"# Download feedback corrections.jsonl\n",
|
| 100 |
+
"jsonl_path = hf_hub_download(\n",
|
| 101 |
+
" repo_id=FEEDBACK_REPO_ID,\n",
|
| 102 |
+
" filename='corrections.jsonl',\n",
|
| 103 |
+
" repo_type='dataset',\n",
|
| 104 |
+
" token=HF_TOKEN,\n",
|
| 105 |
+
")\n",
|
| 106 |
+
"with open(jsonl_path, encoding='utf-8') as f:\n",
|
| 107 |
+
" all_records = [json.loads(l) for l in f if l.strip()]\n",
|
| 108 |
+
"\n",
|
| 109 |
+
"corrections = [\n",
|
| 110 |
+
" r for r in all_records\n",
|
| 111 |
+
" if r.get('is_correction') and r['language'] == TRAIN_LANG\n",
|
| 112 |
+
"]\n",
|
| 113 |
+
"print(f'Total feedback records : {len(all_records)}')\n",
|
| 114 |
+
"print(f'Corrections for {TRAIN_LANG} : {len(corrections)}')\n",
|
| 115 |
+
"\n",
|
| 116 |
+
"if len(corrections) < 5:\n",
|
| 117 |
+
" print('⚠️ Very few corrections — consider collecting more before training.')\n",
|
| 118 |
+
" print(' Training will proceed with Waxal only (corrections will be skipped).')"
|
| 119 |
+
]
|
| 120 |
+
},
|
| 121 |
+
{
|
| 122 |
+
"cell_type": "code",
|
| 123 |
+
"execution_count": null,
|
| 124 |
+
"id": "cell-download-audio",
|
| 125 |
+
"metadata": {},
|
| 126 |
+
"outputs": [],
|
| 127 |
+
"source": [
|
| 128 |
+
"# Cell 5 — Download feedback audio files from HF Dataset repo\n",
|
| 129 |
+
"fb_audio_dir = Path('/tmp/sahel_feedback_audio')\n",
|
| 130 |
+
"fb_audio_dir.mkdir(exist_ok=True)\n",
|
| 131 |
+
"\n",
|
| 132 |
+
"skipped = 0\n",
|
| 133 |
+
"for rec in corrections:\n",
|
| 134 |
+
" local_path = fb_audio_dir / Path(rec['audio_file']).name\n",
|
| 135 |
+
" if local_path.exists():\n",
|
| 136 |
+
" continue\n",
|
| 137 |
+
" try:\n",
|
| 138 |
+
" dl = hf_hub_download(\n",
|
| 139 |
+
" repo_id=FEEDBACK_REPO_ID,\n",
|
| 140 |
+
" filename=rec['audio_file'],\n",
|
| 141 |
+
" repo_type='dataset',\n",
|
| 142 |
+
" token=HF_TOKEN,\n",
|
| 143 |
+
" )\n",
|
| 144 |
+
" shutil.copy(dl, local_path)\n",
|
| 145 |
+
" except Exception as e:\n",
|
| 146 |
+
" skipped += 1\n",
|
| 147 |
+
" print(f' skip {rec[\"audio_file\"]}: {e}')\n",
|
| 148 |
+
"\n",
|
| 149 |
+
"# Point records at local paths\n",
|
| 150 |
+
"for rec in corrections:\n",
|
| 151 |
+
" local = fb_audio_dir / Path(rec['audio_file']).name\n",
|
| 152 |
+
" if local.exists():\n",
|
| 153 |
+
" rec['audio_file'] = str(local)\n",
|
| 154 |
+
"\n",
|
| 155 |
+
"available = [r for r in corrections if Path(r['audio_file']).exists()]\n",
|
| 156 |
+
"print(f'Downloaded {len(available)} / {len(corrections)} audio files (skipped {skipped})')"
|
| 157 |
+
]
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"cell_type": "code",
|
| 161 |
+
"execution_count": null,
|
| 162 |
+
"id": "cell-train",
|
| 163 |
+
"metadata": {},
|
| 164 |
+
"outputs": [],
|
| 165 |
+
"source": [
|
| 166 |
+
"# Cell 6 — Fine-tune: Waxal baseline + farmer corrections\n",
|
| 167 |
+
"#\n",
|
| 168 |
+
"# WhisperLoRATrainer.setup() loads Waxal (streaming).\n",
|
| 169 |
+
"# merge_extra_data() materialises Waxal (up to 500 samples),\n",
|
| 170 |
+
"# appends corrections (3× upsampled), shuffles the combined dataset.\n",
|
| 171 |
+
"# train() runs standard Seq2SeqTrainer on the merged dataset.\n",
|
| 172 |
+
"\n",
|
| 173 |
+
"import os\n",
|
| 174 |
+
"os.environ['HF_TOKEN'] = HF_TOKEN\n",
|
| 175 |
+
"\n",
|
| 176 |
+
"from src.training.trainer import WhisperLoRATrainer\n",
|
| 177 |
+
"\n",
|
| 178 |
+
"lang_config_map = {'bam': 'lora_bambara.yaml', 'ful': 'lora_fula.yaml'}\n",
|
| 179 |
+
"base_cfg = str(space_dir / 'configs' / 'base_config.yaml')\n",
|
| 180 |
+
"lang_cfg = str(space_dir / 'configs' / lang_config_map[TRAIN_LANG])\n",
|
| 181 |
+
"output_dir = f'/tmp/sahel_adapter_{TRAIN_LANG}'\n",
|
| 182 |
+
"\n",
|
| 183 |
+
"# Override output_dir so adapter saves to /tmp on Colab\n",
|
| 184 |
+
"import yaml\n",
|
| 185 |
+
"with open(lang_cfg) as f:\n",
|
| 186 |
+
" lang_config = yaml.safe_load(f)\n",
|
| 187 |
+
"lang_config['output_dir'] = output_dir\n",
|
| 188 |
+
"tmp_lang_cfg = f'/tmp/lora_{TRAIN_LANG}_tmp.yaml'\n",
|
| 189 |
+
"with open(tmp_lang_cfg, 'w') as f:\n",
|
| 190 |
+
" yaml.dump(lang_config, f)\n",
|
| 191 |
+
"\n",
|
| 192 |
+
"trainer = WhisperLoRATrainer(\n",
|
| 193 |
+
" base_config_path=base_cfg,\n",
|
| 194 |
+
" language_config_path=tmp_lang_cfg,\n",
|
| 195 |
+
")\n",
|
| 196 |
+
"trainer.setup()\n",
|
| 197 |
+
"\n",
|
| 198 |
+
"if available:\n",
|
| 199 |
+
" print(f'Merging {len(available)} corrections (×3) with Waxal baseline (cap=500)...')\n",
|
| 200 |
+
" trainer.merge_extra_data(available, repeat=3, waxal_cap=500)\n",
|
| 201 |
+
"else:\n",
|
| 202 |
+
" print('No corrections available — training on Waxal only.')\n",
|
| 203 |
+
"\n",
|
| 204 |
+
"trainer.train()\n",
|
| 205 |
+
"print(f'✅ Training complete — adapter at {output_dir}')"
|
| 206 |
+
]
|
| 207 |
+
},
|
| 208 |
+
{
|
| 209 |
+
"cell_type": "code",
|
| 210 |
+
"execution_count": null,
|
| 211 |
+
"id": "cell-push",
|
| 212 |
+
"metadata": {},
|
| 213 |
+
"outputs": [],
|
| 214 |
+
"source": [
|
| 215 |
+
"# Cell 7 — Push adapter to HF Model repo\n",
|
| 216 |
+
"from huggingface_hub import HfApi\n",
|
| 217 |
+
"api = HfApi(token=HF_TOKEN)\n",
|
| 218 |
+
"\n",
|
| 219 |
+
"path_in_repo = 'adapters/bambara' if TRAIN_LANG == 'bam' else 'adapters/fula'\n",
|
| 220 |
+
"n_corrections = len(available)\n",
|
| 221 |
+
"\n",
|
| 222 |
+
"api.upload_folder(\n",
|
| 223 |
+
" folder_path=output_dir,\n",
|
| 224 |
+
" repo_id=ADAPTER_REPO_ID,\n",
|
| 225 |
+
" repo_type='model',\n",
|
| 226 |
+
" path_in_repo=path_in_repo,\n",
|
| 227 |
+
" commit_message=(\n",
|
| 228 |
+
" f'Fine-tune {TRAIN_LANG}: Waxal baseline + {n_corrections} farmer corrections'\n",
|
| 229 |
+
" ),\n",
|
| 230 |
+
")\n",
|
| 231 |
+
"print(f'✅ Pushed to {ADAPTER_REPO_ID}/{path_in_repo}')\n",
|
| 232 |
+
"print('\\nNext: Space → Tab 3 → Reload Adapters from Hub')"
|
| 233 |
+
]
|
| 234 |
+
},
|
| 235 |
+
{
|
| 236 |
+
"cell_type": "code",
|
| 237 |
+
"execution_count": null,
|
| 238 |
+
"id": "cell-sanity",
|
| 239 |
+
"metadata": {},
|
| 240 |
+
"outputs": [],
|
| 241 |
+
"source": [
|
| 242 |
+
"# Cell 8 — Sanity check: compare WER before vs after adapter\n",
|
| 243 |
+
"import random, torch, librosa, jiwer\n",
|
| 244 |
+
"from transformers import WhisperForConditionalGeneration, WhisperProcessor\n",
|
| 245 |
+
"from peft import PeftModel\n",
|
| 246 |
+
"\n",
|
| 247 |
+
"if not available:\n",
|
| 248 |
+
" print('No test samples — skipping sanity check.')\n",
|
| 249 |
+
"else:\n",
|
| 250 |
+
" test_rec = random.choice(available)\n",
|
| 251 |
+
" print(f'Audio : {Path(test_rec[\"audio_file\"]).name}')\n",
|
| 252 |
+
" print(f'Expected : {test_rec[\"corrected_text\"]}')\n",
|
| 253 |
+
" print(f'Pre-train: {test_rec[\"whisper_output\"]}')\n",
|
| 254 |
+
"\n",
|
| 255 |
+
" # Load base + adapter\n",
|
| 256 |
+
" processor = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID, token=HF_TOKEN)\n",
|
| 257 |
+
" base = WhisperForConditionalGeneration.from_pretrained(\n",
|
| 258 |
+
" WHISPER_MODEL_ID, torch_dtype=torch.float16, token=HF_TOKEN\n",
|
| 259 |
+
" ).to('cuda')\n",
|
| 260 |
+
" model = PeftModel.from_pretrained(base, output_dir).eval()\n",
|
| 261 |
+
"\n",
|
| 262 |
+
" audio_np, _ = librosa.load(test_rec['audio_file'], sr=16000, mono=True)\n",
|
| 263 |
+
" feats = processor.feature_extractor(\n",
|
| 264 |
+
" audio_np, sampling_rate=16000, return_tensors='pt'\n",
|
| 265 |
+
" ).input_features.half().to('cuda')\n",
|
| 266 |
+
"\n",
|
| 267 |
+
" with torch.no_grad():\n",
|
| 268 |
+
" ids = model.generate(feats, max_new_tokens=256)\n",
|
| 269 |
+
" result = processor.batch_decode(ids, skip_special_tokens=True)[0].strip()\n",
|
| 270 |
+
" print(f'Post-train: {result}')\n",
|
| 271 |
+
"\n",
|
| 272 |
+
" ref = test_rec['corrected_text']\n",
|
| 273 |
+
" wer_before = jiwer.wer(ref, test_rec['whisper_output']) if test_rec.get('whisper_output') else 1.0\n",
|
| 274 |
+
" wer_after = jiwer.wer(ref, result)\n",
|
| 275 |
+
" print(f'\\nWER before: {wer_before:.1%} → WER after: {wer_after:.1%}')\n",
|
| 276 |
+
" if wer_after < wer_before:\n",
|
| 277 |
+
" print('✅ Adapter improved transcription quality!')\n",
|
| 278 |
+
" else:\n",
|
| 279 |
+
" print('ℹ️ No improvement on this single sample — collect more corrections and retrain.')"
|
| 280 |
+
]
|
| 281 |
+
}
|
| 282 |
+
]
|
| 283 |
+
}
|
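Cell 6 above relies on WhisperLoRATrainer.merge_extra_data() to blend the Waxal baseline with the farmer corrections. The trainer source is not part of this diff, so the snippet below is only a sketch of what the cell's own comments describe (cap Waxal at 500 samples, repeat the corrections 3x, shuffle); the helper name and record shapes are assumptions, not the actual implementation:

    import random

    def merge_waxal_and_corrections(waxal_samples, corrections, repeat=3, waxal_cap=500, seed=42):
        """Illustrative merge: cap the baseline, upsample field corrections, shuffle."""
        merged = list(waxal_samples)[:waxal_cap]   # Waxal keeps the model grounded
        for _ in range(repeat):                    # 3x weight on real farmer corrections
            merged.extend(corrections)
        random.Random(seed).shuffle(merged)
        return merged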
packages.txt
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
ffmpeg
|
requirements.txt
ADDED
|
@@ -0,0 +1,50 @@
|
| 1 |
+
# -----------------------------------------------------------------------------
|
| 2 |
+
# Sahel-Agri Voice AI — Python Dependencies
|
| 3 |
+
# HuggingFace Spaces (ZeroGPU) deployment — CUDA pre-installed, no +cu128 suffix
|
| 4 |
+
#
|
| 5 |
+
# Local CPU test:
|
| 6 |
+
# pip install -r requirements.txt
|
| 7 |
+
# -----------------------------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
# PyTorch (CPU build — works on HF Spaces cpu-basic and locally)
|
| 10 |
+
torch==2.11.0
|
| 11 |
+
torchaudio==2.11.0
|
| 12 |
+
|
| 13 |
+
# HuggingFace core
|
| 14 |
+
transformers==5.5.0
|
| 15 |
+
datasets==4.8.4
|
| 16 |
+
accelerate==1.13.0
|
| 17 |
+
evaluate==0.4.2
|
| 18 |
+
huggingface-hub==1.9.0
|
| 19 |
+
|
| 20 |
+
# PEFT (LoRA adapters)
|
| 21 |
+
peft==0.18.1
|
| 22 |
+
|
| 23 |
+
# Audio processing
|
| 24 |
+
librosa==0.10.2
|
| 25 |
+
soundfile==0.12.1
|
| 26 |
+
audiomentations==0.43.1
|
| 27 |
+
|
| 28 |
+
# Quantization (CPU: installs fine; 4-bit/8-bit requires GPU at runtime)
|
| 29 |
+
bitsandbytes==0.49.2
|
| 30 |
+
|
| 31 |
+
# Metrics
|
| 32 |
+
jiwer==3.0.4
|
| 33 |
+
|
| 34 |
+
# Config & environment
|
| 35 |
+
pyyaml==6.0.2
|
| 36 |
+
python-dotenv==1.1.0
|
| 37 |
+
|
| 38 |
+
# Gradio (must match sdk_version in README.md)
|
| 39 |
+
gradio==4.44.0
|
| 40 |
+
|
| 41 |
+
# Pydantic v2
|
| 42 |
+
pydantic==2.11.3
|
| 43 |
+
|
| 44 |
+
# Testing
|
| 45 |
+
pytest==8.3.5
|
| 46 |
+
pytest-asyncio==0.26.0
|
| 47 |
+
|
| 48 |
+
# Utilities
|
| 49 |
+
numpy==2.2.4
|
| 50 |
+
scipy==1.15.2
|
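Note that the FastAPI server under src/api/ and scripts/run_server.py imports packages that are not pinned above (fastapi and slowapi in the source, uvicorn as the entry point, python-multipart for UploadFile/Form parsing); the Gradio Space itself does not need them. A sketch of the extra lines to append for the API deployment; version pins are deliberately omitted because this commit does not specify them:

    # FastAPI inference server (scripts/run_server.py), not required by the Gradio Space
    fastapi
    uvicorn[standard]
    python-multipart
    slowapi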
scripts/export_onnx.py
ADDED
|
@@ -0,0 +1,67 @@
|
| 1 |
+
"""
|
| 2 |
+
Phase 4a: Merge LoRA adapters and export language-specific ONNX models.
|
| 3 |
+
Validates that ONNX WER is within 2% of PyTorch baseline.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python scripts/export_onnx.py
|
| 7 |
+
"""
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 14 |
+
|
| 15 |
+
from dotenv import load_dotenv
|
| 16 |
+
|
| 17 |
+
load_dotenv()
|
| 18 |
+
|
| 19 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s — %(message)s")
|
| 20 |
+
|
| 21 |
+
import yaml
|
| 22 |
+
|
| 23 |
+
from src.optimization.onnx_exporter import ONNXExporter
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def export_language(language: str, adapter_path: str, config: dict) -> None:
|
| 27 |
+
from peft import PeftModel
|
| 28 |
+
from transformers import WhisperForConditionalGeneration, WhisperProcessor
|
| 29 |
+
|
| 30 |
+
hf_token = os.getenv("HF_TOKEN")
|
| 31 |
+
model_id = config["model"]["id"]
|
| 32 |
+
|
| 33 |
+
print(f"\n[{language.upper()}] Loading base model...")
|
| 34 |
+
base_model = WhisperForConditionalGeneration.from_pretrained(model_id, token=hf_token)
|
| 35 |
+
processor = WhisperProcessor.from_pretrained(model_id, token=hf_token)
|
| 36 |
+
|
| 37 |
+
print(f"[{language.upper()}] Loading adapter from {adapter_path}...")
|
| 38 |
+
peft_model = PeftModel.from_pretrained(base_model, adapter_path, adapter_name=language)
|
| 39 |
+
|
| 40 |
+
output_dir = f"{config['paths']['models']}/onnx/{language}"
|
| 41 |
+
exporter = ONNXExporter()
|
| 42 |
+
result_path = exporter.merge_and_export(peft_model, processor, output_dir, language)
|
| 43 |
+
print(f"[{language.upper()}] ONNX exported to: {result_path}")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def main() -> None:
|
| 47 |
+
with open("configs/base_config.yaml") as f:
|
| 48 |
+
config = yaml.safe_load(f)
|
| 49 |
+
|
| 50 |
+
print("=" * 60)
|
| 51 |
+
print("Sahel-Agri Voice AI — ONNX Export")
|
| 52 |
+
print("=" * 60)
|
| 53 |
+
|
| 54 |
+
bambara_path = os.getenv("BAMBARA_ADAPTER_PATH", "./adapters/bambara")
|
| 55 |
+
fula_path = os.getenv("FULA_ADAPTER_PATH", "./adapters/fula")
|
| 56 |
+
|
| 57 |
+
for language, adapter_path in [("bambara", bambara_path), ("fula", fula_path)]:
|
| 58 |
+
if Path(adapter_path).exists():
|
| 59 |
+
export_language(language, adapter_path, config)
|
| 60 |
+
else:
|
| 61 |
+
print(f"\nSkipping {language}: adapter not found at {adapter_path}")
|
| 62 |
+
|
| 63 |
+
print("\nExport complete.")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
if __name__ == "__main__":
|
| 67 |
+
main()
|
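The docstring above promises that ONNX WER stays within 2% of the PyTorch baseline, but the script itself only exports the models. A minimal sketch of how that check could be wired in afterwards, assuming you already hold reference transcripts and hypotheses from both backends as lists of strings (the helper name and the 2-percentage-point reading of the threshold are assumptions):

    import jiwer

    def check_onnx_wer_regression(references, pytorch_hyps, onnx_hyps, max_delta=0.02):
        """Fail loudly if the ONNX export degrades WER by more than the allowed budget."""
        wer_pt = jiwer.wer(references, pytorch_hyps)
        wer_onnx = jiwer.wer(references, onnx_hyps)
        delta = wer_onnx - wer_pt
        print(f"PyTorch WER {wer_pt:.1%} | ONNX WER {wer_onnx:.1%} | delta {delta:+.1%}")
        if delta > max_delta:
            raise RuntimeError(f"ONNX WER regression {delta:.1%} exceeds {max_delta:.0%} budget")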
scripts/run_data_pipeline.py
ADDED
|
@@ -0,0 +1,76 @@
|
| 1 |
+
"""
|
| 2 |
+
Phase 2: Download google/waxal, apply augmentation, print statistics.
|
| 3 |
+
Streams examples and caches to data_cache/ as Arrow files.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python scripts/run_data_pipeline.py --subset bam --max-examples 100
|
| 7 |
+
"""
|
| 8 |
+
import argparse
|
| 9 |
+
import sys
|
| 10 |
+
import time
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
from dotenv import load_dotenv
|
| 18 |
+
|
| 19 |
+
load_dotenv()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def main(subset: str, max_examples: int) -> None:
|
| 23 |
+
import yaml
|
| 24 |
+
from transformers import WhisperProcessor
|
| 25 |
+
|
| 26 |
+
from src.data.augmentation import FieldNoiseAugmenter
|
| 27 |
+
from src.data.waxal_loader import WaxalDataLoader
|
| 28 |
+
|
| 29 |
+
with open("configs/base_config.yaml") as f:
|
| 30 |
+
config = yaml.safe_load(f)
|
| 31 |
+
|
| 32 |
+
hf_token = os.getenv("HF_TOKEN")
|
| 33 |
+
model_id = config["model"]["id"]
|
| 34 |
+
|
| 35 |
+
print("=" * 60)
|
| 36 |
+
print(f"Waxal Data Pipeline — subset: {subset}")
|
| 37 |
+
print("=" * 60)
|
| 38 |
+
|
| 39 |
+
print(f"\n[1/4] Loading WhisperProcessor ({model_id})...")
|
| 40 |
+
processor = WhisperProcessor.from_pretrained(model_id, token=hf_token)
|
| 41 |
+
|
| 42 |
+
print("[2/4] Initializing augmenter...")
|
| 43 |
+
augmenter = FieldNoiseAugmenter(config["paths"]["noise_samples"], config)
|
| 44 |
+
print(f" Augmenter ready: {augmenter.is_ready()}")
|
| 45 |
+
|
| 46 |
+
print(f"[3/4] Streaming google/waxal subset={subset}...")
|
| 47 |
+
loader = WaxalDataLoader(subset, config, hf_token=hf_token)
|
| 48 |
+
|
| 49 |
+
t0 = time.time()
|
| 50 |
+
count = 0
|
| 51 |
+
total_duration = 0.0
|
| 52 |
+
|
| 53 |
+
for example in loader.iter_processed(processor, split="train", augmenter=augmenter):
|
| 54 |
+
count += 1
|
| 55 |
+
# input_features shape: (80, 3000) = 30 seconds at most
|
| 56 |
+
# Upper-bound estimate: count each processed example as one padded 30-second chunk
|
| 57 |
+
total_duration += 30.0 # max chunk
|
| 58 |
+
if count >= max_examples:
|
| 59 |
+
break
|
| 60 |
+
|
| 61 |
+
elapsed = time.time() - t0
|
| 62 |
+
|
| 63 |
+
print(f"\n[4/4] Results:")
|
| 64 |
+
print(f" Examples processed: {count}")
|
| 65 |
+
print(f" Approx total audio: {total_duration / 3600:.2f} hours")
|
| 66 |
+
print(f" Processing time: {elapsed:.1f}s")
|
| 67 |
+
print(f" Throughput: {count / elapsed:.1f} examples/sec")
|
| 68 |
+
print(f"\nData pipeline PASSED.")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
if __name__ == "__main__":
|
| 72 |
+
parser = argparse.ArgumentParser()
|
| 73 |
+
parser.add_argument("--subset", default="bam", choices=["bam", "ful"])
|
| 74 |
+
parser.add_argument("--max-examples", type=int, default=50)
|
| 75 |
+
args = parser.parse_args()
|
| 76 |
+
main(args.subset, args.max_examples)
|
scripts/run_server.py
ADDED
|
@@ -0,0 +1,42 @@
|
| 1 |
+
"""
|
| 2 |
+
Phase 4b: Start the FastAPI inference server.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python scripts/run_server.py
|
| 6 |
+
python scripts/run_server.py --host 0.0.0.0 --port 8000
|
| 7 |
+
"""
|
| 8 |
+
import argparse
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 13 |
+
|
| 14 |
+
from dotenv import load_dotenv
|
| 15 |
+
|
| 16 |
+
load_dotenv()
|
| 17 |
+
|
| 18 |
+
import uvicorn
|
| 19 |
+
|
| 20 |
+
if __name__ == "__main__":
|
| 21 |
+
parser = argparse.ArgumentParser(description="Start Sahel-Agri Voice AI server")
|
| 22 |
+
parser.add_argument("--host", default="0.0.0.0")
|
| 23 |
+
parser.add_argument("--port", type=int, default=8000)
|
| 24 |
+
parser.add_argument("--reload", action="store_true", help="Enable hot-reload (dev only)")
|
| 25 |
+
args = parser.parse_args()
|
| 26 |
+
|
| 27 |
+
print(f"Starting server on http://{args.host}:{args.port}")
|
| 28 |
+
print("Endpoints:")
|
| 29 |
+
print(f" GET http://localhost:{args.port}/api/v1/health")
|
| 30 |
+
print(f" POST http://localhost:{args.port}/api/v1/transcribe")
|
| 31 |
+
print(f" POST http://localhost:{args.port}/api/v1/query")
|
| 32 |
+
print(f" GET http://localhost:{args.port}/docs (Swagger UI)")
|
| 33 |
+
print()
|
| 34 |
+
|
| 35 |
+
uvicorn.run(
|
| 36 |
+
"src.api.app:app",
|
| 37 |
+
host=args.host,
|
| 38 |
+
port=args.port,
|
| 39 |
+
workers=1, # Single worker: GPU model shared in memory
|
| 40 |
+
reload=args.reload,
|
| 41 |
+
log_level="info",
|
| 42 |
+
)
|
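With the server running, the endpoints can be exercised from a short client script. A sketch using the requests package (not pinned in requirements.txt) against a local sample.wav; the file name and field values are placeholders:

    import requests

    with open("sample.wav", "rb") as f:
        resp = requests.post(
            "http://localhost:8000/api/v1/transcribe",
            files={"audio_file": ("sample.wav", f, "audio/wav")},
            data={"language": "bam"},  # or "ful" for Fula
            timeout=120,
        )
    resp.raise_for_status()
    print(resp.json()["text"])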
scripts/train_bambara.py
ADDED
|
@@ -0,0 +1,28 @@
|
| 1 |
+
"""
|
| 2 |
+
Phase 3a: Fine-tune LoRA adapter for Bambara (bam).
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python scripts/train_bambara.py
|
| 6 |
+
"""
|
| 7 |
+
import logging
|
| 8 |
+
import sys
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 12 |
+
|
| 13 |
+
from dotenv import load_dotenv
|
| 14 |
+
|
| 15 |
+
load_dotenv()
|
| 16 |
+
|
| 17 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s — %(message)s")
|
| 18 |
+
|
| 19 |
+
from src.training.trainer import WhisperLoRATrainer
|
| 20 |
+
|
| 21 |
+
if __name__ == "__main__":
|
| 22 |
+
trainer = WhisperLoRATrainer(
|
| 23 |
+
base_config_path="configs/base_config.yaml",
|
| 24 |
+
language_config_path="configs/lora_bambara.yaml",
|
| 25 |
+
)
|
| 26 |
+
trainer.setup()
|
| 27 |
+
trainer.train()
|
| 28 |
+
print("\nBambara training complete. Adapter saved to adapters/bambara/")
|
scripts/train_fula.py
ADDED
|
@@ -0,0 +1,29 @@
|
| 1 |
+
"""
|
| 2 |
+
Phase 3b: Fine-tune LoRA adapter for Fula (ful).
|
| 3 |
+
Trains on the same frozen backbone as Bambara — base model weights are NOT modified.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python scripts/train_fula.py
|
| 7 |
+
"""
|
| 8 |
+
import logging
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 13 |
+
|
| 14 |
+
from dotenv import load_dotenv
|
| 15 |
+
|
| 16 |
+
load_dotenv()
|
| 17 |
+
|
| 18 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s — %(message)s")
|
| 19 |
+
|
| 20 |
+
from src.training.trainer import WhisperLoRATrainer
|
| 21 |
+
|
| 22 |
+
if __name__ == "__main__":
|
| 23 |
+
trainer = WhisperLoRATrainer(
|
| 24 |
+
base_config_path="configs/base_config.yaml",
|
| 25 |
+
language_config_path="configs/lora_fula.yaml",
|
| 26 |
+
)
|
| 27 |
+
trainer.setup()
|
| 28 |
+
trainer.train()
|
| 29 |
+
print("\nFula training complete. Adapter saved to adapters/fula/")
|
scripts/verify_baseline.py
ADDED
|
@@ -0,0 +1,78 @@
|
| 1 |
+
"""
|
| 2 |
+
Phase 1 smoke test: load Whisper, run inference on a sample audio clip.
|
| 3 |
+
Prints model info, inference time, GPU memory usage, and sample transcript.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python scripts/verify_baseline.py
|
| 7 |
+
"""
|
| 8 |
+
import sys
|
| 9 |
+
import time
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
# Allow imports from project root
|
| 13 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def main() -> None:
|
| 20 |
+
from src.engine.whisper_base import WhisperBackbone
|
| 21 |
+
|
| 22 |
+
print("=" * 60)
|
| 23 |
+
print("Sahel-Agri Voice AI — Baseline Verification")
|
| 24 |
+
print("=" * 60)
|
| 25 |
+
|
| 26 |
+
# 1. Check environment
|
| 27 |
+
print(f"\nPython: {sys.version.split()[0]}")
|
| 28 |
+
print(f"PyTorch: {torch.__version__}")
|
| 29 |
+
print(f"CUDA available: {torch.cuda.is_available()}")
|
| 30 |
+
if torch.cuda.is_available():
|
| 31 |
+
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
| 32 |
+
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
|
| 33 |
+
|
| 34 |
+
# 2. Load model
|
| 35 |
+
print("\n[1/3] Loading backbone model...")
|
| 36 |
+
t0 = time.time()
|
| 37 |
+
backbone = WhisperBackbone("configs/base_config.yaml")
|
| 38 |
+
backbone.load(device="cuda" if torch.cuda.is_available() else "cpu")
|
| 39 |
+
load_time = time.time() - t0
|
| 40 |
+
print(f" Loaded in {load_time:.1f}s")
|
| 41 |
+
|
| 42 |
+
if torch.cuda.is_available():
|
| 43 |
+
used = torch.cuda.memory_allocated() / 1e9
|
| 44 |
+
reserved = torch.cuda.memory_reserved() / 1e9
|
| 45 |
+
print(f" GPU memory: {used:.2f} GB allocated / {reserved:.2f} GB reserved")
|
| 46 |
+
|
| 47 |
+
# 3. Generate synthetic test audio (1 second of low-amplitude white noise)
|
| 48 |
+
print("\n[2/3] Generating test audio (1s white noise)...")
|
| 49 |
+
sample_rate = 16000
|
| 50 |
+
duration = 1.0
|
| 51 |
+
audio = np.random.randn(int(sample_rate * duration)).astype(np.float32) * 0.01
|
| 52 |
+
|
| 53 |
+
# 4. Run inference
|
| 54 |
+
print("[3/3] Running inference...")
|
| 55 |
+
processor = backbone.processor
|
| 56 |
+
model = backbone.model
|
| 57 |
+
|
| 58 |
+
inputs = processor(audio, sampling_rate=sample_rate, return_tensors="pt")
|
| 59 |
+
input_features = inputs.input_features.to(backbone.device)
|
| 60 |
+
if backbone.device == "cuda":
|
| 61 |
+
input_features = input_features.half()
|
| 62 |
+
|
| 63 |
+
t0 = time.time()
|
| 64 |
+
with torch.no_grad():
|
| 65 |
+
predicted_ids = model.generate(input_features, max_new_tokens=50)
|
| 66 |
+
infer_time = time.time() - t0
|
| 67 |
+
|
| 68 |
+
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
|
| 69 |
+
|
| 70 |
+
print(f"\n{'=' * 60}")
|
| 71 |
+
print(f"Transcript: '{transcription}' (noise input — blank expected)")
|
| 72 |
+
print(f"Inference time: {infer_time * 1000:.0f} ms")
|
| 73 |
+
print(f"\nBaseline verification PASSED.")
|
| 74 |
+
print(f"{'=' * 60}")
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
|
| 78 |
+
main()
|
src/__init__.py
ADDED
|
File without changes
|
src/api/__init__.py
ADDED
|
File without changes
|
src/api/app.py
ADDED
|
@@ -0,0 +1,98 @@
|
| 1 |
+
"""
|
| 2 |
+
FastAPI application factory.
|
| 3 |
+
Uses lifespan context manager to load the Whisper model at startup
|
| 4 |
+
and register language adapters — keeping a single backbone in GPU memory.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
from contextlib import asynccontextmanager
|
| 11 |
+
|
| 12 |
+
import yaml
|
| 13 |
+
from fastapi import FastAPI
|
| 14 |
+
|
| 15 |
+
from src.api.middleware import register_middleware
|
| 16 |
+
from src.api.routes import health, iot, transcribe
|
| 17 |
+
from src.engine.adapter_manager import AdapterManager
|
| 18 |
+
from src.engine.transcriber import Transcriber
|
| 19 |
+
from src.engine.whisper_base import WhisperBackbone
|
| 20 |
+
from src.iot.sensor_bridge import SensorBridge
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
logging.basicConfig(
|
| 25 |
+
level=os.getenv("LOG_LEVEL", "INFO"),
|
| 26 |
+
format="%(asctime)s %(levelname)s %(name)s — %(message)s",
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@asynccontextmanager
|
| 31 |
+
async def lifespan(app: FastAPI):
|
| 32 |
+
"""Load model at startup, free GPU memory at shutdown."""
|
| 33 |
+
with open("configs/base_config.yaml") as f:
|
| 34 |
+
config = yaml.safe_load(f)
|
| 35 |
+
|
| 36 |
+
hf_token = os.getenv("HF_TOKEN")
|
| 37 |
+
device = os.getenv("DEVICE", "cuda")
|
| 38 |
+
bambara_path = os.getenv("BAMBARA_ADAPTER_PATH", "./adapters/bambara")
|
| 39 |
+
fula_path = os.getenv("FULA_ADAPTER_PATH", "./adapters/fula")
|
| 40 |
+
sensor_api_url = os.getenv("SENSOR_API_URL") or None
|
| 41 |
+
|
| 42 |
+
# 1. Load backbone
|
| 43 |
+
logger.info("Loading Whisper backbone...")
|
| 44 |
+
backbone = WhisperBackbone("configs/base_config.yaml")
|
| 45 |
+
backbone.load(device=device, hf_token=hf_token)
|
| 46 |
+
|
| 47 |
+
# 2. Register adapters (they are loaded on first use via activate())
|
| 48 |
+
adapter_manager = AdapterManager(backbone.model, config)
|
| 49 |
+
adapter_manager.register("bam", bambara_path)
|
| 50 |
+
adapter_manager.register("ful", fula_path)
|
| 51 |
+
|
| 52 |
+
# 3. Pre-load the default adapter to warm up VRAM
|
| 53 |
+
try:
|
| 54 |
+
adapter_manager.load_adapter("bam")
|
| 55 |
+
logger.info("Default adapter 'bam' pre-loaded.")
|
| 56 |
+
except Exception as e:
|
| 57 |
+
logger.warning("Could not pre-load 'bam' adapter: %s", e)
|
| 58 |
+
|
| 59 |
+
# 4. Create transcriber and sensor bridge
|
| 60 |
+
transcriber = Transcriber(backbone, adapter_manager)
|
| 61 |
+
sensor_bridge = SensorBridge(sensor_api_url=sensor_api_url)
|
| 62 |
+
|
| 63 |
+
# 5. Attach to app.state for dependency injection
|
| 64 |
+
app.state.backbone = backbone
|
| 65 |
+
app.state.adapter_manager = adapter_manager
|
| 66 |
+
app.state.transcriber = transcriber
|
| 67 |
+
app.state.sensor_bridge = sensor_bridge
|
| 68 |
+
|
| 69 |
+
logger.info("Sahel-Agri Voice AI server ready.")
|
| 70 |
+
yield
|
| 71 |
+
|
| 72 |
+
# Shutdown
|
| 73 |
+
logger.info("Shutting down — freeing GPU memory...")
|
| 74 |
+
backbone.free()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def create_app() -> FastAPI:
|
| 78 |
+
app = FastAPI(
|
| 79 |
+
title="Sahel-Agri Voice AI",
|
| 80 |
+
description=(
|
| 81 |
+
"Modular STT engine for Bambara and Fula — serving Mali and Guinea farmers "
|
| 82 |
+
"via voice-first agricultural intelligence."
|
| 83 |
+
),
|
| 84 |
+
version="0.1.0",
|
| 85 |
+
lifespan=lifespan,
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
register_middleware(app)
|
| 89 |
+
|
| 90 |
+
# Register routes
|
| 91 |
+
app.include_router(health.router, prefix="/api/v1", tags=["health"])
|
| 92 |
+
app.include_router(transcribe.router, prefix="/api/v1", tags=["transcribe"])
|
| 93 |
+
app.include_router(iot.router, prefix="/api/v1", tags=["iot"])
|
| 94 |
+
|
| 95 |
+
return app
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
app = create_app()
|
src/api/dependencies.py
ADDED
|
@@ -0,0 +1,20 @@
|
| 1 |
+
"""FastAPI dependency injection: retrieves shared model objects from app.state."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from fastapi import Request
|
| 5 |
+
|
| 6 |
+
from src.engine.adapter_manager import AdapterManager
|
| 7 |
+
from src.engine.transcriber import Transcriber
|
| 8 |
+
from src.iot.sensor_bridge import SensorBridge
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_transcriber(request: Request) -> Transcriber:
|
| 12 |
+
return request.app.state.transcriber
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_adapter_manager(request: Request) -> AdapterManager:
|
| 16 |
+
return request.app.state.adapter_manager
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_sensor_bridge(request: Request) -> SensorBridge:
|
| 20 |
+
return request.app.state.sensor_bridge
|
src/api/middleware.py
ADDED
|
@@ -0,0 +1,47 @@
|
| 1 |
+
"""CORS, structured request logging, and rate-limit middleware."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import logging
|
| 5 |
+
import time
|
| 6 |
+
import uuid
|
| 7 |
+
|
| 8 |
+
from fastapi import FastAPI, Request, Response
|
| 9 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
+
from slowapi import Limiter, _rate_limit_exceeded_handler
|
| 11 |
+
from slowapi.errors import RateLimitExceeded
|
| 12 |
+
from slowapi.util import get_remote_address
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
limiter = Limiter(key_func=get_remote_address, default_limits=["60/minute"])
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def register_middleware(app: FastAPI) -> None:
|
| 20 |
+
"""Attach all middleware to the FastAPI app."""
|
| 21 |
+
|
| 22 |
+
# CORS — allow WhatsApp webhook domain and local development
|
| 23 |
+
app.add_middleware(
|
| 24 |
+
CORSMiddleware,
|
| 25 |
+
allow_origins=["*"], # Tighten in production with specific domains
|
| 26 |
+
allow_credentials=True,
|
| 27 |
+
allow_methods=["GET", "POST"],
|
| 28 |
+
allow_headers=["*"],
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
# Rate limiting
|
| 32 |
+
app.state.limiter = limiter
|
| 33 |
+
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
| 34 |
+
|
| 35 |
+
@app.middleware("http")
|
| 36 |
+
async def logging_middleware(request: Request, call_next) -> Response:
|
| 37 |
+
request_id = str(uuid.uuid4())[:8]
|
| 38 |
+
t0 = time.perf_counter()
|
| 39 |
+
response = await call_next(request)
|
| 40 |
+
elapsed_ms = int((time.perf_counter() - t0) * 1000)
|
| 41 |
+
logger.info(
|
| 42 |
+
"req_id=%s method=%s path=%s status=%d latency_ms=%d",
|
| 43 |
+
request_id, request.method, request.url.path,
|
| 44 |
+
response.status_code, elapsed_ms,
|
| 45 |
+
)
|
| 46 |
+
response.headers["X-Request-ID"] = request_id
|
| 47 |
+
return response
|
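One wiring detail worth flagging: the Limiter above declares default_limits and the 429 handler is registered, but slowapi only enforces those defaults once its ASGI middleware is added (or individual routes are decorated with @limiter.limit). A sketch of the missing registration, assuming slowapi's bundled SlowAPIMiddleware; treat it as a suggestion rather than part of this commit:

    from fastapi import FastAPI
    from slowapi.middleware import SlowAPIMiddleware

    def register_rate_limiting(app: FastAPI) -> None:
        # Requires app.state.limiter to be set first (register_middleware above does this)
        app.add_middleware(SlowAPIMiddleware)  # enforces default_limits=["60/minute"]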
src/api/routes/__init__.py
ADDED
|
File without changes
|
src/api/routes/health.py
ADDED
|
@@ -0,0 +1,25 @@
|
| 1 |
+
"""GET /api/v1/health — model status and adapter availability."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from fastapi import APIRouter, Depends, Request
|
| 5 |
+
|
| 6 |
+
from src.api.dependencies import get_adapter_manager
|
| 7 |
+
from src.api.schemas import HealthResponse
|
| 8 |
+
from src.engine.adapter_manager import AdapterManager
|
| 9 |
+
|
| 10 |
+
router = APIRouter()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@router.get("/health", response_model=HealthResponse)
|
| 14 |
+
async def health_check(
|
| 15 |
+
request: Request,
|
| 16 |
+
adapter_manager: AdapterManager = Depends(get_adapter_manager),
|
| 17 |
+
) -> HealthResponse:
|
| 18 |
+
model_loaded = hasattr(request.app.state, "transcriber")
|
| 19 |
+
return HealthResponse(
|
| 20 |
+
status="ok" if model_loaded else "loading",
|
| 21 |
+
model_loaded=model_loaded,
|
| 22 |
+
active_adapter=adapter_manager.get_active(),
|
| 23 |
+
adapters_available=adapter_manager.list_available(),
|
| 24 |
+
adapters_loaded=adapter_manager.list_loaded(),
|
| 25 |
+
)
|
src/api/routes/iot.py
ADDED
|
@@ -0,0 +1,90 @@
|
| 1 |
+
"""POST /api/v1/query — full pipeline: audio → transcription → intent → sensor → voice response."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import tempfile
|
| 7 |
+
import time
|
| 8 |
+
from typing import Annotated, Optional
|
| 9 |
+
|
| 10 |
+
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
| 11 |
+
|
| 12 |
+
from src.api.dependencies import get_sensor_bridge, get_transcriber
|
| 13 |
+
from src.api.schemas import IoTQueryResponse
|
| 14 |
+
from src.engine.transcriber import Transcriber
|
| 15 |
+
from src.iot.intent_parser import IntentParser
|
| 16 |
+
from src.iot.sensor_bridge import SensorBridge
|
| 17 |
+
from src.iot.voice_responder import VoiceResponder
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
router = APIRouter()
|
| 21 |
+
|
| 22 |
+
_intent_parser = IntentParser()
|
| 23 |
+
_voice_responder = VoiceResponder(language="fr")
|
| 24 |
+
|
| 25 |
+
SUPPORTED_LANGUAGES = {"bam", "ful"}
|
| 26 |
+
MAX_AUDIO_BYTES = 10 * 1024 * 1024
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@router.post("/query", response_model=IoTQueryResponse)
|
| 30 |
+
async def agricultural_query(
|
| 31 |
+
audio_file: Annotated[UploadFile, File(description="Audio file with farmer's voice query")],
|
| 32 |
+
language: Annotated[str, Form(description="Language code: 'bam' or 'ful'")] = "bam",
|
| 33 |
+
field_id: Annotated[Optional[str], Form(description="Field/location ID for sensor lookup")] = None,
|
| 34 |
+
transcriber: Transcriber = Depends(get_transcriber),
|
| 35 |
+
sensor_bridge: SensorBridge = Depends(get_sensor_bridge),
|
| 36 |
+
) -> IoTQueryResponse:
|
| 37 |
+
t0 = time.perf_counter()
|
| 38 |
+
|
| 39 |
+
if language not in SUPPORTED_LANGUAGES:
|
| 40 |
+
raise HTTPException(
|
| 41 |
+
status_code=422,
|
| 42 |
+
detail=f"Unsupported language '{language}'. Supported: {sorted(SUPPORTED_LANGUAGES)}",
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
audio_bytes = await audio_file.read()
|
| 46 |
+
if len(audio_bytes) > MAX_AUDIO_BYTES:
|
| 47 |
+
raise HTTPException(status_code=413, detail="Audio file too large. Max 10 MB.")
|
| 48 |
+
|
| 49 |
+
ext = os.path.splitext(audio_file.filename or "audio.wav")[1].lower() or ".wav"
|
| 50 |
+
tmp_path = None
|
| 51 |
+
try:
|
| 52 |
+
# Step 1: Transcribe
|
| 53 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
|
| 54 |
+
tmp.write(audio_bytes)
|
| 55 |
+
tmp_path = tmp.name
|
| 56 |
+
|
| 57 |
+
transcription_result = transcriber.transcribe_file(tmp_path, language)
|
| 58 |
+
|
| 59 |
+
# Step 2: Parse intent
|
| 60 |
+
intent = _intent_parser.parse(transcription_result.text, language)
|
| 61 |
+
|
| 62 |
+
# Step 3: Fetch sensor data
|
| 63 |
+
sensor_data = await sensor_bridge.fetch(intent, field_id=field_id)
|
| 64 |
+
|
| 65 |
+
# Step 4: Generate voice response
|
| 66 |
+
voice_response = _voice_responder.generate_response(intent, sensor_data)
|
| 67 |
+
|
| 68 |
+
except HTTPException:
|
| 69 |
+
raise
|
| 70 |
+
except Exception as e:
|
| 71 |
+
logger.error("IoT query failed: %s", e, exc_info=True)
|
| 72 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 73 |
+
finally:
|
| 74 |
+
if tmp_path and os.path.exists(tmp_path):
|
| 75 |
+
os.unlink(tmp_path)
|
| 76 |
+
|
| 77 |
+
elapsed_ms = int((time.perf_counter() - t0) * 1000)
|
| 78 |
+
|
| 79 |
+
return IoTQueryResponse(
|
| 80 |
+
transcription=transcription_result.text,
|
| 81 |
+
language=language,
|
| 82 |
+
intent={
|
| 83 |
+
"action": intent.action,
|
| 84 |
+
"entity": intent.entity,
|
| 85 |
+
"confidence": intent.confidence,
|
| 86 |
+
},
|
| 87 |
+
sensor_data=sensor_data.values,
|
| 88 |
+
voice_response=voice_response,
|
| 89 |
+
processing_time_ms=elapsed_ms,
|
| 90 |
+
)
|
src/api/routes/transcribe.py
ADDED
|
@@ -0,0 +1,74 @@
|
| 1 |
+
"""POST /api/v1/transcribe — convert uploaded audio to text."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import tempfile
|
| 7 |
+
from typing import Annotated
|
| 8 |
+
|
| 9 |
+
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
| 10 |
+
|
| 11 |
+
from src.api.dependencies import get_transcriber
|
| 12 |
+
from src.api.schemas import TranscribeResponse
|
| 13 |
+
from src.engine.transcriber import Transcriber
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
router = APIRouter()
|
| 17 |
+
|
| 18 |
+
SUPPORTED_LANGUAGES = {"bam", "ful"}
|
| 19 |
+
SUPPORTED_EXTENSIONS = {".wav", ".mp3", ".ogg", ".m4a", ".flac", ".webm"}
|
| 20 |
+
MAX_AUDIO_BYTES = 10 * 1024 * 1024 # 10 MB
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@router.post("/transcribe", response_model=TranscribeResponse)
|
| 24 |
+
async def transcribe_audio(
|
| 25 |
+
audio_file: Annotated[UploadFile, File(description="Audio file (wav/mp3/ogg/m4a/flac/webm)")],
|
| 26 |
+
language: Annotated[str, Form(description="Language code: 'bam' (Bambara) or 'ful' (Fula)")] = "bam",
|
| 27 |
+
transcriber: Transcriber = Depends(get_transcriber),
|
| 28 |
+
) -> TranscribeResponse:
|
| 29 |
+
# Validate language
|
| 30 |
+
if language not in SUPPORTED_LANGUAGES:
|
| 31 |
+
raise HTTPException(
|
| 32 |
+
status_code=422,
|
| 33 |
+
detail=f"Unsupported language '{language}'. Supported: {sorted(SUPPORTED_LANGUAGES)}",
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# Validate file extension
|
| 37 |
+
filename = audio_file.filename or "audio.wav"
|
| 38 |
+
ext = os.path.splitext(filename)[1].lower()
|
| 39 |
+
if ext not in SUPPORTED_EXTENSIONS:
|
| 40 |
+
raise HTTPException(
|
| 41 |
+
status_code=422,
|
| 42 |
+
detail=f"Unsupported file type '{ext}'. Supported: {sorted(SUPPORTED_EXTENSIONS)}",
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
# Read and size-check
|
| 46 |
+
audio_bytes = await audio_file.read()
|
| 47 |
+
if len(audio_bytes) > MAX_AUDIO_BYTES:
|
| 48 |
+
raise HTTPException(
|
| 49 |
+
status_code=413,
|
| 50 |
+
detail=f"File too large ({len(audio_bytes) / 1e6:.1f} MB). Max 10 MB.",
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# Windows-safe temp file: delete=False + manual unlink in finally
|
| 54 |
+
tmp_path = None
|
| 55 |
+
try:
|
| 56 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
|
| 57 |
+
tmp.write(audio_bytes)
|
| 58 |
+
tmp_path = tmp.name
|
| 59 |
+
|
| 60 |
+
result = transcriber.transcribe_file(tmp_path, language)
|
| 61 |
+
except Exception as e:
|
| 62 |
+
logger.error("Transcription failed: %s", e, exc_info=True)
|
| 63 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 64 |
+
finally:
|
| 65 |
+
if tmp_path and os.path.exists(tmp_path):
|
| 66 |
+
os.unlink(tmp_path)
|
| 67 |
+
|
| 68 |
+
return TranscribeResponse(
|
| 69 |
+
text=result.text,
|
| 70 |
+
language=result.language,
|
| 71 |
+
duration_s=result.duration_s,
|
| 72 |
+
processing_time_ms=result.processing_time_ms,
|
| 73 |
+
confidence=result.confidence,
|
| 74 |
+
)
|
src/api/schemas.py
ADDED
|
@@ -0,0 +1,36 @@
|
| 1 |
+
"""Pydantic v2 request and response models for all API endpoints."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
from pydantic import BaseModel, Field
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TranscribeResponse(BaseModel):
|
| 10 |
+
text: str
|
| 11 |
+
language: str
|
| 12 |
+
duration_s: float
|
| 13 |
+
processing_time_ms: int
|
| 14 |
+
confidence: Optional[float] = None
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class IoTQueryResponse(BaseModel):
|
| 18 |
+
transcription: str
|
| 19 |
+
language: str
|
| 20 |
+
intent: dict
|
| 21 |
+
sensor_data: dict
|
| 22 |
+
voice_response: str
|
| 23 |
+
processing_time_ms: int
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class HealthResponse(BaseModel):
|
| 27 |
+
status: str
|
| 28 |
+
model_loaded: bool
|
| 29 |
+
active_adapter: Optional[str]
|
| 30 |
+
adapters_available: list[str]
|
| 31 |
+
adapters_loaded: list[str]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ErrorResponse(BaseModel):
|
| 35 |
+
error: str
|
| 36 |
+
detail: str
|
src/data/__init__.py
ADDED
|
File without changes
|
src/data/agri_dictionary.py
ADDED
|
@@ -0,0 +1,92 @@
|
| 1 |
+
"""
|
| 2 |
+
Agricultural vocabulary for Bambara and Fula.
|
| 3 |
+
Used to bias the Whisper decoder toward domain-specific terms via decoder prompt injection.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
from typing import TYPE_CHECKING
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
if TYPE_CHECKING:
|
| 12 |
+
from transformers import WhisperProcessor
|
| 13 |
+
|
| 14 |
+
# Bambara (bam) agricultural vocabulary
|
| 15 |
+
BAMBARA_VOCAB: dict[str, str] = {
|
| 16 |
+
"sɛnɛ": "farming",
|
| 17 |
+
"jiriw": "trees",
|
| 18 |
+
"nɔgɔ": "soil",
|
| 19 |
+
"sani": "fertilizer",
|
| 20 |
+
"kogomali": "groundnut",
|
| 21 |
+
"kaba": "corn/maize",
|
| 22 |
+
"tiga": "peanut",
|
| 23 |
+
"ji": "water",
|
| 24 |
+
"sanji": "rain",
|
| 25 |
+
"teliman": "weather",
|
| 26 |
+
"suruku": "pest/predator",
|
| 27 |
+
"bunding": "soil/earth",
|
| 28 |
+
"sira": "path/way",
|
| 29 |
+
"foro": "field",
|
| 30 |
+
"dugu": "village/land",
|
| 31 |
+
"dibi": "darkness/shade",
|
| 32 |
+
"fanga": "strength/fertilizer",
|
| 33 |
+
"kungoloni": "insects/pests",
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
# Fula (ful / Fulfulde) agricultural vocabulary
|
| 37 |
+
FULA_VOCAB: dict[str, str] = {
|
| 38 |
+
"ngesa": "field",
|
| 39 |
+
"leydi": "land/soil",
|
| 40 |
+
"kosam": "milk",
|
| 41 |
+
"nagge": "cattle",
|
| 42 |
+
"leeɗe": "crops",
|
| 43 |
+
"ndiyam": "water",
|
| 44 |
+
"yeeso": "wind/weather",
|
| 45 |
+
"laabi": "road/way",
|
| 46 |
+
"demoore": "farming",
|
| 47 |
+
"hoore": "head/top",
|
| 48 |
+
"biñ-biñ": "insects/pests",
|
| 49 |
+
"fuɗorde": "sunrise/east field",
|
| 50 |
+
"ngaari": "bull",
|
| 51 |
+
"mbabba": "donkey",
|
| 52 |
+
"ladde": "bush/forest",
|
| 53 |
+
"wutte": "clothing/harvest",
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
LANGUAGE_VOCABS: dict[str, dict[str, str]] = {
|
| 57 |
+
"bam": BAMBARA_VOCAB,
|
| 58 |
+
"ful": FULA_VOCAB,
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class AgriculturalDictionary:
|
| 63 |
+
"""Converts agricultural vocabulary into decoder prompt token IDs for Whisper."""
|
| 64 |
+
|
| 65 |
+
def get_vocab(self, language: str) -> dict[str, str]:
|
| 66 |
+
if language not in LANGUAGE_VOCABS:
|
| 67 |
+
raise ValueError(f"No vocabulary for language '{language}'. Available: {list(LANGUAGE_VOCABS)}")
|
| 68 |
+
return LANGUAGE_VOCABS[language]
|
| 69 |
+
|
| 70 |
+
def get_prompt_text(self, language: str) -> str:
|
| 71 |
+
"""Return a comma-joined string of all terms, used as decoder text prompt."""
|
| 72 |
+
vocab = self.get_vocab(language)
|
| 73 |
+
return ", ".join(vocab.keys())
|
| 74 |
+
|
| 75 |
+
def build_prompt_ids(self, processor: "WhisperProcessor", language: str) -> torch.Tensor:
|
| 76 |
+
"""
|
| 77 |
+
Tokenize the vocabulary as a decoder prompt.
|
| 78 |
+
Pass this as `decoder_input_ids` or `prompt_ids` to model.generate()
|
| 79 |
+
to bias decoding toward known agricultural terms.
|
| 80 |
+
"""
|
| 81 |
+
prompt_text = self.get_prompt_text(language)
|
| 82 |
+
token_ids = processor.tokenizer(
|
| 83 |
+
prompt_text,
|
| 84 |
+
return_tensors="pt",
|
| 85 |
+
add_special_tokens=False,
|
| 86 |
+
).input_ids
|
| 87 |
+
return token_ids # shape: (1, N)
|
| 88 |
+
|
| 89 |
+
def get_token_ids(self, processor: "WhisperProcessor", language: str) -> list[int]:
|
| 90 |
+
"""Return flat list of token IDs for all vocabulary terms."""
|
| 91 |
+
ids = self.build_prompt_ids(processor, language)
|
| 92 |
+
return ids[0].tolist()
|
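A usage sketch for the dictionary above, following the build_prompt_ids docstring's suggestion of biasing generation through Whisper's prompt mechanism. Note that processor.get_prompt_ids() also prepends Whisper's start-of-prev token, which the raw tokenization in build_prompt_ids() does not; the variable names below are assumptions:

    import torch
    from src.data.agri_dictionary import AgriculturalDictionary

    agri = AgriculturalDictionary()

    def transcribe_with_vocab_bias(model, processor, input_features, language="bam"):
        """Nudge decoding toward known agricultural spellings via a decoder prompt."""
        prompt_ids = processor.get_prompt_ids(agri.get_prompt_text(language), return_tensors="pt")
        with torch.no_grad():
            ids = model.generate(input_features, prompt_ids=prompt_ids.to(input_features.device))
        return processor.batch_decode(ids, skip_special_tokens=True)[0]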
src/data/augmentation.py
ADDED
|
@@ -0,0 +1,84 @@
|
| 1 |
+
"""
|
| 2 |
+
Field noise augmentation for West African farm environments.
|
| 3 |
+
Mixes clean speech with tractor, wind, and livestock audio samples.
|
| 4 |
+
Degrades gracefully to Gaussian noise when no .wav files are present.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class FieldNoiseAugmenter:
|
| 17 |
+
"""
|
| 18 |
+
Applies audiomentations transforms that simulate noisy field conditions.
|
| 19 |
+
If the noise_dir has no .wav files, falls back to Gaussian noise only.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, noise_dir: str, config: dict) -> None:
|
| 23 |
+
self.noise_dir = Path(noise_dir)
|
| 24 |
+
self.config = config
|
| 25 |
+
self._compose = None
|
| 26 |
+
self._gaussian_only = False
|
| 27 |
+
self._build_pipeline()
|
| 28 |
+
|
| 29 |
+
def _build_pipeline(self) -> None:
|
| 30 |
+
try:
|
| 31 |
+
from audiomentations import (
|
| 32 |
+
AddBackgroundNoise,
|
| 33 |
+
AddGaussianNoise,
|
| 34 |
+
Compose,
|
| 35 |
+
RoomSimulator,
|
| 36 |
+
TimeStretch,
|
| 37 |
+
)
|
| 38 |
+
except ImportError:
|
| 39 |
+
logger.warning("audiomentations not installed — augmentation disabled.")
|
| 40 |
+
self._compose = None
|
| 41 |
+
return
|
| 42 |
+
|
| 43 |
+
snr_range = self.config.get("audio", {}).get("noise_snr_db_range", [5, 20])
|
| 44 |
+
prob = self.config.get("audio", {}).get("augmentation_prob", 0.6)
|
| 45 |
+
|
| 46 |
+
wav_files = list(self.noise_dir.glob("*.wav")) if self.noise_dir.exists() else []
|
| 47 |
+
|
| 48 |
+
transforms = []
|
| 49 |
+
|
| 50 |
+
if wav_files:
|
| 51 |
+
transforms.append(
|
| 52 |
+
AddBackgroundNoise(
|
| 53 |
+
sounds_path=str(self.noise_dir),
|
| 54 |
+
min_snr_db=float(snr_range[0]),
|
| 55 |
+
max_snr_db=float(snr_range[1]),
|
| 56 |
+
p=prob,
|
| 57 |
+
)
|
| 58 |
+
)
|
| 59 |
+
logger.info("FieldNoiseAugmenter: loaded %d noise files from %s", len(wav_files), self.noise_dir)
|
| 60 |
+
else:
|
| 61 |
+
logger.warning(
|
| 62 |
+
"FieldNoiseAugmenter: no .wav files found in %s — using Gaussian noise only. "
|
| 63 |
+
"Populate noise_samples/ for realistic field augmentation.",
|
| 64 |
+
self.noise_dir,
|
| 65 |
+
)
|
| 66 |
+
self._gaussian_only = True
|
| 67 |
+
|
| 68 |
+
transforms += [
|
| 69 |
+
AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.3),
|
| 70 |
+
TimeStretch(min_rate=0.9, max_rate=1.1, leave_length_unchanged=True, p=0.2),
|
| 71 |
+
RoomSimulator(p=0.3),
|
| 72 |
+
]
|
| 73 |
+
|
| 74 |
+
self._compose = Compose(transforms)
|
| 75 |
+
|
| 76 |
+
def augment(self, audio: np.ndarray, sr: int) -> np.ndarray:
|
| 77 |
+
"""Apply augmentation pipeline to a float32 audio array."""
|
| 78 |
+
if self._compose is None:
|
| 79 |
+
return audio
|
| 80 |
+
return self._compose(samples=audio, sample_rate=sr)
|
| 81 |
+
|
| 82 |
+
def is_ready(self) -> bool:
|
| 83 |
+
"""Returns True if augmentation is available (even Gaussian-only)."""
|
| 84 |
+
return self._compose is not None
|
src/data/feature_extractor.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
| 1 |
+
"""
|
| 2 |
+
Log-mel spectrogram extraction, padding/truncation, and batch collation for Whisper.
|
| 3 |
+
"""
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import TYPE_CHECKING, Any
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torchaudio
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from transformers import WhisperProcessor
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
TARGET_SR = 16_000
|
| 20 |
+
MEL_FRAMES = 3000 # 30 seconds at 100 frames/sec
|
| 21 |
+
N_MELS = 80  # pre-large-v3 default; large-v3 family checkpoints use 128 mel bins
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class AudioFeatureExtractor:
|
| 25 |
+
"""Wraps WhisperProcessor to extract and normalize audio features."""
|
| 26 |
+
|
| 27 |
+
def __init__(self, processor: "WhisperProcessor", config: dict) -> None:
|
| 28 |
+
self.processor = processor
|
| 29 |
+
self.sample_rate = config.get("audio", {}).get("sample_rate", TARGET_SR)
|
| 30 |
+
|
| 31 |
+
def extract(self, audio: np.ndarray, sr: int) -> torch.Tensor:
|
| 32 |
+
"""
|
| 33 |
+
Resample audio to 16kHz, extract log-mel features.
|
| 34 |
+
Returns a tensor of shape (n_mels, 3000); n_mels is 128 for large-v3 checkpoints and 80 for earlier Whisper sizes.
|
| 35 |
+
"""
|
| 36 |
+
if sr != TARGET_SR:
|
| 37 |
+
tensor = torch.from_numpy(audio).unsqueeze(0)
|
| 38 |
+
tensor = torchaudio.functional.resample(tensor, sr, TARGET_SR)
|
| 39 |
+
audio = tensor.squeeze(0).numpy()
|
| 40 |
+
|
| 41 |
+
inputs = self.processor.feature_extractor(
|
| 42 |
+
audio,
|
| 43 |
+
sampling_rate=TARGET_SR,
|
| 44 |
+
return_tensors="pt",
|
| 45 |
+
)
|
| 46 |
+
features = inputs.input_features[0]  # (n_mels, 3000)
|
| 47 |
+
return features
|
| 48 |
+
|
| 49 |
+
def pad_or_truncate(self, features: torch.Tensor) -> torch.Tensor:
|
| 50 |
+
"""Ensure features are exactly (80, 3000)."""
|
| 51 |
+
n_mels, t = features.shape
|
| 52 |
+
if t < MEL_FRAMES:
|
| 53 |
+
pad = torch.zeros(n_mels, MEL_FRAMES - t, dtype=features.dtype)
|
| 54 |
+
features = torch.cat([features, pad], dim=-1)
|
| 55 |
+
elif t > MEL_FRAMES:
|
| 56 |
+
features = features[:, :MEL_FRAMES]
|
| 57 |
+
return features
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@dataclass
|
| 61 |
+
class DataCollatorSpeechSeq2SeqWithPadding:
|
| 62 |
+
"""
|
| 63 |
+
Pads input_features to uniform length and label sequences with -100
|
| 64 |
+
(so they are ignored in the cross-entropy loss).
|
| 65 |
+
Compatible with HuggingFace Seq2SeqTrainer.
|
| 66 |
+
"""
|
| 67 |
+
processor: Any
|
| 68 |
+
decoder_start_token_id: int
|
| 69 |
+
|
| 70 |
+
def __call__(self, features: list[dict]) -> dict[str, torch.Tensor]:
|
| 71 |
+
# Separate input_features and labels
|
| 72 |
+
input_features = [{"input_features": f["input_features"]} for f in features]
|
| 73 |
+
label_features = [{"input_ids": f["labels"]} for f in features]
|
| 74 |
+
|
| 75 |
+
# Pad input features (processor handles this)
|
| 76 |
+
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
|
| 77 |
+
|
| 78 |
+
# Pad labels
|
| 79 |
+
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
|
| 80 |
+
labels = labels_batch["input_ids"].masked_fill(
|
| 81 |
+
labels_batch.attention_mask.ne(1), -100
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
# Remove decoder start token if it was prepended
|
| 85 |
+
if (labels[:, 0] == self.decoder_start_token_id).all().item():
|
| 86 |
+
labels = labels[:, 1:]
|
| 87 |
+
|
| 88 |
+
batch["labels"] = labels
|
| 89 |
+
return batch
|
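Hedged sketch of how the collator above would be wired into a HuggingFace Seq2SeqTrainer; the model id is assumed to be the project's Whisper checkpoint and the trainer call is only indicated in a comment:

    from transformers import WhisperForConditionalGeneration, WhisperProcessor
    from src.data.feature_extractor import DataCollatorSpeechSeq2SeqWithPadding

    model_id = "openai/whisper-large-v3-turbo"
    processor = WhisperProcessor.from_pretrained(model_id)
    model = WhisperForConditionalGeneration.from_pretrained(model_id)

    collator = DataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=model.config.decoder_start_token_id,
    )
    # Pass data_collator=collator to Seq2SeqTrainer alongside the preprocessed Waxal splits.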
src/data/waxal_loader.py
ADDED
|
@@ -0,0 +1,119 @@
|
| 1 |
+
"""
|
| 2 |
+
Loads and preprocesses the google/waxal dataset for Bambara (bam) and Fula (ful).
|
| 3 |
+
Uses streaming to avoid downloading the full corpus before training.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from typing import TYPE_CHECKING, Callable, Iterator
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torchaudio
|
| 13 |
+
from datasets import load_dataset
|
| 14 |
+
|
| 15 |
+
if TYPE_CHECKING:
|
| 16 |
+
from datasets import Dataset, IterableDataset
|
| 17 |
+
from transformers import WhisperProcessor
|
| 18 |
+
|
| 19 |
+
from src.data.augmentation import FieldNoiseAugmenter
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
# google/waxal column names
|
| 24 |
+
AUDIO_COL = "audio"
|
| 25 |
+
TEXT_COL = "transcription"
|
| 26 |
+
TARGET_SR = 16_000
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class WaxalDataLoader:
|
| 30 |
+
"""Streams the google/waxal dataset and prepares examples for Whisper training."""
|
| 31 |
+
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
subset: str,
|
| 35 |
+
config: dict,
|
| 36 |
+
hf_token: str | None = None,
|
| 37 |
+
) -> None:
|
| 38 |
+
if subset not in ("bam", "ful"):
|
| 39 |
+
raise ValueError(f"subset must be 'bam' or 'ful', got '{subset}'")
|
| 40 |
+
self.subset = subset
|
| 41 |
+
self.config = config
|
| 42 |
+
self.hf_token = hf_token
|
| 43 |
+
|
| 44 |
+
def load_split(self, split: str = "train", streaming: bool = True) -> "IterableDataset | Dataset":
|
| 45 |
+
"""Return a single split of google/waxal."""
|
| 46 |
+
logger.info("Loading google/waxal subset=%s split=%s streaming=%s", self.subset, split, streaming)
|
| 47 |
+
ds = load_dataset(
|
| 48 |
+
"google/waxal",
|
| 49 |
+
self.subset,
|
| 50 |
+
split=split,
|
| 51 |
+
token=self.hf_token,
|
| 52 |
+
streaming=streaming,
|
| 53 |
+
trust_remote_code=True,
|
| 54 |
+
)
|
| 55 |
+
if streaming:
|
| 56 |
+
ds = ds.shuffle(seed=42, buffer_size=1000)
|
| 57 |
+
return ds
|
| 58 |
+
|
| 59 |
+
def get_splits(self, streaming: bool = True) -> dict[str, "IterableDataset | Dataset"]:
|
| 60 |
+
"""Return train / validation / test splits."""
|
| 61 |
+
splits = {}
|
| 62 |
+
for split in ("train", "validation", "test"):
|
| 63 |
+
try:
|
| 64 |
+
splits[split] = self.load_split(split, streaming=streaming)
|
| 65 |
+
except Exception:
|
| 66 |
+
logger.warning("Split '%s' not available for subset '%s'", split, self.subset)
|
| 67 |
+
return splits
|
| 68 |
+
|
| 69 |
+
def make_preprocess_fn(
|
| 70 |
+
self,
|
| 71 |
+
processor: "WhisperProcessor",
|
| 72 |
+
augmenter: "FieldNoiseAugmenter | None" = None,
|
| 73 |
+
) -> Callable[[dict], dict]:
|
| 74 |
+
"""Return a function that converts a raw Waxal example into model inputs."""
|
| 75 |
+
|
| 76 |
+
def preprocess(example: dict) -> dict:
|
| 77 |
+
# Extract and resample audio
|
| 78 |
+
audio_array = np.array(example[AUDIO_COL]["array"], dtype=np.float32)
|
| 79 |
+
orig_sr: int = example[AUDIO_COL]["sampling_rate"]
|
| 80 |
+
|
| 81 |
+
if orig_sr != TARGET_SR:
|
| 82 |
+
tensor = torch.from_numpy(audio_array).unsqueeze(0)
|
| 83 |
+
tensor = torchaudio.functional.resample(tensor, orig_sr, TARGET_SR)
|
| 84 |
+
audio_array = tensor.squeeze(0).numpy()
|
| 85 |
+
|
| 86 |
+
# Apply field noise augmentation if provided
|
| 87 |
+
if augmenter is not None and augmenter.is_ready():
|
| 88 |
+
audio_array = augmenter.augment(audio_array, TARGET_SR)
|
| 89 |
+
|
| 90 |
+
# Extract log-mel features
|
| 91 |
+
inputs = processor.feature_extractor(
|
| 92 |
+
audio_array,
|
| 93 |
+
sampling_rate=TARGET_SR,
|
| 94 |
+
return_tensors="np",
|
| 95 |
+
)
|
| 96 |
+
input_features = inputs.input_features[0]  # shape (n_mels, 3000)
|
| 97 |
+
|
| 98 |
+
# Tokenize transcript
|
| 99 |
+
text: str = example[TEXT_COL]
|
| 100 |
+
labels = processor.tokenizer(text, return_tensors="np").input_ids[0]
|
| 101 |
+
|
| 102 |
+
return {
|
| 103 |
+
"input_features": input_features,
|
| 104 |
+
"labels": labels,
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
return preprocess
|
| 108 |
+
|
| 109 |
+
def iter_processed(
|
| 110 |
+
self,
|
| 111 |
+
processor: "WhisperProcessor",
|
| 112 |
+
split: str = "train",
|
| 113 |
+
augmenter: "FieldNoiseAugmenter | None" = None,
|
| 114 |
+
) -> Iterator[dict]:
|
| 115 |
+
"""Yield preprocessed examples one at a time (streaming)."""
|
| 116 |
+
ds = self.load_split(split, streaming=True)
|
| 117 |
+
fn = self.make_preprocess_fn(processor, augmenter)
|
| 118 |
+
for example in ds:
|
| 119 |
+
yield fn(example)
|
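A short sketch of streaming preprocessed examples with the loader above (access to google/waxal is required; the token argument is left as a placeholder):

    from transformers import WhisperProcessor
    from src.data.waxal_loader import WaxalDataLoader

    processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3-turbo")
    loader = WaxalDataLoader(subset="bam", config={}, hf_token=None)  # pass a HF read token here

    for example in loader.iter_processed(processor, split="train"):
        print(example["input_features"].shape, len(example["labels"]))
        break  # inspect a single streamed example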
src/engine/__init__.py
ADDED
|
File without changes
|
src/engine/adapter_manager.py
ADDED
|
@@ -0,0 +1,106 @@
|
| 1 |
+
"""
|
| 2 |
+
LoRA adapter hot-swap manager.
|
| 3 |
+
|
| 4 |
+
Uses PEFT's multi-adapter API:
|
| 5 |
+
- model.load_adapter(path, adapter_name=lang) — first load (~2s per adapter)
|
| 6 |
+
- model.set_adapter(lang) — subsequent swap (~50ms)
|
| 7 |
+
|
| 8 |
+
This keeps a single backbone in VRAM and swaps only the ~50MB adapter weights,
|
| 9 |
+
vs reloading the full 1.5GB model per language.
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import logging
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import TYPE_CHECKING
|
| 16 |
+
|
| 17 |
+
from peft import PeftModel
|
| 18 |
+
|
| 19 |
+
if TYPE_CHECKING:
|
| 20 |
+
from transformers import WhisperForConditionalGeneration
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class AdapterManager:
|
| 26 |
+
"""Manages registration and hot-swapping of LoRA language adapters."""
|
| 27 |
+
|
| 28 |
+
def __init__(self, base_model: "WhisperForConditionalGeneration", config: dict) -> None:
|
| 29 |
+
self._base_model = base_model
|
| 30 |
+
self._config = config
|
| 31 |
+
self._registry: dict[str, str] = {} # language_code -> adapter_path
|
| 32 |
+
self._peft_model: PeftModel | None = None
|
| 33 |
+
self._active: str | None = None
|
| 34 |
+
|
| 35 |
+
def register(self, language: str, adapter_path: str) -> None:
|
| 36 |
+
"""Register an adapter path. Does not load it yet."""
|
| 37 |
+
path = Path(adapter_path)
|
| 38 |
+
if not path.exists():
|
| 39 |
+
logger.warning(
|
| 40 |
+
"Adapter path '%s' for language '%s' does not exist. "
|
| 41 |
+
"Run training first, or check the path.",
|
| 42 |
+
adapter_path, language,
|
| 43 |
+
)
|
| 44 |
+
self._registry[language] = str(path)
|
| 45 |
+
logger.info("Registered adapter '%s' → %s", language, adapter_path)
|
| 46 |
+
|
| 47 |
+
def load_adapter(self, language: str) -> None:
|
| 48 |
+
"""
|
| 49 |
+
Load an adapter into the model for the first time.
|
| 50 |
+
Slow (~2s): reads adapter weights from disk.
|
| 51 |
+
Subsequent activate() calls reuse the already-loaded weights.
|
| 52 |
+
"""
|
| 53 |
+
if language not in self._registry:
|
| 54 |
+
raise KeyError(f"No adapter registered for language '{language}'. "
|
| 55 |
+
f"Available: {list(self._registry)}")
|
| 56 |
+
|
| 57 |
+
adapter_path = self._registry[language]
|
| 58 |
+
|
| 59 |
+
if self._peft_model is None:
|
| 60 |
+
# First adapter: wrap the base model with PeftModel
|
| 61 |
+
logger.info("Wrapping base model with first adapter '%s'...", language)
|
| 62 |
+
self._peft_model = PeftModel.from_pretrained(
|
| 63 |
+
self._base_model,
|
| 64 |
+
adapter_path,
|
| 65 |
+
adapter_name=language,
|
| 66 |
+
)
|
| 67 |
+
else:
|
| 68 |
+
# Subsequent adapters: load into the existing PeftModel
|
| 69 |
+
logger.info("Loading adapter '%s' into existing PeftModel...", language)
|
| 70 |
+
self._peft_model.load_adapter(adapter_path, adapter_name=language)
|
| 71 |
+
|
| 72 |
+
self._active = language
|
| 73 |
+
logger.info("Adapter '%s' loaded and active.", language)
|
| 74 |
+
|
| 75 |
+
def activate(self, language: str) -> None:
|
| 76 |
+
"""
|
| 77 |
+
Hot-swap to a previously loaded adapter (~50ms).
|
| 78 |
+
Call load_adapter() first if this adapter hasn't been loaded.
|
| 79 |
+
"""
|
| 80 |
+
if self._peft_model is None:
|
| 81 |
+
self.load_adapter(language)
|
| 82 |
+
return
|
| 83 |
+
|
| 84 |
+
loaded = set(self._peft_model.peft_config.keys())
|
| 85 |
+
if language not in loaded:
|
| 86 |
+
self.load_adapter(language)
|
| 87 |
+
return
|
| 88 |
+
|
| 89 |
+
self._peft_model.set_adapter(language)
|
| 90 |
+
self._active = language
|
| 91 |
+
logger.debug("Hot-swapped to adapter '%s'.", language)
|
| 92 |
+
|
| 93 |
+
def get_model(self) -> "WhisperForConditionalGeneration | PeftModel":
|
| 94 |
+
"""Return the PeftModel (or base model if no adapter loaded yet)."""
|
| 95 |
+
return self._peft_model if self._peft_model is not None else self._base_model
|
| 96 |
+
|
| 97 |
+
def get_active(self) -> str | None:
|
| 98 |
+
return self._active
|
| 99 |
+
|
| 100 |
+
def list_available(self) -> list[str]:
|
| 101 |
+
return list(self._registry.keys())
|
| 102 |
+
|
| 103 |
+
def list_loaded(self) -> list[str]:
|
| 104 |
+
if self._peft_model is None:
|
| 105 |
+
return []
|
| 106 |
+
return list(self._peft_model.peft_config.keys())
|
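Hedged sketch of the register → activate flow described in the docstring above; the adapter paths are illustrative and assume training has already produced them:

    from src.engine.adapter_manager import AdapterManager
    from src.engine.whisper_base import WhisperBackbone

    backbone = WhisperBackbone("configs/base_config.yaml")
    backbone.load(device="cuda")

    manager = AdapterManager(backbone.model, config={})
    manager.register("bam", "./adapters/bambara")
    manager.register("ful", "./adapters/fula")

    manager.activate("bam")  # first activation loads adapter weights from disk (slow path)
    manager.activate("ful")  # first activation of the second adapter
    manager.activate("bam")  # already loaded, so this is the fast hot swap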
src/engine/transcriber.py
ADDED
|
@@ -0,0 +1,132 @@
|
| 1 |
+
"""
|
| 2 |
+
Public inference interface.
|
| 3 |
+
Accepts audio as a file path or numpy array and returns transcribed text.
|
| 4 |
+
Handles chunking for audio longer than 30 seconds.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
import time
|
| 11 |
+
from dataclasses import dataclass, field
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import TYPE_CHECKING
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
if TYPE_CHECKING:
|
| 19 |
+
from src.engine.adapter_manager import AdapterManager
|
| 20 |
+
from src.engine.whisper_base import WhisperBackbone
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
TARGET_SR = 16_000
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class TranscriptionResult:
|
| 29 |
+
text: str
|
| 30 |
+
language: str
|
| 31 |
+
duration_s: float
|
| 32 |
+
processing_time_ms: int
|
| 33 |
+
confidence: float | None = None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Transcriber:
|
| 37 |
+
"""
|
| 38 |
+
Composes WhisperBackbone + AdapterManager to provide a simple transcription API.
|
| 39 |
+
Thread-safety: Not thread-safe by design — use one worker process.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(self, backbone: "WhisperBackbone", adapter_manager: "AdapterManager") -> None:
|
| 43 |
+
self._backbone = backbone
|
| 44 |
+
self._adapter_manager = adapter_manager
|
| 45 |
+
|
| 46 |
+
def transcribe(
|
| 47 |
+
self,
|
| 48 |
+
audio: np.ndarray,
|
| 49 |
+
sample_rate: int,
|
| 50 |
+
language: str,
|
| 51 |
+
use_agri_prompt: bool = True,
|
| 52 |
+
) -> TranscriptionResult:
|
| 53 |
+
"""
|
| 54 |
+
Transcribe a float32 audio array.
|
| 55 |
+
For audio > 30s, splits the signal into 30-second chunks and concatenates the transcriptions.
|
| 56 |
+
"""
|
| 57 |
+
t0 = time.time()
|
| 58 |
+
|
| 59 |
+
# Activate the correct language adapter
|
| 60 |
+
self._adapter_manager.activate(language)
|
| 61 |
+
|
| 62 |
+
processor = self._backbone.processor
|
| 63 |
+
model = self._adapter_manager.get_model()
|
| 64 |
+
device = self._backbone.device
|
| 65 |
+
duration_s = len(audio) / sample_rate
|
| 66 |
+
|
| 67 |
+
if duration_s <= 30.0:
|
| 68 |
+
text = self._transcribe_chunk(audio, sample_rate, language, processor, model, device)
|
| 69 |
+
else:
|
| 70 |
+
text = self._transcribe_long(audio, sample_rate, language, processor, model, device)
|
| 71 |
+
|
| 72 |
+
elapsed_ms = int((time.time() - t0) * 1000)
|
| 73 |
+
return TranscriptionResult(
|
| 74 |
+
text=text.strip(),
|
| 75 |
+
language=language,
|
| 76 |
+
duration_s=duration_s,
|
| 77 |
+
processing_time_ms=elapsed_ms,
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
def transcribe_file(self, audio_path: str, language: str) -> TranscriptionResult:
|
| 81 |
+
"""Load audio from disk and transcribe."""
|
| 82 |
+
import librosa
|
| 83 |
+
audio, sr = librosa.load(audio_path, sr=TARGET_SR, mono=True)
|
| 84 |
+
return self.transcribe(audio, sr, language)
|
| 85 |
+
|
| 86 |
+
def _transcribe_chunk(
|
| 87 |
+
self,
|
| 88 |
+
audio: np.ndarray,
|
| 89 |
+
sr: int,
|
| 90 |
+
language: str,
|
| 91 |
+
processor,
|
| 92 |
+
model,
|
| 93 |
+
device: str,
|
| 94 |
+
) -> str:
|
| 95 |
+
"""Transcribe a single ≤30s chunk."""
|
| 96 |
+
inputs = processor.feature_extractor(
|
| 97 |
+
audio, sampling_rate=sr, return_tensors="pt"
|
| 98 |
+
)
|
| 99 |
+
input_features = inputs.input_features.to(device)
|
| 100 |
+
if device == "cuda":
|
| 101 |
+
input_features = input_features.half()
|
| 102 |
+
|
| 103 |
+
forced_decoder_ids = processor.get_decoder_prompt_ids(
|
| 104 |
+
language=language, task="transcribe"
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
with torch.no_grad():
|
| 108 |
+
predicted_ids = model.generate(
|
| 109 |
+
input_features,
|
| 110 |
+
forced_decoder_ids=forced_decoder_ids,
|
| 111 |
+
max_new_tokens=128,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
|
| 115 |
+
|
| 116 |
+
def _transcribe_long(
|
| 117 |
+
self,
|
| 118 |
+
audio: np.ndarray,
|
| 119 |
+
sr: int,
|
| 120 |
+
language: str,
|
| 121 |
+
processor,
|
| 122 |
+
model,
|
| 123 |
+
device: str,
|
| 124 |
+
) -> str:
|
| 125 |
+
"""Chunk audio into 30s segments and concatenate transcriptions."""
|
| 126 |
+
chunk_size = TARGET_SR * 30
|
| 127 |
+
chunks = [audio[i : i + chunk_size] for i in range(0, len(audio), chunk_size)]
|
| 128 |
+
parts = []
|
| 129 |
+
for chunk in chunks:
|
| 130 |
+
text = self._transcribe_chunk(chunk, sr, language, processor, model, device)
|
| 131 |
+
parts.append(text)
|
| 132 |
+
return " ".join(parts)
|
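A minimal end-to-end inference sketch composing the backbone, adapter manager, and transcriber; the audio file and adapter path are placeholders:

    from src.engine.adapter_manager import AdapterManager
    from src.engine.transcriber import Transcriber
    from src.engine.whisper_base import WhisperBackbone

    backbone = WhisperBackbone()
    backbone.load(device="cuda")

    manager = AdapterManager(backbone.model, config={})
    manager.register("bam", "./adapters/bambara")

    transcriber = Transcriber(backbone, manager)
    result = transcriber.transcribe_file("sample.wav", language="bam")
    print(result.text, result.duration_s, result.processing_time_ms)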
src/engine/whisper_base.py
ADDED
|
@@ -0,0 +1,77 @@
|
| 1 |
+
"""
|
| 2 |
+
Loads the Whisper backbone model and processor once.
|
| 3 |
+
All other modules receive references to this shared instance.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import yaml
|
| 12 |
+
from transformers import WhisperForConditionalGeneration, WhisperProcessor
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class WhisperBackbone:
|
| 18 |
+
"""Singleton-style loader for the Whisper base model and processor."""
|
| 19 |
+
|
| 20 |
+
def __init__(self, config_path: str = "configs/base_config.yaml") -> None:
|
| 21 |
+
config_path = Path(config_path)
|
| 22 |
+
with open(config_path) as f:
|
| 23 |
+
cfg = yaml.safe_load(f)
|
| 24 |
+
self._model_id: str = cfg["model"]["id"]
|
| 25 |
+
self._model: WhisperForConditionalGeneration | None = None
|
| 26 |
+
self._processor: WhisperProcessor | None = None
|
| 27 |
+
self._device: str = "cpu"
|
| 28 |
+
|
| 29 |
+
def load(self, device: str = "cuda", hf_token: str | None = None) -> None:
|
| 30 |
+
"""Load model and processor into memory. Call once at startup."""
|
| 31 |
+
self._device = device if torch.cuda.is_available() and device == "cuda" else "cpu"
|
| 32 |
+
logger.info("Loading %s on %s", self._model_id, self._device)
|
| 33 |
+
|
| 34 |
+
self._processor = WhisperProcessor.from_pretrained(
|
| 35 |
+
self._model_id,
|
| 36 |
+
token=hf_token,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
dtype = torch.float16 if self._device == "cuda" else torch.float32
|
| 40 |
+
self._model = WhisperForConditionalGeneration.from_pretrained(
|
| 41 |
+
self._model_id,
|
| 42 |
+
torch_dtype=dtype,
|
| 43 |
+
token=hf_token,
|
| 44 |
+
).to(self._device)
|
| 45 |
+
|
| 46 |
+
self._model.eval()
|
| 47 |
+
logger.info("Model loaded successfully (dtype=%s, device=%s)", dtype, self._device)
|
| 48 |
+
|
| 49 |
+
@property
|
| 50 |
+
def model(self) -> WhisperForConditionalGeneration:
|
| 51 |
+
if self._model is None:
|
| 52 |
+
raise RuntimeError("Call WhisperBackbone.load() before accessing the model.")
|
| 53 |
+
return self._model
|
| 54 |
+
|
| 55 |
+
@property
|
| 56 |
+
def processor(self) -> WhisperProcessor:
|
| 57 |
+
if self._processor is None:
|
| 58 |
+
raise RuntimeError("Call WhisperBackbone.load() before accessing the processor.")
|
| 59 |
+
return self._processor
|
| 60 |
+
|
| 61 |
+
@property
|
| 62 |
+
def device(self) -> str:
|
| 63 |
+
return self._device
|
| 64 |
+
|
| 65 |
+
@property
|
| 66 |
+
def model_id(self) -> str:
|
| 67 |
+
return self._model_id
|
| 68 |
+
|
| 69 |
+
def free(self) -> None:
|
| 70 |
+
"""Release GPU memory."""
|
| 71 |
+
del self._model
|
| 72 |
+
del self._processor
|
| 73 |
+
self._model = None
|
| 74 |
+
self._processor = None
|
| 75 |
+
if torch.cuda.is_available():
|
| 76 |
+
torch.cuda.empty_cache()
|
| 77 |
+
logger.info("Backbone freed from memory.")
|
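Brief usage sketch for the backbone loader above; it is loaded once at startup and the same instance is shared with the other modules:

    from src.engine.whisper_base import WhisperBackbone

    backbone = WhisperBackbone("configs/base_config.yaml")
    backbone.load(device="cuda")  # falls back to CPU when CUDA is unavailable
    print(backbone.model_id, backbone.device)
    backbone.free()  # release GPU memory on shutdown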
src/iot/__init__.py
ADDED
|
File without changes
|
src/iot/intent_parser.py
ADDED
|
@@ -0,0 +1,75 @@
|
| 1 |
+
"""
|
| 2 |
+
Maps transcribed Bambara/Fula text to structured intents for IoT sensor queries.
|
| 3 |
+
Uses keyword matching (no ML required for v1).
|
| 4 |
+
Confidence = fraction of intent keywords present in the transcription.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class Intent:
|
| 13 |
+
action: str # e.g., "check_soil", "check_weather"
|
| 14 |
+
entity: str # e.g., "soil", "weather"
|
| 15 |
+
parameters: dict = field(default_factory=dict)
|
| 16 |
+
confidence: float = 0.0
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Intent keyword taxonomy for Bambara (bam) and Fula (ful)
|
| 20 |
+
INTENT_KEYWORDS: dict[str, dict[str, list[str]]] = {
|
| 21 |
+
"check_soil": {
|
| 22 |
+
"bam": ["bunding", "nɔgɔ", "dugu", "foro", "sani"],
|
| 23 |
+
"ful": ["leydi", "ngesa", "ladde"],
|
| 24 |
+
},
|
| 25 |
+
"check_weather": {
|
| 26 |
+
"bam": ["teliman", "sanji", "dibi", "sira"],
|
| 27 |
+
"ful": ["yeeso", "fuɗorde"],
|
| 28 |
+
},
|
| 29 |
+
"irrigation_status": {
|
| 30 |
+
"bam": ["ji", "sanji", "foro"],
|
| 31 |
+
"ful": ["ndiyam", "ngesa"],
|
| 32 |
+
},
|
| 33 |
+
"pest_alert": {
|
| 34 |
+
"bam": ["kungoloni", "suruku"],
|
| 35 |
+
"ful": ["biñ-biñ"],
|
| 36 |
+
},
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
INTENT_ENTITIES = {
|
| 40 |
+
"check_soil": "soil",
|
| 41 |
+
"check_weather": "weather",
|
| 42 |
+
"irrigation_status": "irrigation",
|
| 43 |
+
"pest_alert": "pest",
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class IntentParser:
|
| 48 |
+
"""Parses a transcription string into a structured Intent."""
|
| 49 |
+
|
| 50 |
+
def parse(self, text: str, language: str) -> Intent:
|
| 51 |
+
"""
|
| 52 |
+
Find the best matching intent by counting keyword overlaps.
|
| 53 |
+
Returns the highest-confidence intent.
|
| 54 |
+
"""
|
| 55 |
+
text_lower = text.lower()
|
| 56 |
+
best_action = "unknown"
|
| 57 |
+
best_confidence = 0.0
|
| 58 |
+
|
| 59 |
+
for action, lang_keywords in INTENT_KEYWORDS.items():
|
| 60 |
+
keywords = lang_keywords.get(language, [])
|
| 61 |
+
if not keywords:
|
| 62 |
+
continue
|
| 63 |
+
|
| 64 |
+
matches = sum(1 for kw in keywords if kw in text_lower)
|
| 65 |
+
confidence = matches / len(keywords)
|
| 66 |
+
|
| 67 |
+
if confidence > best_confidence:
|
| 68 |
+
best_confidence = confidence
|
| 69 |
+
best_action = action
|
| 70 |
+
|
| 71 |
+
return Intent(
|
| 72 |
+
action=best_action,
|
| 73 |
+
entity=INTENT_ENTITIES.get(best_action, "unknown"),
|
| 74 |
+
confidence=round(best_confidence, 3),
|
| 75 |
+
)
|
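A small sketch of the keyword matching above; the input phrase is invented for illustration and only needs to contain keywords from the taxonomy:

    from src.iot.intent_parser import IntentParser

    parser = IntentParser()
    intent = parser.parse("foro ji", language="bam")  # contains irrigation keywords
    print(intent.action, intent.entity, intent.confidence)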
src/iot/sensor_bridge.py
ADDED
|
@@ -0,0 +1,121 @@
|
| 1 |
+
"""
|
| 2 |
+
Fetches sensor data (soil moisture, weather, irrigation) from the IoT backend API.
|
| 3 |
+
Falls back to synthetic mock data when SENSOR_API_URL is not configured.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import random
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
from typing import TYPE_CHECKING
|
| 12 |
+
|
| 13 |
+
if TYPE_CHECKING:
|
| 14 |
+
from src.iot.intent_parser import Intent
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class SensorData:
|
| 21 |
+
sensor_type: str
|
| 22 |
+
values: dict[str, float]
|
| 23 |
+
timestamp: str
|
| 24 |
+
unit: str = ""
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SensorBridge:
|
| 28 |
+
"""Async bridge to IoT sensor API. Uses mock data when no API URL is configured."""
|
| 29 |
+
|
| 30 |
+
def __init__(self, sensor_api_url: str | None = None, timeout_s: float = 5.0) -> None:
|
| 31 |
+
self.sensor_api_url = sensor_api_url
|
| 32 |
+
self.timeout_s = timeout_s
|
| 33 |
+
self._mock_mode = not sensor_api_url
|
| 34 |
+
|
| 35 |
+
if self._mock_mode:
|
| 36 |
+
logger.info("SensorBridge: running in MOCK mode (set SENSOR_API_URL to use real sensors).")
|
| 37 |
+
|
| 38 |
+
async def fetch(self, intent: "Intent", field_id: str | None = None) -> SensorData:
|
| 39 |
+
"""Dispatch to the correct sensor fetch method based on intent entity."""
|
| 40 |
+
action = intent.action
|
| 41 |
+
if action == "check_soil":
|
| 42 |
+
return await self.get_soil_data(field_id or "default")
|
| 43 |
+
elif action == "check_weather":
|
| 44 |
+
return await self.get_weather(field_id or "default")
|
| 45 |
+
elif action == "irrigation_status":
|
| 46 |
+
return await self.get_irrigation(field_id or "default")
|
| 47 |
+
elif action == "pest_alert":
|
| 48 |
+
return await self.get_pest_status(field_id or "default")
|
| 49 |
+
else:
|
| 50 |
+
return SensorData(
|
| 51 |
+
sensor_type="unknown",
|
| 52 |
+
values={},
|
| 53 |
+
timestamp=datetime.utcnow().isoformat(),
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
async def get_soil_data(self, location_id: str) -> SensorData:
|
| 57 |
+
if self._mock_mode:
|
| 58 |
+
return SensorData(
|
| 59 |
+
sensor_type="soil",
|
| 60 |
+
values={
|
| 61 |
+
"moisture_pct": round(random.uniform(25, 65), 1),
|
| 62 |
+
"ph": round(random.uniform(5.5, 7.5), 1),
|
| 63 |
+
"nitrogen_ppm": round(random.uniform(10, 40), 1),
|
| 64 |
+
"temperature_c": round(random.uniform(24, 35), 1),
|
| 65 |
+
},
|
| 66 |
+
timestamp=datetime.utcnow().isoformat(),
|
| 67 |
+
)
|
| 68 |
+
return await self._get(f"/sensors/soil/{location_id}", "soil")
|
| 69 |
+
|
| 70 |
+
async def get_weather(self, location_id: str) -> SensorData:
|
| 71 |
+
if self._mock_mode:
|
| 72 |
+
return SensorData(
|
| 73 |
+
sensor_type="weather",
|
| 74 |
+
values={
|
| 75 |
+
"temperature_c": round(random.uniform(28, 42), 1),
|
| 76 |
+
"humidity_pct": round(random.uniform(20, 80), 1),
|
| 77 |
+
"wind_speed_kmh": round(random.uniform(0, 25), 1),
|
| 78 |
+
"rain_probability_pct": round(random.uniform(0, 100), 1),
|
| 79 |
+
},
|
| 80 |
+
timestamp=datetime.utcnow().isoformat(),
|
| 81 |
+
)
|
| 82 |
+
return await self._get(f"/sensors/weather/{location_id}", "weather")
|
| 83 |
+
|
| 84 |
+
async def get_irrigation(self, field_id: str) -> SensorData:
|
| 85 |
+
if self._mock_mode:
|
| 86 |
+
return SensorData(
|
| 87 |
+
sensor_type="irrigation",
|
| 88 |
+
values={
|
| 89 |
+
"flow_rate_lph": round(random.uniform(0, 500), 1),
|
| 90 |
+
"pressure_bar": round(random.uniform(1.0, 4.0), 2),
|
| 91 |
+
"active": float(random.choice([0, 1])),
|
| 92 |
+
"last_irrigation_h_ago": round(random.uniform(1, 48), 1),
|
| 93 |
+
},
|
| 94 |
+
timestamp=datetime.utcnow().isoformat(),
|
| 95 |
+
)
|
| 96 |
+
return await self._get(f"/sensors/irrigation/{field_id}", "irrigation")
|
| 97 |
+
|
| 98 |
+
async def get_pest_status(self, field_id: str) -> SensorData:
|
| 99 |
+
if self._mock_mode:
|
| 100 |
+
return SensorData(
|
| 101 |
+
sensor_type="pest",
|
| 102 |
+
values={
|
| 103 |
+
"trap_count_24h": float(random.randint(0, 50)),
|
| 104 |
+
"alert_level": float(random.randint(0, 3)), # 0=none 1=low 2=medium 3=high
|
| 105 |
+
},
|
| 106 |
+
timestamp=datetime.utcnow().isoformat(),
|
| 107 |
+
)
|
| 108 |
+
return await self._get(f"/sensors/pest/{field_id}", "pest")
|
| 109 |
+
|
| 110 |
+
async def _get(self, path: str, sensor_type: str) -> SensorData:
|
| 111 |
+
import httpx
|
| 112 |
+
url = f"{self.sensor_api_url}{path}"
|
| 113 |
+
async with httpx.AsyncClient(timeout=self.timeout_s) as client:
|
| 114 |
+
response = await client.get(url)
|
| 115 |
+
response.raise_for_status()
|
| 116 |
+
data = response.json()
|
| 117 |
+
return SensorData(
|
| 118 |
+
sensor_type=sensor_type,
|
| 119 |
+
values=data.get("values", data),
|
| 120 |
+
timestamp=data.get("timestamp", datetime.utcnow().isoformat()),
|
| 121 |
+
)
|
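A hedged async usage sketch in mock mode (no SENSOR_API_URL configured); the field id is illustrative:

    import asyncio

    from src.iot.intent_parser import Intent
    from src.iot.sensor_bridge import SensorBridge

    async def main() -> None:
        bridge = SensorBridge(sensor_api_url=None)  # None -> synthetic mock readings
        intent = Intent(action="check_soil", entity="soil")
        data = await bridge.fetch(intent, field_id="field-01")
        print(data.sensor_type, data.values)

    asyncio.run(main())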
src/iot/voice_responder.py
ADDED
|
@@ -0,0 +1,260 @@
|
| 1 |
+
"""
|
| 2 |
+
Generates voice response text from sensor data in the farmer's own language.
|
| 3 |
+
Supports Bambara (bam), Fula (ful), French (fr), and English (en).
|
| 4 |
+
Bambara/Fula templates use short sentences (≤15 words) for best MMS-TTS quality.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from typing import TYPE_CHECKING
|
| 9 |
+
|
| 10 |
+
if TYPE_CHECKING:
|
| 11 |
+
from src.iot.intent_parser import Intent
|
| 12 |
+
from src.iot.sensor_bridge import SensorData
|
| 13 |
+
|
| 14 |
+
# Alert thresholds
|
| 15 |
+
SOIL_MOISTURE_LOW = 30.0 # Below this → immediate irrigation recommended
|
| 16 |
+
SOIL_MOISTURE_HIGH = 70.0 # Above this → drainage warning
|
| 17 |
+
SOIL_PH_LOW = 5.5
|
| 18 |
+
SOIL_PH_HIGH = 7.5
|
| 19 |
+
TEMP_HIGH = 38.0
|
| 20 |
+
PEST_ALERT_HIGH = 2 # Alert level ≥ 2 → warning
|
| 21 |
+
|
| 22 |
+
# ── Bambara templates (≤6 words per sentence for clear MMS-TTS output) ───────
|
| 23 |
+
BAMBARA_TEMPLATES = {
|
| 24 |
+
"soil_moisture_low": "Bunding ji dɔgɔ. I ka foro ji.",
|
| 25 |
+
"soil_moisture_high": "Ji ca kojugu. Foro ma fɛ.",
|
| 26 |
+
"soil_ph_low": "Bunding kɔnɔ jugu. Kalisi fara a kan.",
|
| 27 |
+
"soil_ph_high": "Bunding kɔnɔ tɛmɛ. Soufre fara a kan.",
|
| 28 |
+
"weather_hot": "Teliman gbɛlɛ. Tile ma sigi.",
|
| 29 |
+
"rain_likely": "Sanji bɛ na. Sɔrɔ jɔ.",
|
| 30 |
+
"pest_high": "Dɔgɔw bɛ foro kɔnɔ. Bɔ u.",
|
| 31 |
+
"irrigation_needed": "Foro fɛ ji. Ji sira yɔrɔ.",
|
| 32 |
+
"irrigation_active": "Ji bɛ taa. A bɛ kɛ cogo di.",
|
| 33 |
+
"default": "Kabako jumanw sɔrɔla.",
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
# ── Fula templates (≤6 words per sentence for clear MMS-TTS output) ──────────
|
| 37 |
+
FULA_TEMPLATES = {
|
| 38 |
+
"soil_moisture_low": "Leydi ndiyam famɗi. Wado ngesa.",
|
| 39 |
+
"soil_moisture_high": "Ndiyam heewi. Leydi famɗaali.",
|
| 40 |
+
"soil_ph_low": "Leydi suurii. Waɗ kalisi.",
|
| 41 |
+
"soil_ph_high": "Leydi alkalii. Waɗ soufre.",
|
| 42 |
+
"weather_hot": "Nguleeki heewi. Muusal.",
|
| 43 |
+
"rain_likely": "Ndiyam wadata. Loosu ngesa.",
|
| 44 |
+
"pest_high": "Biñ-biñ ngesa nder. Fiil ɗen.",
|
| 45 |
+
"irrigation_needed": "Ngesa fɛɗɛli ndiyam. Wado.",
|
| 46 |
+
"irrigation_active": "Ndiyam wona jooni.",
|
| 47 |
+
"default": "Humpito juuti waɗaama.",
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class VoiceResponder:
|
| 52 |
+
"""Converts sensor readings into actionable voice messages in the farmer's language."""
|
| 53 |
+
|
| 54 |
+
def __init__(self, language: str = "fr") -> None:
|
| 55 |
+
self.language = language
|
| 56 |
+
|
| 57 |
+
def generate_response(self, intent: "Intent", sensor_data: "SensorData") -> str:
|
| 58 |
+
if self.language == "bam":
|
| 59 |
+
return self._bambara_response(sensor_data)
|
| 60 |
+
elif self.language == "ful":
|
| 61 |
+
return self._fula_response(sensor_data)
|
| 62 |
+
else:
|
| 63 |
+
return self._french_response(sensor_data)
|
| 64 |
+
|
| 65 |
+
# ── Bambara ──────────────────────────────────────────────────────────────
|
| 66 |
+
|
| 67 |
+
def _bambara_response(self, sensor_data: "SensorData") -> str:
|
| 68 |
+
t = sensor_data.sensor_type
|
| 69 |
+
v = sensor_data.values
|
| 70 |
+
T = BAMBARA_TEMPLATES
|
| 71 |
+
|
| 72 |
+
if t == "soil":
|
| 73 |
+
moisture = v.get("moisture_pct")
|
| 74 |
+
if moisture is not None:
|
| 75 |
+
if moisture < SOIL_MOISTURE_LOW:
|
| 76 |
+
return T["soil_moisture_low"]
|
| 77 |
+
elif moisture > SOIL_MOISTURE_HIGH:
|
| 78 |
+
return T["soil_moisture_high"]
|
| 79 |
+
ph = v.get("ph")
|
| 80 |
+
if ph is not None:
|
| 81 |
+
if ph < SOIL_PH_LOW:
|
| 82 |
+
return T["soil_ph_low"]
|
| 83 |
+
elif ph > SOIL_PH_HIGH:
|
| 84 |
+
return T["soil_ph_high"]
|
| 85 |
+
|
| 86 |
+
elif t == "weather":
|
| 87 |
+
temp = v.get("temperature_c")
|
| 88 |
+
rain = v.get("rain_probability_pct")
|
| 89 |
+
if temp is not None and temp > TEMP_HIGH:
|
| 90 |
+
return T["weather_hot"]
|
| 91 |
+
if rain is not None and rain > 70:
|
| 92 |
+
return T["rain_likely"]
|
| 93 |
+
|
| 94 |
+
elif t == "irrigation":
|
| 95 |
+
last = v.get("last_irrigation_h_ago")
|
| 96 |
+
active = v.get("active")
|
| 97 |
+
if active:
|
| 98 |
+
return T["irrigation_active"]
|
| 99 |
+
if last is not None and last > 24:
|
| 100 |
+
return T["irrigation_needed"]
|
| 101 |
+
|
| 102 |
+
elif t == "pest":
|
| 103 |
+
level = int(v.get("alert_level", 0))
|
| 104 |
+
if level >= PEST_ALERT_HIGH:
|
| 105 |
+
return T["pest_high"]
|
| 106 |
+
|
| 107 |
+
return T["default"]
|
| 108 |
+
|
| 109 |
+
# ── Fula ─────────────────────────────────────────────────────────────────
|
| 110 |
+
|
| 111 |
+
def _fula_response(self, sensor_data: "SensorData") -> str:
|
| 112 |
+
t = sensor_data.sensor_type
|
| 113 |
+
v = sensor_data.values
|
| 114 |
+
T = FULA_TEMPLATES
|
| 115 |
+
|
| 116 |
+
if t == "soil":
|
| 117 |
+
moisture = v.get("moisture_pct")
|
| 118 |
+
if moisture is not None:
|
| 119 |
+
if moisture < SOIL_MOISTURE_LOW:
|
| 120 |
+
return T["soil_moisture_low"]
|
| 121 |
+
elif moisture > SOIL_MOISTURE_HIGH:
|
| 122 |
+
return T["soil_moisture_high"]
|
| 123 |
+
ph = v.get("ph")
|
| 124 |
+
if ph is not None:
|
| 125 |
+
if ph < SOIL_PH_LOW:
|
| 126 |
+
return T["soil_ph_low"]
|
| 127 |
+
elif ph > SOIL_PH_HIGH:
|
| 128 |
+
return T["soil_ph_high"]
|
| 129 |
+
|
| 130 |
+
elif t == "weather":
|
| 131 |
+
temp = v.get("temperature_c")
|
| 132 |
+
rain = v.get("rain_probability_pct")
|
| 133 |
+
if temp is not None and temp > TEMP_HIGH:
|
| 134 |
+
return T["weather_hot"]
|
| 135 |
+
if rain is not None and rain > 70:
|
| 136 |
+
return T["rain_likely"]
|
| 137 |
+
|
| 138 |
+
elif t == "irrigation":
|
| 139 |
+
active = v.get("active")
|
| 140 |
+
last = v.get("last_irrigation_h_ago")
|
| 141 |
+
if active:
|
| 142 |
+
return T["irrigation_active"]
|
| 143 |
+
if last is not None and last > 24:
|
| 144 |
+
return T["irrigation_needed"]
|
| 145 |
+
|
| 146 |
+
elif t == "pest":
|
| 147 |
+
level = int(v.get("alert_level", 0))
|
| 148 |
+
if level >= PEST_ALERT_HIGH:
|
| 149 |
+
return T["pest_high"]
|
| 150 |
+
|
| 151 |
+
return T["default"]
|
| 152 |
+
|
| 153 |
+
# ── French (original) ─────────────────────────────────────────────────────
|
| 154 |
+
|
| 155 |
+
def _french_response(self, sensor_data: "SensorData") -> str:
|
| 156 |
+
t = sensor_data.sensor_type
|
| 157 |
+
v = sensor_data.values
|
| 158 |
+
if t == "soil":
|
| 159 |
+
return self._soil_response(v)
|
| 160 |
+
elif t == "weather":
|
| 161 |
+
return self._weather_response(v)
|
| 162 |
+
elif t == "irrigation":
|
| 163 |
+
return self._irrigation_response(v)
|
| 164 |
+
elif t == "pest":
|
| 165 |
+
return self._pest_response(v)
|
| 166 |
+
else:
|
| 167 |
+
return "Données du capteur non disponibles pour le moment."
|
| 168 |
+
|
| 169 |
+
def _soil_response(self, v: dict) -> str:
|
| 170 |
+
parts = []
|
| 171 |
+
moisture = v.get("moisture_pct")
|
| 172 |
+
ph = v.get("ph")
|
| 173 |
+
temp = v.get("temperature_c")
|
| 174 |
+
nitrogen = v.get("nitrogen_ppm")
|
| 175 |
+
|
| 176 |
+
if moisture is not None:
|
| 177 |
+
parts.append(f"Humidité du sol : {moisture:.0f}%.")
|
| 178 |
+
if moisture < SOIL_MOISTURE_LOW:
|
| 179 |
+
parts.append("Irrigation recommandée immédiatement.")
|
| 180 |
+
elif moisture > SOIL_MOISTURE_HIGH:
|
| 181 |
+
parts.append("Sol trop humide, risque d'engorgement.")
|
| 182 |
+
|
| 183 |
+
if ph is not None:
|
| 184 |
+
parts.append(f"pH du sol : {ph:.1f}.")
|
| 185 |
+
if ph < SOIL_PH_LOW:
|
| 186 |
+
parts.append("Sol trop acide — envisagez un amendement calcaire.")
|
| 187 |
+
elif ph > SOIL_PH_HIGH:
|
| 188 |
+
parts.append("Sol trop alcalin — un apport de soufre peut aider.")
|
| 189 |
+
|
| 190 |
+
if temp is not None:
|
| 191 |
+
parts.append(f"Température du sol : {temp:.0f}°C.")
|
| 192 |
+
|
| 193 |
+
if nitrogen is not None:
|
| 194 |
+
parts.append(f"Azote disponible : {nitrogen:.0f} ppm.")
|
| 195 |
+
if nitrogen < 15:
|
| 196 |
+
parts.append("Niveau d'azote faible — envisagez un engrais azoté.")
|
| 197 |
+
|
| 198 |
+
return " ".join(parts) if parts else "Données du sol reçues."
|
| 199 |
+
|
| 200 |
+
def _weather_response(self, v: dict) -> str:
|
| 201 |
+
parts = []
|
| 202 |
+
temp = v.get("temperature_c")
|
| 203 |
+
humidity = v.get("humidity_pct")
|
| 204 |
+
wind = v.get("wind_speed_kmh")
|
| 205 |
+
rain = v.get("rain_probability_pct")
|
| 206 |
+
|
| 207 |
+
if temp is not None:
|
| 208 |
+
parts.append(f"Température : {temp:.0f}°C.")
|
| 209 |
+
if temp > TEMP_HIGH:
|
| 210 |
+
parts.append("Chaleur excessive — évitez les travaux aux heures les plus chaudes.")
|
| 211 |
+
|
| 212 |
+
if humidity is not None:
|
| 213 |
+
parts.append(f"Humidité de l'air : {humidity:.0f}%.")
|
| 214 |
+
|
| 215 |
+
if wind is not None:
|
| 216 |
+
parts.append(f"Vent : {wind:.0f} km/h.")
|
| 217 |
+
|
| 218 |
+
if rain is not None:
|
| 219 |
+
parts.append(f"Probabilité de pluie : {rain:.0f}%.")
|
| 220 |
+
if rain > 70:
|
| 221 |
+
parts.append("Pluie probable — reportez les traitements pesticides.")
|
| 222 |
+
|
| 223 |
+
return " ".join(parts) if parts else "Données météo reçues."
|
| 224 |
+
|
| 225 |
+
def _irrigation_response(self, v: dict) -> str:
|
| 226 |
+
parts = []
|
| 227 |
+
active = v.get("active")
|
| 228 |
+
last = v.get("last_irrigation_h_ago")
|
| 229 |
+
flow = v.get("flow_rate_lph")
|
| 230 |
+
|
| 231 |
+
if active is not None:
|
| 232 |
+
state = "en marche" if active else "arrêtée"
|
| 233 |
+
parts.append(f"Irrigation {state}.")
|
| 234 |
+
|
| 235 |
+
if flow is not None and active:
|
| 236 |
+
parts.append(f"Débit : {flow:.0f} litres par heure.")
|
| 237 |
+
|
| 238 |
+
if last is not None:
|
| 239 |
+
parts.append(f"Dernière irrigation il y a {last:.0f} heures.")
|
| 240 |
+
if last > 24:
|
| 241 |
+
parts.append("Plus de 24 heures sans irrigation — vérifiez les besoins en eau.")
|
| 242 |
+
|
| 243 |
+
return " ".join(parts) if parts else "Statut d'irrigation reçu."
|
| 244 |
+
|
| 245 |
+
def _pest_response(self, v: dict) -> str:
|
| 246 |
+
level = int(v.get("alert_level", 0))
|
| 247 |
+
count = v.get("trap_count_24h")
|
| 248 |
+
|
| 249 |
+
level_labels = {0: "aucune", 1: "faible", 2: "modérée", 3: "élevée"}
|
| 250 |
+
label = level_labels.get(level, "inconnue")
|
| 251 |
+
|
| 252 |
+
parts = [f"Présence d'insectes nuisibles : niveau {label}."]
|
| 253 |
+
|
| 254 |
+
if count is not None:
|
| 255 |
+
parts.append(f"{count:.0f} insectes capturés en 24 heures.")
|
| 256 |
+
|
| 257 |
+
if level >= PEST_ALERT_HIGH:
|
| 258 |
+
parts.append("Traitement recommandé — consultez un agent agricole.")
|
| 259 |
+
|
| 260 |
+
return " ".join(parts)
|
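Sketch of the full voice loop these three IoT modules form (parse → fetch → respond), using mock sensor data; the Fula phrase is illustrative:

    import asyncio

    from src.iot.intent_parser import IntentParser
    from src.iot.sensor_bridge import SensorBridge
    from src.iot.voice_responder import VoiceResponder

    async def main() -> None:
        intent = IntentParser().parse("leydi ndiyam", language="ful")
        data = await SensorBridge().fetch(intent)
        print(VoiceResponder(language="ful").generate_response(intent, data))

    asyncio.run(main())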
src/optimization/__init__.py
ADDED
|
File without changes
|
src/optimization/onnx_exporter.py
ADDED
|
@@ -0,0 +1,106 @@
|
| 1 |
+
"""
|
| 2 |
+
Merges LoRA adapter weights into the backbone and exports to ONNX.
|
| 3 |
+
Produces one ONNX file per language (ONNX cannot hot-swap adapters at runtime).
|
| 4 |
+
|
| 5 |
+
Requires: optimum[onnxruntime]
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import TYPE_CHECKING
|
| 12 |
+
|
| 13 |
+
if TYPE_CHECKING:
|
| 14 |
+
from peft import PeftModel
|
| 15 |
+
from transformers import WhisperProcessor
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ONNXExporter:
|
| 21 |
+
"""Merges a LoRA PeftModel into its base model and exports to ONNX."""
|
| 22 |
+
|
| 23 |
+
def merge_and_export(
|
| 24 |
+
self,
|
| 25 |
+
peft_model: "PeftModel",
|
| 26 |
+
processor: "WhisperProcessor",
|
| 27 |
+
output_dir: str,
|
| 28 |
+
language: str,
|
| 29 |
+
) -> Path:
|
| 30 |
+
"""
|
| 31 |
+
1. Merge LoRA weights into base model (merge_and_unload)
|
| 32 |
+
2. Export merged model to ONNX via optimum
|
| 33 |
+
Returns the output directory path.
|
| 34 |
+
"""
|
| 35 |
+
output_path = Path(output_dir) / language
|
| 36 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 37 |
+
|
| 38 |
+
logger.info("Merging LoRA adapter '%s' into base model...", language)
|
| 39 |
+
merged_model = peft_model.merge_and_unload()
|
| 40 |
+
merged_model.eval()
|
| 41 |
+
|
| 42 |
+
logger.info("Exporting to ONNX: %s", output_path)
|
| 43 |
+
self._export_with_optimum(merged_model, processor, str(output_path))
|
| 44 |
+
|
| 45 |
+
return output_path
|
| 46 |
+
|
| 47 |
+
def _export_with_optimum(
|
| 48 |
+
self,
|
| 49 |
+
merged_model,
|
| 50 |
+
processor: "WhisperProcessor",
|
| 51 |
+
output_dir: str,
|
| 52 |
+
) -> None:
|
| 53 |
+
"""Use optimum's ONNX export pipeline."""
|
| 54 |
+
from optimum.exporters.onnx import main_export
|
| 55 |
+
|
| 56 |
+
# Save merged model to a temp directory first
|
| 57 |
+
import tempfile
|
| 58 |
+
|
| 59 |
+
with tempfile.TemporaryDirectory() as tmp_dir:
|
| 60 |
+
logger.info("Saving merged model to temp dir for export...")
|
| 61 |
+
merged_model.save_pretrained(tmp_dir)
|
| 62 |
+
processor.save_pretrained(tmp_dir)
|
| 63 |
+
|
| 64 |
+
logger.info("Running optimum ONNX export...")
|
| 65 |
+
main_export(
|
| 66 |
+
model_name_or_path=tmp_dir,
|
| 67 |
+
output=output_dir,
|
| 68 |
+
task="automatic-speech-recognition",
|
| 69 |
+
opset=17,
|
| 70 |
+
optimize="O2",
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
logger.info("ONNX export complete: %s", output_dir)
|
| 74 |
+
|
| 75 |
+
def validate(
|
| 76 |
+
self,
|
| 77 |
+
onnx_dir: str,
|
| 78 |
+
processor: "WhisperProcessor",
|
| 79 |
+
test_audio_arrays: list,
|
| 80 |
+
sample_rate: int = 16_000,
|
| 81 |
+
reference_texts: list[str] | None = None,
|
| 82 |
+
) -> dict:
|
| 83 |
+
"""
|
| 84 |
+
Run inference with the exported ONNX model and compute WER vs. references.
|
| 85 |
+
"""
|
| 86 |
+
import numpy as np
|
| 87 |
+
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
|
| 88 |
+
|
| 89 |
+
logger.info("Validating ONNX model at %s...", onnx_dir)
|
| 90 |
+
ort_model = ORTModelForSpeechSeq2Seq.from_pretrained(onnx_dir)
|
| 91 |
+
|
| 92 |
+
transcriptions = []
|
| 93 |
+
for audio in test_audio_arrays:
|
| 94 |
+
inputs = processor(audio, sampling_rate=sample_rate, return_tensors="pt")
|
| 95 |
+
outputs = ort_model.generate(inputs.input_features)
|
| 96 |
+
text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
|
| 97 |
+
transcriptions.append(text)
|
| 98 |
+
|
| 99 |
+
result = {"transcriptions": transcriptions}
|
| 100 |
+
|
| 101 |
+
if reference_texts:
|
| 102 |
+
import jiwer
|
| 103 |
+
wer = jiwer.wer(reference_texts, transcriptions)
|
| 104 |
+
result["wer"] = round(wer, 4)
|
| 105 |
+
|
| 106 |
+
return result
|
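A hedged export sketch; it assumes a trained Bambara adapter already exists at the path shown and loads it directly with PEFT:

    from peft import PeftModel
    from transformers import WhisperForConditionalGeneration, WhisperProcessor
    from src.optimization.onnx_exporter import ONNXExporter

    model_id = "openai/whisper-large-v3-turbo"
    base = WhisperForConditionalGeneration.from_pretrained(model_id)
    processor = WhisperProcessor.from_pretrained(model_id)
    peft_model = PeftModel.from_pretrained(base, "./adapters/bambara")

    exporter = ONNXExporter()
    onnx_dir = exporter.merge_and_export(peft_model, processor, output_dir="./onnx", language="bam")
    print("ONNX model written to", onnx_dir)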
src/optimization/quantizer.py
ADDED
|
@@ -0,0 +1,95 @@
|
| 1 |
+
"""
|
| 2 |
+
BitsAndBytes quantization for GPU-constrained deployment.
|
| 3 |
+
4-bit NF4: reduces Whisper-large-v3-turbo from ~3GB to ~1GB VRAM.
|
| 4 |
+
8-bit: intermediate option with less accuracy loss.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import time
|
| 10 |
+
from typing import TYPE_CHECKING
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from transformers import BitsAndBytesConfig, WhisperForConditionalGeneration, WhisperProcessor
|
| 14 |
+
|
| 15 |
+
if TYPE_CHECKING:
|
| 16 |
+
pass
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def load_4bit(model_id: str, hf_token: str | None = None) -> WhisperForConditionalGeneration:
|
| 22 |
+
"""Load Whisper with 4-bit NF4 quantization. Reduces VRAM to ~1GB."""
|
| 23 |
+
bnb_config = BitsAndBytesConfig(
|
| 24 |
+
load_in_4bit=True,
|
| 25 |
+
bnb_4bit_quant_type="nf4",
|
| 26 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 27 |
+
bnb_4bit_use_double_quant=True,
|
| 28 |
+
)
|
| 29 |
+
logger.info("Loading %s with 4-bit NF4 quantization...", model_id)
|
| 30 |
+
model = WhisperForConditionalGeneration.from_pretrained(
|
| 31 |
+
model_id,
|
| 32 |
+
quantization_config=bnb_config,
|
| 33 |
+
device_map="auto",
|
| 34 |
+
token=hf_token,
|
| 35 |
+
)
|
| 36 |
+
return model
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def load_8bit(model_id: str, hf_token: str | None = None) -> WhisperForConditionalGeneration:
|
| 40 |
+
"""Load Whisper with 8-bit quantization. Reduces VRAM to ~1.5GB."""
|
| 41 |
+
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
|
| 42 |
+
logger.info("Loading %s with 8-bit quantization...", model_id)
|
| 43 |
+
model = WhisperForConditionalGeneration.from_pretrained(
|
| 44 |
+
model_id,
|
| 45 |
+
quantization_config=bnb_config,
|
| 46 |
+
device_map="auto",
|
| 47 |
+
token=hf_token,
|
| 48 |
+
)
|
| 49 |
+
return model
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class ModelQuantizer:
|
| 53 |
+
"""Benchmarks quantized vs full-precision models."""
|
| 54 |
+
|
| 55 |
+
def __init__(self, model_id: str, hf_token: str | None = None) -> None:
|
| 56 |
+
self.model_id = model_id
|
| 57 |
+
self.hf_token = hf_token
|
| 58 |
+
|
| 59 |
+
def benchmark(
|
| 60 |
+
self,
|
| 61 |
+
model: WhisperForConditionalGeneration,
|
| 62 |
+
processor: WhisperProcessor,
|
| 63 |
+
test_audio_arrays: list,
|
| 64 |
+
sample_rate: int = 16_000,
|
| 65 |
+
) -> dict:
|
| 66 |
+
"""Measure latency and memory for a list of audio arrays."""
|
| 67 |
+
import numpy as np
|
| 68 |
+
|
| 69 |
+
device = next(model.parameters()).device
|
| 70 |
+
latencies = []
|
| 71 |
+
|
| 72 |
+
for audio in test_audio_arrays:
|
| 73 |
+
inputs = processor.feature_extractor(audio, sampling_rate=sample_rate, return_tensors="pt")
|
| 74 |
+
features = inputs.input_features.to(device)
|
| 75 |
+
|
| 76 |
+
if device.type == "cuda":
|
| 77 |
+
torch.cuda.synchronize()
|
| 78 |
+
t0 = time.perf_counter()
|
| 79 |
+
|
| 80 |
+
with torch.no_grad():
|
| 81 |
+
model.generate(features, max_new_tokens=50)
|
| 82 |
+
|
| 83 |
+
if device.type == "cuda":
|
| 84 |
+
torch.cuda.synchronize()
|
| 85 |
+
latencies.append((time.perf_counter() - t0) * 1000)
|
| 86 |
+
|
| 87 |
+
result = {
|
| 88 |
+
"mean_latency_ms": round(sum(latencies) / len(latencies), 1),
|
| 89 |
+
"max_latency_ms": round(max(latencies), 1),
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
if torch.cuda.is_available():
|
| 93 |
+
result["vram_allocated_gb"] = round(torch.cuda.memory_allocated() / 1e9, 2)
|
| 94 |
+
|
| 95 |
+
return result
|
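A small benchmarking sketch for the 4-bit path above; the test clips are synthetic placeholders rather than real field recordings:

    import numpy as np
    from transformers import WhisperProcessor
    from src.optimization.quantizer import ModelQuantizer, load_4bit

    model_id = "openai/whisper-large-v3-turbo"
    model = load_4bit(model_id)
    processor = WhisperProcessor.from_pretrained(model_id)

    clips = [np.random.randn(16_000 * 10).astype(np.float32) for _ in range(3)]  # 3 x 10 s
    print(ModelQuantizer(model_id).benchmark(model, processor, clips))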
src/optimization/tflite_converter.py
ADDED
|
@@ -0,0 +1,76 @@
|
| 1 |
+
"""
|
| 2 |
+
Converts ONNX models to TFLite for offline edge deployment (Android phones in rural areas).
|
| 3 |
+
Note: Whisper's encoder and decoder are exported as separate TFLite models and
|
| 4 |
+
orchestrated together at inference time.
|
| 5 |
+
|
| 6 |
+
Requires: onnx-tf, tensorflow (install separately — large dependencies)
|
| 7 |
+
"""
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TFLiteConverter:
|
| 17 |
+
"""Converts ONNX Whisper models to TFLite format for edge deployment."""
|
| 18 |
+
|
| 19 |
+
def convert(
|
| 20 |
+
self,
|
| 21 |
+
onnx_encoder_path: str,
|
| 22 |
+
onnx_decoder_path: str,
|
| 23 |
+
output_dir: str,
|
| 24 |
+
quantize: bool = True,
|
| 25 |
+
) -> dict[str, Path]:
|
| 26 |
+
"""
|
| 27 |
+
Convert encoder and decoder ONNX models to TFLite.
|
| 28 |
+
Returns paths to the generated .tflite files.
|
| 29 |
+
"""
|
| 30 |
+
output_path = Path(output_dir)
|
| 31 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 32 |
+
|
| 33 |
+
encoder_tflite = output_path / "encoder.tflite"
|
| 34 |
+
decoder_tflite = output_path / "decoder.tflite"
|
| 35 |
+
|
| 36 |
+
logger.info("Converting encoder ONNX → TFLite...")
|
| 37 |
+
self._onnx_to_tflite(onnx_encoder_path, str(encoder_tflite), quantize=quantize)
|
| 38 |
+
|
| 39 |
+
logger.info("Converting decoder ONNX → TFLite...")
|
| 40 |
+
self._onnx_to_tflite(onnx_decoder_path, str(decoder_tflite), quantize=quantize)
|
| 41 |
+
|
| 42 |
+
return {"encoder": encoder_tflite, "decoder": decoder_tflite}
|
| 43 |
+
|
| 44 |
+
def _onnx_to_tflite(self, onnx_path: str, output_path: str, quantize: bool) -> None:
|
| 45 |
+
"""Convert a single ONNX model to TFLite via onnx-tf + tensorflow."""
|
| 46 |
+
try:
|
| 47 |
+
import onnx
|
| 48 |
+
import onnx_tf
|
| 49 |
+
import tensorflow as tf
|
| 50 |
+
except ImportError as e:
|
| 51 |
+
raise ImportError(
|
| 52 |
+
"TFLite conversion requires onnx-tf and tensorflow. "
|
| 53 |
+
"Install with: pip install onnx-tf tensorflow"
|
| 54 |
+
) from e
|
| 55 |
+
|
| 56 |
+
import tempfile
|
| 57 |
+
|
| 58 |
+
# Step 1: ONNX → TensorFlow SavedModel
|
| 59 |
+
with tempfile.TemporaryDirectory() as tmp_dir:
|
| 60 |
+
onnx_model = onnx.load(onnx_path)
|
| 61 |
+
tf_rep = onnx_tf.backend.prepare(onnx_model)
|
| 62 |
+
tf_rep.export_graph(tmp_dir)
|
| 63 |
+
|
| 64 |
+
# Step 2: TF SavedModel → TFLite
|
| 65 |
+
converter = tf.lite.TFLiteConverter.from_saved_model(tmp_dir)
|
| 66 |
+
|
| 67 |
+
if quantize:
|
| 68 |
+
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
| 69 |
+
|
| 70 |
+
tflite_model = converter.convert()
|
| 71 |
+
|
| 72 |
+
with open(output_path, "wb") as f:
|
| 73 |
+
f.write(tflite_model)
|
| 74 |
+
|
| 75 |
+
size_mb = Path(output_path).stat().st_size / 1e6
|
| 76 |
+
logger.info("TFLite model saved: %s (%.1f MB)", output_path, size_mb)
|
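Usage sketch for the converter above; the encoder/decoder file names assume the layout produced by the optimum ONNX export and may need adjusting:

    from src.optimization.tflite_converter import TFLiteConverter

    converter = TFLiteConverter()
    paths = converter.convert(
        onnx_encoder_path="./onnx/bam/encoder_model.onnx",
        onnx_decoder_path="./onnx/bam/decoder_model.onnx",
        output_dir="./tflite/bam",
        quantize=True,
    )
    print(paths["encoder"], paths["decoder"])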
src/training/__init__.py
ADDED
File without changes

src/training/callbacks.py
ADDED
@@ -0,0 +1,83 @@
"""
Custom HuggingFace Trainer callbacks:
- EarlyStoppingOnWER: stops training when WER stops improving
- AdapterCheckpointCallback: saves only adapter weights (not full model) per checkpoint
"""
from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

if TYPE_CHECKING:
    pass

logger = logging.getLogger(__name__)


class EarlyStoppingOnWER(TrainerCallback):
    """
    Stops training if eval WER does not improve by min_delta over `patience` evaluations.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.001) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self._best_wer: float = float("inf")
        self._no_improve_count: int = 0

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        metrics: dict,
        **kwargs,
    ) -> None:
        wer = metrics.get("eval_wer")
        if wer is None:
            return

        if wer < self._best_wer - self.min_delta:
            self._best_wer = wer
            self._no_improve_count = 0
            logger.info("WER improved to %.4f", wer)
        else:
            self._no_improve_count += 1
            logger.info(
                "WER %.4f did not improve (best: %.4f). No-improve count: %d/%d",
                wer, self._best_wer, self._no_improve_count, self.patience,
            )
            if self._no_improve_count >= self.patience:
                logger.warning("Early stopping triggered after %d evaluations without improvement.", self.patience)
                control.should_training_stop = True


class AdapterCheckpointCallback(TrainerCallback):
    """
    Saves only the LoRA adapter weights on each checkpoint event.
    Adapter weights are ~50MB vs ~3GB for the full model.
    """

    def __init__(self, adapter_output_dir: str) -> None:
        self.adapter_output_dir = Path(adapter_output_dir)

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model,
        **kwargs,
    ) -> None:
        checkpoint_dir = self.adapter_output_dir / f"checkpoint-{state.global_step}"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # model is a PeftModel — save only adapter weights
        if hasattr(model, "save_pretrained"):
            model.save_pretrained(str(checkpoint_dir))
            logger.info("Adapter checkpoint saved: %s", checkpoint_dir)
        else:
            logger.warning("Model does not have save_pretrained — skipping adapter checkpoint.")
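To make the callback behaviour concrete, here is a small self-contained simulation of the early-stopping logic; the WER sequence is invented for illustration and the adapter output path is only an example. In a real run both callbacks are passed through the Trainer's callbacks argument next to a compute_metrics function that returns {"wer": ...}, which the Trainer logs as eval_wer, the key EarlyStoppingOnWER reads.

# Illustrative simulation; WER values and paths are made up.
from transformers import TrainerControl, TrainerState, TrainingArguments

from src.training.callbacks import AdapterCheckpointCallback, EarlyStoppingOnWER

early_stop = EarlyStoppingOnWER(patience=2, min_delta=0.001)
adapter_ckpt = AdapterCheckpointCallback(adapter_output_dir="adapters/bambara")  # example path

args = TrainingArguments(output_dir="outputs/callback-demo")
state, control = TrainerState(), TrainerControl()

# Two consecutive evaluations without a >0.001 WER improvement set the stop flag.
for wer in [0.62, 0.55, 0.551, 0.552]:
    early_stop.on_evaluate(args, state, control, metrics={"eval_wer": wer})

print(control.should_training_stop)  # True

# During real training the wiring looks like:
#   Seq2SeqTrainer(..., compute_metrics=compute_metrics,
#                  callbacks=[early_stop, adapter_ckpt])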