import argparse
import json
import os
import re
from pathlib import Path
from typing import Any, Dict, List

from openai import OpenAI
|
|
|
|
# Absolute locations on the author's machine; edit these before running elsewhere.
# Prompt template text, prepended to every per-passage request.
PROMPT_PATH = Path("/home/mshahidul/readctrl/prompts/support_check_data_generate")
# JSON file of API keys; load_openai_client() reads its "openai" entry.
API_FILE = Path("/home/mshahidul/api_new.json")
# JSON array of input records (doc_id / label / diff_label_texts are read).
INPUT_PATH = Path("/home/mshahidul/readctrl/code/text_classifier/data/verified_combined_0-80_clean200.json")
# Directory that receives the generated output file (created if missing).
OUTPUT_DIR = Path("/home/mshahidul/readctrl/data/extracting_subclaim")
DEFAULT_OUTPUT_FILE = "synthetic_subclaims_first200.json"
|
|
|
|
def load_openai_client() -> OpenAI:
    """Create an OpenAI client from the key stored under "openai" in API_FILE.

    Raises:
        KeyError: if the key file has no "openai" entry.
    """
    key_store = json.loads(API_FILE.read_text(encoding="utf-8"))
    return OpenAI(api_key=key_store["openai"])
|
|
|
|
def normalize_difficulty(label: str) -> str:
    """Translate a health-literacy label into a coarse difficulty tier.

    Returns "easy", "intermediate" or "hard"; any label not in the table
    (including the empty string) falls back to "intermediate".
    """
    tiers = {
        "low_health_literacy": "easy",
        "intermediate_health_literacy": "intermediate",
        "proficient_health_literacy": "hard",
    }
    return tiers[label] if label in tiers else "intermediate"
|
|
|
|
# Markdown code fences, optionally tagged "json" in any letter case.
_CODE_FENCE_RE = re.compile(r"```(?:json)?", re.IGNORECASE)


def clean_json_response(raw: str) -> Dict[str, Any]:
    """Parse a model reply as JSON, tolerating Markdown fences and chatter.

    The reply is stripped of ``` / ```json fences (case-insensitive) and
    parsed. If that fails, typically because the model added prose before
    or after the JSON, the outermost ``{...}`` span is parsed instead.

    Raises:
        json.JSONDecodeError: if no parseable JSON can be recovered.
    """
    cleaned = _CODE_FENCE_RE.sub("", raw.strip()).strip()
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        # Fall back to the widest brace-delimited object in the text.
        start = cleaned.find("{")
        end = cleaned.rfind("}")
        if start != -1 and end > start:
            return json.loads(cleaned[start:end + 1])
        raise
|
|
|
|
def make_prompt(template: str, item: Dict[str, Any]) -> str:
    """Render the full user prompt for one input record.

    The template is followed by an instruction line and a pretty-printed
    JSON payload with the passage id ("<doc_id>_<label>"), the passage text
    taken from ``diff_label_texts``, and the normalized difficulty label.
    """
    doc_id = item.get("doc_id", "unknown")
    source_label = item.get("label", "unknown")
    passage_text = item.get("diff_label_texts", "")
    difficulty = normalize_difficulty(item.get("label", ""))

    rendered = json.dumps(
        {
            "passage_id": f"{doc_id}_{source_label}",
            "passage": passage_text,
            "difficulty_label": difficulty,
        },
        ensure_ascii=False,
        indent=2,
    )
    return f"{template}\n\nNow generate output for this input:\n{rendered}\n"
|
|
|
|
def load_input_data(limit: int) -> List[Dict[str, Any]]:
    """Read the JSON array at INPUT_PATH and return its first ``limit`` records."""
    records = json.loads(INPUT_PATH.read_text(encoding="utf-8"))
    return records[:limit]
|
|
|
|
def load_existing(path: Path) -> List[Dict[str, Any]]:
    """Return rows saved by a previous run at ``path``, or [] if no file exists yet."""
    if path.exists():
        with path.open("r", encoding="utf-8") as handle:
            return json.load(handle)
    return []
|
|
|
|
def save_json(path: Path, data: List[Dict[str, Any]]) -> None:
    """Write ``data`` to ``path`` as pretty-printed UTF-8 JSON, atomically.

    The JSON goes to a sibling ``<name>.tmp`` file first. That file is
    flushed to disk and then moved over ``path``, so an interruption or a
    serialization error mid-write leaves the previous checkpoint intact.
    Without this, a truncated file would make ``json.load`` fail on resume.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with tmp_path.open("w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
            f.flush()
            os.fsync(f.fileno())
        tmp_path.replace(path)  # atomic rename on the same filesystem
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        tmp_path.unlink(missing_ok=True)
        raise
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Parse and validate command-line options for a generation run."""
    parser = argparse.ArgumentParser(
        description="Generate synthetic claim-verification subclaim dataset from diff_label_texts."
    )
    parser.add_argument("--limit", type=int, default=200, help="Number of input items to process.")
    parser.add_argument("--model", type=str, default="gpt-5", help="OpenAI model name.")
    parser.add_argument(
        "--output-file",
        type=str,
        default=DEFAULT_OUTPUT_FILE,
        help="Output filename inside output directory.",
    )
    parser.add_argument(
        "--save-every",
        type=int,
        default=2,
        help="Persist results after every N processed items.",
    )
    parser.add_argument(
        "--retry-errors",
        action="store_true",
        help="Discard previously failed rows from the output file and regenerate them.",
    )
    args = parser.parse_args()
    if args.save_every < 1:
        # A value of 0 would otherwise raise ZeroDivisionError in the save check.
        parser.error("--save-every must be a positive integer")
    return args


def _is_error_row(row: Dict[str, Any]) -> bool:
    """Return True if ``row`` holds the {"error", "raw_response"} placeholder of a failed request."""
    generated = row.get("generated")
    return isinstance(generated, dict) and "error" in generated and "raw_response" in generated


def _generate(client: "OpenAI", model: str, prompt: str) -> Dict[str, Any]:
    """Send one chat-completion request and parse the reply as JSON.

    Never raises for API or parsing failures. Any exception is returned as
    ``{"error": str(e), "raw_response": text}``, where ``text`` is the reply
    to *this* request ("" if the request itself failed).
    """
    content = ""  # reset per call so a failure never reports another item's reply
    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt},
            ],
        )
        content = response.choices[0].message.content or ""
        return clean_json_response(content)
    except Exception as e:  # best-effort per item: record the failure and keep going
        return {"error": str(e), "raw_response": content}


def main() -> None:
    """Generate subclaim data for each input passage, resuming from prior output.

    Rows already present in the output file (matched by ``source_key``) are
    skipped. Results are checkpointed after every ``--save-every`` items
    processed in this run and saved once more at the end. With
    ``--retry-errors``, previously failed rows are dropped first so they
    are regenerated.
    """
    args = _parse_args()

    with PROMPT_PATH.open("r", encoding="utf-8") as f:
        prompt_template = f.read().strip()

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    output_path = OUTPUT_DIR / args.output_file

    data = load_input_data(limit=args.limit)
    results = load_existing(output_path)
    if args.retry_errors:
        results = [row for row in results if not _is_error_row(row)]
    done_keys = {item.get("source_key") for item in results}

    client = load_openai_client()

    processed = 0
    for idx, item in enumerate(data):
        # The index is part of the key so duplicate (doc_id, label) pairs stay distinct.
        source_key = f"{item.get('doc_id')}_{item.get('label')}_{idx}"
        if source_key in done_keys:
            continue

        prompt = make_prompt(prompt_template, item)
        generated = _generate(client, args.model, prompt)

        results.append(
            {
                "source_key": source_key,
                "doc_id": item.get("doc_id"),
                "source_label": item.get("label"),
                "difficulty_label": normalize_difficulty(item.get("label", "")),
                "generated": generated,
            }
        )
        done_keys.add(source_key)
        processed += 1

        # Count items processed in this run, not rows loaded from a previous run.
        if processed % args.save_every == 0:
            save_json(output_path, results)
            print(f"Saved {len(results)} rows to {output_path}")

    save_json(output_path, results)
    print(f"Done. Saved {len(results)} rows to {output_path}")
|
|
|
|
if __name__ == "__main__":
    # main() returns None, so SystemExit(None) exits with status 0 as before.
    raise SystemExit(main())
|
|