import json
import os
from pathlib import Path

# Cache locations are resolved when transformers/huggingface_hub are first
# imported, so they must be set before those imports run.
# TRANSFORMERS_CACHE is deprecated in favor of HF_HOME but is kept here for
# older transformers releases.
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE_DIR = Path(__file__).resolve().parent
MODEL_DIR = BASE_DIR / "medbot_model"

# Default sampling settings; generation_config.json can override them.
GEN_CFG = {
    "max_new_tokens": 160,
    "temperature": 0.7,
    "top_p": 0.9,
    "do_sample": True,
}
cfg_path = MODEL_DIR / "generation_config.json"
if cfg_path.exists():
    try:
        saved = json.loads(cfg_path.read_text(encoding="utf-8"))
        # Accept only the keys defined above: saved generation configs often
        # carry eos_token_id and similar entries, which would collide with
        # the explicit kwargs passed to model.generate() in _generate().
        GEN_CFG.update({k: v for k, v in saved.items() if k in GEN_CFG})
    except Exception:
        pass
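
# Sketch of an override file (illustrative values; the adapter directory is
# not known to ship one):
#   {"max_new_tokens": 200, "temperature": 0.5, "top_p": 0.95, "do_sample": true}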

# The fine-tuned adapter was trained on top of this base model; BASE_MODEL.txt
# in the adapter directory, if present, overrides the default id.
base_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
base_txt = MODEL_DIR / "BASE_MODEL.txt"
if base_txt.exists():
    t = base_txt.read_text(encoding="utf-8").strip()
    if t:
        base_model_id = t

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_DIR.as_posix(),
    use_fast=False,
    legacy=True,
)
# Llama-family tokenizers ship without a pad token; reuse EOS so padding works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# fp16 on GPU; fp32 on CPU, where half precision is poorly supported.
use_cuda = torch.cuda.is_available()
dtype = torch.float16 if use_cuda else torch.float32
device_map = "auto" if use_cuda else "cpu"

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=dtype,
    device_map=device_map,
)

# Attach the fine-tuned adapter and freeze all weights: inference only.
model = PeftModel.from_pretrained(base_model, MODEL_DIR.as_posix())
model.eval()
for p in model.parameters():
    p.requires_grad_(False)


def _format_prompt(user_text: str) -> str:
    # Prefer the tokenizer's own chat template. hasattr() alone is not a
    # reliable guard here: modern tokenizers always expose
    # apply_chat_template, which raises when no template is configured.
    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(
            [
                {"role": "system", "content": "You are a helpful medical assistant."},
                {"role": "user", "content": user_text.strip()},
            ],
            tokenize=False,
            add_generation_prompt=True,
        )
    # Fallback: TinyLlama/Zephyr-style chat markup.
    return f"<|system|>\nYou are a helpful medical assistant.\n<|user|>\n{user_text.strip()}\n<|assistant|>"


@torch.inference_mode()
def _generate(prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    out = model.generate(
        **inputs,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        **GEN_CFG,
    )
    # generate() returns prompt + completion; keep only the new tokens.
    new_tokens = out[0, inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def get_answer(question: str) -> str:
    text = (question or "").strip()
    if not text:
        return "Please enter a question."
    ans = _generate(_format_prompt(text))
    disclaimer = (
        "MedBot provides general information only and is not a substitute for "
        "professional medical advice. If this is an emergency, call your "
        "local emergency number."
    )
    return f"{disclaimer}\n\n{ans or 'I’m sorry, please rephrase your question.'}"
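

# Minimal smoke test: a sketch of how get_answer() can be exercised directly.
# The sample question is illustrative only.
if __name__ == "__main__":
    print(get_answer("What are common symptoms of dehydration?"))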