vivekjada committed on
Commit
353ae91
·
verified ·
1 Parent(s): 4b1743f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
4
+ from peft import PeftModel
5
# Hugging Face Hub IDs: the LoRA adapter produced by SFT on medical o1
# reasoning data, and the 4-bit Unsloth build of Llama-3.1-8B-Instruct
# it was trained against (adapter must be loaded on this exact base).
MODEL_ADAPTER_ID = "vivekjada/medical-o1-llm-sft-lora"
BASE_ID = "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit"

# Load the base model 4-bit quantized with bfloat16 compute so the 8B
# model fits on modest GPU memory; device_map="auto" lets accelerate
# place layers across available devices.
bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
base = AutoModelForCausalLM.from_pretrained(BASE_ID, quantization_config=bnb, device_map="auto")
# Attach the LoRA adapter on top of the quantized base.
model = PeftModel.from_pretrained(base, MODEL_ADAPTER_ID)
# Tokenizer comes from the base checkpoint (adapter does not change vocab).
tok = AutoTokenizer.from_pretrained(BASE_ID)

# System prompt prepended to every user question.
SYSTEM = ("You are a careful medical assistant. You provide educational information—not medical advice. "
          "Reason step-by-step and end with a concise final answer.")
16
+
17
def respond(question, max_new_tokens, temperature, top_p):
    """Generate a model reply for a medical question.

    Args:
        question: the user's medical question (str).
        max_new_tokens: cap on tokens to generate (slider value, coerced to int).
        temperature: sampling temperature; at (or near) 0 we fall back to
            greedy decoding, since transformers rejects do_sample=True with
            temperature == 0.
        top_p: nucleus-sampling cutoff (only used when sampling).

    Returns:
        Markdown string: a fixed disclaimer followed by the model's reply.
    """
    prompt = (f"<|system|>\n{SYSTEM}\n<|end|>\n"
              f"<|user|>\n{question}\n<|end|>\n"
              f"<|assistant|>\n")
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    # The Temperature slider reaches 0.0; generate() raises on
    # do_sample=True with temperature 0, so switch to greedy decoding there.
    sampling = float(temperature) > 1e-5
    gen_kwargs = dict(max_new_tokens=int(max_new_tokens),
                      do_sample=sampling,
                      eos_token_id=tok.eos_token_id)
    if sampling:
        gen_kwargs.update(temperature=float(temperature), top_p=float(top_p))
    with torch.no_grad():
        out = model.generate(**inputs, **gen_kwargs)
    # Slice off the prompt tokens by position instead of splitting the decoded
    # text on "<|assistant|>" — robust regardless of how the tokenizer
    # round-trips the marker string.
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    reply = tok.decode(new_tokens, skip_special_tokens=True).strip()
    return ("⚠️ **Disclaimer:** This demo is for educational purposes only and is **not** medical advice.\n\n"
            + reply)
30
+
31
# Gradio UI: one question box plus three generation-control sliders, wired
# to `respond`; the model's answer is rendered as Markdown.
_question_box = gr.Textbox(
    label="Enter a medical question",
    lines=6,
    placeholder="e.g., How to interpret borderline TSH in a 1st-trimester patient?",
)
_generation_controls = [
    gr.Slider(64, 1024, value=384, step=32, label="Max new tokens"),
    gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature"),
    gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
]
demo = gr.Interface(
    fn=respond,
    inputs=[_question_box, *_generation_controls],
    outputs=gr.Markdown(label="Model response"),
    title="Medical o1 Reasoning (SFT, LoRA)",
    description="Llama-3.1-8B (Unsloth 4-bit) fine-tuned on medical o1 reasoning. Educational only.",
)

# Start the Gradio server only when executed as a script.
if __name__ == "__main__":
    demo.launch()