from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
from typing import List

from transformers import BertTokenizer, BertForSequenceClassification
import torch

import pickle
import random
from collections import defaultdict

app = FastAPI()

# Load the tokenizer and the fine-tuned classifier; "best_model" is the
# local checkpoint directory produced by training.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("best_model")
model.eval()

# The label encoder maps class indices back to disease names.
with open("best_model/label_encoder.pkl", "rb") as f:
    label_encoder = pickle.load(f)
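
# The pickle is assumed to hold a fitted scikit-learn LabelEncoder (it is
# indexed via .classes_ below). A minimal sketch of how it would have been
# saved during training ("train_labels" is a hypothetical name):
#
#   from sklearn.preprocessing import LabelEncoder
#   le = LabelEncoder().fit(train_labels)
#   with open("best_model/label_encoder.pkl", "wb") as f:
#       pickle.dump(le, f)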


class PredictionResponse(BaseModel):
    disease: str
    probability: float


@app.get("/predict", response_model=List[PredictionResponse])
def predict(symptoms: str = Query(..., description="Comma-separated symptoms")):
    symptoms_list = [s.strip() for s in symptoms.split(",") if s.strip()]
    if not symptoms_list:
        raise HTTPException(status_code=400, detail="No symptoms provided")

    # BERT is sensitive to token order, so average the class probabilities
    # over several random orderings of the symptom list.
    agg_probs = defaultdict(float)
    n_shuffles = 10
    for _ in range(n_shuffles):
        random.shuffle(symptoms_list)
        shuffled_text = ", ".join(symptoms_list)
        inputs = tokenizer(shuffled_text, return_tensors="pt", truncation=True,
                           padding=True, max_length=128)
        with torch.no_grad():
            outputs = model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze()
        for i, p in enumerate(probs):
            agg_probs[i] += p.item()
    for k in agg_probs:
        agg_probs[k] /= n_shuffles

    # Keep the three most likely diseases.
    top_3 = sorted(agg_probs.items(), key=lambda x: x[1], reverse=True)[:3]

    results = []
    for idx, prob in top_3:
        label = label_encoder.classes_[idx]
        results.append({"disease": label, "probability": float(prob)})
    return results
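
# A minimal way to run and query the service locally (assumes uvicorn is
# installed and that this file is named main.py; both are assumptions):
#
#   uvicorn main:app --reload
#   curl "http://127.0.0.1:8000/predict?symptoms=fever,headache,nausea"
#
# The response is a JSON list of the top-3 diseases with their class
# probabilities averaged over the shuffled orderings.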