| import argparse |
| import json |
| import os |
| from collections import Counter |
| from typing import Dict, List, Tuple |
|
|
| import dspy |
| from tqdm import tqdm |
|
|
|
|
# JSON credentials file; load_openai_key() reads its "openai" entry.
API_FILE = "/home/mshahidul/api_new.json"

# Default locations of the compiled DSPy program, the input dataset, and the
# evaluation / cleaning artifacts written by run_inference().
DEFAULT_MODEL_PATH = "/home/mshahidul/readctrl/code/text_classifier/dspy_model/student-gpt5-mini_teacher-gpt5_v1/model.json"
DEFAULT_DATASET_PATH = "/home/mshahidul/readctrl/code/text_classifier/verified_combined_0-80.json"
DEFAULT_OUTPUT_PATH = "/home/mshahidul/readctrl/code/text_classifier/dspy_model/student-gpt5-mini_teacher-gpt5_v1/full_dataset_accuracy.json"
DEFAULT_PREDICTIONS_PATH = "/home/mshahidul/readctrl/code/text_classifier/dspy_model/student-gpt5-mini_teacher-gpt5_v1/full_dataset_predictions.json"
DEFAULT_CLEAN_DATASET_PATH = "/home/mshahidul/readctrl/code/text_classifier/verified_combined_0-80_clean200.json"
DEFAULT_REMOVED_PATH = "/home/mshahidul/readctrl/code/text_classifier/verified_combined_0-80_removed21.json"

# Ordinal rank of each label, from the simplest to the most technical text.
# misclassification_severity() scores errors by the distance between ranks.
LABEL_ORDER = {
    label: rank
    for rank, label in enumerate(
        (
            "low_health_literacy",
            "intermediate_health_literacy",
            "proficient_health_literacy",
        )
    )
}

# Closed set of labels accepted from the dataset and from model output.
VALID_LABELS = set(LABEL_ORDER)
|
|
|
|
# DSPy signature: maps one input text to a health-literacy label.
# NOTE(review): DSPy presumably renders the class docstring and each field's
# ``desc`` into the LM prompt, so that text is runtime behavior rather than
# plain documentation. Edit it only deliberately, and confirm whether the
# compiled model.json overrides these instructions.
class HealthLiteracySignature(dspy.Signature):
    """
    Analyze the linguistic complexity, use of medical jargon, and sentence
    structure of 'generated_text' to determine the health literacy level.
    """

    # Input: the rewritten text whose reading level is being judged.
    generated_text = dspy.InputField(
        desc="A version of the source text rewritten for a specific audience."
    )
    # Output: free-form label text. The rest of this script parses it by
    # substring match against the three snake_case label names.
    literacy_label = dspy.OutputField(
        desc=(
            "Classification: low_health_literacy (simple words, no jargon), "
            "intermediate_health_literacy (moderate technicality), or "
            "proficient_health_literacy (highly technical/original level)."
        )
    )
|
|
|
|
class HealthLiteracyClassifier(dspy.Module):
    """Chain-of-thought DSPy program that predicts a health-literacy label."""

    def __init__(self):
        super().__init__()
        # NOTE(review): the attribute is deliberately still named
        # ``classifier``. Saved DSPy programs are presumably keyed by
        # attribute name, so renaming it would likely break
        # ``load(model_path)``. Confirm before changing.
        self.classifier = dspy.ChainOfThought(HealthLiteracySignature)

    def forward(self, generated_text):
        """Run the chain-of-thought predictor on ``generated_text``.

        Returns the DSPy prediction object. Callers read its
        ``literacy_label`` attribute.
        """
        prediction = self.classifier(generated_text=generated_text)
        return prediction
|
|
|
|
def load_openai_key(api_file: str) -> str:
    """Return the ``"openai"`` entry of the JSON credentials file ``api_file``.

    Raises:
        KeyError: if the parsed JSON has no ``"openai"`` entry.
    """
    with open(api_file, "r") as key_file:
        credentials = json.load(key_file)
    if "openai" in credentials:
        return credentials["openai"]
    raise KeyError(f"'openai' key is missing in {api_file}")
|
|
|
|
def normalize_label(text: str) -> str:
    """Return ``text`` as a stripped, lower-cased string.

    Any falsy input (``None``, ``""``, ``0``, ...) normalizes to ``""``. Other
    values are converted with ``str()`` before stripping and lower-casing.
    """
    if not text:
        return ""
    return str(text).strip().lower()
|
|
|
|
def is_correct(gold_label: str, predicted_label: str) -> bool:
    """Leniently check whether the prediction text mentions the gold label.

    Both values go through ``normalize_label``. The prediction counts as
    correct when the normalized gold label appears anywhere inside the
    normalized prediction text, so a prediction naming several labels is
    correct for each of them.

    Args:
        gold_label: Expected label.
        predicted_label: Raw model output (may be free-form text).

    Returns:
        True on a substring match. Always False when the gold label is empty
        or missing.
    """
    gold = normalize_label(gold_label)
    if not gold:
        # "" is a substring of every string. Without this guard a missing gold
        # label would be scored as correct for any prediction.
        return False
    return gold in normalize_label(predicted_label)
|
|
|
|
def extract_predicted_label(predicted_text: str) -> str:
    """Map raw model output to exactly one label from VALID_LABELS.

    The text is normalized with ``normalize_label`` and searched for each
    valid label as a substring. A label is returned only when exactly one
    label matches. Zero matches or an ambiguous mention of several labels
    yield ``""``.
    """
    haystack = normalize_label(predicted_text)
    hits = [candidate for candidate in VALID_LABELS if candidate in haystack]
    return hits[0] if len(hits) == 1 else ""
|
|
|
|
def misclassification_severity(gold_label: str, predicted_label: str) -> int:
    """Score how far a prediction is from the gold label.

    Returns the absolute distance between the two labels' ranks in
    LABEL_ORDER (0, 1 or 2). When either label is not in LABEL_ORDER (e.g. the
    prediction could not be parsed), it returns 3. That sentinel is worse than
    any real rank distance, so unparseable outputs sort as the most severe.
    """
    try:
        gold_rank = LABEL_ORDER[gold_label]
        pred_rank = LABEL_ORDER[predicted_label]
    except KeyError:
        return 3
    return abs(gold_rank - pred_rank)
|
|
|
|
def load_full_examples(dataset_path: str):
    """Load labeled examples from a JSON file holding a list of records.

    A record is kept when it is a dict, its ``"label"`` is one of
    VALID_LABELS, and its ``"diff_label_texts"`` value is truthy. Other records
    are skipped.

    Args:
        dataset_path: Path to a UTF-8 JSON file whose top-level value is a list.

    Returns:
        A list of dicts with keys ``index`` (position in the original JSON
        list), ``generated_text``, ``gold_label``, ``doc_id`` and ``raw_item``
        (the untouched source record).

    Raises:
        ValueError: if the top-level JSON value is not a list, or if no record
            qualifies.
    """
    # Explicit UTF-8: this script writes datasets with ensure_ascii=False, so
    # the locale-dependent default encoding is not safe to rely on.
    with open(dataset_path, "r", encoding="utf-8") as f:
        raw_data = json.load(f)
    if not isinstance(raw_data, list):
        raise ValueError(
            f"Expected a JSON list of records in {dataset_path}, "
            f"got {type(raw_data).__name__}."
        )

    examples = []
    for idx, item in enumerate(raw_data):
        if not isinstance(item, dict):
            # Malformed entries are skipped, like records with an invalid
            # label or empty text.
            continue
        label = item.get("label")
        text = item.get("diff_label_texts")
        # isinstance guard: an unhashable label would make the set membership
        # test raise TypeError.
        if isinstance(label, str) and label in VALID_LABELS and text:
            examples.append(
                {
                    "index": idx,
                    "generated_text": text,
                    "gold_label": label,
                    "doc_id": item.get("doc_id"),
                    "raw_item": item,
                }
            )
    if not examples:
        raise ValueError("No valid labeled examples found in dataset.")
    return examples
|
|
|
|
def choose_indices_to_remove(
    predictions: List[Dict], remove_count: int
) -> Tuple[List[Dict], List[int]]:
    """Pick ``remove_count`` predictions to drop, balanced across gold labels.

    Selection runs in three passes:

    1. Each label gets a quota of ``remove_count // num_labels``. The
       ``remove_count % num_labels`` leftover slots go to the labels with the
       most misclassified examples (ties broken by LABEL_ORDER). Each label's
       misclassified examples are taken first, worst-ranked first.
    2. If a label has fewer misclassified examples than its quota, the gap is
       filled with that label's remaining examples, worst-ranked first.
    3. If quotas still cannot be met (a label has too few examples overall),
       the shortfall is filled from any remaining prediction, worst-ranked
       first.

    Ranking (see ``_rank_key``) puts misclassified first, then higher
    severity, then unparseable predictions, then longer raw prediction text.
    Ties are broken by ascending ``index``.

    Args:
        predictions: Prediction records with ``index``, ``gold_label``,
            ``predicted_label``, ``raw_prediction_text``, ``exact_correct``
            and ``severity`` keys.
        remove_count: Number of records to remove.

    Returns:
        ``(removed, removed_indices)``: the chosen records sorted by rank, and
        their ``index`` values sorted ascending. Both are empty when
        ``remove_count <= 0``.
    """
    # Guard: a negative count would produce negative quotas, and slicing with
    # ``[:-1]`` would then remove nearly every misclassified example.
    if remove_count <= 0:
        return [], []

    def _rank_key(p: Dict):
        # Smaller tuples sort first, i.e. are removed first.
        return (
            0 if not p["exact_correct"] else 1,
            -p["severity"],
            0 if not p["predicted_label"] else 1,
            -len(normalize_label(p["raw_prediction_text"])),
            p["index"],
        )

    label_sequence = sorted(VALID_LABELS, key=lambda x: LABEL_ORDER[x])
    per_label_all = {label: [] for label in label_sequence}
    per_label_mis = {label: [] for label in label_sequence}
    for p in predictions:
        label = p["gold_label"]
        if label in per_label_all:
            per_label_all[label].append(p)
            if not p["exact_correct"]:
                per_label_mis[label].append(p)

    for label in label_sequence:
        per_label_all[label].sort(key=_rank_key)
        per_label_mis[label].sort(key=_rank_key)

    # Even per-label quotas; leftover slots go to the labels with the most
    # misclassifications.
    base_quota, remainder = divmod(remove_count, len(label_sequence))
    quotas = {label: base_quota for label in label_sequence}
    remainder_order = sorted(
        label_sequence,
        key=lambda label: (-len(per_label_mis[label]), LABEL_ORDER[label]),
    )
    for label in remainder_order[:remainder]:
        quotas[label] += 1

    removed = []
    removed_indices_set = set()

    # Pass 1: misclassified examples, up to each label's quota.
    for label in label_sequence:
        for item in per_label_mis[label][: quotas[label]]:
            removed.append(item)
            removed_indices_set.add(item["index"])

    # Pass 2: top up each label's quota with its remaining examples.
    for label in label_sequence:
        needed = quotas[label] - sum(1 for x in removed if x["gold_label"] == label)
        if needed <= 0:
            continue
        candidates = [
            x for x in per_label_all[label] if x["index"] not in removed_indices_set
        ]
        for item in candidates[:needed]:
            removed.append(item)
            removed_indices_set.add(item["index"])

    # Pass 3: fill any remaining shortfall from all labels.
    if len(removed) < remove_count:
        remaining_global = sorted(
            (p for p in predictions if p["index"] not in removed_indices_set),
            key=_rank_key,
        )
        for item in remaining_global[: remove_count - len(removed)]:
            removed.append(item)
            removed_indices_set.add(item["index"])

    removed = sorted(removed, key=_rank_key)[:remove_count]
    removed_indices = sorted(p["index"] for p in removed)
    return removed, removed_indices
|
|
|
|
def _classify_examples(classifier, examples: List[Dict]) -> List[Dict]:
    """Classify every example and return one prediction record per example.

    Each record holds the gold label, the parsed and raw model output, the
    lenient and exact correctness flags, the misclassification severity
    (0 for exact matches) and the input text.
    """
    total = len(examples)
    predictions = []
    progress = tqdm(examples, desc="Classifying full dataset", unit="sample")
    for position, ex in enumerate(progress, start=1):
        pred = classifier(generated_text=ex["generated_text"])
        raw_pred_label = getattr(pred, "literacy_label", "")
        pred_label = extract_predicted_label(raw_pred_label)
        gold_label = ex["gold_label"]
        exact_correct = pred_label == gold_label
        severity = (
            misclassification_severity(gold_label, pred_label) if not exact_correct else 0
        )
        predictions.append(
            {
                "index": ex["index"],
                "doc_id": ex["doc_id"],
                "gold_label": gold_label,
                "predicted_label": pred_label,
                "raw_prediction_text": raw_pred_label,
                "lenient_correct": is_correct(gold_label, raw_pred_label),
                "exact_correct": exact_correct,
                "severity": severity,
                "generated_text": ex["generated_text"],
            }
        )
        if position % 10 == 0 or position == total:
            tqdm.write(f"Processed {position}/{total}")
    return predictions


def _accuracy_summary(predictions: List[Dict]) -> Dict:
    """Compute the lenient count/accuracy, exact accuracy and per-label lenient accuracy."""
    total = len(predictions)
    num_correct = sum(1 for p in predictions if p["lenient_correct"])
    num_exact = sum(1 for p in predictions if p["exact_correct"])
    label_totals = Counter(p["gold_label"] for p in predictions)
    label_correct = Counter(
        p["gold_label"] for p in predictions if p["lenient_correct"]
    )
    return {
        "num_correct": num_correct,
        "lenient_accuracy": num_correct / total if total else 0.0,
        "exact_accuracy": num_exact / total if total else 0.0,
        "per_label_accuracy": {
            label: (
                (label_correct[label] / label_totals[label])
                if label_totals[label]
                else 0.0
            )
            for label in sorted(VALID_LABELS)
        },
    }


def run_inference(
    model_path: str,
    dataset_path: str,
    output_path: str,
    predictions_path: str,
    clean_dataset_path: str,
    removed_path: str,
    target_clean_size: int,
):
    """Classify the full dataset, report accuracy, and split it into clean/removed subsets.

    Steps:
    1. Load the dataset and validate the size.
    2. Configure the ``gpt-5-mini`` LM and load the compiled classifier.
    3. Classify every example.
    4. Choose ``len(examples) - target_clean_size`` examples to remove via
       ``choose_indices_to_remove``.
    5. Write the report, the per-example predictions, the clean dataset and
       the removed examples as JSON.

    Args:
        model_path: Compiled DSPy program to load.
        dataset_path: Input JSON dataset (see ``load_full_examples``).
        output_path: Destination of the accuracy/cleaning report.
        predictions_path: Destination of the per-example prediction records.
        clean_dataset_path: Destination of the kept raw records.
        removed_path: Destination of the removed raw records.
        target_clean_size: Number of examples to keep.

    Raises:
        ValueError: if ``target_clean_size`` is not between 1 and
            ``len(examples) - 1``, or if the dataset has no valid examples.
    """
    # Validate data and arguments before reading credentials or loading the
    # model, so bad inputs fail fast and without side effects.
    examples = load_full_examples(dataset_path)
    total = len(examples)
    if target_clean_size <= 0 or target_clean_size >= total:
        raise ValueError(
            f"target_clean_size must be between 1 and {total - 1}, got {target_clean_size}"
        )
    remove_count = total - target_clean_size

    openai_api_key = load_openai_key(API_FILE)
    student_lm = dspy.LM(model="gpt-5-mini", api_key=openai_api_key)
    dspy.configure(lm=student_lm)

    classifier = HealthLiteracyClassifier()
    classifier.load(model_path)

    predictions = _classify_examples(classifier, examples)
    summary = _accuracy_summary(predictions)

    removed_examples, removed_indices = choose_indices_to_remove(predictions, remove_count)
    removed_index_set = set(removed_indices)
    clean_dataset = [
        ex["raw_item"] for ex in examples if ex["index"] not in removed_index_set
    ]
    removed_dataset = [
        ex["raw_item"] for ex in examples if ex["index"] in removed_index_set
    ]

    report = {
        "model_path": model_path,
        "dataset_path": dataset_path,
        "num_examples": total,
        "num_correct": summary["num_correct"],
        "lenient_accuracy": summary["lenient_accuracy"],
        "exact_accuracy": summary["exact_accuracy"],
        "per_label_accuracy": summary["per_label_accuracy"],
        "target_clean_size": target_clean_size,
        "removed_count": remove_count,
        "clean_dataset_size": len(clean_dataset),
        "removed_dataset_size": len(removed_dataset),
        "removed_misclassified_count": sum(
            1 for p in removed_examples if not p["exact_correct"]
        ),
        "removed_per_label": dict(
            Counter(p["gold_label"] for p in removed_examples)
        ),
    }

    # (path, payload, ensure_ascii). The dataset files keep non-ASCII text
    # verbatim, which is why every file is opened with an explicit UTF-8
    # encoding rather than the locale default.
    outputs = [
        (output_path, report, True),
        (predictions_path, predictions, True),
        (clean_dataset_path, clean_dataset, False),
        (removed_path, removed_dataset, False),
    ]
    # Create every output directory before writing anything.
    for path, _, _ in outputs:
        output_dir = os.path.dirname(path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
    for path, payload, ensure_ascii in outputs:
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2, ensure_ascii=ensure_ascii)

    print(json.dumps(report, indent=2))
    print(f"Saved predictions to: {predictions_path}")
    print(f"Saved clean dataset to: {clean_dataset_path}")
    print(f"Saved removed examples to: {removed_path}")
    print(f"Saved report to: {output_path}")
|
|
|
|
def main():
    """Parse command-line options and run the full-dataset evaluation."""
    parser = argparse.ArgumentParser(
        description="Load a compiled DSPy classifier and evaluate on full dataset."
    )
    path_options = (
        ("--model-path", DEFAULT_MODEL_PATH),
        ("--dataset-path", DEFAULT_DATASET_PATH),
        ("--output-path", DEFAULT_OUTPUT_PATH),
        ("--predictions-path", DEFAULT_PREDICTIONS_PATH),
        ("--clean-dataset-path", DEFAULT_CLEAN_DATASET_PATH),
        ("--removed-path", DEFAULT_REMOVED_PATH),
    )
    for flag, default_value in path_options:
        parser.add_argument(flag, default=default_value)
    parser.add_argument("--target-clean-size", type=int, default=200)

    # argparse derives each dest from its flag (e.g. --model-path ->
    # model_path). These dests match run_inference's keyword parameters
    # one-to-one.
    run_inference(**vars(parser.parse_args()))


if __name__ == "__main__":
    main()
|
|