File size: 2,040 Bytes
4733bf9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import streamlit as st
from transformers import BertForSequenceClassification, BertTokenizer
from sklearn.preprocessing import LabelEncoder
import torch
import numpy as np

# ``st.cache(allow_output_mutation=True)`` is deprecated; ``st.cache_resource``
# is Streamlit's replacement for caching shared, unhashable objects such as
# models (it returns the same object without copying or hash-checking it).
# Fall back to the legacy decorator on Streamlit versions that predate it.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


# Load model and label encoder
@_cache_model
def load_model_and_label_encoder():
    """Load the fine-tuned classifier, its label encoder and the tokenizer.

    All artifacts are read from paths relative to the current working
    directory: ``./fine_tuned_model`` (model), ``./label_encoder_classes.npy``
    (the encoder's ``classes_`` array) and ``./tokenizer``. The result is
    cached by Streamlit, so the loading runs once rather than on every rerun.

    Returns:
        tuple: ``(fine_tuned_model, label_encoder, tokenizer)``.
    """
    fine_tuned_model = BertForSequenceClassification.from_pretrained('./fine_tuned_model')

    # Rebuild a fitted LabelEncoder by restoring only its ``classes_`` array.
    # allow_pickle=True is presumably needed because the classes were saved
    # as an object (string) array. It unpickles data, so only point this at
    # trusted local files.
    label_encoder = LabelEncoder()
    label_encoder.classes_ = np.load('./label_encoder_classes.npy', allow_pickle=True)

    tokenizer = BertTokenizer.from_pretrained("./tokenizer")  # Load tokenizer from local directory
    return fine_tuned_model, label_encoder, tokenizer

def predict(symptom, fine_tuned_model, label_encoder, tokenizer, top_k=5):
    """Return the ``top_k`` most likely diseases for a free-text symptom.

    Args:
        symptom: Text describing the symptom(s). It is tokenized and
            truncated to at most 512 tokens.
        fine_tuned_model: Classification model called with the tokenizer's
            output. Its result must expose a ``.logits`` tensor.
        label_encoder: Fitted encoder whose ``inverse_transform`` maps class
            indices back to disease names.
        tokenizer: Tokenizer producing PyTorch tensors for ``fine_tuned_model``.
        top_k: Maximum number of predictions to return (default 5). Fewer are
            returned when the model has fewer classes.

    Returns:
        list[dict]: ``{'disease': ..., 'probability': ...}`` entries ordered
        from most to least likely, with ``probability`` given as a
        percentage (0-100).

    Raises:
        ValueError: If ``top_k`` is less than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    encoding = tokenizer(
        symptom,
        padding=True,
        truncation=True,
        return_tensors='pt',
        max_length=512,
        return_attention_mask=True,
        return_token_type_ids=True,
    )

    with torch.no_grad():
        outputs = fine_tuned_model(**encoding)
        # A single string is encoded, so keep row 0 of the (batch, classes)
        # softmax. .cpu() makes .numpy() work even if the model is on a GPU.
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=1).cpu().numpy()[0]

    # Class indices sorted by descending probability; slicing past the
    # number of classes simply yields all of them.
    top_indices = np.argsort(-probabilities)[:top_k]
    diseases = label_encoder.inverse_transform(top_indices)
    top_probabilities = probabilities[top_indices]

    return [
        {'disease': disease, 'probability': probability * 100}
        for disease, probability in zip(diseases, top_probabilities)
    ]

def main():
    """Render the app: a symptom text box, a Predict button and the results.

    On click, the entered text is passed to ``predict`` and each returned
    disease/probability pair is written to the page. Empty or
    whitespace-only input shows a warning instead of running the model.
    """
    st.title("Disease Prediction App")

    # Load model, label encoder and tokenizer. The loader is decorated with a
    # Streamlit cache, so this is cheap after the first run.
    fine_tuned_model, label_encoder, tokenizer = load_model_and_label_encoder()

    # Input
    symptom = st.text_input("Enter symptom:", "")

    # Predict
    if st.button("Predict"):
        # Treat whitespace-only input as empty instead of sending it to the model.
        symptom = symptom.strip()
        if not symptom:
            st.warning("Please enter a symptom before predicting.")
            return

        predictions = predict(symptom, fine_tuned_model, label_encoder, tokenizer)
        # Report the actual count: fewer than 5 come back when the model has fewer classes.
        st.write(f"Top {len(predictions)} Predictions:")
        for prediction in predictions:
            st.write(f"Disease: {prediction['disease']}, Probability: {prediction['probability']:.2f}%")

# Script entry point: `streamlit run` executes this file as __main__, and
# importing it as a module does not launch the UI.
if __name__ == "__main__":
    main()