ARPZ's picture
add app file
4733bf9 verified
raw
history blame
No virus
2.04 kB
import streamlit as st
from transformers import BertForSequenceClassification, BertTokenizer
from sklearn.preprocessing import LabelEncoder
import torch
import numpy as np
# Load model and label encoder
@st.cache_resource
def load_model_and_label_encoder():
    """Load the fine-tuned classifier, its label encoder, and its tokenizer.

    ``st.cache_resource`` keeps a single shared instance across Streamlit
    reruns and sessions, so the weights are read from disk only once. This
    replaces the deprecated ``st.cache(allow_output_mutation=True)``.

    Returns:
        A ``(fine_tuned_model, label_encoder, tokenizer)`` tuple loaded from
        ``./fine_tuned_model``, ``./label_encoder_classes.npy`` and
        ``./tokenizer`` respectively.
    """
    fine_tuned_model = BertForSequenceClassification.from_pretrained('./fine_tuned_model')

    # Rebuild a fitted LabelEncoder by restoring only its saved class array.
    label_encoder = LabelEncoder()
    # allow_pickle=True unpickles data: acceptable only because this file ships
    # with the app and is trusted. Never point this at user-supplied files.
    label_encoder.classes_ = np.load('./label_encoder_classes.npy', allow_pickle=True)

    tokenizer = BertTokenizer.from_pretrained("./tokenizer")  # Load tokenizer from local directory
    return fine_tuned_model, label_encoder, tokenizer
def predict(symptom, fine_tuned_model, label_encoder, tokenizer, top_k=5):
    """Return the ``top_k`` most likely diseases for a free-text symptom.

    Args:
        symptom: Symptom description text to classify.
        fine_tuned_model: Sequence-classification model; it is called with the
            tokenizer's output and its result must expose ``.logits``.
        label_encoder: Fitted encoder whose ``inverse_transform`` maps class
            indices back to disease names.
        tokenizer: Tokenizer matching ``fine_tuned_model``.
        top_k: Maximum number of predictions to return (default 5). Fewer are
            returned when the model has fewer classes.

    Returns:
        A list of ``{'disease': ..., 'probability': ...}`` dicts ordered by
        descending probability; ``probability`` is a percentage (0-100) as a
        plain Python ``float``.

    Raises:
        ValueError: If ``top_k`` is less than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    encoding = tokenizer(
        symptom,
        padding=True,
        truncation=True,
        return_tensors='pt',
        max_length=512,
        return_attention_mask=True,
        return_token_type_ids=True,
    )
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = fine_tuned_model(**encoding)

    # A single string was tokenized, so keep row 0 of the batch.
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=1).numpy()[0]
    # Class indices ordered from most to least probable, truncated to top_k.
    top_indices = np.argsort(-probabilities)[:top_k]
    diseases = label_encoder.inverse_transform(top_indices)

    # float() turns numpy scalars into plain floats so results are JSON-serializable.
    return [
        {'disease': disease, 'probability': float(probabilities[idx]) * 100}
        for disease, idx in zip(diseases, top_indices)
    ]
def main():
    """Render the app: read a symptom, run the classifier, list top predictions."""
    st.title("Disease Prediction App")

    # The loader is wrapped in a Streamlit cache decorator, so reruns reuse it.
    fine_tuned_model, label_encoder, tokenizer = load_model_and_label_encoder()

    symptom = st.text_input("Enter symptom:", "")

    if st.button("Predict"):
        # Whitespace-only input carries no information for the model.
        cleaned = symptom.strip()
        if not cleaned:
            # Previously an empty submit silently did nothing; tell the user why.
            st.warning("Please enter a symptom before predicting.")
            return

        predictions = predict(cleaned, fine_tuned_model, label_encoder, tokenizer)
        # Report the real count: the model may know fewer than five diseases.
        st.write(f"Top {len(predictions)} Predictions:")
        for prediction in predictions:
            st.write(f"Disease: {prediction['disease']}, Probability: {prediction['probability']:.2f}%")
# Entry point: build the UI only when executed as a script, not on import.
if __name__ == "__main__":
    main()