# Web-page residue captured together with the file (not part of the program):
#   ARPZ's picture
#   add app file
#   4733bf9 verified
#   raw
#   history blame
#   2.04 kB
import streamlit as st
from transformers import BertForSequenceClassification, BertTokenizer
from sklearn.preprocessing import LabelEncoder
import torch
import numpy as np
# ``st.cache`` has been deprecated since Streamlit 1.18; ``st.cache_resource``
# is its replacement for global, unserializable objects such as ML models.
# Prefer the new API and fall back to the legacy decorator only when running
# on a Streamlit release that predates it.
_cache_model_resources = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)


@_cache_model_resources
def load_model_and_label_encoder():
    """Load the classifier, label encoder and tokenizer from local files.

    The result is cached by Streamlit so the files are read once rather than
    on every script rerun.

    Returns:
        A ``(fine_tuned_model, label_encoder, tokenizer)`` tuple:
        the ``BertForSequenceClassification`` loaded from
        ``./fine_tuned_model``, a ``LabelEncoder`` whose ``classes_`` come
        from ``./label_encoder_classes.npy``, and the ``BertTokenizer``
        loaded from ``./tokenizer``.
    """
    fine_tuned_model = BertForSequenceClassification.from_pretrained('./fine_tuned_model')

    label_encoder = LabelEncoder()
    # NOTE(security): allow_pickle=True unpickles object arrays, which can
    # execute arbitrary code. This is only acceptable if the .npy file is
    # trusted; re-saving the classes as a plain string array would allow
    # dropping the flag.
    label_encoder.classes_ = np.load('./label_encoder_classes.npy', allow_pickle=True)

    # The tokenizer is loaded from a local directory, not downloaded.
    tokenizer = BertTokenizer.from_pretrained("./tokenizer")

    return fine_tuned_model, label_encoder, tokenizer
def predict(symptom, fine_tuned_model, label_encoder, tokenizer, top_k=5):
    """Return the most probable diseases for a symptom description.

    Args:
        symptom: Text to classify; passed to ``tokenizer`` as-is.
        fine_tuned_model: Callable that accepts the tokenizer output as
            keyword arguments and returns an object with a ``logits``
            tensor whose first axis is the batch.
        label_encoder: Object whose ``inverse_transform`` maps class indices
            to disease labels.
        tokenizer: Callable that turns ``symptom`` into model inputs.
        top_k: Maximum number of predictions to return. Defaults to 5, the
            value that was previously hard-coded. Fewer entries are returned
            when the model has fewer classes.

    Returns:
        A list of ``{'disease': label, 'probability': percent}`` dicts,
        ordered from most to least probable, with ``percent`` a Python float
        on a 0-100 scale.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        # A zero or negative value would silently slice the ranking in a
        # surprising way, so reject it up front.
        raise ValueError("top_k must be at least 1")

    encoding = tokenizer(
        symptom,
        padding=True,
        truncation=True,
        return_tensors='pt',
        max_length=512,
        return_attention_mask=True,
        return_token_type_ids=True,
    )

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = fine_tuned_model(**encoding)

    # Single input, so take row 0 of the (batch, classes) probabilities.
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=1).numpy()[0]

    # Sorting the negated values yields indices in descending probability.
    top_indices = np.argsort(-probabilities)[:top_k]
    diseases = label_encoder.inverse_transform(top_indices)

    return [
        {'disease': disease, 'probability': float(probability) * 100.0}
        for disease, probability in zip(diseases, probabilities[top_indices])
    ]
def main():
    """Render the app: take a symptom, then show the top predictions.

    Loads the model, label encoder and tokenizer, draws a text input and a
    "Predict" button, and writes one line per prediction when the button is
    pressed. If the button is pressed with empty or whitespace-only input, a
    warning is shown instead of running the model.
    """
    st.title("Disease Prediction App")

    # Load model and label encoder
    fine_tuned_model, label_encoder, tokenizer = load_model_and_label_encoder()

    # Input
    symptom = st.text_input("Enter symptom:", "")

    # Nothing to do until the user asks for a prediction.
    if not st.button("Predict"):
        return

    # Give explicit feedback instead of silently ignoring the click.
    if not symptom.strip():
        st.warning("Please enter a symptom before predicting.")
        return

    # Predict
    predictions = predict(symptom, fine_tuned_model, label_encoder, tokenizer)
    st.write("Top 5 Predictions:")
    for prediction in predictions:
        st.write(f"Disease: {prediction['disease']}, Probability: {prediction['probability']:.2f}%")
# Script entry point: build the UI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()