jefffffff9 committed on
Commit 76db545 · 0 parent(s)

Initial commit: Sahel-Agri Voice AI

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .claude/settings.local.json +7 -0
  2. .env.example +20 -0
  3. .gitignore +66 -0
  4. .vscode/extensions.json +5 -0
  5. README.md +39 -0
  6. app.py +611 -0
  7. configs/api_config.yaml +21 -0
  8. configs/base_config.yaml +30 -0
  9. configs/lora_bambara.yaml +19 -0
  10. configs/lora_fula.yaml +19 -0
  11. noise_samples/README.md +20 -0
  12. notebooks/bootstrap_repos.ipynb +308 -0
  13. notebooks/train_colab.ipynb +283 -0
  14. packages.txt +1 -0
  15. requirements.txt +50 -0
  16. scripts/export_onnx.py +67 -0
  17. scripts/run_data_pipeline.py +76 -0
  18. scripts/run_server.py +42 -0
  19. scripts/train_bambara.py +28 -0
  20. scripts/train_fula.py +29 -0
  21. scripts/verify_baseline.py +78 -0
  22. src/__init__.py +0 -0
  23. src/api/__init__.py +0 -0
  24. src/api/app.py +98 -0
  25. src/api/dependencies.py +20 -0
  26. src/api/middleware.py +47 -0
  27. src/api/routes/__init__.py +0 -0
  28. src/api/routes/health.py +25 -0
  29. src/api/routes/iot.py +90 -0
  30. src/api/routes/transcribe.py +74 -0
  31. src/api/schemas.py +36 -0
  32. src/data/__init__.py +0 -0
  33. src/data/agri_dictionary.py +92 -0
  34. src/data/augmentation.py +84 -0
  35. src/data/feature_extractor.py +89 -0
  36. src/data/waxal_loader.py +119 -0
  37. src/engine/__init__.py +0 -0
  38. src/engine/adapter_manager.py +106 -0
  39. src/engine/transcriber.py +132 -0
  40. src/engine/whisper_base.py +77 -0
  41. src/iot/__init__.py +0 -0
  42. src/iot/intent_parser.py +75 -0
  43. src/iot/sensor_bridge.py +121 -0
  44. src/iot/voice_responder.py +260 -0
  45. src/optimization/__init__.py +0 -0
  46. src/optimization/onnx_exporter.py +106 -0
  47. src/optimization/quantizer.py +95 -0
  48. src/optimization/tflite_converter.py +76 -0
  49. src/training/__init__.py +0 -0
  50. src/training/callbacks.py +83 -0
.claude/settings.local.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "permissions": {
+     "allow": [
+       "Bash(pip show:*)"
+     ]
+   }
+ }
.env.example ADDED
@@ -0,0 +1,20 @@
+ # HuggingFace read token (required for accessing google/waxal dataset)
+ HF_TOKEN=hf_your_token_here
+
+ # Model
+ MODEL_ID=openai/whisper-large-v3-turbo
+
+ # Adapter paths (relative to project root)
+ BAMBARA_ADAPTER_PATH=./adapters/bambara
+ FULA_ADAPTER_PATH=./adapters/fula
+
+ # IoT sensor API endpoint (leave empty to use mock data in development)
+ SENSOR_API_URL=
+
+ # FastAPI server
+ API_HOST=0.0.0.0
+ API_PORT=8000
+ LOG_LEVEL=INFO
+
+ # Device: "cuda" for GPU, "cpu" for CPU-only
+ DEVICE=cuda
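
As a quick illustration (not part of the committed files), these variables can be read locally with python-dotenv, which requirements.txt already pins; the fallback values shown here are assumptions for the sketch, not values the project code necessarily uses.

```python
# Sketch: load .env locally with python-dotenv (already a dependency in requirements.txt).
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the current directory, if present

model_id = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
device = os.getenv("DEVICE", "cpu")                    # "cuda" or "cpu"
api_port = int(os.getenv("API_PORT", "8000"))
sensor_api_url = os.getenv("SENSOR_API_URL") or None   # empty string means mock sensor data

print(model_id, device, api_port, sensor_api_url)
```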
.gitignore ADDED
@@ -0,0 +1,66 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *.pyo
+ *.pyd
+ .Python
+ *.egg-info/
+ dist/
+ build/
+ .eggs/
+
+ # Environment
+ .env
+ venv/
+ .venv/
+ env/
+
+ # Model weights (large binary files)
+ *.pt
+ *.pth
+ *.bin
+ *.safetensors
+ *.ckpt
+
+ # ONNX / TFLite exports
+ *.onnx
+ *.tflite
+ models/onnx/
+ models/tflite/
+
+ # HuggingFace cache
+ data_cache/
+ .cache/
+
+ # Audio noise samples (user must provide their own)
+ noise_samples/*.wav
+ noise_samples/*.mp3
+ noise_samples/*.ogg
+
+ # Trained adapters (tracked separately or via DVC)
+ adapters/bambara/
+ adapters/fula/
+
+ # IDE
+ .vscode/settings.json
+ .idea/
+ *.code-workspace
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Logs
+ *.log
+ logs/
+
+ # Local feedback data (audio + corrections live in HF Dataset repo, not git)
+ feedback/
+
+ # Local model downloads
+ models/
+
+ # Pytest
+ .pytest_cache/
+ htmlcov/
+ .coverage
.vscode/extensions.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "recommendations": [
+     "anthropic.claude-code"
+   ]
+ }
README.md ADDED
@@ -0,0 +1,39 @@
+ ---
+ title: Sahel-Agri Voice AI
+ emoji: 🌾
+ colorFrom: green
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: "4.44.0"
+ app_file: app.py
+ hardware: cpu-basic
+ pinned: false
+ license: mit
+ tags:
+   - agriculture
+   - bambara
+   - fula
+   - speech-recognition
+   - text-to-speech
+   - west-africa
+   - low-resource-nlp
+ ---
+
+ # 🌾 Sahel-Agri Voice AI
+
+ Two-way voice assistant for Malian and Guinean farmers. Speak in **Bambara** or **Fula** — get agricultural insights spoken back in your language.
+
+ ## Features
+ - 🎙️ Voice input via microphone or file upload
+ - 🌍 Bambara (bam) and Fula (ful) speech recognition via Whisper + LoRA adapters
+ - 🔊 Native-language voice responses via Facebook MMS-TTS
+ - 📊 Soil, weather, irrigation, and pest alerts from IoT sensors
+ - 💾 Feedback saved to HuggingFace Dataset for continuous improvement
+
+ ## Languages supported
+ | Language | STT | TTS |
+ |----------|-----|-----|
+ | Bambara (bam) | ✅ Whisper + LoRA | ✅ facebook/mms-tts-bam |
+ | Fula (ful) | ✅ Whisper + LoRA | ✅ facebook/mms-tts-ful |
+ | French (fr) | ✅ Whisper | ✅ facebook/mms-tts-fra |
+ | English (en) | ✅ Whisper | ✅ facebook/mms-tts-eng |
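
As a hedged illustration of the TTS column (the app itself goes through src/tts/mms_tts.py, whose interface is not shown in this 50-file view), the MMS checkpoints listed above can be driven directly with transformers; the sample input text is an assumed example.

```python
# Sketch only: direct MMS-TTS synthesis with transformers; the repo wraps this in MMSTTSEngine.
import torch
from transformers import VitsModel, AutoTokenizer

model_id = "facebook/mms-tts-bam"                    # Bambara voice from the table above
model = VitsModel.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("I ni ce", return_tensors="pt")   # short Bambara greeting (example text)
with torch.no_grad():
    waveform = model(**inputs).waveform              # shape (1, num_samples)

sample_rate = model.config.sampling_rate             # MMS VITS checkpoints synthesize at 16 kHz
print(waveform.shape, sample_rate)
```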
app.py ADDED
@@ -0,0 +1,611 @@
1
+ """
2
+ Sahel-Agri Voice AI — HuggingFace Spaces (ZeroGPU)
3
+ Two-way voice assistant: Bambara / Fula / French / English → voice response
4
+
5
+ Environment variables (set in Space Settings → Secrets):
6
+ HF_TOKEN — HF write-access token
7
+ FEEDBACK_REPO_ID — e.g. ous-sow/sahel-agri-feedback (dataset, private)
8
+ ADAPTER_REPO_ID — e.g. ous-sow/sahel-agri-adapters (model, private)
9
+ WHISPER_MODEL_ID — default: openai/whisper-small
10
+ (use openai/whisper-base for local CPU testing)
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import io
16
+ import json
17
+ import os
18
+ import sys
19
+ import tempfile
20
+ import threading
21
+ from datetime import datetime, timezone
22
+ from pathlib import Path
23
+
24
+ import gradio as gr
25
+ import numpy as np
26
+
27
+ ROOT = Path(__file__).parent
28
+ sys.path.insert(0, str(ROOT))
29
+
30
+ # ── env ───────────────────────────────────────────────────────────────────────
31
+ HF_TOKEN = os.environ.get("HF_TOKEN")
32
+ FEEDBACK_REPO_ID = os.environ.get("FEEDBACK_REPO_ID", "ous-sow/sahel-agri-feedback")
33
+ ADAPTER_REPO_ID = os.environ.get("ADAPTER_REPO_ID", "ous-sow/sahel-agri-adapters")
34
+ # whisper-small: ~10s on cpu-basic, good multilingual quality.
35
+ # Override via WHISPER_MODEL_ID env var if you upgrade to a GPU Space later.
36
+ WHISPER_MODEL_ID = os.environ.get("WHISPER_MODEL_ID", "openai/whisper-small")
37
+
38
+ # On local CPU (no HF_TOKEN / no spaces package) fall back gracefully
39
+ _ON_SPACES = os.environ.get("SPACE_ID") is not None
40
+
41
+ SUPPORTED_LANGUAGES = {
42
+ "Bambara (bam)": "bam",
43
+ "Fula (ful)": "ful",
44
+ "French / Français": "fr",
45
+ "English": "en",
46
+ }
47
+
48
+ # ── ZeroGPU decorator (no-op locally) ────────────────────────────────────────
49
+ try:
50
+ import spaces # type: ignore
51
+ _gpu = spaces.GPU(duration=55)
52
+ except ImportError:
53
+ def _gpu(fn): # local fallback: plain function
54
+ return fn
55
+
56
+ # ── Module-level model state (CPU-resident between requests) ─────────────────
57
+ _whisper_model = None # WhisperForConditionalGeneration (base)
58
+ _whisper_processor = None
59
+ _adapter_manager = None # AdapterManager (wraps base model with PEFT if adapters loaded)
60
+ _model_lock = threading.Lock()
61
+ _model_status = "not loaded"
62
+ _adapters_loaded = set() # set of language codes with loaded adapters, e.g. {"bam", "ful"}
63
+
64
+ from src.tts.mms_tts import MMSTTSEngine
65
+ from src.iot.intent_parser import IntentParser
66
+ from src.iot.sensor_bridge import SensorBridge
67
+ from src.iot.voice_responder import VoiceResponder
68
+
69
+ _tts = MMSTTSEngine()
70
+ _intent_parser = IntentParser()
71
+ _sensor_bridge = SensorBridge()
72
+
73
+ # HF API — only instantiate when token present
74
+ _hf_api = None
75
+ if HF_TOKEN:
76
+ from huggingface_hub import HfApi
77
+ _hf_api = HfApi(token=HF_TOKEN)
78
+
79
+
80
+ # ── Model loading ─────────────────────────────────────────────────────────────
81
+
82
+ def _do_load_whisper():
83
+ global _whisper_model, _whisper_processor, _adapter_manager, _model_status
84
+ import torch
85
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
86
+ from src.engine.adapter_manager import AdapterManager
87
+
88
+ _model_status = "loading…"
89
+ try:
90
+ _whisper_processor = WhisperProcessor.from_pretrained(
91
+ WHISPER_MODEL_ID, token=HF_TOKEN
92
+ )
93
+ _whisper_model = WhisperForConditionalGeneration.from_pretrained(
94
+ WHISPER_MODEL_ID,
95
+ torch_dtype=torch.float32,
96
+ token=HF_TOKEN,
97
+ )
98
+ _whisper_model.eval()
99
+
100
+ # Create the AdapterManager wrapping the base model
101
+ _adapter_manager = AdapterManager(base_model=_whisper_model, config={})
102
+
103
+ # Try to load adapters from the local adapter repo snapshot (if already downloaded)
104
+ _try_load_local_adapters()
105
+
106
+ _model_status = f"ready ({WHISPER_MODEL_ID})"
107
+ except Exception as e:
108
+ _model_status = f"error: {e}"
109
+
110
+
111
+ def _try_load_local_adapters() -> None:
112
+ """Load any adapter snapshots that are already on disk (downloaded previously)."""
113
+ global _adapters_loaded
114
+ if _adapter_manager is None:
115
+ return
116
+ if not ADAPTER_REPO_ID:
117
+ return
118
+ try:
119
+ from huggingface_hub import try_to_load_from_cache
120
+ lang_dirs = {"bam": "adapters/bambara", "ful": "adapters/fula"}
121
+ for lang, subdir in lang_dirs.items():
122
+ cached = try_to_load_from_cache(
123
+ repo_id=ADAPTER_REPO_ID,
124
+ filename=f"{subdir}/adapter_config.json",
125
+ repo_type="model",
126
+ token=HF_TOKEN,
127
+ )
128
+ if cached:
129
+ import os
130
+ adapter_path = str(os.path.dirname(cached))
131
+ _adapter_manager.register(lang, adapter_path)
132
+ try:
133
+ _adapter_manager.load_adapter(lang)
134
+ _adapters_loaded.add(lang)
135
+ except Exception:
136
+ pass
137
+ except Exception:
138
+ pass # Adapters not cached yet — will load after first Hub download
139
+
140
+
141
+ def _ensure_whisper_loaded():
142
+ """Load Whisper to CPU in a background thread on first call. Non-blocking."""
143
+ global _model_status
144
+ with _model_lock:
145
+ if _whisper_model is None and "loading" not in _model_status and "error" not in _model_status:
146
+ t = threading.Thread(target=_do_load_whisper, daemon=True)
147
+ t.start()
148
+ return _model_status
149
+
150
+
151
+ def get_model_status() -> str:
152
+ s = _ensure_whisper_loaded()
153
+ if "ready" in s:
154
+ return f"🟢 {s}"
155
+ if "loading" in s:
156
+ return f"🟡 {s}"
157
+ if "error" in s:
158
+ return f"🔴 {s}"
159
+ return f"⚪ {s}"
160
+
161
+
162
+ # ── Core GPU pipeline ─────────────────────────────────────────────────────────
163
+
164
+ @_gpu
165
+ def _run_pipeline(audio_path: str, language_code: str):
166
+ """
167
+ Full STT → Intent → Sensor → TTS pipeline.
168
+ Decorated with @spaces.GPU(duration=55) on HF Spaces; plain function locally.
169
+ Returns: (transcript, response_text, (sample_rate, wav_np))
170
+ """
171
+ import asyncio
172
+ import torch
173
+
174
+ device = "cuda" if torch.cuda.is_available() else "cpu"
175
+
176
+ # ── 1. Whisper STT ────────────────────────────────────────────────────────
177
+ if _whisper_model is None:
178
+ return "⏳ Model still loading…", "", None
179
+
180
+ import librosa
181
+
182
+ audio_np, _ = librosa.load(audio_path, sr=16000, mono=True)
183
+
184
+ # Use adapter-wrapped model if an adapter for this language is loaded;
185
+ # otherwise fall back to base Whisper.
186
+ if _adapter_manager is not None and language_code in _adapters_loaded:
187
+ _adapter_manager.activate(language_code)
188
+ active_model = _adapter_manager.get_model()
189
+ else:
190
+ active_model = _whisper_model
191
+
192
+ active_model.to(device)
193
+ with _model_lock:
194
+ inputs = _whisper_processor.feature_extractor(
195
+ audio_np, sampling_rate=16000, return_tensors="pt"
196
+ )
197
+ input_features = inputs.input_features.to(device)
198
+
199
+ # Bambara and Fula have no Whisper language token — pass None so the model
200
+ # auto-detects or falls back to multilingual decoding.
201
+ if language_code in ("bam", "ful"):
202
+ forced_ids = None
203
+ else:
204
+ forced_ids = _whisper_processor.get_decoder_prompt_ids(
205
+ language=language_code, task="transcribe"
206
+ )
207
+
208
+ with torch.no_grad():
209
+ predicted_ids = active_model.generate(
210
+ input_features,
211
+ forced_decoder_ids=forced_ids if forced_ids else None,
212
+ max_new_tokens=256,
213
+ )
214
+
215
+ transcript = _whisper_processor.batch_decode(
216
+ predicted_ids, skip_special_tokens=True
217
+ )[0].strip()
218
+
219
+ # Free GPU VRAM before TTS
220
+ active_model.to("cpu")
221
+ if device == "cuda":
222
+ torch.cuda.empty_cache()
223
+
224
+ # ── 2. Intent + sensor data (CPU) ─────────────────────────────────────────
225
+ intent = _intent_parser.parse(transcript, language=language_code)
226
+
227
+ try:
228
+ loop = asyncio.new_event_loop()
229
+ sensor_data = loop.run_until_complete(_sensor_bridge.fetch(intent))
230
+ loop.close()
231
+ except Exception:
232
+ from src.iot.sensor_bridge import SensorData
233
+ sensor_data = SensorData(sensor_type="soil", values={
234
+ "moisture_pct": 45.0, "ph": 6.5, "temperature_c": 28.0
235
+ })
236
+
237
+ responder = VoiceResponder(language=language_code)
238
+ response_text = responder.generate_response(intent, sensor_data)
239
+
240
+ # ── 3. MMS-TTS (GPU) ──────────────────────────────────────────────────────
241
+ wav_np, sample_rate = _tts.synthesize(response_text, language_code, device=device)
242
+
243
+ return transcript, response_text, (sample_rate, wav_np)
244
+
245
+
246
+ # ── HF Hub feedback persistence ───────────────────────────────────────────────
247
+
248
+ def _save_feedback_to_hub(
249
+ audio_path: str | None,
250
+ transcript: str,
251
+ corrected_text: str,
252
+ response_text: str,
253
+ rating: int,
254
+ notes: str,
255
+ language_label: str,
256
+ ) -> str:
257
+ language_code = SUPPORTED_LANGUAGES.get(language_label, "bam")
258
+
259
+ if not corrected_text.strip():
260
+ return "⚠️ Corrected text is empty."
261
+
262
+ timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S_%f")
263
+
264
+ record = {
265
+ "id": timestamp,
266
+ "timestamp": datetime.now(timezone.utc).isoformat(),
267
+ "language": language_code,
268
+ "audio_file": f"audio/{language_code}_{timestamp}.wav",
269
+ "whisper_output": transcript,
270
+ "corrected_text": corrected_text.strip(),
271
+ "response_text": response_text,
272
+ "rating": rating,
273
+ "notes": notes.strip(),
274
+ "is_correction": transcript.strip() != corrected_text.strip(),
275
+ "model": WHISPER_MODEL_ID,
276
+ }
277
+
278
+ if _hf_api is None:
279
+ # Local: save to disk instead
280
+ fb_dir = ROOT / "feedback"
281
+ fb_dir.mkdir(exist_ok=True)
282
+ (fb_dir / "audio").mkdir(exist_ok=True)
283
+ corrections_path = fb_dir / "corrections.jsonl"
284
+ if audio_path:
285
+ import shutil
286
+ shutil.copy2(audio_path, fb_dir / "audio" / f"{language_code}_{timestamp}.wav")
287
+ with open(corrections_path, "a", encoding="utf-8") as f:
288
+ f.write(json.dumps(record, ensure_ascii=False) + "\n")
289
+ total = sum(1 for _ in open(corrections_path, encoding="utf-8"))
290
+ return f"✅ Saved locally (#{total}) — HF_TOKEN not set, Hub upload skipped."
291
+
292
+ try:
293
+ # Upload audio
294
+ if audio_path:
295
+ _hf_api.upload_file(
296
+ path_or_fileobj=audio_path,
297
+ path_in_repo=f"audio/{language_code}_{timestamp}.wav",
298
+ repo_id=FEEDBACK_REPO_ID,
299
+ repo_type="dataset",
300
+ )
301
+
302
+ # Download → append → re-upload corrections.jsonl (with retry on conflict)
303
+ from huggingface_hub import hf_hub_download
304
+ for attempt in range(2):
305
+ try:
306
+ local_jsonl = hf_hub_download(
307
+ repo_id=FEEDBACK_REPO_ID,
308
+ filename="corrections.jsonl",
309
+ repo_type="dataset",
310
+ token=HF_TOKEN,
311
+ )
312
+ with open(local_jsonl, encoding="utf-8") as f:
313
+ existing = f.read()
314
+ except Exception:
315
+ existing = ""
316
+
317
+ updated = existing + json.dumps(record, ensure_ascii=False) + "\n"
318
+ buf = io.BytesIO(updated.encode("utf-8"))
319
+
320
+ try:
321
+ _hf_api.upload_file(
322
+ path_or_fileobj=buf,
323
+ path_in_repo="corrections.jsonl",
324
+ repo_id=FEEDBACK_REPO_ID,
325
+ repo_type="dataset",
326
+ )
327
+ break
328
+ except Exception as e:
329
+ if attempt == 1:
330
+ return f"⚠️ Audio uploaded but corrections.jsonl update failed: {e}"
331
+
332
+ total = updated.count("\n")
333
+ return f"✅ Saved to Hub (#{total}) — {FEEDBACK_REPO_ID}"
334
+
335
+ except Exception as e:
336
+ return f"❌ Hub upload error: {e}"
337
+
338
+
339
+ # ── Adapter reload ────────────────────────────────────────────────────────────
340
+
341
+ def _reload_adapters_from_hub() -> str:
342
+ global _adapters_loaded
343
+ if _hf_api is None:
344
+ return "⚠️ HF_TOKEN not set — cannot download adapters."
345
+ if _adapter_manager is None:
346
+ return "⏳ Base model not loaded yet — wait for model to finish loading and try again."
347
+ try:
348
+ from huggingface_hub import snapshot_download
349
+ local_dir = snapshot_download(
350
+ repo_id=ADAPTER_REPO_ID, repo_type="model", token=HF_TOKEN
351
+ )
352
+ results = []
353
+ for lang, subdir in (("bam", "adapters/bambara"), ("ful", "adapters/fula")):
354
+ adapter_path = Path(local_dir) / subdir
355
+ if not adapter_path.exists():
356
+ results.append(f"⚠️ {lang}: `{subdir}` not found in repo")
357
+ continue
358
+ # Check that this looks like a valid PEFT adapter
359
+ if not (adapter_path / "adapter_config.json").exists():
360
+ results.append(f"⚠️ {lang}: `{subdir}` missing adapter_config.json — run training first")
361
+ continue
362
+ try:
363
+ _adapter_manager.register(lang, str(adapter_path))
364
+ _adapter_manager.load_adapter(lang)
365
+ _adapters_loaded.add(lang)
366
+ results.append(f"✅ {lang}: adapter loaded from `{subdir}`")
367
+ except Exception as e:
368
+ results.append(f"❌ {lang}: load failed — {e}")
369
+
370
+ summary = "\n".join(results)
371
+ active = ", ".join(_adapters_loaded) if _adapters_loaded else "none"
372
+ return f"{summary}\n\n**Active adapters:** {active}\n**Repo:** `{ADAPTER_REPO_ID}`"
373
+ except Exception as e:
374
+ return f"❌ Adapter reload failed: {e}"
375
+
376
+
377
+ def _get_adapter_status() -> str:
378
+ lines = []
379
+
380
+ # Show which adapters are currently active in memory
381
+ if _adapters_loaded:
382
+ lines.append(f"**Active adapters (in memory):** {', '.join(sorted(_adapters_loaded))}")
383
+ else:
384
+ lines.append("**Active adapters:** none — using base Whisper")
385
+
386
+ if _hf_api is None:
387
+ lines.append("_HF_TOKEN not set — Hub check skipped._")
388
+ return "\n".join(lines)
389
+
390
+ try:
391
+ from huggingface_hub import list_repo_files
392
+ files = list(list_repo_files(ADAPTER_REPO_ID, repo_type="model", token=HF_TOKEN))
393
+ bam_ok = any("bambara" in f and "adapter_config" in f for f in files)
394
+ ful_ok = any("fula" in f and "adapter_config" in f for f in files)
395
+ lines += [
396
+ f"\n**Hub repo:** `{ADAPTER_REPO_ID}`",
397
+ f"- Bambara (bam): {'✅ trained adapter present' if bam_ok else '⚠️ not yet trained — run bootstrap notebook'}",
398
+ f"- Fula (ful): {'✅ trained adapter present' if ful_ok else '⚠️ not yet trained — run bootstrap notebook'}",
399
+ ]
400
+ if bam_ok or ful_ok:
401
+ lines.append("\n_Click **Reload Adapters** to activate them._")
402
+ except Exception as e:
403
+ lines.append(f"_Could not read Hub repo: {e}_")
404
+
405
+ return "\n".join(lines)
406
+
407
+
408
+ # ── Main ask handler ──────────────────────────────────────────────────────────
409
+
410
+ def handle_ask(audio_path, language_label):
411
+ if audio_path is None:
412
+ return "⚠️ No audio — press Record or upload a file.", "", None
413
+
414
+ language_code = SUPPORTED_LANGUAGES.get(language_label, "bam")
415
+ status = _ensure_whisper_loaded()
416
+
417
+ if _whisper_model is None:
418
+ return f"⏳ Model loading ({status}). Wait a moment and try again.", "", None
419
+
420
+ try:
421
+ transcript, response_text, audio_out = _run_pipeline(audio_path, language_code)
422
+ return transcript, response_text, audio_out
423
+ except Exception as e:
424
+ return f"❌ {e}", "", None
425
+
426
+
427
+ # ── Gradio UI ─────────────────────────────────────────────────────────────────
428
+
429
+ def build_ui() -> gr.Blocks:
430
+ with gr.Blocks(title="Sahel-Agri Voice AI") as demo:
431
+ gr.Markdown("# 🌾 Sahel-Agri Voice AI")
432
+ gr.Markdown(
433
+ "Speak in **Bambara** or **Fula** — get agricultural insights spoken back "
434
+ "in your language. Also supports French and English."
435
+ )
436
+
437
+ model_status_box = gr.Textbox(
438
+ value=get_model_status,
439
+ label="Model status",
440
+ interactive=False,
441
+ every=3,
442
+ )
443
+
444
+ with gr.Tabs():
445
+
446
+ # ── Tab 1: Voice Assistant ────────────────────────────────────────
447
+ with gr.TabItem("🎙️ Voice Assistant"):
448
+ with gr.Row():
449
+ with gr.Column(scale=1):
450
+ language_dd = gr.Dropdown(
451
+ choices=list(SUPPORTED_LANGUAGES.keys()),
452
+ value="Bambara (bam)",
453
+ label="Language / Kan",
454
+ )
455
+ audio_input = gr.Audio(
456
+ sources=["microphone", "upload"],
457
+ type="filepath",
458
+ label="Record or upload audio",
459
+ )
460
+ ask_btn = gr.Button("▶ Ask / Ɲinɛ", variant="primary")
461
+
462
+ with gr.Column(scale=1):
463
+ transcript_box = gr.Textbox(
464
+ label="Whisper heard",
465
+ lines=3,
466
+ placeholder="Your words will appear here…",
467
+ interactive=False,
468
+ )
469
+ response_box = gr.Textbox(
470
+ label="Response / Jaabi",
471
+ lines=3,
472
+ placeholder="Agricultural advice will appear here…",
473
+ interactive=False,
474
+ )
475
+ audio_output = gr.Audio(
476
+ label="Voice response",
477
+ autoplay=True,
478
+ interactive=False,
479
+ )
480
+
481
+ ask_btn.click(
482
+ fn=handle_ask,
483
+ inputs=[audio_input, language_dd],
484
+ outputs=[transcript_box, response_box, audio_output],
485
+ )
486
+
487
+ # ── Tab 2: Feedback & Correction ─────────────────────────────────
488
+ with gr.TabItem("📝 Feedback & Correction"):
489
+ gr.Markdown(
490
+ "Help improve the model by correcting transcription errors. "
491
+ "Your audio and corrections are saved to the training dataset."
492
+ )
493
+ with gr.Row():
494
+ with gr.Column():
495
+ fb_lang = gr.Dropdown(
496
+ choices=list(SUPPORTED_LANGUAGES.keys()),
497
+ value="Bambara (bam)",
498
+ label="Language",
499
+ )
500
+ fb_audio = gr.Audio(
501
+ sources=["microphone", "upload"],
502
+ type="filepath",
503
+ label="Audio (re-record or upload)",
504
+ )
505
+ fb_transcript = gr.Textbox(
506
+ label="Whisper output (what it heard)",
507
+ lines=3,
508
+ placeholder="Paste or type what Whisper said…",
509
+ )
510
+ fb_corrected = gr.Textbox(
511
+ label="Corrected transcription (what was actually said)",
512
+ lines=3,
513
+ placeholder="Type the correct text here…",
514
+ )
515
+
516
+ with gr.Column():
517
+ fb_response = gr.Textbox(
518
+ label="Response text (optional — for rating)",
519
+ lines=2,
520
+ placeholder="Copy the response from Tab 1…",
521
+ )
522
+ fb_rating = gr.Slider(
523
+ minimum=1, maximum=5, step=1, value=3,
524
+ label="Response quality (1 = poor, 5 = excellent)",
525
+ )
526
+ fb_notes = gr.Textbox(
527
+ label="Notes (optional)",
528
+ lines=2,
529
+ placeholder="e.g. noisy background, strong accent…",
530
+ )
531
+ save_btn = gr.Button("💾 Save to Dataset", variant="secondary")
532
+ save_status = gr.Textbox(
533
+ label="Save status", interactive=False, lines=2
534
+ )
535
+
536
+ save_btn.click(
537
+ fn=_save_feedback_to_hub,
538
+ inputs=[
539
+ fb_audio, fb_transcript, fb_corrected,
540
+ fb_response, fb_rating, fb_notes, fb_lang,
541
+ ],
542
+ outputs=[save_status],
543
+ )
544
+
545
+ # ── Tab 3: Training Status ────────────────────────────────────────
546
+ with gr.TabItem("🔧 Training Status"):
547
+ gr.Markdown(
548
+ "After collecting ≥10 corrections per language, run the training "
549
+ "notebook on Google Colab (free GPU), then reload adapters here."
550
+ )
551
+ adapter_status_md = gr.Markdown(value=_get_adapter_status())
552
+ reload_btn = gr.Button("🔄 Reload Adapters from Hub")
553
+ reload_out = gr.Markdown()
554
+
555
+ gr.Markdown("---")
556
+ gr.Markdown(
557
+ "**Training notebook**: "
558
+ "`notebooks/train_colab.ipynb` — open in Colab, run all cells."
559
+ )
560
+ gr.Markdown(
561
+ "**Feedback dataset**: "
562
+ f"`{FEEDBACK_REPO_ID}` (private, auto-updated on each save)"
563
+ )
564
+ gr.Markdown(
565
+ "**Adapter repo**: "
566
+ f"`{ADAPTER_REPO_ID}` (private, updated after each training run)"
567
+ )
568
+
569
+ reload_btn.click(
570
+ fn=_reload_adapters_from_hub,
571
+ outputs=[reload_out],
572
+ )
573
+ reload_btn.click(
574
+ fn=_get_adapter_status,
575
+ outputs=[adapter_status_md],
576
+ )
577
+
578
+ return demo
579
+
580
+
581
+ # ── Entry point ───────────────────────────────────────────────────────────────
582
+
583
+ if __name__ == "__main__":
584
+ from dotenv import load_dotenv
585
+ load_dotenv()
586
+
587
+ # Re-read env after dotenv
588
+ HF_TOKEN = os.environ.get("HF_TOKEN")
589
+ FEEDBACK_REPO_ID = os.environ.get("FEEDBACK_REPO_ID", "ous-sow/sahel-agri-feedback")
590
+ ADAPTER_REPO_ID = os.environ.get("ADAPTER_REPO_ID", "ous-sow/sahel-agri-adapters")
591
+ WHISPER_MODEL_ID = os.environ.get("WHISPER_MODEL_ID", "openai/whisper-small")
592
+
593
+ if HF_TOKEN:
594
+ from huggingface_hub import HfApi
595
+ _hf_api = HfApi(token=HF_TOKEN)
596
+
597
+ # Kick off background model load immediately
598
+ _ensure_whisper_loaded()
599
+
600
+ print(f"Whisper model : {WHISPER_MODEL_ID}")
601
+ print(f"Feedback repo : {FEEDBACK_REPO_ID}")
602
+ print(f"Adapter repo : {ADAPTER_REPO_ID}")
603
+ print(f"HF_TOKEN set : {'yes' if HF_TOKEN else 'no (local-only mode)'}")
604
+ print()
605
+
606
+ demo = build_ui()
607
+ demo.launch(
608
+ server_port=9001,
609
+ inbrowser=True,
610
+ share=False,
611
+ )
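
A minimal local smoke test for the module above, hypothetical and not part of this commit: it assumes the repo's dependencies are installed, the script is run from the project root, and a 16 kHz recording exists at sample.wav. It waits for the background model load and then calls handle_ask directly instead of launching the Gradio UI.

```python
# Hypothetical smoke test: exercise the STT -> intent -> TTS pipeline once without Gradio.
import time
import app

app._ensure_whisper_loaded()                 # kicks off the background CPU load
while "ready" not in app._model_status:
    if "error" in app._model_status:
        raise RuntimeError(app._model_status)
    time.sleep(2)

transcript, response_text, audio_out = app.handle_ask("sample.wav", "Bambara (bam)")
print("Heard   :", transcript)
print("Response:", response_text)            # audio_out is (sample_rate, numpy waveform) or None
```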
configs/api_config.yaml ADDED
@@ -0,0 +1,21 @@
+ server:
+   host: "0.0.0.0"
+   port: 8000
+   workers: 1  # Single worker: shares GPU model in memory
+   timeout_keep_alive: 30
+
+ inference:
+   default_language: "bam"
+   max_audio_size_mb: 10
+   supported_languages:
+     - "bam"
+     - "ful"
+
+ iot:
+   sensor_poll_timeout_s: 5
+   response_language: "fr"  # French for farmer-facing TTS output
+   intent_confidence_threshold: 0.7
+
+ rate_limit:
+   requests_per_minute: 60
+   burst: 10
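
For context, a hedged sketch of how the server block might be consumed; the actual scripts/run_server.py is not visible in this truncated view, and uvicorn plus FastAPI are assumed to be installed for the API server (they are not listed in requirements.txt shown above).

```python
# Sketch, under the assumption that scripts/run_server.py wires the YAML into uvicorn like this.
import yaml
import uvicorn

with open("configs/api_config.yaml") as f:
    cfg = yaml.safe_load(f)

srv = cfg["server"]
uvicorn.run(
    "src.api.app:app",                        # FastAPI application path (assumed)
    host=srv["host"],
    port=srv["port"],
    workers=srv["workers"],                   # 1 worker keeps one shared in-memory model
    timeout_keep_alive=srv["timeout_keep_alive"],
)
```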
configs/base_config.yaml ADDED
@@ -0,0 +1,30 @@
+ model:
+   id: "openai/whisper-large-v3-turbo"
+   task: "transcribe"
+   max_new_tokens: 128
+   chunk_length_s: 30
+
+ training:
+   output_dir: "./adapters"
+   per_device_train_batch_size: 4
+   gradient_accumulation_steps: 4
+   warmup_steps: 200
+   max_steps: 4000
+   save_steps: 500
+   eval_steps: 500
+   learning_rate: 1.0e-4
+   fp16: true
+   # CRITICAL on Windows: multiprocessing spawn breaks with tokenizers
+   dataloader_num_workers: 0
+
+ audio:
+   sample_rate: 16000
+   max_duration_s: 30
+   noise_snr_db_range: [5, 20]
+   augmentation_prob: 0.6
+
+ paths:
+   data_cache: "./data_cache"
+   adapters: "./adapters"
+   models: "./models"
+   noise_samples: "./noise_samples"
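
A hedged sketch of how the training block can map onto transformers' Seq2SeqTrainingArguments; the real mapping lives in src/training/trainer.py, which is not visible in this 50-file view and may differ.

```python
# Sketch: turn configs/base_config.yaml "training" keys into Seq2SeqTrainingArguments.
import yaml
from transformers import Seq2SeqTrainingArguments

with open("configs/base_config.yaml") as f:
    cfg = yaml.safe_load(f)
t = cfg["training"]

args = Seq2SeqTrainingArguments(
    output_dir=t["output_dir"],
    per_device_train_batch_size=t["per_device_train_batch_size"],
    gradient_accumulation_steps=t["gradient_accumulation_steps"],
    warmup_steps=t["warmup_steps"],
    max_steps=t["max_steps"],
    save_steps=t["save_steps"],
    eval_steps=t["eval_steps"],
    learning_rate=t["learning_rate"],
    fp16=t["fp16"],
    dataloader_num_workers=t["dataloader_num_workers"],  # 0 avoids Windows spawn issues
)
```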
configs/lora_bambara.yaml ADDED
@@ -0,0 +1,19 @@
+ language: "bam"
+ language_code: "bm"  # ISO 639-1 code used for Whisper forced_decoder_ids
+ dataset_subset: "bam"
+ adapter_name: "bambara"
+ output_dir: "./adapters/bambara"
+
+ lora:
+   r: 32
+   lora_alpha: 64
+   target_modules:
+     - "q_proj"
+     - "v_proj"
+     - "k_proj"
+     - "out_proj"
+     - "fc1"
+     - "fc2"
+   lora_dropout: 0.05
+   bias: "none"
+   task_type: "SEQ_2_SEQ_LM"
configs/lora_fula.yaml ADDED
@@ -0,0 +1,19 @@
+ language: "ful"
+ language_code: "ff"  # ISO 639-1 code used for Whisper forced_decoder_ids
+ dataset_subset: "ful"
+ adapter_name: "fula"
+ output_dir: "./adapters/fula"
+
+ lora:
+   r: 16  # Smaller rank — Fula dataset is smaller than Bambara
+   lora_alpha: 32
+   target_modules:
+     - "q_proj"
+     - "v_proj"
+     - "k_proj"
+     - "out_proj"
+     - "fc1"
+     - "fc2"
+   lora_dropout: 0.05
+   bias: "none"
+   task_type: "SEQ_2_SEQ_LM"
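
For illustration, either lora block above translates directly into a PEFT LoraConfig. This is a sketch of that translation, not the repo's trainer code; the whisper-small base model here is an assumption (the base config targets whisper-large-v3-turbo).

```python
# Sketch: build a PEFT LoraConfig from the LoRA YAML and attach it to a Whisper base model.
import yaml
from peft import LoraConfig, get_peft_model
from transformers import WhisperForConditionalGeneration

with open("configs/lora_bambara.yaml") as f:
    cfg = yaml.safe_load(f)
l = cfg["lora"]

lora_config = LoraConfig(
    r=l["r"],
    lora_alpha=l["lora_alpha"],
    target_modules=l["target_modules"],
    lora_dropout=l["lora_dropout"],
    bias=l["bias"],
    task_type=l["task_type"],            # "SEQ_2_SEQ_LM"
)

base = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()       # LoRA trains only a small fraction of the weights
```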
noise_samples/README.md ADDED
@@ -0,0 +1,20 @@
+ # Field Noise Samples
+
+ Place `.wav` audio files here to enable realistic field-noise augmentation during training.
+
+ ## Required Files (16kHz mono, any duration ≥5s)
+ - `tractor_engine.wav` — diesel tractor idling or working
+ - `wind_field.wav` — wind in open farmland
+ - `livestock_ambient.wav` — cattle, goats, or chickens in background
+
+ ## Suggested Sources
+ - [Freesound.org](https://freesound.org) — search "tractor", "wind field", "livestock ambient" (filter by CC0 / CC-BY)
+ - Field recordings from partner NGOs or agricultural organizations in Mali/Guinea
+
+ ## Licensing Note
+ Ensure all audio files are licensed for use in ML training datasets.
+ CC0 (public domain) or CC-BY are preferred.
+
+ ## Without Noise Files
+ The augmenter will fall back to Gaussian noise only.
+ Training will still work but model robustness to real-world conditions may be reduced.
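
To make the augmentation behaviour concrete, here is a hedged sketch of mixing a noise clip into speech at a target SNR drawn from the noise_snr_db_range in configs/base_config.yaml; the repo's src/data/augmentation.py may implement this differently (for example via audiomentations).

```python
# Sketch: mix a noise clip into speech at a target SNR (dB). Placeholder arrays stand in for
# real 16 kHz audio loaded with librosa/soundfile.
import numpy as np

def mix_at_snr(speech: np.ndarray, noise: np.ndarray, snr_db: float) -> np.ndarray:
    # Tile or trim the noise to the speech length
    if len(noise) < len(speech):
        noise = np.tile(noise, int(np.ceil(len(speech) / len(noise))))
    noise = noise[: len(speech)]

    speech_power = np.mean(speech ** 2) + 1e-10
    noise_power = np.mean(noise ** 2) + 1e-10
    # Scale noise so that 10*log10(speech_power / scaled_noise_power) == snr_db
    scale = np.sqrt(speech_power / (noise_power * 10 ** (snr_db / 10)))
    return speech + scale * noise

rng = np.random.default_rng(0)
speech = rng.standard_normal(16000).astype(np.float32)   # 1 s of placeholder audio
noise = rng.standard_normal(8000).astype(np.float32)
augmented = mix_at_snr(speech, noise, snr_db=10.0)        # within the configured [5, 20] dB range
```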
notebooks/bootstrap_repos.ipynb ADDED
@@ -0,0 +1,308 @@
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 5,
4
+ "metadata": {
5
+ "kernelspec": {
6
+ "display_name": "Python 3",
7
+ "language": "python",
8
+ "name": "python3"
9
+ },
10
+ "language_info": {
11
+ "name": "python",
12
+ "version": "3.10.0"
13
+ },
14
+ "colab": {
15
+ "provenance": [],
16
+ "gpuType": "T4"
17
+ },
18
+ "accelerator": "GPU"
19
+ },
20
+ "cells": [
21
+ {
22
+ "cell_type": "markdown",
23
+ "id": "cell-title",
24
+ "metadata": {},
25
+ "source": [
26
+ "# 🌾 Sahel-Agri Voice AI — One-Time Bootstrap\n",
27
+ "\n",
28
+ "**Run this notebook ONCE** before deploying your Space. It:\n",
29
+ "\n",
30
+ "1. Creates the three HuggingFace repos (`sahel-agri-feedback`, `sahel-agri-adapters`, `sahel-agri-voice`)\n",
31
+ "2. Seeds the feedback dataset with a `corrections.jsonl` placeholder\n",
32
+ "3. Trains v0 LoRA adapters for **Bambara** and **Fula** on the full Google Waxal dataset\n",
33
+ "4. Pushes adapters to `ous-sow/sahel-agri-adapters`\n",
34
+ "\n",
35
+ "After this notebook completes, push your project code to the Space and your app will start\n",
36
+ "with working Bambara/Fula speech recognition from day 1 — **no user corrections needed yet**.\n",
37
+ "\n",
38
+ "For subsequent improvement runs (after collecting farmer feedback), use `train_colab.ipynb`.\n",
39
+ "\n",
40
+ "---\n",
41
+ "**Before running:** Runtime → Change runtime type → **T4 GPU**"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "id": "cell-gpu-check",
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "# Cell 1 — GPU check\n",
52
+ "import subprocess\n",
53
+ "result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)\n",
54
+ "if result.returncode != 0:\n",
55
+ " raise RuntimeError('No GPU! Runtime → Change runtime type → T4 GPU')\n",
56
+ "print(result.stdout[:500])\n",
57
+ "print('✅ GPU ready')"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": null,
63
+ "id": "cell-install",
64
+ "metadata": {},
65
+ "outputs": [],
66
+ "source": [
67
+ "# Cell 2 — Install dependencies\n",
68
+ "!pip install -q \\\n",
69
+ " torch==2.11.0 torchaudio==2.11.0 \\\n",
70
+ " transformers==5.5.0 datasets==4.8.4 \\\n",
71
+ " accelerate==1.13.0 evaluate==0.4.2 \\\n",
72
+ " huggingface-hub==1.9.0 peft==0.18.1 \\\n",
73
+ " librosa==0.10.2 soundfile==0.12.1 \\\n",
74
+ " jiwer==3.0.4 pyyaml==6.0.2\n",
75
+ "print('✅ Packages installed')"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": null,
81
+ "id": "cell-hf-login",
82
+ "metadata": {},
83
+ "outputs": [],
84
+ "source": "# Cell 3 — HuggingFace login\n# Colab: 🔑 icon (left sidebar) → Add new secret → name=HF_TOKEN\nimport os\ntry:\n from google.colab import userdata # type: ignore\n HF_TOKEN = userdata.get('HF_TOKEN')\nexcept Exception:\n HF_TOKEN = os.environ.get('HF_TOKEN', '')\n\nif not HF_TOKEN:\n raise ValueError(\n 'HF_TOKEN not found.\\n'\n 'Colab: click the 🔑 icon → Add new secret → name=HF_TOKEN'\n )\n\nfrom huggingface_hub import login, HfApi\nlogin(token=HF_TOKEN, add_to_git_credential=False)\napi = HfApi(token=HF_TOKEN)\n\nHF_USERNAME = 'ous-sow'\nFEEDBACK_REPO_ID = f'{HF_USERNAME}/sahel-agri-feedback'\nADAPTER_REPO_ID = f'{HF_USERNAME}/sahel-agri-adapters'\nSPACE_REPO_ID = f'{HF_USERNAME}/sahel-agri-voice'\n# whisper-small trains on Colab T4 in ~25 min and runs on CPU in ~10s.\n# Change to 'openai/whisper-large-v3-turbo' only if you upgrade to a GPU Space.\nWHISPER_MODEL_ID = 'openai/whisper-small'\n\nprint(f'✅ Logged in as {HF_USERNAME}')"
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "id": "cell-create-repos",
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "# Cell 4 — Create HuggingFace repos (skips if they already exist)\n",
94
+ "from huggingface_hub import RepoUrl\n",
95
+ "\n",
96
+ "def create_repo_if_missing(repo_id, repo_type, private=True):\n",
97
+ " try:\n",
98
+ " url = api.create_repo(\n",
99
+ " repo_id=repo_id,\n",
100
+ " repo_type=repo_type,\n",
101
+ " private=private,\n",
102
+ " exist_ok=True,\n",
103
+ " )\n",
104
+ " print(f' ✅ {repo_type}: {repo_id}')\n",
105
+ " return url\n",
106
+ " except Exception as e:\n",
107
+ " print(f' ⚠️ {repo_id}: {e}')\n",
108
+ "\n",
109
+ "print('Creating repos...')\n",
110
+ "create_repo_if_missing(FEEDBACK_REPO_ID, 'dataset', private=True)\n",
111
+ "create_repo_if_missing(ADAPTER_REPO_ID, 'model', private=True)\n",
112
+ "create_repo_if_missing(SPACE_REPO_ID, 'space', private=False)\n",
113
+ "\n",
114
+ "# Seed the feedback dataset with an empty corrections.jsonl\n",
115
+ "import io\n",
116
+ "try:\n",
117
+ " api.upload_file(\n",
118
+ " path_or_fileobj=io.BytesIO(b''),\n",
119
+ " path_in_repo='corrections.jsonl',\n",
120
+ " repo_id=FEEDBACK_REPO_ID,\n",
121
+ " repo_type='dataset',\n",
122
+ " commit_message='Init: empty corrections.jsonl',\n",
123
+ " )\n",
124
+ " print(f' ✅ {FEEDBACK_REPO_ID}/corrections.jsonl initialised')\n",
125
+ "except Exception as e:\n",
126
+ " print(f' ⚠️ corrections.jsonl upload: {e} (may already exist)')"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": null,
132
+ "id": "cell-clone-space",
133
+ "metadata": {},
134
+ "outputs": [],
135
+ "source": [
136
+ "# Cell 5 — Clone Space code (so we can use src/ and configs/)\n",
137
+ "# If the Space is brand new and has no code yet, clone from the local zip instead.\n",
138
+ "import sys\n",
139
+ "from pathlib import Path\n",
140
+ "from huggingface_hub import snapshot_download\n",
141
+ "\n",
142
+ "try:\n",
143
+ " space_dir = Path(snapshot_download(\n",
144
+ " repo_id=SPACE_REPO_ID, repo_type='space', token=HF_TOKEN\n",
145
+ " ))\n",
146
+ " print(f'Space code: {space_dir}')\n",
147
+ "except Exception as e:\n",
148
+ " print(f'Could not download Space ({e})')\n",
149
+ " print('Uploading project code to Space first...')\n",
150
+ " # If you have the project on Colab already (e.g. mounted Drive), set:\n",
151
+ " # space_dir = Path('/content/drive/MyDrive/voice-model')\n",
152
+ " # Otherwise upload via git (see README step 6) and re-run this cell.\n",
153
+ " raise RuntimeError(\n",
154
+ " 'Push your project to the Space first:\\n'\n",
155
+ " ' git remote add space https://huggingface.co/spaces/ous-sow/sahel-agri-voice\\n'\n",
156
+ " ' git push space main\\n'\n",
157
+ " 'Then re-run this notebook.'\n",
158
+ " )\n",
159
+ "\n",
160
+ "sys.path.insert(0, str(space_dir))"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": null,
166
+ "id": "cell-train-bam",
167
+ "metadata": {},
168
+ "outputs": [],
169
+ "source": [
170
+ "# Cell 6 — Train v0 Bambara adapter on full Waxal (bam)\n",
171
+ "#\n",
172
+ "# Uses streaming — Waxal is ~4h of audio, we cap at 2000 samples for Colab budget.\n",
173
+ "# Full training (~4000 steps) on the entire dataset: use a Kaggle P100 (12h limit).\n",
174
+ "import os, yaml\n",
175
+ "os.environ['HF_TOKEN'] = HF_TOKEN\n",
176
+ "\n",
177
+ "from src.training.trainer import WhisperLoRATrainer\n",
178
+ "\n",
179
+ "WAXAL_CAP = 2000 # raise to 10000+ on Kaggle for a stronger v0 model\n",
180
+ "\n",
181
+ "base_cfg = str(space_dir / 'configs' / 'base_config.yaml')\n",
182
+ "bam_cfg_src = str(space_dir / 'configs' / 'lora_bambara.yaml')\n",
183
+ "bam_out = '/tmp/sahel_adapter_bam'\n",
184
+ "\n",
185
+ "# Override output_dir\n",
186
+ "with open(bam_cfg_src) as f:\n",
187
+ " bam_config = yaml.safe_load(f)\n",
188
+ "bam_config['output_dir'] = bam_out\n",
189
+ "tmp_bam_cfg = '/tmp/lora_bam.yaml'\n",
190
+ "with open(tmp_bam_cfg, 'w') as f:\n",
191
+ " yaml.dump(bam_config, f)\n",
192
+ "\n",
193
+ "# Also override max_steps in base config to match Waxal cap\n",
194
+ "with open(base_cfg) as f:\n",
195
+ " base_config = yaml.safe_load(f)\n",
196
+ "# ~2 steps per sample @ batch_size=4, gradient_acc=4\n",
197
+ "base_config['training']['max_steps'] = max(500, WAXAL_CAP // 8)\n",
198
+ "tmp_base_cfg = '/tmp/base_config.yaml'\n",
199
+ "with open(tmp_base_cfg, 'w') as f:\n",
200
+ " yaml.dump(base_config, f)\n",
201
+ "\n",
202
+ "print(f'Training Bambara v0 adapter (Waxal cap={WAXAL_CAP}, max_steps={base_config[\"training\"][\"max_steps\"]})...')\n",
203
+ "trainer_bam = WhisperLoRATrainer(\n",
204
+ " base_config_path=tmp_base_cfg,\n",
205
+ " language_config_path=tmp_bam_cfg,\n",
206
+ ")\n",
207
+ "trainer_bam.setup()\n",
208
+ "\n",
209
+ "# No feedback yet — materialise Waxal and train\n",
210
+ "trainer_bam.merge_extra_data([], repeat=1, waxal_cap=WAXAL_CAP)\n",
211
+ "\n",
212
+ "trainer_bam.train()\n",
213
+ "print(f'✅ Bambara v0 adapter saved to {bam_out}')"
214
+ ]
215
+ },
216
+ {
217
+ "cell_type": "code",
218
+ "execution_count": null,
219
+ "id": "cell-train-ful",
220
+ "metadata": {},
221
+ "outputs": [],
222
+ "source": [
223
+ "# Cell 7 — Train v0 Fula adapter on full Waxal (ful)\n",
224
+ "ful_cfg_src = str(space_dir / 'configs' / 'lora_fula.yaml')\n",
225
+ "ful_out = '/tmp/sahel_adapter_ful'\n",
226
+ "\n",
227
+ "with open(ful_cfg_src) as f:\n",
228
+ " ful_config = yaml.safe_load(f)\n",
229
+ "ful_config['output_dir'] = ful_out\n",
230
+ "tmp_ful_cfg = '/tmp/lora_ful.yaml'\n",
231
+ "with open(tmp_ful_cfg, 'w') as f:\n",
232
+ " yaml.dump(ful_config, f)\n",
233
+ "\n",
234
+ "print(f'Training Fula v0 adapter (Waxal cap={WAXAL_CAP})...')\n",
235
+ "trainer_ful = WhisperLoRATrainer(\n",
236
+ " base_config_path=tmp_base_cfg,\n",
237
+ " language_config_path=tmp_ful_cfg,\n",
238
+ ")\n",
239
+ "trainer_ful.setup()\n",
240
+ "trainer_ful.merge_extra_data([], repeat=1, waxal_cap=WAXAL_CAP)\n",
241
+ "trainer_ful.train()\n",
242
+ "print(f'✅ Fula v0 adapter saved to {ful_out}')"
243
+ ]
244
+ },
245
+ {
246
+ "cell_type": "code",
247
+ "execution_count": null,
248
+ "id": "cell-push-adapters",
249
+ "metadata": {},
250
+ "outputs": [],
251
+ "source": [
252
+ "# Cell 8 — Push both adapters to HF Model repo\n",
253
+ "from huggingface_hub import HfApi\n",
254
+ "api = HfApi(token=HF_TOKEN)\n",
255
+ "\n",
256
+ "for lang, out_dir, path_in_repo in [\n",
257
+ " ('bam', bam_out, 'adapters/bambara'),\n",
258
+ " ('ful', ful_out, 'adapters/fula'),\n",
259
+ "]:\n",
260
+ " api.upload_folder(\n",
261
+ " folder_path=out_dir,\n",
262
+ " repo_id=ADAPTER_REPO_ID,\n",
263
+ " repo_type='model',\n",
264
+ " path_in_repo=path_in_repo,\n",
265
+ " commit_message=f'v0 {lang} adapter trained on Waxal (cap={WAXAL_CAP} samples)',\n",
266
+ " )\n",
267
+ " print(f'✅ {lang} → {ADAPTER_REPO_ID}/{path_in_repo}')\n",
268
+ "\n",
269
+ "print()\n",
270
+ "print('Bootstrap complete!')\n",
271
+ "print()\n",
272
+ "print('Next steps:')\n",
273
+ "print(' 1. Push your project code to the Space (git push space main)')\n",
274
+ "print(' 2. In Space Settings → Secrets, add HF_TOKEN, FEEDBACK_REPO_ID, ADAPTER_REPO_ID')\n",
275
+ "print(' 3. Space will build — your app at https://huggingface.co/spaces/ous-sow/sahel-agri-voice')\n",
276
+ "print(' 4. Tab 3 → Reload Adapters — Bambara + Fula adapters will be loaded')\n",
277
+ "print(' 5. Collect farmer corrections, then run train_colab.ipynb to keep improving')"
278
+ ]
279
+ },
280
+ {
281
+ "cell_type": "code",
282
+ "execution_count": null,
283
+ "id": "cell-verify",
284
+ "metadata": {},
285
+ "outputs": [],
286
+ "source": [
287
+ "# Cell 9 — Quick verification: list what was pushed to the adapter repo\n",
288
+ "from huggingface_hub import list_repo_files\n",
289
+ "\n",
290
+ "files = sorted(list_repo_files(ADAPTER_REPO_ID, repo_type='model', token=HF_TOKEN))\n",
291
+ "print(f'Files in {ADAPTER_REPO_ID}:')\n",
292
+ "for f in files:\n",
293
+ " print(f' {f}')\n",
294
+ "\n",
295
+ "bam_ok = any('bambara/adapter_config.json' in f for f in files)\n",
296
+ "ful_ok = any('fula/adapter_config.json' in f for f in files)\n",
297
+ "print()\n",
298
+ "print(f'Bambara adapter: {\"✅\" if bam_ok else \"❌\"}')\n",
299
+ "print(f'Fula adapter: {\"✅\" if ful_ok else \"❌\"}')\n",
300
+ "\n",
301
+ "if bam_ok and ful_ok:\n",
302
+ " print('\\n🎉 Both adapters ready. Your Space will use them automatically on the next reload.')\n",
303
+ "else:\n",
304
+ " print('\\n⚠️ Some adapters are missing — check the training cells above for errors.')"
305
+ ]
306
+ }
307
+ ]
308
+ }
notebooks/train_colab.ipynb ADDED
@@ -0,0 +1,283 @@
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 5,
4
+ "metadata": {
5
+ "kernelspec": {
6
+ "display_name": "Python 3",
7
+ "language": "python",
8
+ "name": "python3"
9
+ },
10
+ "language_info": {
11
+ "name": "python",
12
+ "version": "3.10.0"
13
+ },
14
+ "colab": {
15
+ "provenance": [],
16
+ "gpuType": "T4"
17
+ },
18
+ "accelerator": "GPU"
19
+ },
20
+ "cells": [
21
+ {
22
+ "cell_type": "markdown",
23
+ "id": "cell-title",
24
+ "metadata": {},
25
+ "source": [
26
+ "# 🌾 Sahel-Agri Voice AI — Fine-tune on Farmer Feedback\n",
27
+ "\n",
28
+ "**Run after collecting ≥10 corrections in the Space.** \n",
29
+ "First run? Use `bootstrap_repos.ipynb` instead to train the v0 Waxal adapter.\n",
30
+ "\n",
31
+ "This notebook fine-tunes the existing LoRA adapter using:\n",
32
+ "- **Waxal baseline** (up to 500 samples) — keeps the model grounded\n",
33
+ "- **Farmer corrections** (3× upsampled) — targeted improvement from real field use\n",
34
+ "\n",
35
+ "**Before running:** Runtime → Change runtime type → **T4 GPU**"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "id": "cell-gpu-check",
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "# Cell 1 — GPU check\n",
46
+ "import subprocess\n",
47
+ "result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)\n",
48
+ "if result.returncode != 0:\n",
49
+ " raise RuntimeError('No GPU! Runtime → Change runtime type → T4 GPU')\n",
50
+ "print(result.stdout[:500])\n",
51
+ "print('✅ GPU ready')"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "id": "cell-install",
58
+ "metadata": {},
59
+ "outputs": [],
60
+ "source": [
61
+ "# Cell 2 — Install dependencies (matching Space versions)\n",
62
+ "!pip install -q \\\n",
63
+ " torch==2.11.0 torchaudio==2.11.0 \\\n",
64
+ " transformers==5.5.0 datasets==4.8.4 \\\n",
65
+ " accelerate==1.13.0 evaluate==0.4.2 \\\n",
66
+ " huggingface-hub==1.9.0 peft==0.18.1 \\\n",
67
+ " librosa==0.10.2 soundfile==0.12.1 \\\n",
68
+ " jiwer==3.0.4 pyyaml==6.0.2\n",
69
+ "print('✅ Packages installed')"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": null,
75
+ "id": "cell-hf-login",
76
+ "metadata": {},
77
+ "outputs": [],
78
+ "source": "# Cell 3 — HuggingFace login\n# Colab: 🔑 icon (left sidebar) → Add new secret → name=HF_TOKEN\n# Kaggle: Add Data → add as Kaggle secret named HF_TOKEN\nimport os\ntry:\n from google.colab import userdata # type: ignore\n HF_TOKEN = userdata.get('HF_TOKEN')\nexcept Exception:\n HF_TOKEN = os.environ.get('HF_TOKEN', '')\n\nif not HF_TOKEN:\n raise ValueError('HF_TOKEN not found — see instructions above.')\n\nfrom huggingface_hub import login\nlogin(token=HF_TOKEN, add_to_git_credential=False)\n\nSPACE_REPO_ID = 'ous-sow/sahel-agri-voice'\nFEEDBACK_REPO_ID = 'ous-sow/sahel-agri-feedback'\nADAPTER_REPO_ID = 'ous-sow/sahel-agri-adapters'\n# Must match what the Space uses — whisper-small for cpu-basic, whisper-large-v3-turbo for GPU.\nWHISPER_MODEL_ID = 'openai/whisper-small'\nTRAIN_LANG = 'bam' # ← change to 'ful' for Fula\n\nprint(f'✅ Logged in | training language: {TRAIN_LANG}')"
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "id": "cell-download",
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": [
87
+ "# Cell 4 — Download Space code and feedback corrections\n",
88
+ "import json, shutil, sys\n",
89
+ "from pathlib import Path\n",
90
+ "from huggingface_hub import snapshot_download, hf_hub_download\n",
91
+ "\n",
92
+ "# Get Space code (contains src/, configs/)\n",
93
+ "space_dir = Path(snapshot_download(\n",
94
+ " repo_id=SPACE_REPO_ID, repo_type='space', token=HF_TOKEN\n",
95
+ "))\n",
96
+ "sys.path.insert(0, str(space_dir))\n",
97
+ "print(f'Space code: {space_dir}')\n",
98
+ "\n",
99
+ "# Download feedback corrections.jsonl\n",
100
+ "jsonl_path = hf_hub_download(\n",
101
+ " repo_id=FEEDBACK_REPO_ID,\n",
102
+ " filename='corrections.jsonl',\n",
103
+ " repo_type='dataset',\n",
104
+ " token=HF_TOKEN,\n",
105
+ ")\n",
106
+ "with open(jsonl_path, encoding='utf-8') as f:\n",
107
+ " all_records = [json.loads(l) for l in f if l.strip()]\n",
108
+ "\n",
109
+ "corrections = [\n",
110
+ " r for r in all_records\n",
111
+ " if r.get('is_correction') and r['language'] == TRAIN_LANG\n",
112
+ "]\n",
113
+ "print(f'Total feedback records : {len(all_records)}')\n",
114
+ "print(f'Corrections for {TRAIN_LANG} : {len(corrections)}')\n",
115
+ "\n",
116
+ "if len(corrections) < 5:\n",
117
+ " print('⚠️ Very few corrections — consider collecting more before training.')\n",
118
+ " print(' Training will proceed with Waxal only (corrections will be skipped).')"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": null,
124
+ "id": "cell-download-audio",
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "# Cell 5 — Download feedback audio files from HF Dataset repo\n",
129
+ "fb_audio_dir = Path('/tmp/sahel_feedback_audio')\n",
130
+ "fb_audio_dir.mkdir(exist_ok=True)\n",
131
+ "\n",
132
+ "skipped = 0\n",
133
+ "for rec in corrections:\n",
134
+ " local_path = fb_audio_dir / Path(rec['audio_file']).name\n",
135
+ " if local_path.exists():\n",
136
+ " continue\n",
137
+ " try:\n",
138
+ " dl = hf_hub_download(\n",
139
+ " repo_id=FEEDBACK_REPO_ID,\n",
140
+ " filename=rec['audio_file'],\n",
141
+ " repo_type='dataset',\n",
142
+ " token=HF_TOKEN,\n",
143
+ " )\n",
144
+ " shutil.copy(dl, local_path)\n",
145
+ " except Exception as e:\n",
146
+ " skipped += 1\n",
147
+ " print(f' skip {rec[\"audio_file\"]}: {e}')\n",
148
+ "\n",
149
+ "# Point records at local paths\n",
150
+ "for rec in corrections:\n",
151
+ " local = fb_audio_dir / Path(rec['audio_file']).name\n",
152
+ " if local.exists():\n",
153
+ " rec['audio_file'] = str(local)\n",
154
+ "\n",
155
+ "available = [r for r in corrections if Path(r['audio_file']).exists()]\n",
156
+ "print(f'Downloaded {len(available)} / {len(corrections)} audio files (skipped {skipped})')"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "id": "cell-train",
163
+ "metadata": {},
164
+ "outputs": [],
165
+ "source": [
166
+ "# Cell 6 — Fine-tune: Waxal baseline + farmer corrections\n",
167
+ "#\n",
168
+ "# WhisperLoRATrainer.setup() loads Waxal (streaming).\n",
169
+ "# merge_extra_data() materialises Waxal (up to 500 samples),\n",
170
+ "# appends corrections (3× upsampled), shuffles the combined dataset.\n",
171
+ "# train() runs standard Seq2SeqTrainer on the merged dataset.\n",
172
+ "\n",
173
+ "import os\n",
174
+ "os.environ['HF_TOKEN'] = HF_TOKEN\n",
175
+ "\n",
176
+ "from src.training.trainer import WhisperLoRATrainer\n",
177
+ "\n",
178
+ "lang_config_map = {'bam': 'lora_bambara.yaml', 'ful': 'lora_fula.yaml'}\n",
179
+ "base_cfg = str(space_dir / 'configs' / 'base_config.yaml')\n",
180
+ "lang_cfg = str(space_dir / 'configs' / lang_config_map[TRAIN_LANG])\n",
181
+ "output_dir = f'/tmp/sahel_adapter_{TRAIN_LANG}'\n",
182
+ "\n",
183
+ "# Override output_dir so adapter saves to /tmp on Colab\n",
184
+ "import yaml\n",
185
+ "with open(lang_cfg) as f:\n",
186
+ " lang_config = yaml.safe_load(f)\n",
187
+ "lang_config['output_dir'] = output_dir\n",
188
+ "tmp_lang_cfg = f'/tmp/lora_{TRAIN_LANG}_tmp.yaml'\n",
189
+ "with open(tmp_lang_cfg, 'w') as f:\n",
190
+ " yaml.dump(lang_config, f)\n",
191
+ "\n",
192
+ "trainer = WhisperLoRATrainer(\n",
193
+ " base_config_path=base_cfg,\n",
194
+ " language_config_path=tmp_lang_cfg,\n",
195
+ ")\n",
196
+ "trainer.setup()\n",
197
+ "\n",
198
+ "if available:\n",
199
+ " print(f'Merging {len(available)} corrections (×3) with Waxal baseline (cap=500)...')\n",
200
+ " trainer.merge_extra_data(available, repeat=3, waxal_cap=500)\n",
201
+ "else:\n",
202
+ " print('No corrections available — training on Waxal only.')\n",
203
+ "\n",
204
+ "trainer.train()\n",
205
+ "print(f'✅ Training complete — adapter at {output_dir}')"
206
+ ]
207
+ },
208
+ {
209
+ "cell_type": "code",
210
+ "execution_count": null,
211
+ "id": "cell-push",
212
+ "metadata": {},
213
+ "outputs": [],
214
+ "source": [
215
+ "# Cell 7 — Push adapter to HF Model repo\n",
216
+ "from huggingface_hub import HfApi\n",
217
+ "api = HfApi(token=HF_TOKEN)\n",
218
+ "\n",
219
+ "path_in_repo = 'adapters/bambara' if TRAIN_LANG == 'bam' else 'adapters/fula'\n",
220
+ "n_corrections = len(available)\n",
221
+ "\n",
222
+ "api.upload_folder(\n",
223
+ " folder_path=output_dir,\n",
224
+ " repo_id=ADAPTER_REPO_ID,\n",
225
+ " repo_type='model',\n",
226
+ " path_in_repo=path_in_repo,\n",
227
+ " commit_message=(\n",
228
+ " f'Fine-tune {TRAIN_LANG}: Waxal baseline + {n_corrections} farmer corrections'\n",
229
+ " ),\n",
230
+ ")\n",
231
+ "print(f'✅ Pushed to {ADAPTER_REPO_ID}/{path_in_repo}')\n",
232
+ "print('\\nNext: Space → Tab 3 → Reload Adapters from Hub')"
233
+ ]
234
+ },
235
+ {
236
+ "cell_type": "code",
237
+ "execution_count": null,
238
+ "id": "cell-sanity",
239
+ "metadata": {},
240
+ "outputs": [],
241
+ "source": [
242
+ "# Cell 8 — Sanity check: compare WER before vs after adapter\n",
243
+ "import random, torch, librosa, jiwer\n",
244
+ "from transformers import WhisperForConditionalGeneration, WhisperProcessor\n",
245
+ "from peft import PeftModel\n",
246
+ "\n",
247
+ "if not available:\n",
248
+ " print('No test samples — skipping sanity check.')\n",
249
+ "else:\n",
250
+ " test_rec = random.choice(available)\n",
251
+ " print(f'Audio : {Path(test_rec[\"audio_file\"]).name}')\n",
252
+ " print(f'Expected : {test_rec[\"corrected_text\"]}')\n",
253
+ " print(f'Pre-train: {test_rec[\"whisper_output\"]}')\n",
254
+ "\n",
255
+ " # Load base + adapter\n",
256
+ " processor = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID, token=HF_TOKEN)\n",
257
+ " base = WhisperForConditionalGeneration.from_pretrained(\n",
258
+ " WHISPER_MODEL_ID, torch_dtype=torch.float16, token=HF_TOKEN\n",
259
+ " ).to('cuda')\n",
260
+ " model = PeftModel.from_pretrained(base, output_dir).eval()\n",
261
+ "\n",
262
+ " audio_np, _ = librosa.load(test_rec['audio_file'], sr=16000, mono=True)\n",
263
+ " feats = processor.feature_extractor(\n",
264
+ " audio_np, sampling_rate=16000, return_tensors='pt'\n",
265
+ " ).input_features.half().to('cuda')\n",
266
+ "\n",
267
+ " with torch.no_grad():\n",
268
+ " ids = model.generate(feats, max_new_tokens=256)\n",
269
+ " result = processor.batch_decode(ids, skip_special_tokens=True)[0].strip()\n",
270
+ " print(f'Post-train: {result}')\n",
271
+ "\n",
272
+ " ref = test_rec['corrected_text']\n",
273
+ " wer_before = jiwer.wer(ref, test_rec['whisper_output']) if test_rec.get('whisper_output') else 1.0\n",
274
+ " wer_after = jiwer.wer(ref, result)\n",
275
+ " print(f'\\nWER before: {wer_before:.1%} → WER after: {wer_after:.1%}')\n",
276
+ " if wer_after < wer_before:\n",
277
+ " print('✅ Adapter improved transcription quality!')\n",
278
+ " else:\n",
279
+ " print('ℹ️ No improvement on this single sample — collect more corrections and retrain.')"
280
+ ]
281
+ }
282
+ ]
283
+ }
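
For readers skimming Cell 6 above, a hedged sketch of what the described merge amounts to: cap the Waxal split, repeat the corrections three times, and shuffle. WhisperLoRATrainer.merge_extra_data() is the authoritative implementation (not shown in this view), and this sketch assumes the Waxal split is already materialised as a non-streaming Dataset.

```python
# Sketch of the merge described in the notebook: Waxal cap + 3x upsampled corrections + shuffle.
from datasets import Dataset, concatenate_datasets

def merge_sketch(waxal: Dataset, corrections: Dataset,
                 repeat: int = 3, waxal_cap: int = 500) -> Dataset:
    waxal_part = waxal.select(range(min(waxal_cap, len(waxal))))
    upsampled = concatenate_datasets([corrections] * repeat) if len(corrections) else corrections
    return concatenate_datasets([waxal_part, upsampled]).shuffle(seed=42)
```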
packages.txt ADDED
@@ -0,0 +1 @@
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -----------------------------------------------------------------------------
2
+ # Sahel-Agri Voice AI — Python Dependencies
3
+ # HuggingFace Spaces (ZeroGPU) deployment — CUDA pre-installed, no +cu128 suffix
4
+ #
5
+ # Local CPU test:
6
+ # pip install -r requirements.txt
7
+ # -----------------------------------------------------------------------------
8
+
9
+ # PyTorch (CPU build — works on HF Spaces cpu-basic and locally)
10
+ torch==2.11.0
11
+ torchaudio==2.11.0
12
+
13
+ # HuggingFace core
14
+ transformers==5.5.0
15
+ datasets==4.8.4
16
+ accelerate==1.13.0
17
+ evaluate==0.4.2
18
+ huggingface-hub==1.9.0
19
+
20
+ # PEFT (LoRA adapters)
21
+ peft==0.18.1
22
+
23
+ # Audio processing
24
+ librosa==0.10.2
25
+ soundfile==0.12.1
26
+ audiomentations==0.43.1
27
+
28
+ # Quantization (CPU: installs fine; 4-bit/8-bit requires GPU at runtime)
29
+ bitsandbytes==0.49.2
30
+
31
+ # Metrics
32
+ jiwer==3.0.4
33
+
34
+ # Config & environment
35
+ pyyaml==6.0.2
36
+ python-dotenv==1.1.0
37
+
38
+ # Gradio (must match sdk_version in README.md)
39
+ gradio==4.44.0
40
+
41
+ # Pydantic v2
42
+ pydantic==2.11.3
43
+
44
+ # Testing
45
+ pytest==8.3.5
46
+ pytest-asyncio==0.26.0
47
+
48
+ # Utilities
49
+ numpy==2.2.4
50
+ scipy==1.15.2
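The pins above cover the Gradio Space. The FastAPI path (src/api, scripts/run_server.py) and the sensor bridge additionally import fastapi, uvicorn, slowapi, and httpx, which are not listed here; FastAPI also needs python-multipart to parse the UploadFile/Form endpoints. A sketch of the extra lines for a local API deployment — left unpinned on purpose, since the exact versions are an assumption for the deployer:

    # API-server extras (not needed for the Gradio Space itself)
    fastapi
    uvicorn
    slowapi
    httpx
    python-multipart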
scripts/export_onnx.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Phase 4a: Merge LoRA adapters and export language-specific ONNX models.
3
+ Validates that ONNX WER is within 2% of PyTorch baseline.
4
+
5
+ Usage:
6
+ python scripts/export_onnx.py
7
+ """
8
+ import logging
9
+ import os
10
+ import sys
11
+ from pathlib import Path
12
+
13
+ sys.path.insert(0, str(Path(__file__).parent.parent))
14
+
15
+ from dotenv import load_dotenv
16
+
17
+ load_dotenv()
18
+
19
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s — %(message)s")
20
+
21
+ import yaml
22
+
23
+ from src.optimization.onnx_exporter import ONNXExporter
24
+
25
+
26
+ def export_language(language: str, adapter_path: str, config: dict) -> None:
27
+ from peft import PeftModel
28
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
29
+
30
+ hf_token = os.getenv("HF_TOKEN")
31
+ model_id = config["model"]["id"]
32
+
33
+ print(f"\n[{language.upper()}] Loading base model...")
34
+ base_model = WhisperForConditionalGeneration.from_pretrained(model_id, token=hf_token)
35
+ processor = WhisperProcessor.from_pretrained(model_id, token=hf_token)
36
+
37
+ print(f"[{language.upper()}] Loading adapter from {adapter_path}...")
38
+ peft_model = PeftModel.from_pretrained(base_model, adapter_path, adapter_name=language)
39
+
40
+ output_dir = f"{config['paths']['models']}/onnx/{language}"
41
+ exporter = ONNXExporter()
42
+ result_path = exporter.merge_and_export(peft_model, processor, output_dir, language)
43
+ print(f"[{language.upper()}] ONNX exported to: {result_path}")
44
+
45
+
46
+ def main() -> None:
47
+ with open("configs/base_config.yaml") as f:
48
+ config = yaml.safe_load(f)
49
+
50
+ print("=" * 60)
51
+ print("Sahel-Agri Voice AI — ONNX Export")
52
+ print("=" * 60)
53
+
54
+ bambara_path = os.getenv("BAMBARA_ADAPTER_PATH", "./adapters/bambara")
55
+ fula_path = os.getenv("FULA_ADAPTER_PATH", "./adapters/fula")
56
+
57
+ for language, adapter_path in [("bambara", bambara_path), ("fula", fula_path)]:
58
+ if Path(adapter_path).exists():
59
+ export_language(language, adapter_path, config)
60
+ else:
61
+ print(f"\nSkipping {language}: adapter not found at {adapter_path}")
62
+
63
+ print("\nExport complete.")
64
+
65
+
66
+ if __name__ == "__main__":
67
+ main()
scripts/run_data_pipeline.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Phase 2: Download google/waxal, apply augmentation, print statistics.
3
+ Streams examples and caches to data_cache/ as Arrow files.
4
+
5
+ Usage:
6
+ python scripts/run_data_pipeline.py --subset bam --max-examples 100
7
+ """
8
+ import argparse
9
+ import sys
10
+ import time
11
+ from pathlib import Path
12
+
13
+ sys.path.insert(0, str(Path(__file__).parent.parent))
14
+
15
+ import os
16
+
17
+ from dotenv import load_dotenv
18
+
19
+ load_dotenv()
20
+
21
+
22
+ def main(subset: str, max_examples: int) -> None:
23
+ import yaml
24
+ from transformers import WhisperProcessor
25
+
26
+ from src.data.augmentation import FieldNoiseAugmenter
27
+ from src.data.waxal_loader import WaxalDataLoader
28
+
29
+ with open("configs/base_config.yaml") as f:
30
+ config = yaml.safe_load(f)
31
+
32
+ hf_token = os.getenv("HF_TOKEN")
33
+ model_id = config["model"]["id"]
34
+
35
+ print("=" * 60)
36
+ print(f"Waxal Data Pipeline — subset: {subset}")
37
+ print("=" * 60)
38
+
39
+ print(f"\n[1/4] Loading WhisperProcessor ({model_id})...")
40
+ processor = WhisperProcessor.from_pretrained(model_id, token=hf_token)
41
+
42
+ print("[2/4] Initializing augmenter...")
43
+ augmenter = FieldNoiseAugmenter(config["paths"]["noise_samples"], config)
44
+ print(f" Augmenter ready: {augmenter.is_ready()}")
45
+
46
+ print(f"[3/4] Streaming google/waxal subset={subset}...")
47
+ loader = WaxalDataLoader(subset, config, hf_token=hf_token)
48
+
49
+ t0 = time.time()
50
+ count = 0
51
+ total_duration = 0.0
52
+
53
+ for example in loader.iter_processed(processor, split="train", augmenter=augmenter):
54
+ count += 1
55
+ # input_features shape: (n_mels, 3000) — each example is padded to a 30 s window,
57
+ # so this sums an upper bound of 30 s per example, not the true audio duration
58
+ total_duration += 30.0 # 30 s upper bound per example
58
+ if count >= max_examples:
59
+ break
60
+
61
+ elapsed = time.time() - t0
62
+
63
+ print(f"\n[4/4] Results:")
64
+ print(f" Examples processed: {count}")
65
+ print(f" Approx total audio: {total_duration / 3600:.2f} hours")
66
+ print(f" Processing time: {elapsed:.1f}s")
67
+ print(f" Throughput: {count / elapsed:.1f} examples/sec")
68
+ print(f"\nData pipeline PASSED.")
69
+
70
+
71
+ if __name__ == "__main__":
72
+ parser = argparse.ArgumentParser()
73
+ parser.add_argument("--subset", default="bam", choices=["bam", "ful"])
74
+ parser.add_argument("--max-examples", type=int, default=50)
75
+ args = parser.parse_args()
76
+ main(args.subset, args.max_examples)
scripts/run_server.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Phase 4b: Start the FastAPI inference server.
3
+
4
+ Usage:
5
+ python scripts/run_server.py
6
+ python scripts/run_server.py --host 0.0.0.0 --port 8000
7
+ """
8
+ import argparse
9
+ import sys
10
+ from pathlib import Path
11
+
12
+ sys.path.insert(0, str(Path(__file__).parent.parent))
13
+
14
+ from dotenv import load_dotenv
15
+
16
+ load_dotenv()
17
+
18
+ import uvicorn
19
+
20
+ if __name__ == "__main__":
21
+ parser = argparse.ArgumentParser(description="Start Sahel-Agri Voice AI server")
22
+ parser.add_argument("--host", default="0.0.0.0")
23
+ parser.add_argument("--port", type=int, default=8000)
24
+ parser.add_argument("--reload", action="store_true", help="Enable hot-reload (dev only)")
25
+ args = parser.parse_args()
26
+
27
+ print(f"Starting server on http://{args.host}:{args.port}")
28
+ print("Endpoints:")
29
+ print(f" GET http://localhost:{args.port}/api/v1/health")
30
+ print(f" POST http://localhost:{args.port}/api/v1/transcribe")
31
+ print(f" POST http://localhost:{args.port}/api/v1/query")
32
+ print(f" GET http://localhost:{args.port}/docs (Swagger UI)")
33
+ print()
34
+
35
+ uvicorn.run(
36
+ "src.api.app:app",
37
+ host=args.host,
38
+ port=args.port,
39
+ workers=1, # Single worker: GPU model shared in memory
40
+ reload=args.reload,
41
+ log_level="info",
42
+ )
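Once the server is running, both POST endpoints accept multipart form uploads. An illustrative client call in Python (the sample.wav path is a placeholder; httpx is already used elsewhere in the repo):

    import httpx

    with open("sample.wav", "rb") as f:
        r = httpx.post(
            "http://localhost:8000/api/v1/transcribe",
            files={"audio_file": ("sample.wav", f, "audio/wav")},
            data={"language": "bam"},
            timeout=120,
        )
    print(r.json()["text"])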
scripts/train_bambara.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Phase 3a: Fine-tune LoRA adapter for Bambara (bam).
3
+
4
+ Usage:
5
+ python scripts/train_bambara.py
6
+ """
7
+ import logging
8
+ import sys
9
+ from pathlib import Path
10
+
11
+ sys.path.insert(0, str(Path(__file__).parent.parent))
12
+
13
+ from dotenv import load_dotenv
14
+
15
+ load_dotenv()
16
+
17
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s — %(message)s")
18
+
19
+ from src.training.trainer import WhisperLoRATrainer
20
+
21
+ if __name__ == "__main__":
22
+ trainer = WhisperLoRATrainer(
23
+ base_config_path="configs/base_config.yaml",
24
+ language_config_path="configs/lora_bambara.yaml",
25
+ )
26
+ trainer.setup()
27
+ trainer.train()
28
+ print("\nBambara training complete. Adapter saved to adapters/bambara/")
scripts/train_fula.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Phase 3b: Fine-tune LoRA adapter for Fula (ful).
3
+ Trains on the same frozen backbone as Bambara — base model weights are NOT modified.
4
+
5
+ Usage:
6
+ python scripts/train_fula.py
7
+ """
8
+ import logging
9
+ import sys
10
+ from pathlib import Path
11
+
12
+ sys.path.insert(0, str(Path(__file__).parent.parent))
13
+
14
+ from dotenv import load_dotenv
15
+
16
+ load_dotenv()
17
+
18
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s — %(message)s")
19
+
20
+ from src.training.trainer import WhisperLoRATrainer
21
+
22
+ if __name__ == "__main__":
23
+ trainer = WhisperLoRATrainer(
24
+ base_config_path="configs/base_config.yaml",
25
+ language_config_path="configs/lora_fula.yaml",
26
+ )
27
+ trainer.setup()
28
+ trainer.train()
29
+ print("\nFula training complete. Adapter saved to adapters/fula/")
scripts/verify_baseline.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Phase 1 smoke test: load Whisper, run inference on a sample audio clip.
3
+ Prints model info, inference time, GPU memory usage, and sample transcript.
4
+
5
+ Usage:
6
+ python scripts/verify_baseline.py
7
+ """
8
+ import sys
9
+ import time
10
+ from pathlib import Path
11
+
12
+ # Allow imports from project root
13
+ sys.path.insert(0, str(Path(__file__).parent.parent))
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+
19
+ def main() -> None:
20
+ from src.engine.whisper_base import WhisperBackbone
21
+
22
+ print("=" * 60)
23
+ print("Sahel-Agri Voice AI — Baseline Verification")
24
+ print("=" * 60)
25
+
26
+ # 1. Check environment
27
+ print(f"\nPython: {sys.version.split()[0]}")
28
+ print(f"PyTorch: {torch.__version__}")
29
+ print(f"CUDA available: {torch.cuda.is_available()}")
30
+ if torch.cuda.is_available():
31
+ print(f"GPU: {torch.cuda.get_device_name(0)}")
32
+ print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
33
+
34
+ # 2. Load model
35
+ print("\n[1/3] Loading backbone model...")
36
+ t0 = time.time()
37
+ backbone = WhisperBackbone("configs/base_config.yaml")
38
+ backbone.load(device="cuda")
39
+ load_time = time.time() - t0
40
+ print(f" Loaded in {load_time:.1f}s")
41
+
42
+ if torch.cuda.is_available():
43
+ used = torch.cuda.memory_allocated() / 1e9
44
+ reserved = torch.cuda.memory_reserved() / 1e9
45
+ print(f" GPU memory: {used:.2f} GB allocated / {reserved:.2f} GB reserved")
46
+
47
+ # 3. Generate synthetic test audio (1 second of low-amplitude white noise)
48
+ print("\n[2/3] Generating test audio (1s white noise)...")
49
+ sample_rate = 16000
50
+ duration = 1.0
51
+ audio = np.random.randn(int(sample_rate * duration)).astype(np.float32) * 0.01
52
+
53
+ # 4. Run inference
54
+ print("[3/3] Running inference...")
55
+ processor = backbone.processor
56
+ model = backbone.model
57
+
58
+ inputs = processor(audio, sampling_rate=sample_rate, return_tensors="pt")
59
+ input_features = inputs.input_features.to(backbone.device)
60
+ if backbone.device == "cuda":
61
+ input_features = input_features.half()
62
+
63
+ t0 = time.time()
64
+ with torch.no_grad():
65
+ predicted_ids = model.generate(input_features, max_new_tokens=50)
66
+ infer_time = time.time() - t0
67
+
68
+ transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
69
+
70
+ print(f"\n{'=' * 60}")
71
+ print(f"Transcript: '{transcription}' (noise input — blank expected)")
72
+ print(f"Inference time: {infer_time * 1000:.0f} ms")
73
+ print(f"\nBaseline verification PASSED.")
74
+ print(f"{'=' * 60}")
75
+
76
+
77
+ if __name__ == "__main__":
78
+ main()
src/__init__.py ADDED
File without changes
src/api/__init__.py ADDED
File without changes
src/api/app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI application factory.
3
+ Uses lifespan context manager to load the Whisper model at startup
4
+ and register language adapters — keeping a single backbone in GPU memory.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import logging
9
+ import os
10
+ from contextlib import asynccontextmanager
11
+
12
+ import yaml
13
+ from fastapi import FastAPI
14
+
15
+ from src.api.middleware import register_middleware
16
+ from src.api.routes import health, iot, transcribe
17
+ from src.engine.adapter_manager import AdapterManager
18
+ from src.engine.transcriber import Transcriber
19
+ from src.engine.whisper_base import WhisperBackbone
20
+ from src.iot.sensor_bridge import SensorBridge
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ logging.basicConfig(
25
+ level=os.getenv("LOG_LEVEL", "INFO"),
26
+ format="%(asctime)s %(levelname)s %(name)s — %(message)s",
27
+ )
28
+
29
+
30
+ @asynccontextmanager
31
+ async def lifespan(app: FastAPI):
32
+ """Load model at startup, free GPU memory at shutdown."""
33
+ with open("configs/base_config.yaml") as f:
34
+ config = yaml.safe_load(f)
35
+
36
+ hf_token = os.getenv("HF_TOKEN")
37
+ device = os.getenv("DEVICE", "cuda")
38
+ bambara_path = os.getenv("BAMBARA_ADAPTER_PATH", "./adapters/bambara")
39
+ fula_path = os.getenv("FULA_ADAPTER_PATH", "./adapters/fula")
40
+ sensor_api_url = os.getenv("SENSOR_API_URL") or None
41
+
42
+ # 1. Load backbone
43
+ logger.info("Loading Whisper backbone...")
44
+ backbone = WhisperBackbone("configs/base_config.yaml")
45
+ backbone.load(device=device, hf_token=hf_token)
46
+
47
+ # 2. Register adapters (they are loaded on first use via activate())
48
+ adapter_manager = AdapterManager(backbone.model, config)
49
+ adapter_manager.register("bam", bambara_path)
50
+ adapter_manager.register("ful", fula_path)
51
+
52
+ # 3. Pre-load the default adapter to warm up VRAM
53
+ try:
54
+ adapter_manager.load_adapter("bam")
55
+ logger.info("Default adapter 'bam' pre-loaded.")
56
+ except Exception as e:
57
+ logger.warning("Could not pre-load 'bam' adapter: %s", e)
58
+
59
+ # 4. Create transcriber and sensor bridge
60
+ transcriber = Transcriber(backbone, adapter_manager)
61
+ sensor_bridge = SensorBridge(sensor_api_url=sensor_api_url)
62
+
63
+ # 5. Attach to app.state for dependency injection
64
+ app.state.backbone = backbone
65
+ app.state.adapter_manager = adapter_manager
66
+ app.state.transcriber = transcriber
67
+ app.state.sensor_bridge = sensor_bridge
68
+
69
+ logger.info("Sahel-Agri Voice AI server ready.")
70
+ yield
71
+
72
+ # Shutdown
73
+ logger.info("Shutting down — freeing GPU memory...")
74
+ backbone.free()
75
+
76
+
77
+ def create_app() -> FastAPI:
78
+ app = FastAPI(
79
+ title="Sahel-Agri Voice AI",
80
+ description=(
81
+ "Modular STT engine for Bambara and Fula — serving Mali and Guinea farmers "
82
+ "via voice-first agricultural intelligence."
83
+ ),
84
+ version="0.1.0",
85
+ lifespan=lifespan,
86
+ )
87
+
88
+ register_middleware(app)
89
+
90
+ # Register routes
91
+ app.include_router(health.router, prefix="/api/v1", tags=["health"])
92
+ app.include_router(transcribe.router, prefix="/api/v1", tags=["transcribe"])
93
+ app.include_router(iot.router, prefix="/api/v1", tags=["iot"])
94
+
95
+ return app
96
+
97
+
98
+ app = create_app()
src/api/dependencies.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI dependency injection: retrieves shared model objects from app.state."""
2
+ from __future__ import annotations
3
+
4
+ from fastapi import Request
5
+
6
+ from src.engine.adapter_manager import AdapterManager
7
+ from src.engine.transcriber import Transcriber
8
+ from src.iot.sensor_bridge import SensorBridge
9
+
10
+
11
+ def get_transcriber(request: Request) -> Transcriber:
12
+ return request.app.state.transcriber
13
+
14
+
15
+ def get_adapter_manager(request: Request) -> AdapterManager:
16
+ return request.app.state.adapter_manager
17
+
18
+
19
+ def get_sensor_bridge(request: Request) -> SensorBridge:
20
+ return request.app.state.sensor_bridge
src/api/middleware.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CORS, structured request logging, and rate-limit middleware."""
2
+ from __future__ import annotations
3
+
4
+ import logging
5
+ import time
6
+ import uuid
7
+
8
+ from fastapi import FastAPI, Request, Response
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from slowapi import Limiter, _rate_limit_exceeded_handler
11
+ from slowapi.errors import RateLimitExceeded
+ from slowapi.middleware import SlowAPIMiddleware
12
+ from slowapi.util import get_remote_address
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ limiter = Limiter(key_func=get_remote_address, default_limits=["60/minute"])
17
+
18
+
19
+ def register_middleware(app: FastAPI) -> None:
20
+ """Attach all middleware to the FastAPI app."""
21
+
22
+ # CORS — allow WhatsApp webhook domain and local development
23
+ app.add_middleware(
24
+ CORSMiddleware,
25
+ allow_origins=["*"], # Tighten in production with specific domains
26
+ allow_credentials=True,
27
+ allow_methods=["GET", "POST"],
28
+ allow_headers=["*"],
29
+ )
30
+
31
+ # Rate limiting — the 60/minute default only takes effect once the middleware is installed
32
+ app.state.limiter = limiter
33
+ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
+ app.add_middleware(SlowAPIMiddleware)
34
+
35
+ @app.middleware("http")
36
+ async def logging_middleware(request: Request, call_next) -> Response:
37
+ request_id = str(uuid.uuid4())[:8]
38
+ t0 = time.perf_counter()
39
+ response = await call_next(request)
40
+ elapsed_ms = int((time.perf_counter() - t0) * 1000)
41
+ logger.info(
42
+ "req_id=%s method=%s path=%s status=%d latency_ms=%d",
43
+ request_id, request.method, request.url.path,
44
+ response.status_code, elapsed_ms,
45
+ )
46
+ response.headers["X-Request-ID"] = request_id
47
+ return response
src/api/routes/__init__.py ADDED
File without changes
src/api/routes/health.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GET /api/v1/health — model status and adapter availability."""
2
+ from __future__ import annotations
3
+
4
+ from fastapi import APIRouter, Depends, Request
5
+
6
+ from src.api.dependencies import get_adapter_manager
7
+ from src.api.schemas import HealthResponse
8
+ from src.engine.adapter_manager import AdapterManager
9
+
10
+ router = APIRouter()
11
+
12
+
13
+ @router.get("/health", response_model=HealthResponse)
14
+ async def health_check(
15
+ request: Request,
16
+ adapter_manager: AdapterManager = Depends(get_adapter_manager),
17
+ ) -> HealthResponse:
18
+ model_loaded = hasattr(request.app.state, "transcriber")
19
+ return HealthResponse(
20
+ status="ok" if model_loaded else "loading",
21
+ model_loaded=model_loaded,
22
+ active_adapter=adapter_manager.get_active(),
23
+ adapters_available=adapter_manager.list_available(),
24
+ adapters_loaded=adapter_manager.list_loaded(),
25
+ )
src/api/routes/iot.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """POST /api/v1/query — full pipeline: audio → transcription → intent → sensor → voice response."""
2
+ from __future__ import annotations
3
+
4
+ import logging
5
+ import os
6
+ import tempfile
7
+ import time
8
+ from typing import Annotated, Optional
9
+
10
+ from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
11
+
12
+ from src.api.dependencies import get_sensor_bridge, get_transcriber
13
+ from src.api.schemas import IoTQueryResponse
14
+ from src.engine.transcriber import Transcriber
15
+ from src.iot.intent_parser import IntentParser
16
+ from src.iot.sensor_bridge import SensorBridge
17
+ from src.iot.voice_responder import VoiceResponder
18
+
19
+ logger = logging.getLogger(__name__)
20
+ router = APIRouter()
21
+
22
+ _intent_parser = IntentParser()
23
+ _voice_responder = VoiceResponder(language="fr")
24
+
25
+ SUPPORTED_LANGUAGES = {"bam", "ful"}
26
+ MAX_AUDIO_BYTES = 10 * 1024 * 1024
27
+
28
+
29
+ @router.post("/query", response_model=IoTQueryResponse)
30
+ async def agricultural_query(
31
+ audio_file: Annotated[UploadFile, File(description="Audio file with farmer's voice query")],
32
+ language: Annotated[str, Form(description="Language code: 'bam' or 'ful'")] = "bam",
33
+ field_id: Annotated[Optional[str], Form(description="Field/location ID for sensor lookup")] = None,
34
+ transcriber: Transcriber = Depends(get_transcriber),
35
+ sensor_bridge: SensorBridge = Depends(get_sensor_bridge),
36
+ ) -> IoTQueryResponse:
37
+ t0 = time.perf_counter()
38
+
39
+ if language not in SUPPORTED_LANGUAGES:
40
+ raise HTTPException(
41
+ status_code=422,
42
+ detail=f"Unsupported language '{language}'. Supported: {sorted(SUPPORTED_LANGUAGES)}",
43
+ )
44
+
45
+ audio_bytes = await audio_file.read()
46
+ if len(audio_bytes) > MAX_AUDIO_BYTES:
47
+ raise HTTPException(status_code=413, detail="Audio file too large. Max 10 MB.")
48
+
49
+ ext = os.path.splitext(audio_file.filename or "audio.wav")[1].lower() or ".wav"
50
+ tmp_path = None
51
+ try:
52
+ # Step 1: Transcribe
53
+ with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
54
+ tmp.write(audio_bytes)
55
+ tmp_path = tmp.name
56
+
57
+ transcription_result = transcriber.transcribe_file(tmp_path, language)
58
+
59
+ # Step 2: Parse intent
60
+ intent = _intent_parser.parse(transcription_result.text, language)
61
+
62
+ # Step 3: Fetch sensor data
63
+ sensor_data = await sensor_bridge.fetch(intent, field_id=field_id)
64
+
65
+ # Step 4: Generate voice response
66
+ voice_response = _voice_responder.generate_response(intent, sensor_data)
67
+
68
+ except HTTPException:
69
+ raise
70
+ except Exception as e:
71
+ logger.error("IoT query failed: %s", e, exc_info=True)
72
+ raise HTTPException(status_code=500, detail=str(e))
73
+ finally:
74
+ if tmp_path and os.path.exists(tmp_path):
75
+ os.unlink(tmp_path)
76
+
77
+ elapsed_ms = int((time.perf_counter() - t0) * 1000)
78
+
79
+ return IoTQueryResponse(
80
+ transcription=transcription_result.text,
81
+ language=language,
82
+ intent={
83
+ "action": intent.action,
84
+ "entity": intent.entity,
85
+ "confidence": intent.confidence,
86
+ },
87
+ sensor_data=sensor_data.values,
88
+ voice_response=voice_response,
89
+ processing_time_ms=elapsed_ms,
90
+ )
src/api/routes/transcribe.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """POST /api/v1/transcribe — convert uploaded audio to text."""
2
+ from __future__ import annotations
3
+
4
+ import logging
5
+ import os
6
+ import tempfile
7
+ from typing import Annotated
8
+
9
+ from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
10
+
11
+ from src.api.dependencies import get_transcriber
12
+ from src.api.schemas import TranscribeResponse
13
+ from src.engine.transcriber import Transcriber
14
+
15
+ logger = logging.getLogger(__name__)
16
+ router = APIRouter()
17
+
18
+ SUPPORTED_LANGUAGES = {"bam", "ful"}
19
+ SUPPORTED_EXTENSIONS = {".wav", ".mp3", ".ogg", ".m4a", ".flac", ".webm"}
20
+ MAX_AUDIO_BYTES = 10 * 1024 * 1024 # 10 MB
21
+
22
+
23
+ @router.post("/transcribe", response_model=TranscribeResponse)
24
+ async def transcribe_audio(
25
+ audio_file: Annotated[UploadFile, File(description="Audio file (wav/mp3/ogg/m4a/flac/webm)")],
26
+ language: Annotated[str, Form(description="Language code: 'bam' (Bambara) or 'ful' (Fula)")] = "bam",
27
+ transcriber: Transcriber = Depends(get_transcriber),
28
+ ) -> TranscribeResponse:
29
+ # Validate language
30
+ if language not in SUPPORTED_LANGUAGES:
31
+ raise HTTPException(
32
+ status_code=422,
33
+ detail=f"Unsupported language '{language}'. Supported: {sorted(SUPPORTED_LANGUAGES)}",
34
+ )
35
+
36
+ # Validate file extension
37
+ filename = audio_file.filename or "audio.wav"
38
+ ext = os.path.splitext(filename)[1].lower()
39
+ if ext not in SUPPORTED_EXTENSIONS:
40
+ raise HTTPException(
41
+ status_code=422,
42
+ detail=f"Unsupported file type '{ext}'. Supported: {sorted(SUPPORTED_EXTENSIONS)}",
43
+ )
44
+
45
+ # Read and size-check
46
+ audio_bytes = await audio_file.read()
47
+ if len(audio_bytes) > MAX_AUDIO_BYTES:
48
+ raise HTTPException(
49
+ status_code=413,
50
+ detail=f"File too large ({len(audio_bytes) / 1e6:.1f} MB). Max 10 MB.",
51
+ )
52
+
53
+ # Windows-safe temp file: delete=False + manual unlink in finally
54
+ tmp_path = None
55
+ try:
56
+ with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
57
+ tmp.write(audio_bytes)
58
+ tmp_path = tmp.name
59
+
60
+ result = transcriber.transcribe_file(tmp_path, language)
61
+ except Exception as e:
62
+ logger.error("Transcription failed: %s", e, exc_info=True)
63
+ raise HTTPException(status_code=500, detail=str(e))
64
+ finally:
65
+ if tmp_path and os.path.exists(tmp_path):
66
+ os.unlink(tmp_path)
67
+
68
+ return TranscribeResponse(
69
+ text=result.text,
70
+ language=result.language,
71
+ duration_s=result.duration_s,
72
+ processing_time_ms=result.processing_time_ms,
73
+ confidence=result.confidence,
74
+ )
src/api/schemas.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic v2 request and response models for all API endpoints."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Literal, Optional
5
+
6
+ from pydantic import BaseModel, Field
7
+
8
+
9
+ class TranscribeResponse(BaseModel):
10
+ text: str
11
+ language: str
12
+ duration_s: float
13
+ processing_time_ms: int
14
+ confidence: Optional[float] = None
15
+
16
+
17
+ class IoTQueryResponse(BaseModel):
18
+ transcription: str
19
+ language: str
20
+ intent: dict
21
+ sensor_data: dict
22
+ voice_response: str
23
+ processing_time_ms: int
24
+
25
+
26
+ class HealthResponse(BaseModel):
27
+ status: str
28
+ model_loaded: bool
29
+ active_adapter: Optional[str]
30
+ adapters_available: list[str]
31
+ adapters_loaded: list[str]
32
+
33
+
34
+ class ErrorResponse(BaseModel):
35
+ error: str
36
+ detail: str
src/data/__init__.py ADDED
File without changes
src/data/agri_dictionary.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Agricultural vocabulary for Bambara and Fula.
3
+ Used to bias the Whisper decoder toward domain-specific terms via decoder prompt injection.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ from typing import TYPE_CHECKING
8
+
9
+ import torch
10
+
11
+ if TYPE_CHECKING:
12
+ from transformers import WhisperProcessor
13
+
14
+ # Bambara (bam) agricultural vocabulary
15
+ BAMBARA_VOCAB: dict[str, str] = {
16
+ "sɛnɛ": "farming",
17
+ "jiriw": "trees",
18
+ "nɔgɔ": "soil",
19
+ "sani": "fertilizer",
20
+ "kogomali": "groundnut",
21
+ "kaba": "corn/maize",
22
+ "tiga": "peanut",
23
+ "ji": "water",
24
+ "sanji": "rain",
25
+ "teliman": "weather",
26
+ "suruku": "pest/predator",
27
+ "bunding": "soil/earth",
28
+ "sira": "path/way",
29
+ "foro": "field",
30
+ "dugu": "village/land",
31
+ "dibi": "darkness/shade",
32
+ "fanga": "strength/fertilizer",
33
+ "kungoloni": "insects/pests",
34
+ }
35
+
36
+ # Fula (ful / Fulfulde) agricultural vocabulary
37
+ FULA_VOCAB: dict[str, str] = {
38
+ "ngesa": "field",
39
+ "leydi": "land/soil",
40
+ "kosam": "milk",
41
+ "nagge": "cattle",
42
+ "leeɗe": "crops",
43
+ "ndiyam": "water",
44
+ "yeeso": "wind/weather",
45
+ "laabi": "road/way",
46
+ "demoore": "farming",
47
+ "hoore": "head/top",
48
+ "biñ-biñ": "insects/pests",
49
+ "fuɗorde": "sunrise/east field",
50
+ "ngaari": "bull",
51
+ "mbabba": "donkey",
52
+ "ladde": "bush/forest",
53
+ "wutte": "clothing/harvest",
54
+ }
55
+
56
+ LANGUAGE_VOCABS: dict[str, dict[str, str]] = {
57
+ "bam": BAMBARA_VOCAB,
58
+ "ful": FULA_VOCAB,
59
+ }
60
+
61
+
62
+ class AgriculturalDictionary:
63
+ """Converts agricultural vocabulary into decoder prompt token IDs for Whisper."""
64
+
65
+ def get_vocab(self, language: str) -> dict[str, str]:
66
+ if language not in LANGUAGE_VOCABS:
67
+ raise ValueError(f"No vocabulary for language '{language}'. Available: {list(LANGUAGE_VOCABS)}")
68
+ return LANGUAGE_VOCABS[language]
69
+
70
+ def get_prompt_text(self, language: str) -> str:
71
+ """Return a comma-joined string of all terms, used as decoder text prompt."""
72
+ vocab = self.get_vocab(language)
73
+ return ", ".join(vocab.keys())
74
+
75
+ def build_prompt_ids(self, processor: "WhisperProcessor", language: str) -> torch.Tensor:
76
+ """
77
+ Tokenize the vocabulary as a decoder prompt.
78
+ Pass this as `decoder_input_ids` or `prompt_ids` to model.generate()
79
+ to bias decoding toward known agricultural terms.
80
+ """
81
+ prompt_text = self.get_prompt_text(language)
82
+ token_ids = processor.tokenizer(
83
+ prompt_text,
84
+ return_tensors="pt",
85
+ add_special_tokens=False,
86
+ ).input_ids
87
+ return token_ids # shape: (1, N)
88
+
89
+ def get_token_ids(self, processor: "WhisperProcessor", language: str) -> list[int]:
90
+ """Return flat list of token IDs for all vocabulary terms."""
91
+ ids = self.build_prompt_ids(processor, language)
92
+ return ids[0].tolist()
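A minimal sketch of how the vocabulary bias plugs into generation. It assumes a loaded `model`, `processor`, and an `input_features` batch prepared earlier; note that `prompt_ids` support in `generate()`, and how strongly a plain token-id prompt biases decoding, depend on the installed transformers version:

    from src.data.agri_dictionary import AgriculturalDictionary

    agri = AgriculturalDictionary()
    prompt_ids = agri.build_prompt_ids(processor, "bam")[0].to(model.device)  # 1-D token-id tensor
    predicted_ids = model.generate(
        input_features,            # log-mel features prepared earlier (assumed)
        prompt_ids=prompt_ids,     # bias decoding toward the agricultural terms
        max_new_tokens=128,
    )
    print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])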
src/data/augmentation.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Field noise augmentation for West African farm environments.
3
+ Mixes clean speech with tractor, wind, and livestock audio samples.
4
+ Degrades gracefully to Gaussian noise when no .wav files are present.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import logging
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class FieldNoiseAugmenter:
17
+ """
18
+ Applies audiomentations transforms that simulate noisy field conditions.
19
+ If the noise_dir has no .wav files, falls back to Gaussian noise only.
20
+ """
21
+
22
+ def __init__(self, noise_dir: str, config: dict) -> None:
23
+ self.noise_dir = Path(noise_dir)
24
+ self.config = config
25
+ self._compose = None
26
+ self._gaussian_only = False
27
+ self._build_pipeline()
28
+
29
+ def _build_pipeline(self) -> None:
30
+ try:
31
+ from audiomentations import (
32
+ AddBackgroundNoise,
33
+ AddGaussianNoise,
34
+ Compose,
35
+ RoomSimulator,
36
+ TimeStretch,
37
+ )
38
+ except ImportError:
39
+ logger.warning("audiomentations not installed — augmentation disabled.")
40
+ self._compose = None
41
+ return
42
+
43
+ snr_range = self.config.get("audio", {}).get("noise_snr_db_range", [5, 20])
44
+ prob = self.config.get("audio", {}).get("augmentation_prob", 0.6)
45
+
46
+ wav_files = list(self.noise_dir.glob("*.wav")) if self.noise_dir.exists() else []
47
+
48
+ transforms = []
49
+
50
+ if wav_files:
51
+ transforms.append(
52
+ AddBackgroundNoise(
53
+ sounds_path=str(self.noise_dir),
54
+ min_snr_db=float(snr_range[0]),
55
+ max_snr_db=float(snr_range[1]),
56
+ p=prob,
57
+ )
58
+ )
59
+ logger.info("FieldNoiseAugmenter: loaded %d noise files from %s", len(wav_files), self.noise_dir)
60
+ else:
61
+ logger.warning(
62
+ "FieldNoiseAugmenter: no .wav files found in %s — using Gaussian noise only. "
63
+ "Populate noise_samples/ for realistic field augmentation.",
64
+ self.noise_dir,
65
+ )
66
+ self._gaussian_only = True
67
+
68
+ transforms += [
69
+ AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.3),
70
+ TimeStretch(min_rate=0.9, max_rate=1.1, leave_length_unchanged=True, p=0.2),
71
+ RoomSimulator(p=0.3),
72
+ ]
73
+
74
+ self._compose = Compose(transforms)
75
+
76
+ def augment(self, audio: np.ndarray, sr: int) -> np.ndarray:
77
+ """Apply augmentation pipeline to a float32 audio array."""
78
+ if self._compose is None:
79
+ return audio
80
+ return self._compose(samples=audio, sample_rate=sr)
81
+
82
+ def is_ready(self) -> bool:
83
+ """Returns True if augmentation is available (even Gaussian-only)."""
84
+ return self._compose is not None
src/data/feature_extractor.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Log-mel spectrogram extraction, padding/truncation, and batch collation for Whisper.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ import logging
7
+ from dataclasses import dataclass
8
+ from typing import TYPE_CHECKING, Any
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torchaudio
13
+
14
+ if TYPE_CHECKING:
15
+ from transformers import WhisperProcessor
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ TARGET_SR = 16_000
20
+ MEL_FRAMES = 3000 # 30 seconds at 100 frames/sec
21
+ N_MELS = 80
22
+
23
+
24
+ class AudioFeatureExtractor:
25
+ """Wraps WhisperProcessor to extract and normalize audio features."""
26
+
27
+ def __init__(self, processor: "WhisperProcessor", config: dict) -> None:
28
+ self.processor = processor
29
+ self.sample_rate = config.get("audio", {}).get("sample_rate", TARGET_SR)
30
+
31
+ def extract(self, audio: np.ndarray, sr: int) -> torch.Tensor:
32
+ """
33
+ Resample audio to 16kHz, extract log-mel features.
34
+ Returns tensor of shape (80, 3000).
35
+ """
36
+ if sr != TARGET_SR:
37
+ tensor = torch.from_numpy(audio).unsqueeze(0)
38
+ tensor = torchaudio.functional.resample(tensor, sr, TARGET_SR)
39
+ audio = tensor.squeeze(0).numpy()
40
+
41
+ inputs = self.processor.feature_extractor(
42
+ audio,
43
+ sampling_rate=TARGET_SR,
44
+ return_tensors="pt",
45
+ )
46
+ features = inputs.input_features[0] # (80, 3000)
47
+ return features
48
+
49
+ def pad_or_truncate(self, features: torch.Tensor) -> torch.Tensor:
50
+ """Ensure features are exactly (80, 3000)."""
51
+ _, t = features.shape
52
+ if t < MEL_FRAMES:
53
+ pad = torch.zeros(N_MELS, MEL_FRAMES - t, dtype=features.dtype)
54
+ features = torch.cat([features, pad], dim=-1)
55
+ elif t > MEL_FRAMES:
56
+ features = features[:, :MEL_FRAMES]
57
+ return features
58
+
59
+
60
+ @dataclass
61
+ class DataCollatorSpeechSeq2SeqWithPadding:
62
+ """
63
+ Pads input_features to uniform length and label sequences with -100
64
+ (so they are ignored in the cross-entropy loss).
65
+ Compatible with HuggingFace Seq2SeqTrainer.
66
+ """
67
+ processor: Any
68
+ decoder_start_token_id: int
69
+
70
+ def __call__(self, features: list[dict]) -> dict[str, torch.Tensor]:
71
+ # Separate input_features and labels
72
+ input_features = [{"input_features": f["input_features"]} for f in features]
73
+ label_features = [{"input_ids": f["labels"]} for f in features]
74
+
75
+ # Pad input features (processor handles this)
76
+ batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
77
+
78
+ # Pad labels
79
+ labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
80
+ labels = labels_batch["input_ids"].masked_fill(
81
+ labels_batch.attention_mask.ne(1), -100
82
+ )
83
+
84
+ # Remove decoder start token if it was prepended
85
+ if (labels[:, 0] == self.decoder_start_token_id).all().item():
86
+ labels = labels[:, 1:]
87
+
88
+ batch["labels"] = labels
89
+ return batch
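A hedged sketch of wiring the collator into a HuggingFace Seq2SeqTrainer. The training arguments shown are placeholders, not the values used by WhisperLoRATrainer, and `train_dataset` is assumed to yield the {"input_features", "labels"} dicts produced by the Waxal preprocessing:

    from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

    collator = DataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=model.config.decoder_start_token_id,
    )
    trainer = Seq2SeqTrainer(
        model=model,                        # PEFT-wrapped Whisper model (assumed)
        args=Seq2SeqTrainingArguments(
            output_dir="adapters/bambara",  # placeholder path
            per_device_train_batch_size=8,
        ),
        train_dataset=train_dataset,
        data_collator=collator,
    )
    trainer.train()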
src/data/waxal_loader.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Loads and preprocesses the google/waxal dataset for Bambara (bam) and Fula (ful).
3
+ Uses streaming to avoid downloading the full corpus before training.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ import logging
8
+ from typing import TYPE_CHECKING, Callable, Iterator
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torchaudio
13
+ from datasets import load_dataset
14
+
15
+ if TYPE_CHECKING:
16
+ from datasets import Dataset, IterableDataset
17
+ from transformers import WhisperProcessor
18
+
19
+ from src.data.augmentation import FieldNoiseAugmenter
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # google/waxal column names
24
+ AUDIO_COL = "audio"
25
+ TEXT_COL = "transcription"
26
+ TARGET_SR = 16_000
27
+
28
+
29
+ class WaxalDataLoader:
30
+ """Streams the google/waxal dataset and prepares examples for Whisper training."""
31
+
32
+ def __init__(
33
+ self,
34
+ subset: str,
35
+ config: dict,
36
+ hf_token: str | None = None,
37
+ ) -> None:
38
+ if subset not in ("bam", "ful"):
39
+ raise ValueError(f"subset must be 'bam' or 'ful', got '{subset}'")
40
+ self.subset = subset
41
+ self.config = config
42
+ self.hf_token = hf_token
43
+
44
+ def load_split(self, split: str = "train", streaming: bool = True) -> "IterableDataset | Dataset":
45
+ """Return a single split of google/waxal."""
46
+ logger.info("Loading google/waxal subset=%s split=%s streaming=%s", self.subset, split, streaming)
47
+ ds = load_dataset(
48
+ "google/waxal",
49
+ self.subset,
50
+ split=split,
51
+ token=self.hf_token,
52
+ streaming=streaming,
53
+ trust_remote_code=True,
54
+ )
55
+ if streaming:
56
+ ds = ds.shuffle(seed=42, buffer_size=1000)
57
+ return ds
58
+
59
+ def get_splits(self, streaming: bool = True) -> dict[str, "IterableDataset | Dataset"]:
60
+ """Return train / validation / test splits."""
61
+ splits = {}
62
+ for split in ("train", "validation", "test"):
63
+ try:
64
+ splits[split] = self.load_split(split, streaming=streaming)
65
+ except Exception:
66
+ logger.warning("Split '%s' not available for subset '%s'", split, self.subset)
67
+ return splits
68
+
69
+ def make_preprocess_fn(
70
+ self,
71
+ processor: "WhisperProcessor",
72
+ augmenter: "FieldNoiseAugmenter | None" = None,
73
+ ) -> Callable[[dict], dict]:
74
+ """Return a function that converts a raw Waxal example into model inputs."""
75
+
76
+ def preprocess(example: dict) -> dict:
77
+ # Extract and resample audio
78
+ audio_array = np.array(example[AUDIO_COL]["array"], dtype=np.float32)
79
+ orig_sr: int = example[AUDIO_COL]["sampling_rate"]
80
+
81
+ if orig_sr != TARGET_SR:
82
+ tensor = torch.from_numpy(audio_array).unsqueeze(0)
83
+ tensor = torchaudio.functional.resample(tensor, orig_sr, TARGET_SR)
84
+ audio_array = tensor.squeeze(0).numpy()
85
+
86
+ # Apply field noise augmentation if provided
87
+ if augmenter is not None and augmenter.is_ready():
88
+ audio_array = augmenter.augment(audio_array, TARGET_SR)
89
+
90
+ # Extract log-mel features
91
+ inputs = processor.feature_extractor(
92
+ audio_array,
93
+ sampling_rate=TARGET_SR,
94
+ return_tensors="np",
95
+ )
96
+ input_features = inputs.input_features[0] # shape (80, 3000)
97
+
98
+ # Tokenize transcript
99
+ text: str = example[TEXT_COL]
100
+ labels = processor.tokenizer(text, return_tensors="np").input_ids[0]
101
+
102
+ return {
103
+ "input_features": input_features,
104
+ "labels": labels,
105
+ }
106
+
107
+ return preprocess
108
+
109
+ def iter_processed(
110
+ self,
111
+ processor: "WhisperProcessor",
112
+ split: str = "train",
113
+ augmenter: "FieldNoiseAugmenter | None" = None,
114
+ ) -> Iterator[dict]:
115
+ """Yield preprocessed examples one at a time (streaming)."""
116
+ ds = self.load_split(split, streaming=True)
117
+ fn = self.make_preprocess_fn(processor, augmenter)
118
+ for example in ds:
119
+ yield fn(example)
src/engine/__init__.py ADDED
File without changes
src/engine/adapter_manager.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LoRA adapter hot-swap manager.
3
+
4
+ Uses PEFT's multi-adapter API:
5
+ - model.load_adapter(path, adapter_name=lang) — first load (~2s per adapter)
6
+ - model.set_adapter(lang) — subsequent swap (~50ms)
7
+
8
+ This keeps a single backbone in VRAM and swaps only the ~50MB adapter weights,
9
+ vs reloading the full 1.5GB model per language.
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import logging
14
+ from pathlib import Path
15
+ from typing import TYPE_CHECKING
16
+
17
+ from peft import PeftModel
18
+
19
+ if TYPE_CHECKING:
20
+ from transformers import WhisperForConditionalGeneration
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class AdapterManager:
26
+ """Manages registration and hot-swapping of LoRA language adapters."""
27
+
28
+ def __init__(self, base_model: "WhisperForConditionalGeneration", config: dict) -> None:
29
+ self._base_model = base_model
30
+ self._config = config
31
+ self._registry: dict[str, str] = {} # language_code -> adapter_path
32
+ self._peft_model: PeftModel | None = None
33
+ self._active: str | None = None
34
+
35
+ def register(self, language: str, adapter_path: str) -> None:
36
+ """Register an adapter path. Does not load it yet."""
37
+ path = Path(adapter_path)
38
+ if not path.exists():
39
+ logger.warning(
40
+ "Adapter path '%s' for language '%s' does not exist. "
41
+ "Run training first, or check the path.",
42
+ adapter_path, language,
43
+ )
44
+ self._registry[language] = str(path)
45
+ logger.info("Registered adapter '%s' → %s", language, adapter_path)
46
+
47
+ def load_adapter(self, language: str) -> None:
48
+ """
49
+ Load an adapter into the model for the first time.
50
+ Slow (~2s): reads adapter weights from disk.
51
+ Subsequent activate() calls reuse the already-loaded weights.
52
+ """
53
+ if language not in self._registry:
54
+ raise KeyError(f"No adapter registered for language '{language}'. "
55
+ f"Available: {list(self._registry)}")
56
+
57
+ adapter_path = self._registry[language]
58
+
59
+ if self._peft_model is None:
60
+ # First adapter: wrap the base model with PeftModel
61
+ logger.info("Wrapping base model with first adapter '%s'...", language)
62
+ self._peft_model = PeftModel.from_pretrained(
63
+ self._base_model,
64
+ adapter_path,
65
+ adapter_name=language,
66
+ )
67
+ else:
68
+ # Subsequent adapters: load into the existing PeftModel
69
+ logger.info("Loading adapter '%s' into existing PeftModel...", language)
70
+ self._peft_model.load_adapter(adapter_path, adapter_name=language)
71
+
72
+ self._active = language
73
+ logger.info("Adapter '%s' loaded and active.", language)
74
+
75
+ def activate(self, language: str) -> None:
76
+ """
77
+ Hot-swap to a previously loaded adapter (~50ms).
78
+ Call load_adapter() first if this adapter hasn't been loaded.
79
+ """
80
+ if self._peft_model is None:
81
+ self.load_adapter(language)
82
+ return
83
+
84
+ loaded = set(self._peft_model.peft_config.keys())
85
+ if language not in loaded:
86
+ self.load_adapter(language)
87
+ return
88
+
89
+ self._peft_model.set_adapter(language)
90
+ self._active = language
91
+ logger.debug("Hot-swapped to adapter '%s'.", language)
92
+
93
+ def get_model(self) -> "WhisperForConditionalGeneration | PeftModel":
94
+ """Return the PeftModel (or base model if no adapter loaded yet)."""
95
+ return self._peft_model if self._peft_model is not None else self._base_model
96
+
97
+ def get_active(self) -> str | None:
98
+ return self._active
99
+
100
+ def list_available(self) -> list[str]:
101
+ return list(self._registry.keys())
102
+
103
+ def list_loaded(self) -> list[str]:
104
+ if self._peft_model is None:
105
+ return []
106
+ return list(self._peft_model.peft_config.keys())
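A minimal usage sketch of the register → load → hot-swap flow, assuming a backbone already loaded the way src/api/app.py does it:

    manager = AdapterManager(backbone.model, config)
    manager.register("bam", "./adapters/bambara")
    manager.register("ful", "./adapters/fula")

    manager.load_adapter("bam")   # first load: reads the adapter weights from disk (~2 s)
    manager.activate("ful")       # not loaded yet, so this also hits the disk once
    manager.activate("bam")       # already loaded: in-memory swap (~50 ms)
    model = manager.get_model()   # PeftModel with the 'bam' adapter active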
src/engine/transcriber.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Public inference interface.
3
+ Accepts audio as a file path or numpy array and returns transcribed text.
4
+ Handles chunking for audio longer than 30 seconds.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import logging
9
+ import os
10
+ import time
11
+ from dataclasses import dataclass, field
12
+ from pathlib import Path
13
+ from typing import TYPE_CHECKING
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+ if TYPE_CHECKING:
19
+ from src.engine.adapter_manager import AdapterManager
20
+ from src.engine.whisper_base import WhisperBackbone
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ TARGET_SR = 16_000
25
+
26
+
27
+ @dataclass
28
+ class TranscriptionResult:
29
+ text: str
30
+ language: str
31
+ duration_s: float
32
+ processing_time_ms: int
33
+ confidence: float | None = None
34
+
35
+
36
+ class Transcriber:
37
+ """
38
+ Composes WhisperBackbone + AdapterManager to provide a simple transcription API.
39
+ Thread-safety: Not thread-safe by design — use one worker process.
40
+ """
41
+
42
+ def __init__(self, backbone: "WhisperBackbone", adapter_manager: "AdapterManager") -> None:
43
+ self._backbone = backbone
44
+ self._adapter_manager = adapter_manager
45
+
46
+ def transcribe(
47
+ self,
48
+ audio: np.ndarray,
49
+ sample_rate: int,
50
+ language: str,
51
+ use_agri_prompt: bool = True,
52
+ ) -> TranscriptionResult:
53
+ """
54
+ Transcribe a float32 audio array.
55
+ For audio > 30s, uses transformers pipeline with chunking.
56
+ """
57
+ t0 = time.time()
58
+
59
+ # Activate the correct language adapter
60
+ self._adapter_manager.activate(language)
61
+
62
+ processor = self._backbone.processor
63
+ model = self._adapter_manager.get_model()
64
+ device = self._backbone.device
65
+ duration_s = len(audio) / sample_rate
66
+
67
+ if duration_s <= 30.0:
68
+ text = self._transcribe_chunk(audio, sample_rate, language, processor, model, device)
69
+ else:
70
+ text = self._transcribe_long(audio, sample_rate, language, processor, model, device)
71
+
72
+ elapsed_ms = int((time.time() - t0) * 1000)
73
+ return TranscriptionResult(
74
+ text=text.strip(),
75
+ language=language,
76
+ duration_s=duration_s,
77
+ processing_time_ms=elapsed_ms,
78
+ )
79
+
80
+ def transcribe_file(self, audio_path: str, language: str) -> TranscriptionResult:
81
+ """Load audio from disk and transcribe."""
82
+ import librosa
83
+ audio, sr = librosa.load(audio_path, sr=TARGET_SR, mono=True)
84
+ return self.transcribe(audio, sr, language)
85
+
86
+ def _transcribe_chunk(
87
+ self,
88
+ audio: np.ndarray,
89
+ sr: int,
90
+ language: str,
91
+ processor,
92
+ model,
93
+ device: str,
94
+ ) -> str:
95
+ """Transcribe a single ≤30s chunk."""
96
+ inputs = processor.feature_extractor(
97
+ audio, sampling_rate=sr, return_tensors="pt"
98
+ )
99
+ input_features = inputs.input_features.to(device)
100
+ if device == "cuda":
101
+ input_features = input_features.half()
102
+
103
+ forced_decoder_ids = processor.get_decoder_prompt_ids(
104
+ language=language, task="transcribe"
105
+ )
106
+
107
+ with torch.no_grad():
108
+ predicted_ids = model.generate(
109
+ input_features,
110
+ forced_decoder_ids=forced_decoder_ids,
111
+ max_new_tokens=128,
112
+ )
113
+
114
+ return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
115
+
116
+ def _transcribe_long(
117
+ self,
118
+ audio: np.ndarray,
119
+ sr: int,
120
+ language: str,
121
+ processor,
122
+ model,
123
+ device: str,
124
+ ) -> str:
125
+ """Chunk audio into 30s segments and concatenate transcriptions."""
126
+ chunk_size = TARGET_SR * 30
127
+ chunks = [audio[i : i + chunk_size] for i in range(0, len(audio), chunk_size)]
128
+ parts = []
129
+ for chunk in chunks:
130
+ text = self._transcribe_chunk(chunk, sr, language, processor, model, device)
131
+ parts.append(text)
132
+ return " ".join(parts)
src/engine/whisper_base.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Loads the Whisper backbone model and processor once.
3
+ All other modules receive references to this shared instance.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ import logging
8
+ from pathlib import Path
9
+
10
+ import torch
11
+ import yaml
12
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class WhisperBackbone:
18
+ """Singleton-style loader for the Whisper base model and processor."""
19
+
20
+ def __init__(self, config_path: str = "configs/base_config.yaml") -> None:
21
+ config_path = Path(config_path)
22
+ with open(config_path) as f:
23
+ cfg = yaml.safe_load(f)
24
+ self._model_id: str = cfg["model"]["id"]
25
+ self._model: WhisperForConditionalGeneration | None = None
26
+ self._processor: WhisperProcessor | None = None
27
+ self._device: str = "cpu"
28
+
29
+ def load(self, device: str = "cuda", hf_token: str | None = None) -> None:
30
+ """Load model and processor into memory. Call once at startup."""
31
+ self._device = device if torch.cuda.is_available() and device == "cuda" else "cpu"
32
+ logger.info("Loading %s on %s", self._model_id, self._device)
33
+
34
+ self._processor = WhisperProcessor.from_pretrained(
35
+ self._model_id,
36
+ token=hf_token,
37
+ )
38
+
39
+ dtype = torch.float16 if self._device == "cuda" else torch.float32
40
+ self._model = WhisperForConditionalGeneration.from_pretrained(
41
+ self._model_id,
42
+ torch_dtype=dtype,
43
+ token=hf_token,
44
+ ).to(self._device)
45
+
46
+ self._model.eval()
47
+ logger.info("Model loaded successfully (dtype=%s, device=%s)", dtype, self._device)
48
+
49
+ @property
50
+ def model(self) -> WhisperForConditionalGeneration:
51
+ if self._model is None:
52
+ raise RuntimeError("Call WhisperBackbone.load() before accessing the model.")
53
+ return self._model
54
+
55
+ @property
56
+ def processor(self) -> WhisperProcessor:
57
+ if self._processor is None:
58
+ raise RuntimeError("Call WhisperBackbone.load() before accessing the processor.")
59
+ return self._processor
60
+
61
+ @property
62
+ def device(self) -> str:
63
+ return self._device
64
+
65
+ @property
66
+ def model_id(self) -> str:
67
+ return self._model_id
68
+
69
+ def free(self) -> None:
70
+ """Release GPU memory."""
71
+ del self._model
72
+ del self._processor
73
+ self._model = None
74
+ self._processor = None
75
+ if torch.cuda.is_available():
76
+ torch.cuda.empty_cache()
77
+ logger.info("Backbone freed from memory.")
src/iot/__init__.py ADDED
File without changes
src/iot/intent_parser.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Maps transcribed Bambara/Fula text to structured intents for IoT sensor queries.
3
+ Uses keyword matching (no ML required for v1).
4
+ Confidence = fraction of intent keywords present in the transcription.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass, field
9
+
10
+
11
+ @dataclass
12
+ class Intent:
13
+ action: str # e.g., "check_soil", "check_weather"
14
+ entity: str # e.g., "soil", "weather"
15
+ parameters: dict = field(default_factory=dict)
16
+ confidence: float = 0.0
17
+
18
+
19
+ # Intent keyword taxonomy for Bambara (bam) and Fula (ful)
20
+ INTENT_KEYWORDS: dict[str, dict[str, list[str]]] = {
21
+ "check_soil": {
22
+ "bam": ["bunding", "nɔgɔ", "dugu", "foro", "sani"],
23
+ "ful": ["leydi", "ngesa", "ladde"],
24
+ },
25
+ "check_weather": {
26
+ "bam": ["teliman", "sanji", "dibi", "sira"],
27
+ "ful": ["yeeso", "fuɗorde"],
28
+ },
29
+ "irrigation_status": {
30
+ "bam": ["ji", "sanji", "foro"],
31
+ "ful": ["ndiyam", "ngesa"],
32
+ },
33
+ "pest_alert": {
34
+ "bam": ["kungoloni", "suruku"],
35
+ "ful": ["biñ-biñ"],
36
+ },
37
+ }
38
+
39
+ INTENT_ENTITIES = {
40
+ "check_soil": "soil",
41
+ "check_weather": "weather",
42
+ "irrigation_status": "irrigation",
43
+ "pest_alert": "pest",
44
+ }
45
+
46
+
47
+ class IntentParser:
48
+ """Parses a transcription string into a structured Intent."""
49
+
50
+ def parse(self, text: str, language: str) -> Intent:
51
+ """
52
+ Find the best matching intent by counting keyword overlaps.
53
+ Returns the highest-confidence intent.
54
+ """
55
+ text_lower = text.lower()
56
+ best_action = "unknown"
57
+ best_confidence = 0.0
58
+
59
+ for action, lang_keywords in INTENT_KEYWORDS.items():
60
+ keywords = lang_keywords.get(language, [])
61
+ if not keywords:
62
+ continue
63
+
64
+ matches = sum(1 for kw in keywords if kw in text_lower)
65
+ confidence = matches / len(keywords)
66
+
67
+ if confidence > best_confidence:
68
+ best_confidence = confidence
69
+ best_action = action
70
+
71
+ return Intent(
72
+ action=best_action,
73
+ entity=INTENT_ENTITIES.get(best_action, "unknown"),
74
+ confidence=round(best_confidence, 3),
75
+ )
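A quick illustration of the keyword-overlap scoring; the phrase is assembled from the keyword lists above for illustration, not a verified Bambara sentence:

    parser = IntentParser()
    intent = parser.parse("foro ji sanji", "bam")
    # 'foro', 'ji' and 'sanji' are all irrigation_status keywords, so that intent
    # wins with confidence 3/3 = 1.0 (check_weather only matches 'sanji', 1/4 = 0.25)
    print(intent.action, intent.entity, intent.confidence)  # irrigation_status irrigation 1.0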
src/iot/sensor_bridge.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fetches sensor data (soil moisture, weather, irrigation) from the IoT backend API.
3
+ Falls back to synthetic mock data when SENSOR_API_URL is not configured.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ import logging
8
+ import random
9
+ from dataclasses import dataclass, field
10
+ from datetime import datetime
11
+ from typing import TYPE_CHECKING
12
+
13
+ if TYPE_CHECKING:
14
+ from src.iot.intent_parser import Intent
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ @dataclass
20
+ class SensorData:
21
+ sensor_type: str
22
+ values: dict[str, float]
23
+ timestamp: str
24
+ unit: str = ""
25
+
26
+
27
+ class SensorBridge:
28
+ """Async bridge to IoT sensor API. Uses mock data when no API URL is configured."""
29
+
30
+ def __init__(self, sensor_api_url: str | None = None, timeout_s: float = 5.0) -> None:
31
+ self.sensor_api_url = sensor_api_url
32
+ self.timeout_s = timeout_s
33
+ self._mock_mode = not sensor_api_url
34
+
35
+ if self._mock_mode:
36
+ logger.info("SensorBridge: running in MOCK mode (set SENSOR_API_URL to use real sensors).")
37
+
38
+ async def fetch(self, intent: "Intent", field_id: str | None = None) -> SensorData:
39
+ """Dispatch to the correct sensor fetch method based on intent entity."""
40
+ action = intent.action
41
+ if action == "check_soil":
42
+ return await self.get_soil_data(field_id or "default")
43
+ elif action == "check_weather":
44
+ return await self.get_weather(field_id or "default")
45
+ elif action == "irrigation_status":
46
+ return await self.get_irrigation(field_id or "default")
47
+ elif action == "pest_alert":
48
+ return await self.get_pest_status(field_id or "default")
49
+ else:
50
+ return SensorData(
51
+ sensor_type="unknown",
52
+ values={},
53
+ timestamp=datetime.utcnow().isoformat(),
54
+ )
55
+
56
+ async def get_soil_data(self, location_id: str) -> SensorData:
57
+ if self._mock_mode:
58
+ return SensorData(
59
+ sensor_type="soil",
60
+ values={
61
+ "moisture_pct": round(random.uniform(25, 65), 1),
62
+ "ph": round(random.uniform(5.5, 7.5), 1),
63
+ "nitrogen_ppm": round(random.uniform(10, 40), 1),
64
+ "temperature_c": round(random.uniform(24, 35), 1),
65
+ },
66
+ timestamp=datetime.utcnow().isoformat(),
67
+ )
68
+ return await self._get(f"/sensors/soil/{location_id}", "soil")
69
+
70
+ async def get_weather(self, location_id: str) -> SensorData:
71
+ if self._mock_mode:
72
+ return SensorData(
73
+ sensor_type="weather",
74
+ values={
75
+ "temperature_c": round(random.uniform(28, 42), 1),
76
+ "humidity_pct": round(random.uniform(20, 80), 1),
77
+ "wind_speed_kmh": round(random.uniform(0, 25), 1),
78
+ "rain_probability_pct": round(random.uniform(0, 100), 1),
79
+ },
80
+ timestamp=datetime.utcnow().isoformat(),
81
+ )
82
+ return await self._get(f"/sensors/weather/{location_id}", "weather")
83
+
84
+ async def get_irrigation(self, field_id: str) -> SensorData:
85
+ if self._mock_mode:
86
+ return SensorData(
87
+ sensor_type="irrigation",
88
+ values={
89
+ "flow_rate_lph": round(random.uniform(0, 500), 1),
90
+ "pressure_bar": round(random.uniform(1.0, 4.0), 2),
91
+ "active": float(random.choice([0, 1])),
92
+ "last_irrigation_h_ago": round(random.uniform(1, 48), 1),
93
+ },
94
+ timestamp=datetime.utcnow().isoformat(),
95
+ )
96
+ return await self._get(f"/sensors/irrigation/{field_id}", "irrigation")
97
+
98
+ async def get_pest_status(self, field_id: str) -> SensorData:
99
+ if self._mock_mode:
100
+ return SensorData(
101
+ sensor_type="pest",
102
+ values={
103
+ "trap_count_24h": float(random.randint(0, 50)),
104
+ "alert_level": float(random.randint(0, 3)), # 0=none 1=low 2=medium 3=high
105
+ },
106
+ timestamp=datetime.utcnow().isoformat(),
107
+ )
108
+ return await self._get(f"/sensors/pest/{field_id}", "pest")
109
+
110
+ async def _get(self, path: str, sensor_type: str) -> SensorData:
111
+ import httpx
112
+ url = f"{self.sensor_api_url}{path}"
113
+ async with httpx.AsyncClient(timeout=self.timeout_s) as client:
114
+ response = await client.get(url)
115
+ response.raise_for_status()
116
+ data = response.json()
117
+ return SensorData(
118
+ sensor_type=sensor_type,
119
+ values=data.get("values", data),
120
+ timestamp=data.get("timestamp", datetime.utcnow().isoformat()),
121
+ )
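
A quick way to exercise this module is to run the bridge in mock mode, which needs no IoT backend at all. The sketch below is illustrative (the field ID and entry point are made up for the example) and assumes the project root is on PYTHONPATH:

```python
# Minimal smoke test of SensorBridge in mock mode (SENSOR_API_URL left unset).
import asyncio

from src.iot.sensor_bridge import SensorBridge


async def main() -> None:
    bridge = SensorBridge(sensor_api_url=None)  # falls back to synthetic readings
    soil = await bridge.get_soil_data("field-01")  # "field-01" is a placeholder ID
    print(soil.sensor_type, soil.values, soil.timestamp)


if __name__ == "__main__":
    asyncio.run(main())
```

In the API layer the same readings are obtained through fetch() once an Intent has been parsed from the transcription.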
src/iot/voice_responder.py ADDED
@@ -0,0 +1,260 @@
1
+ """
2
+ Generates voice response text from sensor data in the farmer's own language.
3
+ Supports Bambara (bam), Fula (ful), French (fr), and English (en).
4
+ Bambara/Fula templates use short sentences (≤6 words each) for best MMS-TTS quality.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING
9
+
10
+ if TYPE_CHECKING:
11
+ from src.iot.intent_parser import Intent
12
+ from src.iot.sensor_bridge import SensorData
13
+
14
+ # Alert thresholds
15
+ SOIL_MOISTURE_LOW = 30.0 # Below this → immediate irrigation recommended
16
+ SOIL_MOISTURE_HIGH = 70.0 # Above this → drainage warning
17
+ SOIL_PH_LOW = 5.5
18
+ SOIL_PH_HIGH = 7.5
19
+ TEMP_HIGH = 38.0
20
+ PEST_ALERT_HIGH = 2 # Alert level ≥ 2 → warning
21
+
22
+ # ── Bambara templates (≤6 words per sentence for clear MMS-TTS output) ───────
23
+ BAMBARA_TEMPLATES = {
24
+ "soil_moisture_low": "Bunding ji dɔgɔ. I ka foro ji.",
25
+ "soil_moisture_high": "Ji ca kojugu. Foro ma fɛ.",
26
+ "soil_ph_low": "Bunding kɔnɔ jugu. Kalisi fara a kan.",
27
+ "soil_ph_high": "Bunding kɔnɔ tɛmɛ. Soufre fara a kan.",
28
+ "weather_hot": "Teliman gbɛlɛ. Tile ma sigi.",
29
+ "rain_likely": "Sanji bɛ na. Sɔrɔ jɔ.",
30
+ "pest_high": "Dɔgɔw bɛ foro kɔnɔ. Bɔ u.",
31
+ "irrigation_needed": "Foro fɛ ji. Ji sira yɔrɔ.",
32
+ "irrigation_active": "Ji bɛ taa. A bɛ kɛ cogo di.",
33
+ "default": "Kabako jumanw sɔrɔla.",
34
+ }
35
+
36
+ # ── Fula templates (≤6 words per sentence for clear MMS-TTS output) ──────────
37
+ FULA_TEMPLATES = {
38
+ "soil_moisture_low": "Leydi ndiyam famɗi. Wado ngesa.",
39
+ "soil_moisture_high": "Ndiyam heewi. Leydi famɗaali.",
40
+ "soil_ph_low": "Leydi suurii. Waɗ kalisi.",
41
+ "soil_ph_high": "Leydi alkalii. Waɗ soufre.",
42
+ "weather_hot": "Nguleeki heewi. Muusal.",
43
+ "rain_likely": "Ndiyam wadata. Loosu ngesa.",
44
+ "pest_high": "Biñ-biñ ngesa nder. Fiil ɗen.",
45
+ "irrigation_needed": "Ngesa fɛɗɛli ndiyam. Wado.",
46
+ "irrigation_active": "Ndiyam wona jooni.",
47
+ "default": "Humpito juuti waɗaama.",
48
+ }
49
+
50
+
51
+ class VoiceResponder:
52
+ """Converts sensor readings into actionable voice messages in the farmer's language."""
53
+
54
+ def __init__(self, language: str = "fr") -> None:
55
+ self.language = language
56
+
57
+ def generate_response(self, intent: "Intent", sensor_data: "SensorData") -> str:
58
+ if self.language == "bam":
59
+ return self._bambara_response(sensor_data)
60
+ elif self.language == "ful":
61
+ return self._fula_response(sensor_data)
62
+ else:
63
+ return self._french_response(sensor_data)
64
+
65
+ # ── Bambara ──────────────────────────────────────────────────────────────
66
+
67
+ def _bambara_response(self, sensor_data: "SensorData") -> str:
68
+ t = sensor_data.sensor_type
69
+ v = sensor_data.values
70
+ T = BAMBARA_TEMPLATES
71
+
72
+ if t == "soil":
73
+ moisture = v.get("moisture_pct")
74
+ if moisture is not None:
75
+ if moisture < SOIL_MOISTURE_LOW:
76
+ return T["soil_moisture_low"]
77
+ elif moisture > SOIL_MOISTURE_HIGH:
78
+ return T["soil_moisture_high"]
79
+ ph = v.get("ph")
80
+ if ph is not None:
81
+ if ph < SOIL_PH_LOW:
82
+ return T["soil_ph_low"]
83
+ elif ph > SOIL_PH_HIGH:
84
+ return T["soil_ph_high"]
85
+
86
+ elif t == "weather":
87
+ temp = v.get("temperature_c")
88
+ rain = v.get("rain_probability_pct")
89
+ if temp is not None and temp > TEMP_HIGH:
90
+ return T["weather_hot"]
91
+ if rain is not None and rain > 70:
92
+ return T["rain_likely"]
93
+
94
+ elif t == "irrigation":
95
+ last = v.get("last_irrigation_h_ago")
96
+ active = v.get("active")
97
+ if active:
98
+ return T["irrigation_active"]
99
+ if last is not None and last > 24:
100
+ return T["irrigation_needed"]
101
+
102
+ elif t == "pest":
103
+ level = int(v.get("alert_level", 0))
104
+ if level >= PEST_ALERT_HIGH:
105
+ return T["pest_high"]
106
+
107
+ return T["default"]
108
+
109
+ # ── Fula ─────────────────────────────────────────────────────────────────
110
+
111
+ def _fula_response(self, sensor_data: "SensorData") -> str:
112
+ t = sensor_data.sensor_type
113
+ v = sensor_data.values
114
+ T = FULA_TEMPLATES
115
+
116
+ if t == "soil":
117
+ moisture = v.get("moisture_pct")
118
+ if moisture is not None:
119
+ if moisture < SOIL_MOISTURE_LOW:
120
+ return T["soil_moisture_low"]
121
+ elif moisture > SOIL_MOISTURE_HIGH:
122
+ return T["soil_moisture_high"]
123
+ ph = v.get("ph")
124
+ if ph is not None:
125
+ if ph < SOIL_PH_LOW:
126
+ return T["soil_ph_low"]
127
+ elif ph > SOIL_PH_HIGH:
128
+ return T["soil_ph_high"]
129
+
130
+ elif t == "weather":
131
+ temp = v.get("temperature_c")
132
+ rain = v.get("rain_probability_pct")
133
+ if temp is not None and temp > TEMP_HIGH:
134
+ return T["weather_hot"]
135
+ if rain is not None and rain > 70:
136
+ return T["rain_likely"]
137
+
138
+ elif t == "irrigation":
139
+ active = v.get("active")
140
+ last = v.get("last_irrigation_h_ago")
141
+ if active:
142
+ return T["irrigation_active"]
143
+ if last is not None and last > 24:
144
+ return T["irrigation_needed"]
145
+
146
+ elif t == "pest":
147
+ level = int(v.get("alert_level", 0))
148
+ if level >= PEST_ALERT_HIGH:
149
+ return T["pest_high"]
150
+
151
+ return T["default"]
152
+
153
+ # ── French (original) ─────────────────────────────────────────────────────
154
+
155
+ def _french_response(self, sensor_data: "SensorData") -> str:
156
+ t = sensor_data.sensor_type
157
+ v = sensor_data.values
158
+ if t == "soil":
159
+ return self._soil_response(v)
160
+ elif t == "weather":
161
+ return self._weather_response(v)
162
+ elif t == "irrigation":
163
+ return self._irrigation_response(v)
164
+ elif t == "pest":
165
+ return self._pest_response(v)
166
+ else:
167
+ return "Données du capteur non disponibles pour le moment."
168
+
169
+ def _soil_response(self, v: dict) -> str:
170
+ parts = []
171
+ moisture = v.get("moisture_pct")
172
+ ph = v.get("ph")
173
+ temp = v.get("temperature_c")
174
+ nitrogen = v.get("nitrogen_ppm")
175
+
176
+ if moisture is not None:
177
+ parts.append(f"Humidité du sol : {moisture:.0f}%.")
178
+ if moisture < SOIL_MOISTURE_LOW:
179
+ parts.append("Irrigation recommandée immédiatement.")
180
+ elif moisture > SOIL_MOISTURE_HIGH:
181
+ parts.append("Sol trop humide, risque d'engorgement.")
182
+
183
+ if ph is not None:
184
+ parts.append(f"pH du sol : {ph:.1f}.")
185
+ if ph < SOIL_PH_LOW:
186
+ parts.append("Sol trop acide — envisagez un amendement calcaire.")
187
+ elif ph > SOIL_PH_HIGH:
188
+ parts.append("Sol trop alcalin — un apport de soufre peut aider.")
189
+
190
+ if temp is not None:
191
+ parts.append(f"Température du sol : {temp:.0f}°C.")
192
+
193
+ if nitrogen is not None:
194
+ parts.append(f"Azote disponible : {nitrogen:.0f} ppm.")
195
+ if nitrogen < 15:
196
+ parts.append("Niveau d'azote faible — envisagez un engrais azoté.")
197
+
198
+ return " ".join(parts) if parts else "Données du sol reçues."
199
+
200
+ def _weather_response(self, v: dict) -> str:
201
+ parts = []
202
+ temp = v.get("temperature_c")
203
+ humidity = v.get("humidity_pct")
204
+ wind = v.get("wind_speed_kmh")
205
+ rain = v.get("rain_probability_pct")
206
+
207
+ if temp is not None:
208
+ parts.append(f"Température : {temp:.0f}°C.")
209
+ if temp > TEMP_HIGH:
210
+ parts.append("Chaleur excessive — évitez les travaux aux heures les plus chaudes.")
211
+
212
+ if humidity is not None:
213
+ parts.append(f"Humidité de l'air : {humidity:.0f}%.")
214
+
215
+ if wind is not None:
216
+ parts.append(f"Vent : {wind:.0f} km/h.")
217
+
218
+ if rain is not None:
219
+ parts.append(f"Probabilité de pluie : {rain:.0f}%.")
220
+ if rain > 70:
221
+ parts.append("Pluie probable — reportez les traitements pesticides.")
222
+
223
+ return " ".join(parts) if parts else "Données météo reçues."
224
+
225
+ def _irrigation_response(self, v: dict) -> str:
226
+ parts = []
227
+ active = v.get("active")
228
+ last = v.get("last_irrigation_h_ago")
229
+ flow = v.get("flow_rate_lph")
230
+
231
+ if active is not None:
232
+ state = "en marche" if active else "arrêtée"
233
+ parts.append(f"Irrigation {state}.")
234
+
235
+ if flow is not None and active:
236
+ parts.append(f"Débit : {flow:.0f} litres par heure.")
237
+
238
+ if last is not None:
239
+ parts.append(f"Dernière irrigation il y a {last:.0f} heures.")
240
+ if last > 24:
241
+ parts.append("Plus de 24 heures sans irrigation — vérifiez les besoins en eau.")
242
+
243
+ return " ".join(parts) if parts else "Statut d'irrigation reçu."
244
+
245
+ def _pest_response(self, v: dict) -> str:
246
+ level = int(v.get("alert_level", 0))
247
+ count = v.get("trap_count_24h")
248
+
249
+ level_labels = {0: "aucune", 1: "faible", 2: "modérée", 3: "élevée"}
250
+ label = level_labels.get(level, "inconnue")
251
+
252
+ parts = [f"Présence d'insectes nuisibles : niveau {label}."]
253
+
254
+ if count is not None:
255
+ parts.append(f"{count:.0f} insectes capturés en 24 heures.")
256
+
257
+ if level >= PEST_ALERT_HIGH:
258
+ parts.append("Traitement recommandé — consultez un agent agricole.")
259
+
260
+ return " ".join(parts)
src/optimization/__init__.py ADDED
File without changes
src/optimization/onnx_exporter.py ADDED
@@ -0,0 +1,106 @@
1
+ """
2
+ Merges LoRA adapter weights into the backbone and exports to ONNX.
3
+ Produces a separate ONNX export per language, since ONNX cannot hot-swap LoRA adapters at runtime.
4
+
5
+ Requires: optimum[onnxruntime]
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ from pathlib import Path
11
+ from typing import TYPE_CHECKING
12
+
13
+ if TYPE_CHECKING:
14
+ from peft import PeftModel
15
+ from transformers import WhisperProcessor
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ class ONNXExporter:
21
+ """Merges a LoRA PeftModel into its base model and exports to ONNX."""
22
+
23
+ def merge_and_export(
24
+ self,
25
+ peft_model: "PeftModel",
26
+ processor: "WhisperProcessor",
27
+ output_dir: str,
28
+ language: str,
29
+ ) -> Path:
30
+ """
31
+ 1. Merge LoRA weights into base model (merge_and_unload)
32
+ 2. Export merged model to ONNX via optimum
33
+ Returns the output directory path.
34
+ """
35
+ output_path = Path(output_dir) / language
36
+ output_path.mkdir(parents=True, exist_ok=True)
37
+
38
+ logger.info("Merging LoRA adapter '%s' into base model...", language)
39
+ merged_model = peft_model.merge_and_unload()
40
+ merged_model.eval()
41
+
42
+ logger.info("Exporting to ONNX: %s", output_path)
43
+ self._export_with_optimum(merged_model, processor, str(output_path))
44
+
45
+ return output_path
46
+
47
+ def _export_with_optimum(
48
+ self,
49
+ merged_model,
50
+ processor: "WhisperProcessor",
51
+ output_dir: str,
52
+ ) -> None:
53
+ """Use optimum's ONNX export pipeline."""
54
+ from optimum.exporters.onnx import main_export
55
+
56
+ # Save merged model to a temp directory first
57
+ import tempfile
58
+
59
+ with tempfile.TemporaryDirectory() as tmp_dir:
60
+ logger.info("Saving merged model to temp dir for export...")
61
+ merged_model.save_pretrained(tmp_dir)
62
+ processor.save_pretrained(tmp_dir)
63
+
64
+ logger.info("Running optimum ONNX export...")
65
+ main_export(
66
+ model_name_or_path=tmp_dir,
67
+ output=output_dir,
68
+ task="automatic-speech-recognition",
69
+ opset=17,
70
+ optimize="O2",
71
+ )
72
+
73
+ logger.info("ONNX export complete: %s", output_dir)
74
+
75
+ def validate(
76
+ self,
77
+ onnx_dir: str,
78
+ processor: "WhisperProcessor",
79
+ test_audio_arrays: list,
80
+ sample_rate: int = 16_000,
81
+ reference_texts: list[str] | None = None,
82
+ ) -> dict:
83
+ """
84
+ Run inference with the exported ONNX model and compute WER vs. references.
85
+ """
86
+ import numpy as np
87
+ from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
88
+
89
+ logger.info("Validating ONNX model at %s...", onnx_dir)
90
+ ort_model = ORTModelForSpeechSeq2Seq.from_pretrained(onnx_dir)
91
+
92
+ transcriptions = []
93
+ for audio in test_audio_arrays:
94
+ inputs = processor(audio, sampling_rate=sample_rate, return_tensors="pt")
95
+ outputs = ort_model.generate(inputs.input_features)
96
+ text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
97
+ transcriptions.append(text)
98
+
99
+ result = {"transcriptions": transcriptions}
100
+
101
+ if reference_texts:
102
+ import jiwer
103
+ wer = jiwer.wer(reference_texts, transcriptions)
104
+ result["wer"] = round(wer, 4)
105
+
106
+ return result
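
A typical export flow, sketched under assumptions (the adapter directory and output directory below are illustrative; the base model matches the project default), loads the backbone, attaches the trained LoRA adapter with peft, then hands both to the exporter:

```python
# Illustrative merge-and-export flow for the Bambara adapter.
from peft import PeftModel
from transformers import WhisperForConditionalGeneration, WhisperProcessor

from src.optimization.onnx_exporter import ONNXExporter

base = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3-turbo")
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3-turbo")

# Assumes a trained adapter has been saved to ./adapters/bambara
peft_model = PeftModel.from_pretrained(base, "./adapters/bambara")

exporter = ONNXExporter()
onnx_dir = exporter.merge_and_export(peft_model, processor, "models/onnx", language="bambara")
print(f"ONNX export written to {onnx_dir}")
```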
src/optimization/quantizer.py ADDED
@@ -0,0 +1,95 @@
1
+ """
2
+ BitsAndBytes quantization for GPU-constrained deployment.
3
+ 4-bit NF4: reduces Whisper-large-v3-turbo from ~3GB to ~1GB VRAM.
4
+ 8-bit: intermediate option with less accuracy loss.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import logging
9
+ import time
10
+ from typing import TYPE_CHECKING
11
+
12
+ import torch
13
+ from transformers import BitsAndBytesConfig, WhisperForConditionalGeneration, WhisperProcessor
14
+
15
+ if TYPE_CHECKING:
16
+ pass
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ def load_4bit(model_id: str, hf_token: str | None = None) -> WhisperForConditionalGeneration:
22
+ """Load Whisper with 4-bit NF4 quantization. Reduces VRAM to ~1GB."""
23
+ bnb_config = BitsAndBytesConfig(
24
+ load_in_4bit=True,
25
+ bnb_4bit_quant_type="nf4",
26
+ bnb_4bit_compute_dtype=torch.float16,
27
+ bnb_4bit_use_double_quant=True,
28
+ )
29
+ logger.info("Loading %s with 4-bit NF4 quantization...", model_id)
30
+ model = WhisperForConditionalGeneration.from_pretrained(
31
+ model_id,
32
+ quantization_config=bnb_config,
33
+ device_map="auto",
34
+ token=hf_token,
35
+ )
36
+ return model
37
+
38
+
39
+ def load_8bit(model_id: str, hf_token: str | None = None) -> WhisperForConditionalGeneration:
40
+ """Load Whisper with 8-bit quantization. Reduces VRAM to ~1.5GB."""
41
+ bnb_config = BitsAndBytesConfig(load_in_8bit=True)
42
+ logger.info("Loading %s with 8-bit quantization...", model_id)
43
+ model = WhisperForConditionalGeneration.from_pretrained(
44
+ model_id,
45
+ quantization_config=bnb_config,
46
+ device_map="auto",
47
+ token=hf_token,
48
+ )
49
+ return model
50
+
51
+
52
+ class ModelQuantizer:
53
+ """Benchmarks quantized vs full-precision models."""
54
+
55
+ def __init__(self, model_id: str, hf_token: str | None = None) -> None:
56
+ self.model_id = model_id
57
+ self.hf_token = hf_token
58
+
59
+ def benchmark(
60
+ self,
61
+ model: WhisperForConditionalGeneration,
62
+ processor: WhisperProcessor,
63
+ test_audio_arrays: list,
64
+ sample_rate: int = 16_000,
65
+ ) -> dict:
66
+ """Measure latency and memory for a list of audio arrays."""
67
+ import numpy as np
68
+
69
+ device = next(model.parameters()).device
70
+ latencies = []
71
+
72
+ for audio in test_audio_arrays:
73
+ inputs = processor.feature_extractor(audio, sampling_rate=sample_rate, return_tensors="pt")
74
+ features = inputs.input_features.to(device)
75
+
76
+ if device.type == "cuda":
77
+ torch.cuda.synchronize()
78
+ t0 = time.perf_counter()
79
+
80
+ with torch.no_grad():
81
+ model.generate(features, max_new_tokens=50)
82
+
83
+ if device.type == "cuda":
84
+ torch.cuda.synchronize()
85
+ latencies.append((time.perf_counter() - t0) * 1000)
86
+
87
+ result = {
88
+ "mean_latency_ms": round(sum(latencies) / len(latencies), 1),
89
+ "max_latency_ms": round(max(latencies), 1),
90
+ }
91
+
92
+ if torch.cuda.is_available():
93
+ result["vram_allocated_gb"] = round(torch.cuda.memory_allocated() / 1e9, 2)
94
+
95
+ return result
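
For reference, a minimal benchmark run might look like the sketch below. It assumes a CUDA GPU with bitsandbytes installed; the one-second test signal is random noise rather than real speech, so only the latency and VRAM numbers are meaningful:

```python
# Sketch: load the 4-bit model and time it on one second of synthetic audio.
import numpy as np
from transformers import WhisperProcessor

from src.optimization.quantizer import ModelQuantizer, load_4bit

model_id = "openai/whisper-large-v3-turbo"  # project default backbone
model = load_4bit(model_id)
processor = WhisperProcessor.from_pretrained(model_id)

fake_audio = np.random.randn(16_000).astype(np.float32)  # 1 s at 16 kHz
quantizer = ModelQuantizer(model_id)
print(quantizer.benchmark(model, processor, [fake_audio]))
# e.g. {"mean_latency_ms": ..., "max_latency_ms": ..., "vram_allocated_gb": ...}
```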
src/optimization/tflite_converter.py ADDED
@@ -0,0 +1,76 @@
1
+ """
2
+ Converts ONNX models to TFLite for offline edge deployment (Android phones in rural areas).
3
+ Note: Whisper's encoder and decoder are exported as separate TFLite models and
4
+ orchestrated together at inference time.
5
+
6
+ Requires: onnx-tf, tensorflow (install separately — large dependencies)
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import logging
11
+ from pathlib import Path
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class TFLiteConverter:
17
+ """Converts ONNX Whisper models to TFLite format for edge deployment."""
18
+
19
+ def convert(
20
+ self,
21
+ onnx_encoder_path: str,
22
+ onnx_decoder_path: str,
23
+ output_dir: str,
24
+ quantize: bool = True,
25
+ ) -> dict[str, Path]:
26
+ """
27
+ Convert encoder and decoder ONNX models to TFLite.
28
+ Returns paths to the generated .tflite files.
29
+ """
30
+ output_path = Path(output_dir)
31
+ output_path.mkdir(parents=True, exist_ok=True)
32
+
33
+ encoder_tflite = output_path / "encoder.tflite"
34
+ decoder_tflite = output_path / "decoder.tflite"
35
+
36
+ logger.info("Converting encoder ONNX → TFLite...")
37
+ self._onnx_to_tflite(onnx_encoder_path, str(encoder_tflite), quantize=quantize)
38
+
39
+ logger.info("Converting decoder ONNX → TFLite...")
40
+ self._onnx_to_tflite(onnx_decoder_path, str(decoder_tflite), quantize=quantize)
41
+
42
+ return {"encoder": encoder_tflite, "decoder": decoder_tflite}
43
+
44
+ def _onnx_to_tflite(self, onnx_path: str, output_path: str, quantize: bool) -> None:
45
+ """Convert a single ONNX model to TFLite via onnx-tf + tensorflow."""
46
+ try:
47
+ import onnx
48
+ import onnx_tf
49
+ import tensorflow as tf
50
+ except ImportError as e:
51
+ raise ImportError(
52
+ "TFLite conversion requires onnx-tf and tensorflow. "
53
+ "Install with: pip install onnx-tf tensorflow"
54
+ ) from e
55
+
56
+ import tempfile
57
+
58
+ # Step 1: ONNX → TensorFlow SavedModel
59
+ with tempfile.TemporaryDirectory() as tmp_dir:
60
+ onnx_model = onnx.load(onnx_path)
61
+ tf_rep = onnx_tf.backend.prepare(onnx_model)
62
+ tf_rep.export_graph(tmp_dir)
63
+
64
+ # Step 2: TF SavedModel → TFLite
65
+ converter = tf.lite.TFLiteConverter.from_saved_model(tmp_dir)
66
+
67
+ if quantize:
68
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
69
+
70
+ tflite_model = converter.convert()
71
+
72
+ with open(output_path, "wb") as f:
73
+ f.write(tflite_model)
74
+
75
+ size_mb = Path(output_path).stat().st_size / 1e6
76
+ logger.info("TFLite model saved: %s (%.1f MB)", output_path, size_mb)
src/training/__init__.py ADDED
File without changes
src/training/callbacks.py ADDED
@@ -0,0 +1,83 @@
1
+ """
2
+ Custom HuggingFace Trainer callbacks:
3
+ - EarlyStoppingOnWER: stops training when WER stops improving
4
+ - AdapterCheckpointCallback: saves only adapter weights (not full model) per checkpoint
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import logging
9
+ from pathlib import Path
10
+ from typing import TYPE_CHECKING
11
+
12
+ from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
13
+
14
+ if TYPE_CHECKING:
15
+ pass
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ class EarlyStoppingOnWER(TrainerCallback):
21
+ """
22
+ Stops training if eval WER does not improve by min_delta over `patience` evaluations.
23
+ """
24
+
25
+ def __init__(self, patience: int = 5, min_delta: float = 0.001) -> None:
26
+ self.patience = patience
27
+ self.min_delta = min_delta
28
+ self._best_wer: float = float("inf")
29
+ self._no_improve_count: int = 0
30
+
31
+ def on_evaluate(
32
+ self,
33
+ args: TrainingArguments,
34
+ state: TrainerState,
35
+ control: TrainerControl,
36
+ metrics: dict,
37
+ **kwargs,
38
+ ) -> None:
39
+ wer = metrics.get("eval_wer")
40
+ if wer is None:
41
+ return
42
+
43
+ if wer < self._best_wer - self.min_delta:
44
+ self._best_wer = wer
45
+ self._no_improve_count = 0
46
+ logger.info("WER improved to %.4f", wer)
47
+ else:
48
+ self._no_improve_count += 1
49
+ logger.info(
50
+ "WER %.4f did not improve (best: %.4f). No-improve count: %d/%d",
51
+ wer, self._best_wer, self._no_improve_count, self.patience,
52
+ )
53
+ if self._no_improve_count >= self.patience:
54
+ logger.warning("Early stopping triggered after %d evaluations without improvement.", self.patience)
55
+ control.should_training_stop = True
56
+
57
+
58
+ class AdapterCheckpointCallback(TrainerCallback):
59
+ """
60
+ Saves only the LoRA adapter weights on each checkpoint event.
61
+ Adapter weights are ~50MB vs ~3GB for the full model.
62
+ """
63
+
64
+ def __init__(self, adapter_output_dir: str) -> None:
65
+ self.adapter_output_dir = Path(adapter_output_dir)
66
+
67
+ def on_save(
68
+ self,
69
+ args: TrainingArguments,
70
+ state: TrainerState,
71
+ control: TrainerControl,
72
+ model,
73
+ **kwargs,
74
+ ) -> None:
75
+ checkpoint_dir = self.adapter_output_dir / f"checkpoint-{state.global_step}"
76
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
77
+
78
+ # model is a PeftModel — save only adapter weights
79
+ if hasattr(model, "save_pretrained"):
80
+ model.save_pretrained(str(checkpoint_dir))
81
+ logger.info("Adapter checkpoint saved: %s", checkpoint_dir)
82
+ else:
83
+ logger.warning("Model does not have save_pretrained — skipping adapter checkpoint.")