# GPU pinning must happen before torch/unsloth are imported: CUDA reads these
# environment variables when it initializes, so setting them later has no effect.
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # order GPUs by PCI bus id (matches nvidia-smi numbering)
os.environ["CUDA_VISIBLE_DEVICES"] = "2"        # expose only physical GPU 2 to this process

import torch
from unsloth import FastLanguageModel
import json
import tqdm
import re
| |
|
| | |
| | |
| | |
| | _model_cache = {"model": None, "tokenizer": None} |
| |
|
def load_finetuned_model(model_path: str):
    """Load the fine-tuned classifier model, memoizing it in ``_model_cache``.

    The original cache ignored ``model_path``: once any model was loaded,
    every later call returned it even when a *different* path was requested.
    The cache is now keyed on the path, so asking for another checkpoint
    triggers a fresh load instead of silently reusing the wrong model.

    Args:
        model_path: Local directory or hub name of the fine-tuned checkpoint.

    Returns:
        ``(model, tokenizer)`` ready for inference.
    """
    if (_model_cache["model"] is not None
            and _model_cache.get("path") == model_path):
        return _model_cache["model"], _model_cache["tokenizer"]

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=8192,
        load_in_4bit=False,
        load_in_8bit=False,
        full_finetuning=False,
    )

    # Switch unsloth's kernels into inference mode (disables training paths).
    FastLanguageModel.for_inference(model)

    _model_cache["model"] = model
    _model_cache["tokenizer"] = tokenizer
    _model_cache["path"] = model_path
    return model, tokenizer
| |
|
| | |
| | |
| | |
def classification_prompt(full_text: str, summary: str) -> str:
    """Build the readability-classification prompt for the evaluator model.

    Embeds *full_text* (context only, not graded) and *summary* (the text
    to be graded) into a fixed instruction template that asks the model to
    return a 1-5 readability score as a JSON object.
    """
    return f"""You are a medical readability evaluator.

### Task
Compare the "GENERATED TEXT" against the "FULL TEXT" to determine its readability for a general, non-medical audience.

### Input Data
- **FULL TEXT:** {full_text}
- **GENERATED TEXT (Evaluate this):** {summary}

### Readability Scale
1: Very Easy - Minimal medical language, uses simple terms.
2: Easy - Accessible to most, minor jargon explained.
3: Medium - Some technical terms, moderate complexity.
4: Hard - Clinical tone, assumes some prior knowledge.
5: Very Hard - Extremely technical, requires medical expertise.

### Constraints
- Evaluate ONLY the "GENERATED TEXT".
- Use "FULL TEXT" only for context of the subject matter.
- Do NOT assess factual accuracy.

### Output Format
Return ONLY a valid JSON object:
{{
"readability_score": <integer_1_to_5>
}}"""
| |
|
| | |
| | |
| | |
def _extract_score_json(output_text: str) -> dict:
    """Extract the readability-score JSON object from raw model output.

    Drops any chain-of-thought prefix (everything up to the last
    ``</think>`` tag), then parses the first ``{...}`` span found.

    Returns:
        The parsed dict on success, otherwise the error sentinel
        ``{"readability_score": "error", "raw": <text>}``.
    """
    if "</think>" in output_text:
        output_text = output_text.split("</think>")[-1].strip()

    match = re.search(r"\{.*\}", output_text, re.DOTALL)
    if match:
        try:
            return json.loads(match.group())
        except ValueError:
            # json.JSONDecodeError subclasses ValueError; anything narrower
            # than the original blanket `except Exception` that still covers
            # malformed JSON. Fall through to the error sentinel.
            pass
    return {"readability_score": "error", "raw": output_text}


def infer_readability(full_text: str,
                      summary: str,
                      model_path: str) -> dict:
    """Classify the readability of *summary* using the fine-tuned model.

    Args:
        full_text: Source document, used only as subject-matter context.
        summary: Generated text whose readability is scored.
        model_path: Checkpoint passed through to ``load_finetuned_model``.

    Returns:
        Parsed model output, normally ``{"readability_score": <int 1-5>}``;
        on parse failure ``{"readability_score": "error", "raw": <text>}``.
    """
    model, tokenizer = load_finetuned_model(model_path)
    prompt = classification_prompt(full_text, summary)

    messages = [{"role": "user", "content": prompt}]

    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Script pins CUDA_VISIBLE_DEVICES at import time, so "cuda" is the one
    # exposed GPU.
    inputs = tokenizer(chat_text, return_tensors="pt").to("cuda")

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=50,
            # Greedy decoding. The original also passed temperature=0.1,
            # which transformers ignores (with a warning) when
            # do_sample=False, so it is dropped.
            do_sample=False,
        )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    prompt_len = inputs.input_ids.shape[1]
    output_text = tokenizer.decode(
        output_ids[0][prompt_len:], skip_special_tokens=True
    ).strip()

    return _extract_score_json(output_text)
| |
|
| | |
| | |
| | |
if __name__ == "__main__":

    INPUT_FILE = "/home/mshahidul/readctrl/data/processed_raw_data/multiclinsum_test_en.json"
    SAVE_FOLDER = "/home/mshahidul/readctrl/data/classified_readability"
    MODEL_PATH = "/home/mshahidul/readctrl_model/qwen3-32B_classifier_en"

    os.makedirs(SAVE_FOLDER, exist_ok=True)
    file_name = os.path.basename(INPUT_FILE).split(".json")[0]
    OUTPUT_FILE = os.path.join(SAVE_FOLDER, f"classified_{file_name}.json")

    # Explicit UTF-8: output is dumped with ensure_ascii=False, so relying on
    # the locale's default encoding can corrupt or crash on non-ASCII text.
    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Resume support: reload prior results so already-classified ids are skipped.
    result = []
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
            result = json.load(f)

    # Built unconditionally (result is [] on a fresh run), so existing_ids is
    # always defined even when no previous output file exists.
    existing_ids = {item["id"] for item in result}

    print(f"Starting classification. Saving to: {OUTPUT_FILE}")

    for item in tqdm.tqdm(data):
        if item["id"] in existing_ids:
            continue

        full_text = item.get("fulltext", "")
        summary = item.get("summary", "")

        classification_res = infer_readability(
            full_text=full_text,
            summary=summary,
            model_path=MODEL_PATH
        )

        result.append({
            "id": item["id"],
            "readability_score": classification_res.get("readability_score"),
            "fulltext": full_text,
            "summary": summary
        })

        # Checkpoint every 50 accumulated results so a crash loses little work.
        if len(result) % 50 == 0:
            with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
                json.dump(result, f, indent=4, ensure_ascii=False)

    # Final save covers the tail of results since the last checkpoint.
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=4, ensure_ascii=False)

    print(f"Classification completed. {len(result)} items processed.")