| import json |
| from pathlib import Path |
| from openai import OpenAI |
| from datasets import load_dataset |
| from transformers import AutoTokenizer |
| from unsloth.chat_templates import get_chat_template |
|
|
| |
# --- Inference configuration -------------------------------------------------
# OpenAI-compatible endpoint serving the fine-tuned model (presumably a local
# vLLM/TGI-style server — verify against deployment).
API_BASE = "http://172.16.34.22:8086/v1"
# Served model identifier as registered on the endpoint ("sc" — assumed to be
# the fine-tuned subclaim-support model; TODO confirm).
MODEL_PATH = "sc"
# Tokenizer used only to render the chat template locally; the rendered raw
# prompt is then sent to the completions (not chat) endpoint.
TOKENIZER_NAME = "meta-llama/Llama-3.1-8B-Instruct"
# Fine-tuning dataset; main() re-derives the same 90/10 split (seed 3407) to
# recover the held-out test set.
DATASET_FILE = Path("/home/mshahidul/readctrl/data/finetuning_data/finetune_dataset_subclaim_support_v2.json")
# Which text field of each dataset item is fed to the model.
TEXT_VARIANT = "hard_text"


# Client for the OpenAI-compatible server; api_key is a placeholder
# (local servers typically ignore it — NOTE(review): confirm).
client = OpenAI(api_key="EMPTY", base_url=API_BASE)
# Wrap the HF tokenizer with unsloth's Llama-3.1 chat template so prompts are
# rendered exactly as during fine-tuning.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
tokenizer = get_chat_template(tokenizer, chat_template="llama-3.1")
|
|
|
|
def render_chat_prompt(user_prompt: str) -> str:
    """Render a single-turn user message through the Llama-3.1 chat template.

    Args:
        user_prompt: Full user-turn text to wrap in the chat template.

    Returns:
        The templated prompt string (untokenized), with the generation header
        appended so the completions endpoint continues as the assistant turn.
    """
    messages = [{"role": "user", "content": user_prompt}]
    # Bug fix: removed a leftover `import ipdb; ipdb.set_trace()` breakpoint
    # (halted every call and required a dev-only dependency) and a debug print.
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
|
|
def build_user_prompt(text: str, subclaims: list[str]) -> str:
    """Assemble the evidence-checking prompt for one passage.

    Args:
        text: The medical passage to check subclaims against.
        subclaims: Subclaim strings, labeled in order by the model.

    Returns:
        Instruction header followed by the passage and a numbered subclaim list.
    """
    listing = "\n".join(
        f"{num}. {claim}" for num, claim in enumerate(subclaims, start=1)
    )
    header = (
        "You are a medical evidence checker.\n"
        "Given a medical passage and a list of subclaims, return labels for each "
        "subclaim in the same order.\n\n"
        "Allowed labels: supported, not_supported.\n"
        "Output format: a JSON array of strings only.\n\n"
    )
    return header + f"Medical text:\n{text}\n\nSubclaims:\n{listing}"
|
|
def main():
    """Run subclaim-support inference over the held-out test split.

    Re-derives the same 90/10 split used at training time (seed 3407), sends
    each non-empty item's prompt to the completion endpoint, prints each
    prediction next to its gold labels, and dumps all results to
    ``inference_results.json``.
    """
    raw_dataset = load_dataset("json", data_files=str(DATASET_FILE), split="train")

    # Same split parameters as training so this is the true held-out set.
    splits = raw_dataset.train_test_split(test_size=0.1, seed=3407, shuffle=True)
    test_split = splits["test"]

    print(f"Running inference on {len(test_split)} samples...")

    results = []
    for row in test_split:
        for item in row.get("items", []):
            text = item.get(TEXT_VARIANT, "").strip()
            # Single pass over the subclaim records (was two comprehensions
            # iterating item.get("subclaims", []) twice).
            records = item.get("subclaims", [])
            subclaims = [rec["subclaim"] for rec in records]
            gold_labels = [rec["label"] for rec in records]

            # Nothing to check: skip empty passages or items with no subclaims.
            if not text or not subclaims:
                continue

            prompt = render_chat_prompt(build_user_prompt(text, subclaims))
            # temperature=0 for deterministic labeling.
            response = client.completions.create(
                model=MODEL_PATH,
                prompt=prompt,
                temperature=0,
                max_tokens=256,
            )

            pred_text = response.choices[0].text.strip()

            print("--- Sample ---")  # was an f-string with no placeholders
            print(f"Pred: {pred_text}")
            print(f"Gold: {gold_labels}")

            results.append({
                "predicted": pred_text,
                "gold": gold_labels,
            })

    # Explicit encoding + ensure_ascii=False keep non-ASCII medical text
    # readable in the output file instead of \uXXXX escapes.
    with open("inference_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4, ensure_ascii=False)
|
|
# Script entry point: run inference only when executed directly, not on import.
if __name__ == "__main__":
    main()