# afridialeval/src/decode.py
# Uploaded by millicentochieng using huggingface_hub (commit e2b8b61, verified).
import argparse
import json
import re
from pathlib import Path
from typing import Any, Dict, List, Optional

from src.config import DECODED_DIR, PROMPTS_DIR
from src.generator import Generator
from src.region_registry import get_region_description
def parse_args() -> argparse.Namespace:
    """Define the decode-stage command line and parse ``sys.argv``.

    Returns:
        The parsed namespace: ``input_path``, ``output_path``, ``language``,
        ``region``, ``decode_prompt``, ``model``, ``max_instances``,
        ``start_idx``, ``end_idx`` and ``dont_use_cached``.
    """
    parser = argparse.ArgumentParser(
        description="Decode localized DAS into natural dialogue."
    )
    add = parser.add_argument

    # Input / output files.
    add("--input_path", type=str, required=True,
        help="Path to localized JSON file.")
    add("--output_path", type=str, default=None,
        help="Path to save decoded output JSON.")

    # Target locale.
    add("--language", type=str, required=True,
        help="Target language label, e.g. 'Swahili'.")
    add("--region", type=str, default="",
        help="Optional target region/community, e.g. 'Kenya - Nairobi'.")

    # Prompt and model selection.
    add("--decode_prompt", type=str, default=str(PROMPTS_DIR / "das_decode.md"),
        help="Path to DAS decode prompt.")
    add("--model", type=str, default=None,
        help="Model alias from model_registry.py")

    # Subsetting of the input list.
    add("--max_instances", type=int, default=None,
        help="Optional cap on number of dialogues to process.")
    add("--start_idx", type=int, default=0,
        help="Optional start index for slicing input data.")
    add("--end_idx", type=int, default=None,
        help="Optional end index for slicing input data.")

    # Caching behaviour.
    add("--dont_use_cached", action="store_true",
        help="Disable cached prompt responses.")

    return parser.parse_args()
def load_json(path: str) -> Any:
    """Read the UTF-8 file at *path* and return its decoded JSON value."""
    with Path(path).open(encoding="utf-8") as handle:
        return json.load(handle)
def save_json(path: str, data: Any) -> None:
    """Write *data* to *path* as 2-space-indented UTF-8 JSON.

    Non-ASCII characters are written as-is (``ensure_ascii=False``), and any
    missing parent directories are created first.
    """
    destination = Path(path)
    destination.parent.mkdir(exist_ok=True, parents=True)
    serialized = json.dumps(data, ensure_ascii=False, indent=2)
    destination.write_text(serialized, encoding="utf-8")
def normalize_speaker_id(speaker_id: Any) -> str:
    """Return a letter speaker label for *speaker_id*.

    ``1``/``"1"`` becomes ``"A"`` and ``2``/``"2"`` becomes ``"B"``; any other
    value is returned as its ``str()`` form unchanged.
    """
    as_text = str(speaker_id)
    return {"1": "A", "2": "B"}.get(as_text, as_text)
def preprocess_localized_conversation(localized_das: List[Dict[str, Any]]) -> str:
    """Render localized DAS turns as numbered prompt lines.

    Each turn becomes ``"<n>: <speaker>.<functions>"`` with 1-based ``n``. The
    speaker comes from ``speaker_id`` (default ``"A"``) via
    ``normalize_speaker_id``. ``functions`` is joined with ``"; "`` when it is a
    list and stringified otherwise (default ``""``). Lines are joined by ``"\\n"``.
    """
    def _render(position: int, turn: Dict[str, Any]) -> str:
        who = normalize_speaker_id(turn.get("speaker_id", "A"))
        acts = turn.get("functions", "")
        acts_text = "; ".join(map(str, acts)) if isinstance(acts, list) else str(acts)
        return f"{position}: {who}.{acts_text}"

    return "\n".join(
        _render(position, turn)
        for position, turn in enumerate(localized_das, start=1)
    )
def get_original_dialogue(item: Dict[str, Any]) -> List[str]:
    """Return the source dialogue stored on *item*, or ``[]`` if none is found.

    The keys ``original``, ``conversation``, ``dialogue`` and ``utterances`` are
    tried in that order. The first one whose value is a list wins; values of
    any other type are skipped.
    """
    for candidate_key in ("original", "conversation", "dialogue", "utterances"):
        candidate = item.get(candidate_key)
        if isinstance(candidate, list):
            return candidate
    return []
def _validate_decode_item(index: int, item: Any) -> None:
    """Raise ValueError, naming the item index, if input item *index* cannot be decoded."""
    # A non-dict item (e.g. a string) would otherwise pass the `in` checks
    # below through substring matching and fail later with a TypeError.
    if not isinstance(item, dict):
        raise ValueError(
            f"Input item {index} must be a JSON object, got {type(item).__name__}."
        )
    if "localized_das" not in item:
        raise ValueError(
            f"Each input item must contain 'localized_das' (missing in item {index})."
        )
    if "localized_context" not in item:
        raise ValueError(
            f"Each input item must contain 'localized_context' (missing in item {index})."
        )
    if not isinstance(item["localized_das"], list):
        raise ValueError(
            f"localized_das must be a list, got {type(item['localized_das']).__name__} "
            f"(item {index})"
        )


def preprocess_decode_input(
    data: List[Dict[str, Any]],
    language: str,
    region: str,
) -> List[Dict[str, Any]]:
    """Validate localized items and add the fields the decode prompt needs.

    Each returned item is a shallow copy of the input item plus:
    ``language``, ``region``, ``region_description`` (from
    ``get_region_description(region, "decode", language)``, or ``""``),
    ``turns`` (rendered by ``preprocess_localized_conversation``) and
    ``context`` (an alias of ``localized_context``).

    Args:
        data: Items that each carry ``localized_das`` (a list) and
            ``localized_context``.
        language: Target language label.
        region: Target region/community; may be empty.

    Returns:
        A new list of enriched items in input order.

    Raises:
        ValueError: if an item is not a dict, lacks ``localized_das`` or
            ``localized_context``, or has a non-list ``localized_das``. The
            message includes the offending item's index.
    """
    # Validate everything up front so a bad record fails before any lookup work.
    for index, item in enumerate(data):
        _validate_decode_item(index, item)

    # The description depends only on (region, language), so fetch it once.
    decode_desc = get_region_description(region, "decode", language) or ""

    processed: List[Dict[str, Any]] = []
    for item in data:
        # dict(item) already carries every original key, including localized_context.
        new_item = dict(item)
        new_item["language"] = language
        new_item["region"] = region
        new_item["region_description"] = decode_desc
        new_item["turns"] = preprocess_localized_conversation(item["localized_das"])
        # NOTE(review): "context" is presumably the placeholder name the decode
        # prompt template uses; confirm against das_decode.md.
        new_item["context"] = item["localized_context"]
        processed.append(new_item)
    return processed
def strip_code_fences(text: str) -> str:
    """Remove a surrounding Markdown code fence from *text*, if present.

    The text is whitespace-stripped first. If it starts with a triple backtick,
    the opening fence and its optional info tag (e.g. ``json``) are removed,
    along with a trailing closing fence. The result is stripped again.
    """
    cleaned = text.strip()
    if not cleaned.startswith("```"):
        return cleaned
    without_opening = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", cleaned)
    without_closing = re.sub(r"\n?```$", "", without_opening)
    return without_closing.strip()
def parse_numbered_dialogue_string(raw: str) -> List[str]:
    """Split a model's free-text dialogue into individual turn strings.

    Code fences are removed first via ``strip_code_fences``. Turns are
    delimited by line-leading markers such as ``1:`` or ``2.``. The marker is
    dropped, each turn is whitespace-stripped, and empty turns are discarded.
    Text before the first marker is ignored. If no marker is found, each
    non-blank line is returned as its own turn.

    NOTE(review): any line that starts with digits followed by ':' or '.'
    (e.g. "10.5 ...") is also treated as a turn marker.
    """
    text = strip_code_fences(raw)
    marker_pattern = r"(?m)^\s*(\d+)[\:\.]\s*"
    markers = list(re.finditer(marker_pattern, text))
    if not markers:
        stripped_lines = (line.strip() for line in text.splitlines())
        return [line for line in stripped_lines if line]
    # Each turn spans from the end of its marker to the start of the next one (or EOF).
    stops = [marker.start() for marker in markers[1:]]
    stops.append(len(text))
    pieces = (text[marker.end():stop].strip() for marker, stop in zip(markers, stops))
    return [piece for piece in pieces if piece]
def normalize_generated_conversation(generated_conversation: Any) -> List[str]:
    """Coerce the model's ``generated_conversation`` into a list of turn texts.

    Accepted shapes:
      * a list whose entries are strings or dicts with a ``"text"`` key; each
        entry becomes its stripped text (dict values are ``str()``-ed first);
      * a single string, which is split by ``parse_numbered_dialogue_string``.

    Raises:
        ValueError: for any other top-level type, or a list entry that is
            neither a string nor a dict containing ``"text"``.
    """
    if isinstance(generated_conversation, str):
        # The model returned one big numbered block of text.
        return parse_numbered_dialogue_string(generated_conversation)
    if not isinstance(generated_conversation, list):
        raise ValueError(
            f"Unsupported generated_conversation type: {type(generated_conversation).__name__}"
        )
    texts: List[str] = []
    for entry in generated_conversation:
        if isinstance(entry, str):
            texts.append(entry.strip())
        elif isinstance(entry, dict) and "text" in entry:
            texts.append(str(entry["text"]).strip())
        else:
            raise ValueError(
                f"Unsupported generated_conversation list item: {entry}"
            )
    return texts
def merge_decoded_responses(
    base_data: List[Dict[str, Any]],
    responses: List[Optional[str]],
    language: str,
) -> List[Dict[str, Any]]:
    """Pair each input item with its model response and build the output records.

    Args:
        base_data: Preprocessed items, in the same order the prompts were built.
        responses: One raw model response per item. ``None`` marks a failed
            generation; that item is skipped with a log line.
        language: Target language label. Its lowercased, underscore-joined form
            names the output key (e.g. ``"Swahili"`` -> ``"decoded_swahili"``).

    Returns:
        One record per successful response. Each record has ``dialogue_id``
        (copied from ``dialogue_id`` or ``id`` when present), ``original`` (from
        ``get_original_dialogue``) and the decoded turn texts.

    Raises:
        ValueError: if ``responses`` and ``base_data`` differ in length, or a
            parsed response is not an object containing
            ``generated_conversation``.
    """
    responses = list(responses)
    # zip() would silently truncate on a length mismatch and pair dialogues with
    # the wrong responses, so refuse instead of writing misaligned output.
    if len(responses) != len(base_data):
        raise ValueError(
            f"Got {len(responses)} responses for {len(base_data)} input items; "
            "cannot align decoded output with its source dialogues."
        )

    decoded_key = f"decoded_{language.strip().lower().replace(' ', '_')}"
    merged: List[Dict[str, Any]] = []
    for index, (item, response_text) in enumerate(zip(base_data, responses)):
        if response_text is None:
            label = item.get("dialogue_id", item.get("id", index))
            print(f"[Decode] Skipping item {label} with failed generation")
            continue

        response_json = Generator.parse_json_response(response_text)
        # NOTE(review): guard against a non-object payload (e.g. a top-level list),
        # presumably possible if the model ignores the schema; confirm against
        # Generator.parse_json_response.
        if not isinstance(response_json, dict) or "generated_conversation" not in response_json:
            raise ValueError(
                f"Missing 'generated_conversation' in model response:\n{response_text}"
            )
        text_only_dialogue = normalize_generated_conversation(
            response_json["generated_conversation"]
        )

        output_item: Dict[str, Any] = {}
        if "dialogue_id" in item:
            output_item["dialogue_id"] = item["dialogue_id"]
        elif "id" in item:
            output_item["dialogue_id"] = item["id"]
        output_item["original"] = get_original_dialogue(item)
        output_item[decoded_key] = text_only_dialogue
        merged.append(output_item)
    return merged
def default_output_path(input_path: str, language: str, region: str) -> str:
    """Build the default decoded-output path under ``DECODED_DIR``.

    The filename is ``<input stem>_<language>[_<region>]_decoded.json``.
    Language and region are stripped, lowercased and have spaces replaced by
    underscores; '/' in the region also becomes '_'. The region part is
    omitted when the region is blank.
    """
    name_parts = [Path(input_path).stem, language.strip().lower().replace(" ", "_")]
    region_clean = region.strip()
    if region_clean:
        name_parts.append(region_clean.lower().replace(" ", "_").replace("/", "_"))
    name_parts.append("decoded")
    filename = "_".join(name_parts) + ".json"
    return str(DECODED_DIR / filename)
def main() -> None:
    """Run the decode stage: load localized dialogues, decode them, save JSON.

    Steps:
      1. Read ``--input_path`` (must be a JSON list).
      2. Slice it with ``--start_idx``/``--end_idx`` and cap it at ``--max_instances``.
      3. Validate and enrich the items.
      4. Build prompts from ``--decode_prompt`` and query the model.
      5. Write the merged results to ``--output_path`` (or the
         ``default_output_path`` location).

    Raises:
        ValueError: if ``--max_instances`` is negative, the input is not a
            JSON list, or an input item fails validation.
    """
    args = parse_args()
    # A negative cap would silently drop items from the *end* via slicing.
    if args.max_instances is not None and args.max_instances < 0:
        raise ValueError(f"--max_instances must be >= 0, got {args.max_instances}.")

    raw_data = load_json(args.input_path)
    if not isinstance(raw_data, list):
        raise ValueError("Input JSON must be a list of dialogue objects.")
    sliced_data = raw_data[args.start_idx:args.end_idx]
    if args.max_instances is not None:
        sliced_data = sliced_data[: args.max_instances]

    # Validate/enrich before constructing the Generator so malformed input fails
    # fast without initialising the model client.
    processed_data = preprocess_decode_input(
        data=sliced_data,
        language=args.language,
        region=args.region,
    )
    output_path = args.output_path or default_output_path(
        args.input_path,
        args.language,
        args.region,
    )

    generator = Generator(
        model_alias=args.model,
        use_cache=not args.dont_use_cached,
    )
    print(f"[Decode] Building decoded dialogue for {len(processed_data)} items...")
    decode_prompts, decode_response_format = generator.build_prompts(
        args.decode_prompt,
        processed_data,
    )
    # skip_failures=True: failed generations are handled (skipped) in
    # merge_decoded_responses via None entries.
    decode_responses = generator.prompt(
        prompts=decode_prompts,
        response_format=decode_response_format,
        dont_use_cached=args.dont_use_cached,
        skip_failures=True,
    )
    final_data = merge_decoded_responses(
        processed_data,
        decode_responses,
        args.language,
    )
    save_json(output_path, final_data)
    print(f"Saved decoded data to: {output_path}")
    generator.print_usage_summary(stage="Decode")


if __name__ == "__main__":
    main()