Spaces:
Running
Running
| import argparse | |
| import json | |
| import re | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
| from src.config import DECODED_DIR, PROMPTS_DIR | |
| from src.generator import Generator | |
| from src.region_registry import get_region_description | |
def parse_args() -> argparse.Namespace:
    """Build and parse the command-line interface for the DAS decode step."""
    p = argparse.ArgumentParser(
        description="Decode localized DAS into natural dialogue."
    )
    p.add_argument("--input_path", type=str, required=True,
                   help="Path to localized JSON file.")
    p.add_argument("--output_path", type=str, default=None,
                   help="Path to save decoded output JSON.")
    p.add_argument("--language", type=str, required=True,
                   help="Target language label, e.g. 'Swahili'.")
    p.add_argument("--region", type=str, default="",
                   help="Optional target region/community, e.g. 'Kenya - Nairobi'.")
    p.add_argument("--decode_prompt", type=str,
                   default=str(PROMPTS_DIR / "das_decode.md"),
                   help="Path to DAS decode prompt.")
    p.add_argument("--model", type=str, default=None,
                   help="Model alias from model_registry.py")
    p.add_argument("--max_instances", type=int, default=None,
                   help="Optional cap on number of dialogues to process.")
    p.add_argument("--start_idx", type=int, default=0,
                   help="Optional start index for slicing input data.")
    p.add_argument("--end_idx", type=int, default=None,
                   help="Optional end index for slicing input data.")
    p.add_argument("--dont_use_cached", action="store_true",
                   help="Disable cached prompt responses.")
    return p.parse_args()
def load_json(path: str) -> Any:
    """Read *path* as UTF-8 text and deserialize its JSON content."""
    with Path(path).open(encoding="utf-8") as handle:
        return json.load(handle)
def save_json(path: str, data: Any) -> None:
    """Write *data* as pretty-printed UTF-8 JSON, creating parent dirs as needed."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(data, indent=2, ensure_ascii=False)
    target.write_text(serialized, encoding="utf-8")
def normalize_speaker_id(speaker_id: Any) -> str:
    """Map numeric speaker ids '1'/'2' to 'A'/'B'; pass anything else through as str."""
    speaker = str(speaker_id)
    return {"1": "A", "2": "B"}.get(speaker, speaker)
def preprocess_localized_conversation(localized_das: List[Dict[str, Any]]) -> str:
    """Format localized DAS turns as numbered '<n>: <speaker>.<functions>' lines."""
    rendered_turns: List[str] = []
    for turn_no, turn in enumerate(localized_das, start=1):
        # Inlined speaker normalization: "1"/"2" become "A"/"B", others pass through.
        raw_speaker = str(turn.get("speaker_id", "A"))
        speaker = {"1": "A", "2": "B"}.get(raw_speaker, raw_speaker)
        functions = turn.get("functions", "")
        functions_str = (
            "; ".join(str(fn) for fn in functions)
            if isinstance(functions, list)
            else str(functions)
        )
        rendered_turns.append(f"{turn_no}: {speaker}.{functions_str}")
    return "\n".join(rendered_turns)
def get_original_dialogue(item: Dict[str, Any]) -> List[str]:
    """Return the first list-valued dialogue field on *item*, checked in priority
    order ('original', 'conversation', 'dialogue', 'utterances'); [] if none."""
    for key in ("original", "conversation", "dialogue", "utterances"):
        value = item.get(key)
        if isinstance(value, list):
            return value
    return []
def preprocess_decode_input(
    data: List[Dict[str, Any]],
    language: str,
    region: str,
) -> List[Dict[str, Any]]:
    """Validate input items and enrich them with decode metadata.

    Each item must carry 'localized_das' (a list) and 'localized_context'.
    Returns copies of the items with language/region info, a region
    description, formatted 'turns', and a 'context' alias added.

    Raises:
        ValueError: when a required key is missing or 'localized_das' is not a list.
    """
    region_description = get_region_description(region, "decode", language) or ""
    enriched: List[Dict[str, Any]] = []
    for item in data:
        for required_key in ("localized_das", "localized_context"):
            if required_key not in item:
                raise ValueError(f"Each input item must contain '{required_key}'.")
        localized_das = item["localized_das"]
        if not isinstance(localized_das, list):
            raise ValueError(
                f"localized_das must be a list, got {type(localized_das).__name__}"
            )
        enriched_item = dict(item)
        enriched_item["language"] = language
        enriched_item["region"] = region
        enriched_item["region_description"] = region_description
        enriched_item["turns"] = preprocess_localized_conversation(localized_das)
        enriched_item["localized_context"] = item["localized_context"]
        # 'context' mirrors 'localized_context' for the prompt template.
        enriched_item["context"] = item["localized_context"]
        enriched.append(enriched_item)
    return enriched
def strip_code_fences(text: str) -> str:
    """Remove a wrapping Markdown code fence (```lang ... ```) from *text*, if any."""
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped
    without_open = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", stripped)
    without_close = re.sub(r"\n?```$", "", without_open)
    return without_close.strip()
def parse_numbered_dialogue_string(raw: str) -> List[str]:
    """Split a numbered dialogue string into individual turn texts.

    Turns are delimited by line-leading markers such as "1: " or "2. ".
    Falls back to one turn per non-empty line when no markers are found.
    """
    # Inlined fence stripping: drop a wrapping Markdown code fence, if present.
    raw = raw.strip()
    if raw.startswith("```"):
        raw = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", raw)
        raw = re.sub(r"\n?```$", "", raw)
        raw = raw.strip()
    markers = list(re.finditer(r"(?m)^\s*(\d+)[\:\.]\s*", raw))
    if not markers:
        return [line.strip() for line in raw.splitlines() if line.strip()]
    # Each turn runs from the end of its marker to the start of the next one.
    segment_ends = [m.start() for m in markers[1:]] + [len(raw)]
    turns: List[str] = []
    for marker, end in zip(markers, segment_ends):
        segment = raw[marker.end():end].strip()
        if segment:
            turns.append(segment)
    return turns
def normalize_generated_conversation(generated_conversation: Any) -> List[str]:
    """Coerce a model's generated conversation into a flat list of turn strings.

    Accepts either a list of strings / {'text': ...} dicts, or one big
    numbered string; raises ValueError for any other shape.
    """
    if isinstance(generated_conversation, str):
        # Model returned a single numbered string instead of a list.
        return parse_numbered_dialogue_string(generated_conversation)
    if not isinstance(generated_conversation, list):
        raise ValueError(
            f"Unsupported generated_conversation type: {type(generated_conversation).__name__}"
        )
    turns: List[str] = []
    for entry in generated_conversation:
        if isinstance(entry, dict) and "text" in entry:
            turns.append(str(entry["text"]).strip())
        elif isinstance(entry, str):
            turns.append(entry.strip())
        else:
            raise ValueError(
                f"Unsupported generated_conversation list item: {entry}"
            )
    return turns
def merge_decoded_responses(
    base_data: List[Dict[str, Any]],
    responses: List[str],
    language: str,
) -> List[Dict[str, Any]]:
    """Merge raw model responses back into output records.

    Pairs each input item with its response (zip ignores unmatched tails),
    skips failed (None) generations, and emits records with the dialogue id
    (when present), the original dialogue, and the decoded turns stored under
    a language-specific key such as 'decoded_swahili'.

    Raises:
        ValueError: if a parsed response lacks 'generated_conversation'.
    """
    merged: List[Dict[str, Any]] = []
    # e.g. language "Swahili" -> key "decoded_swahili"
    decoded_key = f"decoded_{language.strip().lower().replace(' ', '_')}"
    for item, response_text in zip(base_data, responses):
        if response_text is None:
            # Generation failed upstream (skip_failures=True); drop this item.
            # Fix: plain string instead of an f-string with no placeholders (F541).
            print("[Decode] Skipping item with failed generation")
            continue
        response_json = Generator.parse_json_response(response_text)
        if "generated_conversation" not in response_json:
            raise ValueError(
                f"Missing 'generated_conversation' in model response:\n{response_text}"
            )
        generated_conversation = response_json["generated_conversation"]
        text_only_dialogue = normalize_generated_conversation(generated_conversation)
        output_item: Dict[str, Any] = {}
        # Prefer an explicit dialogue_id; fall back to a generic 'id'.
        if "dialogue_id" in item:
            output_item["dialogue_id"] = item["dialogue_id"]
        elif "id" in item:
            output_item["dialogue_id"] = item["id"]
        output_item["original"] = get_original_dialogue(item)
        output_item[decoded_key] = text_only_dialogue
        merged.append(output_item)
    return merged
def default_output_path(input_path: str, language: str, region: str) -> str:
    """Derive an output path under DECODED_DIR from the input stem, language, region."""

    def slug(text: str) -> str:
        # Lowercase, trimmed, spaces replaced by underscores.
        return text.strip().lower().replace(" ", "_")

    suffix_parts = [slug(language)]
    if region.strip():
        # Region slugs also replace path separators so they stay filename-safe.
        suffix_parts.append(slug(region).replace("/", "_"))
    suffix = "_".join(suffix_parts)
    stem = Path(input_path).stem
    return str(DECODED_DIR / f"{stem}_{suffix}_decoded.json")
def main() -> None:
    """CLI entry point: load localized dialogues, decode them, save the result."""
    args = parse_args()

    data = load_json(args.input_path)
    if not isinstance(data, list):
        raise ValueError("Input JSON must be a list of dialogue objects.")

    # Apply optional slicing/capping before any expensive processing.
    subset = data[args.start_idx:args.end_idx]
    if args.max_instances is not None:
        subset = subset[: args.max_instances]

    prepared = preprocess_decode_input(
        data=subset,
        language=args.language,
        region=args.region,
    )

    destination = args.output_path or default_output_path(
        args.input_path,
        args.language,
        args.region,
    )

    generator = Generator(
        model_alias=args.model,
        use_cache=not args.dont_use_cached,
    )

    print(f"[Decode] Building decoded dialogue for {len(prepared)} items...")
    prompts, response_format = generator.build_prompts(args.decode_prompt, prepared)
    responses = generator.prompt(
        prompts=prompts,
        response_format=response_format,
        dont_use_cached=args.dont_use_cached,
        skip_failures=True,
    )

    decoded = merge_decoded_responses(prepared, responses, args.language)
    save_json(destination, decoded)
    print(f"Saved decoded data to: {destination}")
    generator.print_usage_summary(stage="Decode")


if __name__ == "__main__":
    main()