ground-zero / src /data /agri_dictionary.py
jefffffff9
Initial commit: Sahel-Agri Voice AI
76db545
"""
Agricultural vocabulary for Bambara and Fula.
Used to bias the Whisper decoder toward domain-specific terms via decoder prompt injection.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
if TYPE_CHECKING:
from transformers import WhisperProcessor
# Bambara (bam) agricultural vocabulary.
# Each key is a Bambara term and each value is its English gloss. Only the
# keys end up in the decoder prompt text (see
# AgriculturalDictionary.get_prompt_text); the glosses are not tokenized.
# Insertion order is the order in which terms appear in the prompt.
# NOTE(review): the glosses have not been checked here -- confirm them with a
# Bambara speaker before relying on them.
BAMBARA_VOCAB: dict[str, str] = {
    "sɛnɛ": "farming",
    "jiriw": "trees",
    "nɔgɔ": "soil",
    "sani": "fertilizer",
    "kogomali": "groundnut",
    "kaba": "corn/maize",
    "tiga": "peanut",
    "ji": "water",
    "sanji": "rain",
    "teliman": "weather",
    "suruku": "pest/predator",
    "bunding": "soil/earth",
    "sira": "path/way",
    "foro": "field",
    "dugu": "village/land",
    "dibi": "darkness/shade",
    "fanga": "strength/fertilizer",
    "kungoloni": "insects/pests",
}
# Fula (ful / Fulfulde) agricultural vocabulary.
# Each key is a Fula term and each value is its English gloss. Only the keys
# end up in the decoder prompt text (see
# AgriculturalDictionary.get_prompt_text); the glosses are not tokenized.
# Insertion order is the order in which terms appear in the prompt.
# NOTE(review): the glosses have not been checked here -- confirm them with a
# Fulfulde speaker before relying on them.
FULA_VOCAB: dict[str, str] = {
    "ngesa": "field",
    "leydi": "land/soil",
    "kosam": "milk",
    "nagge": "cattle",
    "leeɗe": "crops",
    "ndiyam": "water",
    "yeeso": "wind/weather",
    "laabi": "road/way",
    "demoore": "farming",
    "hoore": "head/top",
    "biñ-biñ": "insects/pests",
    "fuɗorde": "sunrise/east field",
    "ngaari": "bull",
    "mbabba": "donkey",
    "ladde": "bush/forest",
    "wutte": "clothing/harvest",
}
# Registry of the available vocabularies, keyed by the short language code
# that callers pass as ``language`` ("bam" for Bambara, "ful" for Fula).
# AgriculturalDictionary.get_vocab raises ValueError for any other code.
LANGUAGE_VOCABS: dict[str, dict[str, str]] = {
    "bam": BAMBARA_VOCAB,
    "ful": FULA_VOCAB,
}
class AgriculturalDictionary:
    """Converts agricultural vocabulary into decoder prompt token IDs for Whisper.

    By default every lookup goes to the module-level ``LANGUAGE_VOCABS``
    registry. A custom ``{language: {term: gloss}}`` registry may be supplied
    at construction time, e.g. to add a language or trim the term list
    without touching the module globals.
    """

    # Class-level fallback so that instances created without running
    # ``__init__`` (a subclass that skips ``super().__init__()``, or
    # ``__new__``) still resolve to the module-level registry.
    _vocabs: dict[str, dict[str, str]] | None = None

    def __init__(self, vocabs: dict[str, dict[str, str]] | None = None) -> None:
        """Create a dictionary backed by ``vocabs`` or the module registry.

        Args:
            vocabs: Optional registry mapping a language code to its
                ``{term: gloss}`` vocabulary. ``None`` (the default) means
                "use ``LANGUAGE_VOCABS``".
        """
        self._vocabs = vocabs

    def _registry(self) -> dict[str, dict[str, str]]:
        """Return the registry that lookups should use."""
        # Resolved on every call instead of being captured in ``__init__`` so
        # that later changes to LANGUAGE_VOCABS remain visible to existing
        # instances, exactly as when the global was read directly.
        return LANGUAGE_VOCABS if self._vocabs is None else self._vocabs

    def get_vocab(self, language: str) -> dict[str, str]:
        """Return the ``{term: gloss}`` vocabulary for ``language``.

        Args:
            language: Language code used as a key in the registry.

        Returns:
            The registry's own dict for that language (not a copy).

        Raises:
            ValueError: If the registry has no entry for ``language``.
        """
        registry = self._registry()
        try:
            return registry[language]
        except KeyError:
            # ``from None``: the KeyError adds nothing to the message below.
            raise ValueError(
                f"No vocabulary for language '{language}'. Available: {list(registry)}"
            ) from None

    def get_prompt_text(self, language: str) -> str:
        """Return a comma-joined string of all terms, used as decoder text prompt.

        Only the vocabulary keys (the terms) are included, in insertion
        order; the glosses are left out.

        Raises:
            ValueError: If the registry has no entry for ``language``.
        """
        return ", ".join(self.get_vocab(language))

    def build_prompt_ids(self, processor: "WhisperProcessor", language: str) -> torch.Tensor:
        """Tokenize the vocabulary terms into a tensor of token IDs.

        The comma-joined term list is passed to ``processor.tokenizer`` with
        ``add_special_tokens=False`` and ``return_tensors="pt"``. The intent
        is to bias decoding toward known agricultural terms by handing these
        IDs to ``model.generate()`` as decoder prompt tokens.

        NOTE(review): the Hugging Face documentation for Whisper's
        ``generate(prompt_ids=...)`` describes a rank-1 tensor produced by
        ``WhisperProcessor.get_prompt_ids``. The tensor returned here keeps a
        leading batch dimension and is tokenized without special tokens, so
        confirm that the caller adapts it (or calls ``get_prompt_ids`` on
        ``get_prompt_text(language)``) before passing it as ``prompt_ids``.

        Args:
            processor: Object exposing a callable ``tokenizer`` attribute.
            language: Language code used as a key in the registry.

        Returns:
            The tokenizer's ``input_ids`` tensor, shape (1, N).

        Raises:
            ValueError: If the registry has no entry for ``language``.
        """
        encoded = processor.tokenizer(
            self.get_prompt_text(language),
            return_tensors="pt",
            add_special_tokens=False,
        )
        return encoded.input_ids  # shape: (1, N)

    def get_token_ids(self, processor: "WhisperProcessor", language: str) -> list[int]:
        """Return flat list of token IDs for all vocabulary terms.

        Raises:
            ValueError: If the registry has no entry for ``language``.
        """
        # Drop the batch dimension added by ``return_tensors="pt"``.
        return self.build_prompt_ids(processor, language)[0].tolist()