CVD-Predictor / main.py
narainkumbari's picture
Updated app with remote API inference and multi-input support
129775b
raw
history blame contribute delete
818 Bytes
# main.py
from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# ASGI application object; the route decorators in this module register on it.
app = FastAPI()

# Hugging Face Hub repository id of the causal language model to serve.
MODEL_PATH = "Tufan1/BioMedLM-Cardio-Fold1-CPU"

# Half precision is only a safe default on GPU: on CPU, float16 kernels are
# slow and (depending on the PyTorch version) missing for some operations.
# Use float16 when CUDA is present and fall back to float32 otherwise.
_TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

# Both objects are loaded once at import time and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",  # let the loader decide where the weights are placed
    torch_dtype=_TORCH_DTYPE,
)
class PatientData(BaseModel):
    """Request body accepted by the ``/predict`` endpoint."""

    # Text handed to the tokenizer as-is; the endpoint looks for a
    # "Diagnosis:" marker in the decoded model output.
    input_text: str
@app.post("/predict")
def predict(data: PatientData):
    """Run the language model on the submitted text and return a diagnosis.

    The input text is tokenized, up to four new tokens are generated, and the
    full output sequence is decoded. The response contains whatever follows
    the last ``"Diagnosis:"`` marker in the decoded text, stripped of
    surrounding whitespace. If the marker is absent, the whole decoded text
    is returned.

    Args:
        data: Request body carrying the prompt in ``input_text``.

    Returns:
        A dict of the form ``{"diagnosis": <str>}``.
    """
    # Move the input tensors to the device the model actually lives on.
    # Hard-coding "cuda" crashes on CPU-only hosts and can disagree with the
    # placement chosen by ``device_map="auto"``.
    inputs = tokenizer(data.input_text, return_tensors="pt").to(model.device)

    # Idempotent; makes sure training-only behaviour is off before generating.
    model.eval()
    with torch.no_grad():  # inference only, no autograd bookkeeping needed
        outputs = model.generate(**inputs, max_new_tokens=4)

    # outputs[0] is the single sequence produced for the single input text.
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Keep only the text after the last "Diagnosis:" marker.
    diagnosis = decoded.split("Diagnosis:")[-1].strip()
    return {"diagnosis": diagnosis}