# Source: idk/src/api.py — author: Reyall — commit b122186 "Update src/api.py" (verified)
import pickle
import random
from collections import defaultdict
from typing import List

import torch
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
from transformers import BertForSequenceClassification, BertTokenizer
app = FastAPI()

# Load the tokenizer, the fine-tuned classifier and the label encoder once at
# import time; every request handled by this module reuses these objects.
# NOTE(review): the tokenizer comes from the base 'bert-base-uncased' checkpoint
# while the weights come from "best_model" — confirm they share a vocabulary.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# nn.Module.eval() returns the module itself, so it can be chained; eval mode
# turns off training-only behaviour such as dropout.
model = BertForSequenceClassification.from_pretrained("best_model").eval()

# NOTE(review): pickle.load can execute arbitrary code from the file; only
# deploy a trusted label_encoder.pkl next to the model weights.
with open("best_model/label_encoder.pkl", "rb") as encoder_file:
    label_encoder = pickle.load(encoder_file)
# One ranked entry of the list returned by the /predict endpoint.
class PredictionResponse(BaseModel):
    disease: str  # class name looked up in label_encoder.classes_
    probability: float  # softmax probability averaged over the shuffled symptom orderings
@app.get("/predict", response_model=List[PredictionResponse])
def predict(symptoms: str = Query(..., description="Comma-separated symptoms")):
    """Return the most likely diseases for a comma-separated list of symptoms.

    The symptoms are shuffled ``n_shuffles`` times. Each ordering is joined
    with ", " and classified, and the softmax probabilities are averaged over
    all orderings, so the ranking depends less on the order in which the client
    listed the symptoms. All orderings go through the model as one padded batch.

    Args:
        symptoms: Comma-separated symptom names. Surrounding whitespace and
            empty entries are ignored.

    Returns:
        Up to ``top_k`` (3) dicts with keys ``"disease"`` and ``"probability"``,
        sorted by descending averaged probability. The list is shorter if the
        model has fewer labels than that.

    Raises:
        HTTPException: 422 if ``symptoms`` contains no non-empty entry.
    """
    symptoms_list = [s.strip() for s in symptoms.split(",") if s.strip()]
    if not symptoms_list:
        # Without this guard an empty string would be classified and a
        # prediction returned even though no symptom was given.
        raise HTTPException(
            status_code=422, detail="Provide at least one non-empty symptom."
        )

    n_shuffles = 10
    top_k = 3

    # random.shuffle uses the module-level RNG, so repeated calls with the same
    # input can return slightly different probabilities.
    shuffled_texts = []
    for _ in range(n_shuffles):
        random.shuffle(symptoms_list)
        shuffled_texts.append(", ".join(symptoms_list))

    # One batched forward pass instead of n_shuffles sequential ones; padding
    # plus the attention mask keeps rows of different lengths independent.
    inputs = tokenizer(
        shuffled_texts,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    with torch.no_grad():
        logits = model(**inputs).logits  # one row per shuffled text

    # Softmax per row, then average over the batch dimension. Reducing over
    # dim 0 (rather than squeeze()) keeps a 1-D result even for a single label.
    avg_probs = torch.nn.functional.softmax(logits, dim=-1).mean(dim=0)

    top = torch.topk(avg_probs, k=min(top_k, avg_probs.shape[0]))
    return [
        {"disease": label_encoder.classes_[idx], "probability": prob}
        for prob, idx in zip(top.values.tolist(), top.indices.tolist())
    ]