| import argparse | |
| import glob | |
| import json | |
| import os | |
| import traceback | |
| import urllib.error | |
| import urllib.request | |
| from datetime import datetime | |
| from typing import Any, Dict, List | |
| import dspy | |
| from tqdm import tqdm | |
# Default OpenAI-compatible endpoint of the local vLLM server; can be
# overridden with --api-base or the VLLM_API_BASE environment variable.
DEFAULT_API_BASE = "http://172.16.34.21:8040/v1"
# Compiled DSPy program saved by the training pipeline (model.json state file).
DEFAULT_MODEL_PATH = (
    "/home/mshahidul/readctrl/code/text_classifier/"
    "dspy_model/vllm-Meta-Llama-3.1-8B-Instruct_teacher-gpt5_v1/model.json"
)
# Directory auto-searched for vllm_inference_*.jsonl when no explicit file is given.
DEFAULT_INPUT_PATH = "/home/mshahidul/readctrl/code/RL_model/inference_data"
# Specific vLLM inference dump evaluated by default.
DEFAULT_INPUT_FILE = (
    "/home/mshahidul/readctrl/code/RL_model/inference_data/"
    "vllm_inference_qwen-qwen3-4b-instruct-2507_20260213_173334.jsonl"
)
# Where the timestamped summary (.json) and per-row details (.jsonl) are written.
DEFAULT_OUTPUT_DIR = "/home/mshahidul/readctrl/code/rl_inference/test_result"
# Gold labels accepted during evaluation; rows with any other label are skipped.
VALID_LABELS = {
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
}
# DSPy signature: one input (rewritten text) -> one output (literacy label).
# NOTE: deliberately no class docstring — dspy.Signature uses the docstring as
# the LM instruction, so adding one would change prompting behavior.
class HealthLiteracySignature(dspy.Signature):
    # Text to classify (a rewrite of some source text for a target audience).
    generated_text = dspy.InputField(
        desc="A version of the source text rewritten for a specific audience."
    )
    # Expected to be one of the three labels in VALID_LABELS.
    literacy_label = dspy.OutputField(
        desc=(
            "Classification: low_health_literacy (simple words, no jargon), "
            "intermediate_health_literacy (moderate technicality), or "
            "proficient_health_literacy (highly technical/original level)."
        )
    )
class HealthLiteracyClassifier(dspy.Module):
    """Chain-of-thought DSPy module wrapping HealthLiteracySignature."""

    def __init__(self):
        super().__init__()
        # Attribute name matters: saved-state loading (classifier.load) keys
        # parameters by attribute path, so do not rename `classifier`.
        self.classifier = dspy.ChainOfThought(HealthLiteracySignature)

    def forward(self, generated_text):
        """Run the classifier; returns a prediction with `literacy_label`."""
        return self.classifier(generated_text=generated_text)
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI for the classifier-evaluation script."""
    cli = argparse.ArgumentParser(
        description="Evaluate saved DSPy classifier on saved vLLM inference outputs."
    )
    # Paths: compiled model, input JSONL, auto-search directory, output directory.
    cli.add_argument("--model-path", default=DEFAULT_MODEL_PATH)
    input_help = (
        "Path to vLLM output JSONL (e.g. vllm_inference_*.jsonl). "
        "Set to empty string to auto-select latest file in --search-dir."
    )
    cli.add_argument("--input-path", default=DEFAULT_INPUT_FILE, help=input_help)
    cli.add_argument(
        "--search-dir",
        default=DEFAULT_INPUT_PATH,
        help="Directory to auto-search for vllm_inference_*.jsonl",
    )
    # Endpoint: environment variable wins over the hard-coded default.
    cli.add_argument(
        "--api-base", default=os.environ.get("VLLM_API_BASE", DEFAULT_API_BASE)
    )
    cli.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR)
    # Evaluation controls.
    cli.add_argument(
        "--max-samples", type=int, default=-1, help="Use -1 for all rows."
    )
    cli.add_argument(
        "--provide-traceback",
        action="store_true",
        help="Print full traceback if runtime error happens.",
    )
    return cli.parse_args()
def check_api_base(api_base: str) -> None:
    """Fail fast if the OpenAI-compatible endpoint is unreachable or unhealthy.

    Sends GET <api_base>/models (the standard OpenAI model-listing route).

    Args:
        api_base: Base URL such as "http://host:port/v1".

    Raises:
        RuntimeError: Server responded, but with an HTTP error status.
        ConnectionError: Server could not be reached (refused, DNS, timeout).
    """
    models_url = api_base.rstrip("/") + "/models"
    req = urllib.request.Request(models_url, method="GET")
    try:
        with urllib.request.urlopen(req, timeout=5) as resp:
            # Defensive only: urlopen raises HTTPError for status >= 400, so
            # this branch is unreachable for standard response objects.
            if resp.status >= 400:
                raise RuntimeError(
                    f"Endpoint reachable but unhealthy: {models_url} (status={resp.status})"
                )
    except urllib.error.HTTPError as exc:
        # Fix: HTTPError is a URLError subclass and was previously swallowed by
        # the branch below, mislabeling a live-but-erroring server as
        # unreachable. Report it as unhealthy instead.
        raise RuntimeError(
            f"Endpoint reachable but unhealthy: {models_url} (status={exc.code})"
        ) from exc
    except urllib.error.URLError as exc:
        raise ConnectionError(
            "Cannot reach OpenAI-compatible endpoint. "
            f"api_base={api_base}. "
            "Start your vLLM server or pass correct --api-base."
        ) from exc
def resolve_input_path(input_path: str, search_dir: str) -> str:
    """Return the JSONL file to evaluate.

    An explicit non-empty *input_path* wins and must exist. Otherwise the most
    recently modified vllm_inference_*.jsonl under *search_dir* is chosen.

    Raises:
        FileNotFoundError: Explicit path missing, or no candidate file found.
    """
    if input_path:
        if not os.path.exists(input_path):
            raise FileNotFoundError(f"Input file not found: {input_path}")
        return input_path

    pattern = os.path.join(search_dir, "vllm_inference_*.jsonl")
    matches = sorted(glob.glob(pattern), key=os.path.getmtime)
    if matches:
        # Newest by modification time.
        return matches[-1]
    raise FileNotFoundError(
        "No vLLM output file found. Expected pattern: "
        f"{search_dir}/vllm_inference_*.jsonl"
    )
def load_compiled_classifier(path: str):
    """Load the compiled DSPy program stored at *path*.

    Prefers the modern ``dspy.load`` entry point when present; on any failure
    falls back to instantiating HealthLiteracyClassifier and restoring its
    saved state via ``.load``.

    Raises:
        RuntimeError: The fallback state restore failed.
    """
    loader = getattr(dspy, "load", None)
    if loader is not None:
        try:
            return loader(path)
        except Exception:
            # Older save formats are handled by the fallback below.
            pass

    classifier = HealthLiteracyClassifier()
    try:
        classifier.load(path)
    except Exception as exc:
        raise RuntimeError(f"Failed to load compiled model from {path}") from exc
    return classifier
def normalize_pred_label(pred_obj: Any) -> str:
    """Return a lowercased, trimmed label from a DSPy prediction.

    Falsy objects and objects without a ``literacy_label`` attribute yield "".
    """
    if pred_obj and hasattr(pred_obj, "literacy_label"):
        return str(pred_obj.literacy_label).strip().lower()
    return ""
def load_eval_items(path: str, valid_labels: Any = None) -> List[Dict[str, Any]]:
    """Parse a vLLM inference JSONL dump into evaluation-ready rows.

    Args:
        path: JSONL file; one inference record per line.
        valid_labels: Acceptable gold labels. Defaults to the module-level
            VALID_LABELS set when None (backward-compatible generalization).

    Returns:
        One dict per usable row with line_no, row_index, doc_id, gold_label,
        and generated_text.

    Rows are skipped when: the line is blank, the gold label is not in
    *valid_labels*, the row recorded a generation error, or no generated text
    is available (the legacy "prediction" field is tried as a fallback first).
    """
    labels = VALID_LABELS if valid_labels is None else valid_labels
    items: List[Dict[str, Any]] = []
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            row = json.loads(line)
            gold_label = str(row.get("gold_label", "")).strip()
            generated_text = str(row.get("generated_text", "")).strip()
            if not generated_text:
                # Older dumps stored the model output under "prediction".
                generated_text = str(row.get("prediction", "")).strip()
            err_msg = str(row.get("error", "")).strip()
            # Keep only clean rows with a recognized gold label and some text.
            if gold_label not in labels or err_msg or not generated_text:
                continue
            items.append(
                {
                    "line_no": line_no,
                    "row_index": row.get("row_index"),
                    "doc_id": row.get("doc_id"),
                    "gold_label": gold_label,
                    "generated_text": generated_text,
                }
            )
    return items
def main() -> None:
    """Evaluate the saved DSPy classifier on a saved vLLM inference dump.

    Pipeline: resolve input file -> verify endpoint -> configure DSPy LM ->
    load compiled classifier -> classify each row -> write summary JSON and
    per-row details JSONL under --output-dir.
    """
    args = parse_args()
    # May swap in the latest vllm_inference_*.jsonl from --search-dir.
    args.input_path = resolve_input_path(args.input_path, args.search_dir)
    if not os.path.exists(args.model_path):
        raise FileNotFoundError(f"Model file not found: {args.model_path}")
    try:
        check_api_base(args.api_base)
        # temperature=0.0 for deterministic classification; api_key is a
        # placeholder required by the OpenAI-compatible client.
        lm = dspy.LM(
            model="openai/dspy",
            api_base=args.api_base,
            api_key="EMPTY",
            temperature=0.0,
        )
        dspy.configure(lm=lm)
        classifier = load_compiled_classifier(args.model_path)
        print(f"[INFO] Using input file: {args.input_path}")
        parsed_items = load_eval_items(args.input_path)
        if args.max_samples > 0:
            parsed_items = parsed_items[: args.max_samples]
        if not parsed_items:
            raise RuntimeError("No valid rows found in input file for classifier evaluation.")
        correct = 0
        results: List[Dict[str, Any]] = []
        for item in tqdm(parsed_items, desc="Classifying"):
            pred = classifier(generated_text=item["generated_text"])
            pred_label = normalize_pred_label(pred)
            # NOTE(review): lenient substring match — counts as correct if the
            # gold label appears anywhere in the predicted text. Presumably
            # intended to tolerate verbose LM outputs; confirm this scoring.
            is_correct = item["gold_label"] in pred_label
            correct += int(is_correct)
            results.append(
                {
                    "line_no": item["line_no"],
                    "row_index": item["row_index"],
                    "doc_id": item.get("doc_id"),
                    "gold_label": item["gold_label"],
                    "pred_label": pred_label,
                    "is_correct": is_correct,
                }
            )
        total = len(results)
        # Guard against division by zero (total is > 0 here by construction).
        accuracy = correct / total if total else 0.0
        # Timestamped file names keep repeated runs from overwriting each other.
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        os.makedirs(args.output_dir, exist_ok=True)
        summary_path = os.path.join(args.output_dir, f"classifier_eval_vllm_{ts}.json")
        details_path = os.path.join(args.output_dir, f"classifier_eval_vllm_{ts}.jsonl")
        with open(summary_path, "w", encoding="utf-8") as f:
            json.dump(
                {
                    "model_path": args.model_path,
                    "input_path": args.input_path,
                    "api_base": args.api_base,
                    "total_samples": total,
                    "correct_samples": correct,
                    "accuracy_score": accuracy,
                    "details_path": details_path,
                },
                f,
                indent=2,
            )
        # Per-row details as JSONL, one record per classified item.
        with open(details_path, "w", encoding="utf-8") as f:
            for r in results:
                f.write(json.dumps(r, ensure_ascii=False) + "\n")
        print(json.dumps({"total_samples": total, "accuracy_score": accuracy}, indent=2))
        print(f"[DONE] Summary saved: {summary_path}")
        print(f"[DONE] Details saved: {details_path}")
    except Exception as exc:
        # Top-level boundary: report, optionally show traceback, then re-raise
        # so the process still exits non-zero.
        print(f"[error] {type(exc).__name__}: {exc}")
        if args.provide_traceback:
            traceback.print_exc()
        raise
if __name__ == "__main__":
    main()