| import argparse |
| import json |
| import os |
| import time |
| import urllib.error |
| import urllib.request |
| from datetime import datetime |
| from typing import Any, Dict, List, Optional |
|
|
| from tqdm import tqdm |
|
|
|
|
# Load provider API keys from a local JSON file at import time.
# NOTE(review): hardcoded absolute path; presumably the file maps provider
# names (e.g. "openai") to key strings -- confirm against api_new.json.
api_file = "/home/mshahidul/api_new.json"
with open(api_file, "r", encoding="utf-8") as f:
    api_keys = json.load(f)
|
|
# OpenAI-compatible endpoint and default file locations for this run.
DEFAULT_API_BASE = "https://api.openai.com/v1"
DEFAULT_INPUT_PATH = (
    "/home/mshahidul/readctrl/data/annotators_validate_data_(20_80)/combine/"
    "verified_combined_0-80.json"
)
DEFAULT_OUTPUT_DIR = "/home/mshahidul/readctrl/code/rl_inference/test_result"
# One prompt template file per target health-literacy level; loaded verbatim
# by load_prompt_templates().
DEFAULT_PROMPT_LOW_PATH = (
    "/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt_low"
)
DEFAULT_PROMPT_INTERMEDIATE_PATH = (
    "/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt_intermediate"
)
DEFAULT_PROMPT_PROFICIENT_PATH = (
    "/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt_proficient"
)
# Comma-separated default model list; split into a list by parse_models().
DEFAULT_MODELS = "gpt-5-mini,gpt-5-nano"


# Gold labels accepted from the input dataset; rows carrying any other
# label are skipped in main().
VALID_LABELS = {
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
}
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line options for this script.

    Environment variables OPENAI_API_BASE / OPENAI_API_KEY override the
    built-in defaults; explicit CLI flags override everything.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Generate outputs with gpt-5-mini and gpt-5-nano using "
            "verified_combined dataset and literacy-level prompts."
        )
    )
    env = os.environ
    # Endpoint and credentials.
    parser.add_argument("--api-base", default=env.get("OPENAI_API_BASE", DEFAULT_API_BASE))
    parser.add_argument("--api-key", default=env.get("OPENAI_API_KEY", api_keys["openai"]))
    # What to run and where to read/write.
    parser.add_argument("--models", default=DEFAULT_MODELS, help="Comma-separated model list.")
    parser.add_argument("--input-path", default=DEFAULT_INPUT_PATH)
    parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR)
    parser.add_argument("--prompt-low-path", default=DEFAULT_PROMPT_LOW_PATH)
    parser.add_argument("--prompt-intermediate-path", default=DEFAULT_PROMPT_INTERMEDIATE_PATH)
    parser.add_argument("--prompt-proficient-path", default=DEFAULT_PROMPT_PROFICIENT_PATH)
    # Sampling and request-handling knobs.
    parser.add_argument("--max-samples", type=int, default=-1, help="Use -1 for all rows.")
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--timeout-seconds", type=int, default=120)
    parser.add_argument("--max-retries", type=int, default=2)
    parser.add_argument("--retry-wait-seconds", type=float, default=2.0)
    return parser.parse_args()
|
|
|
|
def check_api_base(api_base: str, api_key: str, timeout_seconds: int) -> None:
    """Probe ``GET {api_base}/models`` to verify the endpoint is usable.

    Args:
        api_base: Base URL of an OpenAI-compatible API.
        api_key: Bearer token; header is omitted when empty.
        timeout_seconds: Socket timeout for the probe request.

    Raises:
        RuntimeError: The endpoint answered with an HTTP error status
            (e.g. 401 for a bad key) -- connectivity is fine, the request
            or credentials are not.
        ConnectionError: The endpoint could not be reached at all.
    """
    models_url = api_base.rstrip("/") + "/models"
    req = urllib.request.Request(models_url, method="GET")
    if api_key:
        req.add_header("Authorization", f"Bearer {api_key}")
    try:
        with urllib.request.urlopen(req, timeout=timeout_seconds) as resp:
            # urlopen raises HTTPError for status >= 400, so this branch is
            # a defensive check for non-conforming servers only.
            if resp.status >= 400:
                raise RuntimeError(
                    f"Endpoint reachable but unhealthy: {models_url} (status={resp.status})"
                )
    except urllib.error.HTTPError as exc:
        # HTTPError is a URLError subclass; without this branch a 401/403
        # (bad key) would be misreported as "cannot reach endpoint".
        raise RuntimeError(
            f"Endpoint reachable but unhealthy: {models_url} (status={exc.code})"
        ) from exc
    except urllib.error.URLError as exc:
        raise ConnectionError(
            "Cannot reach OpenAI-compatible endpoint. "
            f"api_base={api_base}. Check network/API base/API key."
        ) from exc
|
|
|
|
def load_prompt_templates(args: argparse.Namespace) -> Dict[str, str]:
    """Read the three literacy-level prompt files into a label-keyed dict.

    Raises FileNotFoundError if any configured template path is missing.
    """
    path_for = {
        "low_health_literacy": args.prompt_low_path,
        "intermediate_health_literacy": args.prompt_intermediate_path,
        "proficient_health_literacy": args.prompt_proficient_path,
    }
    loaded: Dict[str, str] = {}
    for label, template_path in path_for.items():
        if not os.path.exists(template_path):
            raise FileNotFoundError(f"Prompt file not found: {template_path}")
        with open(template_path, "r", encoding="utf-8") as handle:
            loaded[label] = handle.read()
    return loaded
|
|
|
|
def infer_source_lang(fulltext: str) -> str:
    """Return "English" if *fulltext* contains an ASCII letter, else "Unknown"."""
    if not fulltext:
        return "Unknown"
    for ch in fulltext:
        if "a" <= ch.lower() <= "z":
            return "English"
    return "Unknown"
|
|
|
|
def build_prompt(template: str, fulltext: str, summary: str, source_lang: str) -> str:
    """Fill the template's {source_lang}, {gold_summary} and {full_text} slots."""
    filled = template.replace("{source_lang}", source_lang)
    filled = filled.replace("{gold_summary}", summary)
    return filled.replace("{full_text}", fulltext)
|
|
|
|
def load_verified_rows(path: str) -> List[Dict[str, Any]]:
    """Read *path* as a JSON array and keep only its dict entries.

    Raises FileNotFoundError when the file is missing, ValueError when the
    top-level JSON value is not a list.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Input file not found: {path}")
    with open(path, "r", encoding="utf-8") as handle:
        data = json.load(handle)
    if not isinstance(data, list):
        raise ValueError(f"Expected top-level JSON array in {path}")
    return [entry for entry in data if isinstance(entry, dict)]
|
|
|
|
def parse_models(models_arg: str) -> List[str]:
    """Split a comma-separated model string into a non-empty name list."""
    names = []
    for chunk in models_arg.split(","):
        name = chunk.strip()
        if name:
            names.append(name)
    if not names:
        raise ValueError("No models provided. Example: --models gpt-5-mini,gpt-5-nano")
    return names
|
|
|
|
| def _clean_json_block(text: str) -> str: |
| cleaned = text.strip() |
| if "```json" in cleaned: |
| cleaned = cleaned.split("```json", 1)[1].split("```", 1)[0].strip() |
| elif "```" in cleaned: |
| cleaned = cleaned.split("```", 1)[1].split("```", 1)[0].strip() |
| return cleaned |
|
|
|
|
def extract_generated_text(raw_response: str, expected_label: str) -> str:
    """Extract the string stored under *expected_label* in a JSON response.

    Falls back to the stripped raw response when the payload is not valid
    JSON, is not an object, or carries no non-empty string under the label.
    """
    fallback = raw_response.strip()
    candidate = _clean_json_block(raw_response)
    try:
        payload = json.loads(candidate)
    except json.JSONDecodeError:
        return fallback
    if not isinstance(payload, dict):
        return fallback
    text = payload.get(expected_label)
    if isinstance(text, str) and text.strip():
        return text.strip()
    return fallback
|
|
|
|
def call_chat_completion(
    *,
    api_base: str,
    api_key: str,
    model: str,
    prompt: str,
    temperature: float,
    timeout_seconds: int,
    max_retries: int,
    retry_wait_seconds: float,
) -> str:
    """POST a single-user-message chat completion and return its content.

    Makes up to ``max_retries + 1`` attempts, sleeping *retry_wait_seconds*
    between retries. On the final failed attempt the underlying exception is
    re-raised (HTTPError, URLError, JSONDecodeError on a non-JSON body, or
    KeyError/IndexError on an unexpected response shape).
    """
    url = api_base.rstrip("/") + "/chat/completions"
    # NOTE(review): *temperature* is accepted but never added to the payload
    # -- possibly deliberate (some models reject non-default temperatures);
    # confirm whether it should be sent for the configured models.
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
    }
    # Payload is identical across retries, so it is serialized once here.
    data = json.dumps(payload).encode("utf-8")

    last_error: Optional[Exception] = None
    for attempt in range(max_retries + 1):
        req = urllib.request.Request(url, data=data, method="POST")
        req.add_header("Content-Type", "application/json")
        if api_key:
            req.add_header("Authorization", f"Bearer {api_key}")
        try:
            with urllib.request.urlopen(req, timeout=timeout_seconds) as resp:
                body = resp.read().decode("utf-8")
                parsed = json.loads(body)
                return str(parsed["choices"][0]["message"]["content"]).strip()
        except urllib.error.HTTPError as exc:
            # Only status codes that indicate a transient condition are retried;
            # other HTTP errors (e.g. 400/401) fail immediately.
            retriable = exc.code in (408, 409, 429, 500, 502, 503, 504)
            last_error = exc
            if attempt < max_retries and retriable:
                time.sleep(retry_wait_seconds)
                continue
            raise
        except (urllib.error.URLError, KeyError, IndexError, json.JSONDecodeError) as exc:
            # Network failures and malformed responses are always retriable.
            last_error = exc
            if attempt < max_retries:
                time.sleep(retry_wait_seconds)
                continue
            raise

    # Defensive fallback: the loop body always returns or re-raises, so this
    # code is effectively unreachable.
    if last_error:
        raise last_error
    raise RuntimeError("Unknown error during chat completion call.")
|
|
|
|
def main() -> None:
    """Run inference for every configured model over the verified dataset.

    Writes one JSONL file per model, a combined JSONL across all models, and
    a JSON summary with per-model success/error counts, all timestamped into
    the configured output directory.
    """
    args = parse_args()
    if not args.api_key:
        raise ValueError("Missing API key. Set OPENAI_API_KEY or pass --api-key.")

    # Fail fast if any prompt template is missing (load_prompt_templates
    # re-checks, but this surfaces the error before the network probe).
    for path in (
        args.prompt_low_path,
        args.prompt_intermediate_path,
        args.prompt_proficient_path,
    ):
        if not os.path.exists(path):
            raise FileNotFoundError(f"Prompt file not found: {path}")

    check_api_base(args.api_base, args.api_key, args.timeout_seconds)
    models = parse_models(args.models)
    templates = load_prompt_templates(args)
    rows = load_verified_rows(args.input_path)

    # Build one prompt per usable row; rows with unknown labels or a missing
    # fulltext/summary are silently skipped.
    parsed_items: List[Dict[str, Any]] = []
    for idx, row in enumerate(rows):
        gold_label = str(row.get("label", "")).strip()
        fulltext = str(row.get("fulltext", "")).strip()
        summary = str(row.get("summary", "")).strip()
        if gold_label not in VALID_LABELS:
            continue
        if not fulltext or not summary:
            continue
        source_lang = infer_source_lang(fulltext)
        prompt = build_prompt(
            template=templates[gold_label],
            fulltext=fulltext,
            summary=summary,
            source_lang=source_lang,
        )
        parsed_items.append(
            {
                "row_index": idx,  # index into the original input array
                "doc_id": row.get("doc_id"),
                "gold_label": gold_label,
                "source_lang": source_lang,
                "prompt": prompt,
            }
        )

    # max_samples <= 0 means "use everything".
    if args.max_samples > 0:
        parsed_items = parsed_items[: args.max_samples]
    if not parsed_items:
        raise RuntimeError("No valid rows found in input file.")

    # One timestamp shared by every output file from this run.
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs(args.output_dir, exist_ok=True)
    summary_path = os.path.join(args.output_dir, f"gpt5_inference_summary_{ts}.json")
    combined_path = os.path.join(args.output_dir, f"gpt5_inference_all_{ts}.jsonl")

    combined_records: List[Dict[str, Any]] = []
    model_stats: Dict[str, Dict[str, Any]] = {}

    for model in models:
        # Keep path separators out of the per-model filename.
        model_slug = model.replace("/", "_")
        model_output_path = os.path.join(
            args.output_dir, f"gpt5_inference_{model_slug}_{ts}.jsonl"
        )
        success_count = 0
        error_count = 0

        with open(model_output_path, "w", encoding="utf-8") as f_model:
            total = len(parsed_items)
            progress_iter = tqdm(
                parsed_items,
                total=total,
                desc=f"{model}",
                unit="item",
            )
            for item in progress_iter:
                record: Dict[str, Any] = {
                    "model": model,
                    "row_index": item["row_index"],
                    "doc_id": item.get("doc_id"),
                    "gold_label": item["gold_label"],
                    "source_lang": item["source_lang"],
                    "prompt": item["prompt"],
                }
                try:
                    raw_response = call_chat_completion(
                        api_base=args.api_base,
                        api_key=args.api_key,
                        model=model,
                        prompt=item["prompt"],
                        temperature=args.temperature,
                        timeout_seconds=args.timeout_seconds,
                        max_retries=args.max_retries,
                        retry_wait_seconds=args.retry_wait_seconds,
                    )
                    generated_text = extract_generated_text(raw_response, item["gold_label"])
                    record["prediction"] = raw_response
                    record["generated_text"] = generated_text
                    record["error"] = ""
                    success_count += 1
                except Exception as exc:
                    # Per-item failures are recorded in the output row instead
                    # of aborting the whole run.
                    record["prediction"] = ""
                    record["generated_text"] = ""
                    record["error"] = f"{type(exc).__name__}: {exc}"
                    error_count += 1

                # One JSON object per line (JSONL), flushed incrementally so a
                # partial run still leaves usable output.
                f_model.write(json.dumps(record, ensure_ascii=False) + "\n")
                combined_records.append(record)

        model_stats[model] = {
            "output_path": model_output_path,
            "total_rows": len(parsed_items),
            "success_count": success_count,
            "error_count": error_count,
        }
        print(f"[DONE] {model} output: {model_output_path}")

    # All models' records concatenated in run order.
    with open(combined_path, "w", encoding="utf-8") as f_all:
        for record in combined_records:
            f_all.write(json.dumps(record, ensure_ascii=False) + "\n")

    # Run-level metadata and per-model stats for later inspection.
    summary_obj = {
        "input_path": args.input_path,
        "api_base": args.api_base,
        "models": models,
        "max_samples": args.max_samples,
        "temperature": args.temperature,
        "total_dataset_rows_used": len(parsed_items),
        "combined_output_path": combined_path,
        "model_stats": model_stats,
    }
    with open(summary_path, "w", encoding="utf-8") as f_summary:
        json.dump(summary_obj, f_summary, ensure_ascii=False, indent=2)

    print(f"[DONE] Combined output: {combined_path}")
    print(f"[DONE] Summary output: {summary_path}")




if __name__ == "__main__":
    main()
|
|