# idk / src/streamlit_app.py
# Author: Reyall — "Update src/streamlit_app.py" (commit eaeb0e4, verified)
import os
import streamlit as st
from transformers.models.bert import BertTokenizer, BertForSequenceClassification
import torch
import pickle
import random
from collections import defaultdict
import json
# Load the fitted label encoder from disk.
def load_name_encoder():
    """Load the pickled name encoder from ``best_model/name_encoder.pkl``.

    Halts the Streamlit app with an error message if the file is missing.

    Returns:
        The unpickled encoder object (exposes ``classes_``).
    """
    file_path = os.path.join(os.getcwd(), "best_model", "name_encoder.pkl")
    if not os.path.exists(file_path):
        # The app cannot work without the label encoder — stop rendering.
        st.error(f"Name encoder faylı tapılmadı: {file_path}")
        st.stop()
    # NOTE(review): pickle.load is only safe here because the file ships
    # with the app and is never user-supplied.
    with open(file_path, "rb") as fh:
        return pickle.load(fh)
# Load tokenizer, fine-tuned model and label encoder (cached across reruns).
@st.cache_resource
def load_model():
    """Return ``(tokenizer, model, name_encoder)``, cached by Streamlit.

    The classifier head size is derived from the number of encoder classes;
    weights are loaded from the local ``best_model`` directory.
    """
    name_encoder = load_name_encoder()
    weights_dir = os.path.join(os.getcwd(), "best_model")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained(
        weights_dir,
        num_labels=len(name_encoder.classes_),
    )
    model.eval()  # inference mode: disables dropout
    return tokenizer, model, name_encoder
# Prediction function
def predict_disease(symptoms_text, tokenizer, model, name_encoder, n_shuffles=10):
    """Predict the most likely diseases for a comma-separated symptom string.

    The symptom list is shuffled ``n_shuffles`` times and the softmax
    probabilities are averaged across the orderings, reducing the model's
    sensitivity to symptom order in the input text.

    Args:
        symptoms_text: Comma-separated symptoms, e.g. ``"fever, cough"``.
        tokenizer: Callable mapping text to model inputs (PyTorch tensors).
        model: Sequence-classification model whose output has ``.logits``.
        name_encoder: Fitted label encoder exposing ``classes_``.
        n_shuffles: Number of random symptom orderings to average over
            (default 10, matching the previous hard-coded value).

    Returns:
        Up to three dicts ``{"disease": str, "probability": float}``,
        sorted by descending averaged probability.
    """
    symptoms = [s.strip() for s in symptoms_text.split(",") if s.strip()]
    agg_probs = defaultdict(float)
    for _ in range(n_shuffles):
        random.shuffle(symptoms)
        shuffled_text = ", ".join(symptoms)
        inputs = tokenizer(
            shuffled_text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128
        )
        with torch.no_grad():
            outputs = model(**inputs)
        # Index [0] instead of .squeeze(): squeeze() would produce a 0-dim
        # tensor (breaking enumerate) if the model ever had a single label.
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
        for idx, p in enumerate(probs):
            agg_probs[idx] += p.item()
    # Turn accumulated sums into averages over the shuffles.
    for k in agg_probs:
        agg_probs[k] /= n_shuffles
    top_3 = sorted(agg_probs.items(), key=lambda x: x[1], reverse=True)[:3]
    return [
        {"disease": name_encoder.classes_[idx], "probability": prob}
        for idx, prob in top_3
    ]
# Page config
st.set_page_config(page_title="Disease API", layout="wide")

# Query parameters.
# BUG FIX: st.query_params returns plain strings (unlike the legacy
# st.experimental_get_query_params, which returned lists). The old
# list-style access `.get("api", ["false"])[0]` took only the FIRST
# CHARACTER of a present value ("t" from "true"), so API mode could never
# activate and the symptoms value would have been truncated to one letter.
query_params = st.query_params
is_api_mode = str(query_params.get("api", "false")).lower() == "true"

# Load model (cached across reruns)
tokenizer, model, name_encoder = load_model()

# API mode: render the prediction as a JSON code block and stop rendering.
if is_api_mode:
    symptoms = query_params.get("symptoms", "")
    if symptoms:
        results = predict_disease(symptoms, tokenizer, model, name_encoder)
        api_response = {
            "status": "success",
            "input": symptoms,
            "predictions": results
        }
    else:
        api_response = {
            "status": "error",
            "message": "symptoms parameter required"
        }
    st.markdown(
        f"```json\n{json.dumps(api_response, ensure_ascii=False, indent=2)}\n```"
    )
    st.stop()
# --- Web interface ---
st.title("🏥 Disease Prediction")
st.success("Model yükləndi!")

# Debug aid: show which disease labels the encoder knows about.
st.write("Available classes:", list(name_encoder.classes_))

# Show how to call the app in API mode via query parameters.
st.markdown("### API İstifadəsi")
space_url = "https://your-username-your-space-name.hf.space"
st.code(f"{space_url}/?api=true&symptoms=fever,cough,headache")

# Symptom input form.
with st.form(key="predict_form"):
    text = st.text_area("Simptomları daxil edin (vergüllə ayırın):")
    submit_button = st.form_submit_button(label="Predict")

if submit_button:
    if not text.strip():
        st.warning("Simptomları daxil edin!")
    else:
        results = predict_disease(text, tokenizer, model, name_encoder)
        st.subheader("🔍 Nəticələr:")
        for result in results:
            st.write(f"**{result['disease']}** — {result['probability']*100:.2f}%")