|
|
import os |
|
|
import streamlit as st |
|
|
from transformers.models.bert import BertTokenizer, BertForSequenceClassification |
|
|
import torch |
|
|
import pickle |
|
|
import random |
|
|
from collections import defaultdict |
|
|
import json |
|
|
|
|
|
|
|
|
def load_name_encoder():
    """Load the fitted label encoder mapping class indices to disease names.

    Looks for ``best_model/name_encoder.pkl`` under the current working
    directory. If the file is missing, shows an error in the Streamlit UI
    and halts script execution via ``st.stop()``.
    """
    encoder_path = os.path.join(os.getcwd(), "best_model", "name_encoder.pkl")

    if not os.path.exists(encoder_path):
        st.error(f"Name encoder faylı tapılmadı: {encoder_path}")
        st.stop()

    # NOTE(review): pickle.load is acceptable here only because this file
    # ships with the app; never point it at untrusted input.
    with open(encoder_path, "rb") as fh:
        return pickle.load(fh)
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model():
    """Create and cache the tokenizer, fine-tuned BERT classifier, and encoder.

    Decorated with ``st.cache_resource`` so these heavyweight objects are
    built once per server process. The classification head is sized to the
    number of classes known to the label encoder, and the fine-tuned weights
    are read from the local ``best_model`` directory.

    Returns:
        tuple: ``(tokenizer, model, name_encoder)`` ready for inference.
    """
    encoder = load_name_encoder()
    weights_dir = os.path.join(os.getcwd(), "best_model")

    # Vocabulary comes from the stock pretrained tokenizer; the head's
    # output dimension must match the encoder's class count.
    bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    classifier = BertForSequenceClassification.from_pretrained(
        weights_dir,
        num_labels=len(encoder.classes_),
    )
    classifier.eval()  # inference mode: disables dropout etc.

    return bert_tokenizer, classifier, encoder
|
|
|
|
|
|
|
|
def predict_disease(symptoms_text, tokenizer, model, name_encoder):
    """Predict the three most likely diseases for comma-separated symptoms.

    To reduce sensitivity to symptom ordering, the symptom list is shuffled
    and re-scored ten times; per-class softmax probabilities are averaged
    over all runs before ranking.

    Args:
        symptoms_text: Comma-separated symptom string (e.g. "fever, cough").
        tokenizer: Callable producing ``return_tensors="pt"`` model inputs.
        model: Sequence-classification model whose output exposes ``.logits``.
        name_encoder: Fitted encoder whose ``classes_`` maps index -> label.

    Returns:
        list[dict]: Up to three ``{"disease", "probability"}`` entries,
        sorted by descending averaged probability.
    """
    symptom_list = [part.strip() for part in symptoms_text.split(",") if part.strip()]
    n_shuffles = 10
    accumulated = defaultdict(float)

    for _ in range(n_shuffles):
        random.shuffle(symptom_list)
        encoded = tokenizer(
            ", ".join(symptom_list),
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128,
        )
        with torch.no_grad():
            logits = model(**encoded).logits
        class_probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()

        for class_idx, prob in enumerate(class_probs):
            accumulated[class_idx] += prob.item()

    # Average over runs, then keep the three highest-probability classes.
    averaged = {idx: total / n_shuffles for idx, total in accumulated.items()}
    ranked = sorted(averaged.items(), key=lambda pair: pair[1], reverse=True)

    return [
        {"disease": name_encoder.classes_[idx], "probability": prob}
        for idx, prob in ranked[:3]
    ]
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# App entry: page config, query-string parsing, and optional JSON "API mode".
# ---------------------------------------------------------------------------
st.set_page_config(page_title="Disease API", layout="wide")

# BUG FIX: the modern ``st.query_params`` API returns plain strings, not the
# list-valued results of the old ``st.experimental_get_query_params``.
# Indexing with [0] therefore took only the FIRST CHARACTER of the value
# ("true" -> "t"), so API mode could never activate and only the first
# character of ``symptoms`` survived. Use string defaults and no indexing.
query_params = st.query_params
is_api_mode = str(query_params.get("api", "false")).lower() == "true"

tokenizer, model, name_encoder = load_model()

if is_api_mode:
    symptoms = query_params.get("symptoms", "")
    if symptoms:
        results = predict_disease(symptoms, tokenizer, model, name_encoder)
        api_response = {
            "status": "success",
            "input": symptoms,
            "predictions": results,
        }
    else:
        api_response = {
            "status": "error",
            "message": "symptoms parameter required",
        }

    # Render the JSON response as a fenced code block, then halt so the
    # interactive UI below is never drawn in API mode.
    st.markdown(
        f"```json\n{json.dumps(api_response, ensure_ascii=False, indent=2)}\n```"
    )
    st.stop()
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Interactive UI (non-API mode): header, usage hint, and prediction form.
# ---------------------------------------------------------------------------
st.title("🏥 Disease Prediction")
st.success("Model yükləndi!")

st.write("Available classes:", list(name_encoder.classes_))

# Show callers how to hit the same app as a JSON endpoint.
st.markdown("### API İstifadəsi")
space_url = "https://your-username-your-space-name.hf.space"
st.code(f"{space_url}/?api=true&symptoms=fever,cough,headache")

with st.form(key="predict_form"):
    symptoms_input = st.text_area("Simptomları daxil edin (vergüllə ayırın):")
    do_predict = st.form_submit_button(label="Predict")

if do_predict:
    if symptoms_input.strip():
        predictions = predict_disease(symptoms_input, tokenizer, model, name_encoder)
        st.subheader("🔍 Nəticələr:")
        for entry in predictions:
            st.write(f"**{entry['disease']}** — {entry['probability']*100:.2f}%")
    else:
        st.warning("Simptomları daxil edin!")