| | import argparse |
| | import json |
| | import os |
| | import re |
| | import traceback |
| | import urllib.error |
| | import urllib.request |
| | from datetime import datetime |
| | from typing import Any, Dict, List, Tuple |
| |
|
| | import dspy |
| | from openai import OpenAI |
| | from tqdm import tqdm |
| |
|
| |
|
# Default OpenAI-compatible endpoints; both can be overridden from the CLI
# (and, for the CLI defaults, from environment variables in parse_args).
DEFAULT_CLASSIFIER_API_BASE = "http://172.16.34.22:8040/v1"
DEFAULT_SUPPORT_API_BASE = "http://172.16.34.22:3090/v1"

# Default location of the compiled DSPy classifier state.
DEFAULT_MODEL_PATH = "/".join(
    (
        "/home/mshahidul/readctrl/code/text_classifier",
        "dspy_model",
        "vllm-Meta-Llama-3.1-8B-Instruct_teacher-gpt5_v1",
        "model.json",
    )
)

# Default JSONL file with the generated texts to evaluate.
DEFAULT_INPUT_FILE = "/".join(
    (
        "/home/mshahidul/readctrl/code/RL_model",
        "inference_data",
        "RL_model_inference_v1.jsonl",
    )
)

# Default JSON file holding the reference subclaims per (doc_id, label).
DEFAULT_REFERENCE_SUBCLAIMS_FILE = "/".join(
    (
        "/home/mshahidul/readctrl/code/text_classifier",
        "data",
        "verified_combined_0-80_clean200_with_subclaims.json",
    )
)

# Default directory that receives the summary and per-row result files.
DEFAULT_OUTPUT_DIR = "/".join(
    ("/home/mshahidul/readctrl/code", "rl_inference", "test_result_v2")
)
| |
|
# Raw-completions prompt wrapper: a fixed system header followed by a single
# user turn and an open assistant turn. ``{user_prompt}`` is the only
# placeholder and is filled with ``str.format``.
CHAT_TEMPLATE = "".join(
    [
        "<|begin_of_text|>",
        "<|start_header_id|>system<|end_header_id|>\n\n",
        "Cutting Knowledge Date: December 2023\n",
        "Today Date: 26 July 2024\n\n",
        "<|eot_id|>",
        "<|start_header_id|>user<|end_header_id|>\n\n",
        "{user_prompt}",
        "<|eot_id|>",
        "<|start_header_id|>assistant<|end_header_id|>\n\n",
    ]
)
| |
|
# The three literacy labels accepted as gold labels; anything else is skipped.
VALID_LABELS = {
    f"{level}_health_literacy" for level in ("low", "intermediate", "proficient")
}
| |
|
| |
|
# DSPy signature for the literacy classifier: one input text, one output label.
# NOTE(review): no class docstring is added on purpose -- DSPy is understood to
# treat a Signature's docstring as prompt instructions, so adding one could
# change what the model is asked; confirm against the DSPy version in use.
class HealthLiteracySignature(dspy.Signature):
    # Input: the rewritten passage whose reading level is to be classified.
    generated_text = dspy.InputField(
        desc="A version of the source text rewritten for a specific audience."
    )
    # Output: the predicted label; the description enumerates the three
    # label strings the rest of this file compares against.
    literacy_label = dspy.OutputField(
        desc=(
            "Classification: low_health_literacy (simple words, no jargon), "
            "intermediate_health_literacy (moderate technicality), or "
            "proficient_health_literacy (highly technical/original level)."
        )
    )
| |
|
| |
|
class HealthLiteracyClassifier(dspy.Module):
    """DSPy module wrapping a chain-of-thought predictor for literacy labels."""

    def __init__(self):
        """Create the chain-of-thought predictor over HealthLiteracySignature."""
        super().__init__()
        # The attribute name is deliberately left as ``classifier``.
        # NOTE(review): saved program state is presumably keyed by it -- confirm.
        self.classifier = dspy.ChainOfThought(HealthLiteracySignature)

    def forward(self, generated_text):
        """Classify ``generated_text`` and return the predictor's result object."""
        prediction = self.classifier(generated_text=generated_text)
        return prediction
| |
|
| |
|
class MedicalClaimVerifier:
    """Score a generated text against reference subclaims via a completions API.

    For each list of subclaims the verifier asks the support model to label
    every subclaim ``supported`` / ``not_supported`` and reports the fraction
    labelled ``supported``.

    Attributes:
        model_name: Model identifier sent with every completion request.
        base_url: Base URL of the OpenAI-compatible endpoint.
        max_tokens: Completion budget per request (default 256, as before).
        client: OpenAI client bound to ``base_url``.
        cov_iqr_ranges: Inclusive ``(low, high)`` coverage bounds per level
            key (``low`` / ``intermediate`` / ``proficient``).
            NOTE(review): the name suggests inter-quartile ranges of a
            reference set; the provenance is not visible here.
    """

    def __init__(self, base_url: str, model_name: str, max_tokens: int = 256):
        """Create the API client.

        Args:
            base_url: Base URL of the OpenAI-compatible endpoint.
            model_name: Model identifier to request.
            max_tokens: Completion budget per request. Long subclaim lists
                may need more than the default to avoid a truncated (and
                therefore unparseable) label array.
        """
        self.model_name = model_name
        self.base_url = base_url
        self.max_tokens = max_tokens
        self.client = OpenAI(api_key="EMPTY", base_url=self.base_url)
        self.cov_iqr_ranges = {
            "low": (0.1765, 0.3226),
            "intermediate": (0.1818, 0.4091),
            "proficient": (0.7725, 0.9347),
        }

    def build_user_prompt(self, text: str, subclaims: List[str]) -> str:
        """Return the instruction prompt listing ``subclaims`` numbered from 1."""
        numbered_subclaims = "\n".join(
            f"{idx + 1}. {subclaim}" for idx, subclaim in enumerate(subclaims)
        )
        return (
            "You are a medical evidence checker.\n"
            "Given a medical passage and a list of subclaims, return labels for each "
            "subclaim in the same order.\n\n"
            "Allowed labels: supported, not_supported.\n"
            "Output format: a JSON array of strings only.\n\n"
            f"Medical text:\n{text}\n\n"
            f"Subclaims:\n{numbered_subclaims}"
        )

    def render_chat_prompt(self, user_prompt: str) -> str:
        """Wrap ``user_prompt`` in the module-level CHAT_TEMPLATE."""
        return CHAT_TEMPLATE.format(user_prompt=user_prompt)

    def extract_label_list(self, text: str) -> List[str]:
        """Extract the first JSON array found in a model response.

        The whole response is tried first. Otherwise every ``[`` is tried as
        the start of a JSON value, so surrounding prose -- including prose
        that itself contains brackets after the array -- no longer defeats
        the parse the way a greedy first-``[``-to-last-``]`` match did.

        Returns:
            The parsed list, or ``[]`` when no JSON array can be decoded.
        """
        cleaned = text.strip()
        try:
            parsed = json.loads(cleaned)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, list):
            return parsed

        decoder = json.JSONDecoder()
        for opening in re.finditer(r"\[", cleaned):
            try:
                # raw_decode stops at the end of the value, ignoring the rest.
                candidate, _end = decoder.raw_decode(cleaned, opening.start())
            except json.JSONDecodeError:
                continue
            if isinstance(candidate, list):
                return candidate
        return []

    def check_support_api(self, context: str, subclaims: List[str]) -> List[str]:
        """Ask the support model to label each subclaim against ``context``.

        Returns:
            Lower-cased, stripped labels in response order. ``[]`` when
            either input is empty or when the request or parsing fails; a
            failure is reported on stderr instead of being dropped silently.
        """
        if not context or not subclaims:
            return []

        user_prompt = self.build_user_prompt(context, subclaims)
        prompt = self.render_chat_prompt(user_prompt)
        try:
            response = self.client.completions.create(
                model=self.model_name,
                prompt=prompt,
                max_tokens=self.max_tokens,
                temperature=0,
            )
            pred_text = response.choices[0].text.strip()
            labels = self.extract_label_list(pred_text)
            return [str(x).strip().lower() for x in labels]
        except Exception as exc:
            # Deliberately best-effort: the caller scores a failed call as
            # zero support. Surface the cause so an outage is not mistaken
            # for a genuinely unsupported text.
            import sys

            print(
                f"[warn] support check failed ({type(exc).__name__}: {exc}); "
                f"returning no labels for {len(subclaims)} subclaim(s).",
                file=sys.stderr,
            )
            return []

    @staticmethod
    def average_supported(labels: List[str], expected_len: int) -> float:
        """Return the fraction of ``expected_len`` labels equal to ``supported``.

        Missing labels count as unsupported and surplus labels are ignored,
        so the denominator is always ``expected_len``. Returns ``0.0`` when
        ``expected_len`` is not positive.
        """
        if expected_len <= 0:
            return 0.0
        normalized = [str(x).strip().lower() for x in labels]
        if len(normalized) < expected_len:
            normalized.extend(["invalid"] * (expected_len - len(normalized)))
        elif len(normalized) > expected_len:
            normalized = normalized[:expected_len]
        supported_count = sum(1 for item in normalized if item == "supported")
        return supported_count / expected_len

    def evaluate_level(
        self, gen_text: str, gold_subs: List[str], full_subs: List[str]
    ) -> Tuple[float, float]:
        """Return ``(completeness, coverage)`` for ``gen_text``.

        Completeness is the supported fraction of ``gold_subs`` and coverage
        the supported fraction of ``full_subs``; each needs one API call.
        Returns ``(0.0, 0.0)`` when any argument is empty.
        """
        if not gen_text or not gold_subs or not full_subs:
            return 0.0, 0.0
        comp_labels = self.check_support_api(gen_text, gold_subs)
        cov_labels = self.check_support_api(gen_text, full_subs)
        comp_score = self.average_supported(comp_labels, len(gold_subs))
        cov_score = self.average_supported(cov_labels, len(full_subs))
        return comp_score, cov_score
| |
|
| |
|
def parse_args() -> argparse.Namespace:
    """Parse the command line for the evaluation script.

    Endpoint and support-model defaults are taken from the environment
    (``VLLM_API_BASE``, ``SUPPORT_API_BASE``, ``VLLM_MODEL``) when set, and
    otherwise from the module-level defaults.

    Returns:
        The populated ``argparse.Namespace``.
    """
    env = os.environ
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        ("--model-path", {"default": DEFAULT_MODEL_PATH}),
        (
            "--input-file",
            {
                "default": DEFAULT_INPUT_FILE,
                "help": "Path to RL inference JSONL (e.g. RL_model_inference_v1.jsonl).",
            },
        ),
        (
            "--reference-subclaims-file",
            {
                "default": DEFAULT_REFERENCE_SUBCLAIMS_FILE,
                "help": (
                    "JSON list file that contains summary_subclaims/fulltext_subclaims "
                    "(used for lookup by doc_id + label)."
                ),
            },
        ),
        (
            "--classifier-api-base",
            {"default": env.get("VLLM_API_BASE", DEFAULT_CLASSIFIER_API_BASE)},
        ),
        (
            "--support-api-base",
            {"default": env.get("SUPPORT_API_BASE", DEFAULT_SUPPORT_API_BASE)},
        ),
        ("--support-model", {"default": env.get("VLLM_MODEL", "sc")}),
        ("--output-dir", {"default": DEFAULT_OUTPUT_DIR}),
        (
            "--generated-text-key",
            {
                "default": "generated_text",
                "help": "Field name to evaluate text from input JSONL.",
            },
        ),
        (
            "--comp-min-threshold",
            {
                "type": float,
                "default": 0.9,
                "help": "Completeness pass lower bound (inclusive).",
            },
        ),
        (
            "--comp-max-threshold",
            {
                "type": float,
                "default": 1.0,
                "help": "Completeness pass upper bound (inclusive).",
            },
        ),
        (
            "--max-samples",
            {"type": int, "default": -1, "help": "Use -1 for all rows."},
        ),
        (
            "--provide-traceback",
            {
                "action": "store_true",
                "help": "Print full traceback if runtime error happens.",
            },
        ),
    ]

    cli = argparse.ArgumentParser(
        description=(
            "Evaluate classifier accuracy plus subclaim support thresholds "
            "(completeness + coverage)."
        )
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
| |
|
| |
|
def check_api_base(api_base: str) -> None:
    """Verify that an OpenAI-compatible endpoint answers ``GET <api_base>/models``.

    Args:
        api_base: Base URL of the endpoint; a trailing slash is tolerated.

    Raises:
        RuntimeError: The server answered with an HTTP status >= 400.
        ConnectionError: The server could not be reached (refused
            connection, DNS failure, timeout, ...).
    """
    models_url = api_base.rstrip("/") + "/models"
    req = urllib.request.Request(models_url, method="GET")
    try:
        with urllib.request.urlopen(req, timeout=5) as resp:
            status = resp.status
    except urllib.error.HTTPError as exc:
        # urlopen raises HTTPError for 4xx/5xx replies. It subclasses
        # URLError, so it must be handled first or a reachable-but-failing
        # server would be reported as unreachable.
        raise RuntimeError(
            f"Endpoint reachable but unhealthy: {models_url} (status={exc.code})"
        ) from exc
    except OSError as exc:
        # URLError is an OSError; catching OSError also covers raw socket
        # timeouts and resets that urlopen does not wrap.
        raise ConnectionError(
            "Cannot reach OpenAI-compatible endpoint. "
            f"api_base={api_base}. "
            "Start your vLLM server or pass correct api base."
        ) from exc
    if status >= 400:
        raise RuntimeError(
            f"Endpoint reachable but unhealthy: {models_url} (status={status})"
        )
| |
|
| |
|
def load_compiled_classifier(path: str):
    """Load the compiled classifier stored at ``path``.

    ``dspy.load`` is tried first when the installed DSPy provides it. If it
    is missing or fails, a fresh ``HealthLiteracyClassifier`` is created and
    its ``load`` method is used instead.

    Raises:
        RuntimeError: The fallback ``load`` also failed; the original error
            is chained as the cause.
    """
    if hasattr(dspy, "load"):
        try:
            program = dspy.load(path)
        except Exception:
            # Fall through to loading into a freshly built module below.
            pass
        else:
            return program

    module = HealthLiteracyClassifier()
    try:
        module.load(path)
    except Exception as exc:
        raise RuntimeError(f"Failed to load compiled model from {path}") from exc
    return module
| |
|
| |
|
def normalize_pred_label(pred_obj: Any) -> str:
    """Return the prediction's ``literacy_label`` stripped and lower-cased.

    An empty string is returned when ``pred_obj`` is falsy or has no
    ``literacy_label`` attribute.
    """
    if pred_obj and hasattr(pred_obj, "literacy_label"):
        raw_label = pred_obj.literacy_label
        return str(raw_label).strip().lower()
    return ""
| |
|
| |
|
def load_items(path: str) -> List[Dict[str, Any]]:
    """Read the inference JSONL file into a list of row dicts.

    Blank lines are skipped. Each returned dict keeps every field of the
    source row, so a caller can read a custom text field by name, and adds
    or overrides these normalised entries:

    * ``line_no``: 1-based physical line number in the file.
    * ``row_index`` / ``doc_id``: copied from the row (``None`` if absent).
    * ``gold_label`` / ``generated_text``: stringified and stripped
      (``""`` if absent).

    Args:
        path: Path to a UTF-8 JSONL file with one JSON object per line.

    Raises:
        json.JSONDecodeError: A non-blank line is not valid JSON; the
            message names the file and line.
        ValueError: A line holds valid JSON that is not an object.
    """
    items: List[Dict[str, Any]] = []
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError as exc:
                # Same exception type, but say which input line is broken.
                raise json.JSONDecodeError(
                    f"{path}, input line {line_no}: {exc.msg}", exc.doc, exc.pos
                ) from exc
            if not isinstance(row, dict):
                raise ValueError(
                    f"{path}, input line {line_no}: expected a JSON object, "
                    f"got {type(row).__name__}"
                )
            # Start from the raw row so non-default text keys stay available,
            # then overlay the normalised fields.
            item: Dict[str, Any] = dict(row)
            item.update(
                {
                    "line_no": line_no,
                    "row_index": row.get("row_index"),
                    "doc_id": row.get("doc_id"),
                    "gold_label": str(row.get("gold_label", "")).strip(),
                    "generated_text": str(row.get("generated_text", "")).strip(),
                }
            )
            items.append(item)
    return items
| |
|
| |
|
def load_subclaim_lookup(
    reference_path: str,
) -> Dict[Tuple[Any, str], Tuple[List[str], List[str]]]:
    """Build a ``(doc_id, label) -> (summary_subclaims, fulltext_subclaims)`` map.

    Rows are skipped when the label is not in VALID_LABELS, when either
    subclaim field is not a list, or when either list is empty. If a key
    occurs more than once, the first usable row wins.

    Args:
        reference_path: Path to a UTF-8 JSON file whose top level is a list.

    Raises:
        ValueError: The file's top-level JSON value is not a list.
    """
    with open(reference_path, "r", encoding="utf-8") as handle:
        records = json.load(handle)
    if not isinstance(records, list):
        raise ValueError("Reference subclaims file must be a JSON list.")

    table: Dict[Tuple[Any, str], Tuple[List[str], List[str]]] = {}
    for record in records:
        doc_id = record.get("doc_id")
        level = str(record.get("label", "")).strip()
        summary_claims = record.get("summary_subclaims", [])
        fulltext_claims = record.get("fulltext_subclaims", [])

        usable = (
            level in VALID_LABELS
            and isinstance(summary_claims, list)
            and isinstance(fulltext_claims, list)
            and bool(summary_claims)
            and bool(fulltext_claims)
        )
        if usable:
            # setdefault keeps an existing entry, so the first row wins.
            table.setdefault((doc_id, level), (summary_claims, fulltext_claims))
    return table
| |
|
| |
|
def to_level_key(label: str) -> str:
    """Map a full literacy label to its short level key, or ``""`` if unknown."""
    short_by_label = {
        f"{level}_health_literacy": level
        for level in ("low", "intermediate", "proficient")
    }
    try:
        return short_by_label[label]
    except KeyError:
        return ""
| |
|
| |
|
def in_range(value: float, lower: float, upper: float) -> bool:
    """Return True when ``value`` lies in the closed interval [lower, upper]."""
    at_least_lower = value >= lower
    return at_least_lower and value <= upper
| |
|
| |
|
def main() -> None:
    """Run the full evaluation and write a summary JSON plus per-row JSONL.

    Steps: parse arguments, check that the input files exist and both API
    endpoints respond, load the classifier and the reference subclaims, then
    for every usable row classify the generated text and score completeness
    and coverage against the reference subclaims.

    Raises:
        FileNotFoundError: The model, input, or reference file is missing.
        RuntimeError: No row was usable for evaluation.
        Exception: Any other failure is printed and re-raised unchanged.
    """
    args = parse_args()
    # Fail fast on missing files before touching any network endpoint.
    if not os.path.exists(args.model_path):
        raise FileNotFoundError(f"Model file not found: {args.model_path}")
    if not os.path.exists(args.input_file):
        raise FileNotFoundError(f"Input file not found: {args.input_file}")
    if not os.path.exists(args.reference_subclaims_file):
        raise FileNotFoundError(
            f"Reference subclaims file not found: {args.reference_subclaims_file}"
        )

    try:
        # Both endpoints must answer before any model work starts.
        check_api_base(args.classifier_api_base)
        check_api_base(args.support_api_base)

        # Language model used by the DSPy classifier.
        lm = dspy.LM(
            model="openai/dspy",
            api_base=args.classifier_api_base,
            api_key="EMPTY",
            temperature=0.0,
        )
        dspy.configure(lm=lm)
        classifier = load_compiled_classifier(args.model_path)
        verifier = MedicalClaimVerifier(
            base_url=args.support_api_base,
            model_name=args.support_model,
        )
        subclaim_lookup = load_subclaim_lookup(args.reference_subclaims_file)

        rows = load_items(args.input_file)
        # A non-positive --max-samples means "use every row".
        if args.max_samples > 0:
            rows = rows[: args.max_samples]

        # Counters; every rate in the summary is divided by ``total``.
        unmatched_rows = 0
        total = 0
        classifier_correct = 0
        comp_pass_count = 0
        cov_pass_count = 0
        cls_and_comp_pass_count = 0
        cls_comp_cov_pass_count = 0
        details: List[Dict[str, Any]] = []

        # ``idx`` counts every input row, including the ones skipped below.
        for idx, row in enumerate(tqdm(rows, desc="Evaluating"), start=1):
            gold_label = str(row.get("gold_label", "")).strip()
            if gold_label not in VALID_LABELS:
                continue

            # Text is read from the configured key of the loaded row dict.
            generated_text = str(row.get(args.generated_text_key, "")).strip()
            subclaims = subclaim_lookup.get((row.get("doc_id"), gold_label))
            if not generated_text or not subclaims:
                # Only rows without reference subclaims count as unmatched;
                # rows with empty text are skipped without being counted.
                if not subclaims:
                    unmatched_rows += 1
                continue
            gold_subs, full_subs = subclaims

            total += 1
            pred = classifier(generated_text=generated_text)
            pred_label = normalize_pred_label(pred)
            # Substring test: a prediction that contains the gold label
            # together with extra text still counts as correct.
            is_cls_correct = gold_label in pred_label
            classifier_correct += int(is_cls_correct)

            comp_score, cov_score = verifier.evaluate_level(
                gen_text=generated_text,
                gold_subs=gold_subs,
                full_subs=full_subs,
            )

            # Completeness passes inside the CLI-provided closed interval.
            comp_pass = in_range(
                comp_score, args.comp_min_threshold, args.comp_max_threshold
            )
            comp_pass_count += int(comp_pass)

            # Coverage passes inside the per-level range held by the verifier;
            # gold_label was validated above, so level_key is always a key.
            level_key = to_level_key(gold_label)
            cov_low, cov_high = verifier.cov_iqr_ranges[level_key]
            cov_pass = in_range(cov_score, cov_low, cov_high)
            cov_pass_count += int(cov_pass)

            # Combined criteria are cumulative: cls -> cls+comp -> cls+comp+cov.
            cls_and_comp_pass = is_cls_correct and comp_pass
            cls_comp_cov_pass = cls_and_comp_pass and cov_pass
            cls_and_comp_pass_count += int(cls_and_comp_pass)
            cls_comp_cov_pass_count += int(cls_comp_cov_pass)

            details.append(
                {
                    "idx": idx,
                    "line_no": row.get("line_no"),
                    "row_index": row.get("row_index"),
                    "doc_id": row.get("doc_id"),
                    "gold_label": gold_label,
                    "pred_label": pred_label,
                    "classifier_correct": is_cls_correct,
                    "completeness_score": comp_score,
                    "coverage_score": cov_score,
                    "completeness_threshold": [
                        args.comp_min_threshold,
                        args.comp_max_threshold,
                    ],
                    "completeness_pass": comp_pass,
                    "coverage_iqr_threshold": [cov_low, cov_high],
                    "coverage_pass": cov_pass,
                    "pass_cls_and_completeness": cls_and_comp_pass,
                    "pass_cls_comp_cov": cls_comp_cov_pass,
                }
            )

        if total == 0:
            raise RuntimeError("No valid rows were found for evaluation.")

        def safe_rate(n: int) -> float:
            # ``total`` is non-zero here; the guard is purely defensive.
            return n / total if total else 0.0

        # Summary and details share a timestamped stem and differ only in
        # extension (.json vs .jsonl).
        os.makedirs(args.output_dir, exist_ok=True)
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        summary_path = os.path.join(
            args.output_dir, f"classifier_subclaim_threshold_eval_{ts}.json"
        )
        details_path = os.path.join(
            args.output_dir, f"classifier_subclaim_threshold_eval_{ts}.jsonl"
        )

        summary = {
            "model_path": args.model_path,
            "input_file": args.input_file,
            "reference_subclaims_file": args.reference_subclaims_file,
            "generated_text_key": args.generated_text_key,
            "classifier_api_base": args.classifier_api_base,
            "support_api_base": args.support_api_base,
            "support_model": args.support_model,
            "total_samples": total,
            "unmatched_rows": unmatched_rows,
            "classifier_only_accuracy": safe_rate(classifier_correct),
            "completeness_pass_rate": safe_rate(comp_pass_count),
            "coverage_pass_rate": safe_rate(cov_pass_count),
            "accuracy_cls_and_completeness_threshold": safe_rate(
                cls_and_comp_pass_count
            ),
            "accuracy_cls_completeness_coverage_threshold": safe_rate(
                cls_comp_cov_pass_count
            ),
            "completeness_threshold": [args.comp_min_threshold, args.comp_max_threshold],
            "coverage_thresholds": verifier.cov_iqr_ranges,
            "details_path": details_path,
        }

        with open(summary_path, "w", encoding="utf-8") as f:
            json.dump(summary, f, indent=2)

        # One JSON object per evaluated row.
        with open(details_path, "w", encoding="utf-8") as f:
            for item in details:
                f.write(json.dumps(item, ensure_ascii=False) + "\n")

        print(json.dumps(summary, indent=2))
        print(f"[DONE] Summary saved: {summary_path}")
        print(f"[DONE] Details saved: {details_path}")

    except Exception as exc:
        # Print a one-line summary (and optionally the traceback), then
        # re-raise so the process still fails.
        print(f"[error] {type(exc).__name__}: {exc}")
        if args.provide_traceback:
            traceback.print_exc()
        raise
| |
|
| |
|
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
| |
|