# GPU selection must happen BEFORE importing torch/unsloth: both read
# CUDA_VISIBLE_DEVICES when CUDA is initialized, so setting it later has no effect.
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # enumerate GPUs by PCI bus id so index "2" is stable across drivers
os.environ["CUDA_VISIBLE_DEVICES"] = "2"        # pin this process to physical GPU 2
import torch
from unsloth import FastLanguageModel
import json
import tqdm
import argparse
|
|
|
|
| |
| |
| |
| _model_cache = {"model": None, "tokenizer": None} |
|
|
| def load_finetuned_model(model_path: str): |
| if _model_cache["model"] is not None: |
| return _model_cache["model"], _model_cache["tokenizer"] |
|
|
| model, tokenizer = FastLanguageModel.from_pretrained( |
| model_name=model_path, |
| max_seq_length=8192, |
| load_in_4bit=False, |
| load_in_8bit=False, |
| full_finetuning=False, |
| ) |
| _model_cache["model"], _model_cache["tokenizer"] = model, tokenizer |
| return model, tokenizer |
|
|
| |
| |
| |
def extraction_prompt(medical_text: str) -> str:
    """Build the instruction prompt that asks the model to decompose
    *medical_text* into atomic, independently verifiable subclaims and to
    answer with a JSON list of strings only."""
    return f"""
You are an expert medical annotator. Your task is to extract granular, factual subclaims from medical text.
A subclaim is the smallest standalone factual unit that can be independently verified.

Instructions:
1. Read the provided medical text.
2. Break it into clear, objective, atomic subclaims.
3. Each subclaim must come directly from the text.
4. Return ONLY a valid JSON list of strings.

Medical Text:
{medical_text}

Return your output in JSON list format:
[
    "subclaim 1",
    "subclaim 2"
]
"""
| |
| |
| |
def infer_subclaims(medical_text: str, model, tokenizer, temperature: float = 0.2, max_tokens: int = 2048, retries: int = 1) -> list:
    """Run the extraction model on *medical_text* and return a list of subclaims.

    Args:
        medical_text: Text to decompose; empty/whitespace/None yields ``[]``.
        model: A generation-capable HF/unsloth model already on CUDA.
        tokenizer: Matching tokenizer with a chat template.
        temperature: Kept for interface compatibility; with ``do_sample=False``
            decoding is greedy and this value has no effect.
        max_tokens: ``max_new_tokens`` budget for generation.
        retries: How many times to retry with a +2048-token budget when the
            output does not parse as a JSON list (assumed truncation).

    Returns:
        The parsed JSON list of subclaim strings, or — after retries are
        exhausted — a one-element list containing the raw model output so the
        caller can inspect the failure.
    """
    if not medical_text or medical_text.strip() == "":
        return []

    prompt = extraction_prompt(medical_text)
    messages = [{"role": "user", "content": prompt}]
    chat_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(chat_text, return_tensors="pt").to("cuda")

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=False,
        )

    # BUG FIX: decode only the newly generated tokens. model.generate returns
    # prompt + completion; decoding the whole sequence re-includes the prompt,
    # and the prompt itself contains a JSON example ('[ "subclaim 1", ... ]'),
    # so find('[') below would lock onto the example instead of the answer.
    prompt_len = inputs["input_ids"].shape[-1]
    output_text = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip()

    # Reasoning models may emit a chain-of-thought block; keep only what
    # follows the final </think> tag.
    if "</think>" in output_text:
        output_text = output_text.split("</think>")[-1].strip()

    try:
        # Extract the outermost [...] span and parse it as JSON.
        start_idx = output_text.find('[')
        end_idx = output_text.rfind(']') + 1

        if start_idx != -1 and end_idx > start_idx:
            content = output_text[start_idx:end_idx]
            parsed = json.loads(content)
            if isinstance(parsed, list):
                return parsed

        raise ValueError("Incomplete JSON list")

    except (json.JSONDecodeError, ValueError):
        # Most parse failures here are the closing ']' getting cut off by the
        # token budget — retry once (or more) with a larger budget.
        if retries > 0:
            new_max = max_tokens + 2048
            print(f"\n[Warning] Truncation detected. Retrying with {new_max} tokens...")
            return infer_subclaims(medical_text, model, tokenizer, temperature, max_tokens=new_max, retries=retries - 1)

        # Give the caller the raw output rather than silently dropping the item.
        return [output_text]
|
|
| |
| |
| |
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--input_file", type=str, required=True) |
| args = parser.parse_args() |
| |
| INPUT_FILE = args.input_file |
| file_name = os.path.basename(INPUT_FILE).split(".json")[0] |
| SAVE_FOLDER = "/home/mshahidul/readctrl/data/extracting_subclaim" |
| MODEL_PATH = "/home/mshahidul/readctrl_model/qwen3-32B_subclaims-extraction-8b_ctx" |
|
|
| os.makedirs(SAVE_FOLDER, exist_ok=True) |
| OUTPUT_FILE = os.path.join(SAVE_FOLDER, f"extracted_subclaims_{file_name}.json") |
|
|
| model, tokenizer = load_finetuned_model(MODEL_PATH) |
|
|
| with open(INPUT_FILE, "r") as f: |
| data = json.load(f) |
|
|
| result = [] |
| if os.path.exists(OUTPUT_FILE): |
| with open(OUTPUT_FILE, "r") as f: |
| result = json.load(f) |
|
|
| processed_data = {str(item.get("index") or item.get("id")): item for item in result} |
|
|
| for item in tqdm.tqdm(data): |
| item_id = str(item.get("index") if item.get("index") is not None else item.get("id")) |
| existing_entry = processed_data.get(item_id) |
|
|
| |
| if not existing_entry or not isinstance(existing_entry.get("fulltext_subclaims"), list): |
| f_sub = infer_subclaims(item.get("fulltext", ""), model, tokenizer, max_tokens=3072, retries=2) |
| else: |
| f_sub = existing_entry["fulltext_subclaims"] |
|
|
| |
| if not existing_entry or not isinstance(existing_entry.get("summary_subclaims"), list): |
| s_sub = infer_subclaims(item.get("summary", ""), model, tokenizer, max_tokens=2048, retries=1) |
| else: |
| s_sub = existing_entry["summary_subclaims"] |
|
|
| |
| diff_label_texts = item.get("diff_label_texts", {}) |
| diff_label_subclaims = existing_entry.get("diff_label_subclaims", {}) if existing_entry else {} |
|
|
| for label, text in diff_label_texts.items(): |
| if label not in diff_label_subclaims or not isinstance(diff_label_subclaims[label], list): |
| |
| diff_label_subclaims[label] = infer_subclaims(text, model, tokenizer, max_tokens=1536, retries=1) |
|
|
| |
| new_entry = { |
| "index": item.get("index"), |
| "id": item.get("id"), |
| "fulltext": item.get("fulltext", ""), |
| "fulltext_subclaims": f_sub, |
| "summary": item.get("summary", ""), |
| "summary_subclaims": s_sub, |
| "diff_label_texts": diff_label_texts, |
| "diff_label_subclaims": diff_label_subclaims, |
| "readability_score": item.get("readability_score", None) |
| } |
| processed_data[item_id] = new_entry |
|
|
| if len(processed_data) % 10 == 0: |
| with open(OUTPUT_FILE, "w") as f: |
| json.dump(list(processed_data.values()), f, indent=4, ensure_ascii=False) |
|
|
| with open(OUTPUT_FILE, "w") as f: |
| json.dump(list(processed_data.values()), f, indent=4, ensure_ascii=False) |
|
|
| print(f"Extraction completed. File saved at: {OUTPUT_FILE}") |