import os

# Pin this process to physical GPU 2. These environment variables must be set
# BEFORE torch / unsloth are imported, which is why they sit above the other
# imports: CUDA reads them once at initialization.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"


import torch
from unsloth import FastLanguageModel
import json
import tqdm
import re


# Process-wide cache so the (large) classifier model is loaded only once;
# populated by load_finetuned_model().
_model_cache = {"model": None, "tokenizer": None}
|
|
def load_finetuned_model(model_path: str):
    """Load (or return the cached) fine-tuned model and tokenizer.

    The first call loads the model from *model_path* and switches it to
    inference mode; subsequent calls with the same path return the cached
    pair without reloading.

    Bug fixed: the original cache ignored *model_path*, so a later call
    with a DIFFERENT path silently returned the previously loaded model.
    The cache now also remembers which path it holds and reloads on a
    mismatch.

    Args:
        model_path: Filesystem path / hub name of the fine-tuned model.

    Returns:
        Tuple of (model, tokenizer).
    """
    if _model_cache["model"] is not None and _model_cache.get("path") == model_path:
        return _model_cache["model"], _model_cache["tokenizer"]

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=8192,   # long clinical documents need a large context window
        load_in_4bit=False,
        load_in_8bit=False,
        full_finetuning=False,
    )

    # Enable unsloth's optimized inference path (disables training-only code).
    FastLanguageModel.for_inference(model)

    _model_cache.update(model=model, tokenizer=tokenizer, path=model_path)
    return model, tokenizer
|
|
| |
| |
| |
def classification_prompt(full_text: str, summary: str) -> str:
    """Build the evaluator prompt that asks the model to rate, on a 1-5
    scale, how readable *summary* is for a lay audience, with *full_text*
    supplied only as subject-matter context."""
    return f"""You are a medical readability evaluator.

### Task
Compare the "GENERATED TEXT" against the "FULL TEXT" to determine its readability for a general, non-medical audience.

### Input Data
- **FULL TEXT:** {full_text}
- **GENERATED TEXT (Evaluate this):** {summary}

### Readability Scale
1: Very Easy - Minimal medical language, uses simple terms.
2: Easy - Accessible to most, minor jargon explained.
3: Medium - Some technical terms, moderate complexity.
4: Hard - Clinical tone, assumes some prior knowledge.
5: Very Hard - Extremely technical, requires medical expertise.

### Constraints
- Evaluate ONLY the "GENERATED TEXT".
- Use "FULL TEXT" only for context of the subject matter.
- Do NOT assess factual accuracy.

### Output Format
Return ONLY a valid JSON object:
{{
"readability_score": <integer_1_to_5>
}}"""
|
|
| |
| |
| |
def infer_readability(full_text: str,
                      summary: str,
                      model_path: str) -> dict:
    """Score the readability of *summary* using the fine-tuned classifier.

    Args:
        full_text: Source document; used only as context in the prompt.
        summary: Generated text whose readability is being scored.
        model_path: Passed through to ``load_finetuned_model``.

    Returns:
        ``{"readability_score": <int 1-5>}`` parsed from the model output,
        or ``{"readability_score": "error", "raw": <output>}`` when no valid
        JSON object could be extracted from the generation.
    """
    model, tokenizer = load_finetuned_model(model_path)
    prompt = classification_prompt(full_text, summary)

    messages = [{"role": "user", "content": prompt}]
    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Place inputs on the model's own device instead of assuming "cuda"
    # (works unchanged if the model is ever loaded on CPU or another device).
    inputs = tokenizer(chat_text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        # Greedy decoding. `temperature=0.1` was dropped: it is ignored
        # (and emits a warning) whenever do_sample=False.
        output_ids = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False,
        )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    prompt_len = inputs.input_ids.shape[1]
    output_text = tokenizer.decode(
        output_ids[0][prompt_len:], skip_special_tokens=True
    ).strip()

    # Reasoning-tuned models may prepend <think>...</think>; keep only the
    # text after the final closing tag.
    if "</think>" in output_text:
        output_text = output_text.split("</think>")[-1].strip()

    # Extract the first {...} span and parse it; on any failure fall through
    # to a single error payload that preserves the raw output for debugging
    # (the original duplicated this return in two branches with a broad except).
    match = re.search(r"\{.*\}", output_text, re.DOTALL)
    if match:
        try:
            return json.loads(match.group())
        except json.JSONDecodeError:
            pass
    return {"readability_score": "error", "raw": output_text}
|
|
| |
| |
| |
if __name__ == "__main__":

    # Input corpus, output folder, and the fine-tuned classifier checkpoint.
    INPUT_FILE = "/home/mshahidul/readctrl/data/processed_raw_data/multiclinsum_test_en.json"
    SAVE_FOLDER = "/home/mshahidul/readctrl/data/classified_readability"
    MODEL_PATH = "/home/mshahidul/readctrl_model/qwen3-32B_classifier_en"

    os.makedirs(SAVE_FOLDER, exist_ok=True)
    # splitext is the robust way to drop the extension (the original split
    # on the literal ".json" substring).
    file_name = os.path.splitext(os.path.basename(INPUT_FILE))[0]
    OUTPUT_FILE = os.path.join(SAVE_FOLDER, f"classified_{file_name}.json")

    # encoding="utf-8" everywhere: the dumps use ensure_ascii=False, so the
    # files contain raw non-ASCII text and must not depend on the platform
    # default encoding.
    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Resume support: reload previously classified items and skip their ids.
    result = []
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
            result = json.load(f)
    existing_ids = {item["id"] for item in result}

    def _save() -> None:
        # Single writer for the output file (was duplicated in two places).
        with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=4, ensure_ascii=False)

    print(f"Starting classification. Saving to: {OUTPUT_FILE}")

    for item in tqdm.tqdm(data):
        if item["id"] in existing_ids:
            continue

        full_text = item.get("fulltext", "")
        summary = item.get("summary", "")

        classification_res = infer_readability(
            full_text=full_text,
            summary=summary,
            model_path=MODEL_PATH,
        )

        result.append({
            "id": item["id"],
            "readability_score": classification_res.get("readability_score"),
            "fulltext": full_text,
            "summary": summary,
        })

        # Periodic checkpoint so a crash loses at most 50 items of work.
        if len(result) % 50 == 0:
            _save()

    _save()
    print(f"Classification completed. {len(result)} items processed.")