# Symptom_Cheking / app.py
# Author: Mohsen-drnext (commit 3bb9a91, "Upload 12 files")
import gradio as gr
from transformers import BertForSequenceClassification, BertTokenizer
from sklearn.preprocessing import LabelEncoder
import torch
import numpy as np
# Restore the fine-tuned BERT classifier from the local ./fine_tuned_model directory.
fine_tuned_model = BertForSequenceClassification.from_pretrained('./fine_tuned_model')

# The tokenizer is likewise read from a local directory ("./tokenizer").
tokenizer = BertTokenizer.from_pretrained("./tokenizer")

# Rebuild the label mapping without refitting: assigning the saved class array
# to `classes_` is what inverse_transform needs to map class ids back to names.
# allow_pickle=True is required if the array was saved with dtype=object
# (TODO confirm how it was saved). Pickle loading can execute arbitrary code, so
# this is only acceptable while the .npy file is a trusted artifact shipped with
# the app -- never point it at user-supplied files.
label_encoder = LabelEncoder()
label_encoder.classes_ = np.load('./label_encoder_classes.npy', allow_pickle=True)
def predict(symptom):
user_input_encoding = tokenizer(symptom, padding=True, truncation=True, return_tensors='pt', max_length=512, return_attention_mask=True, return_token_type_ids=True)
with torch.no_grad():
logits = fine_tuned_model(**user_input_encoding)
probabilities = torch.nn.functional.softmax(logits.logits, dim=1).numpy()[0]
predicted_labels = np.argsort(-probabilities)[:5]
predicted_diseases = label_encoder.inverse_transform(predicted_labels)
predicted_probabilities = probabilities[predicted_labels]
predictions = [{'disease': disease, 'probability': probability * 100} for disease, probability in zip(predicted_diseases, predicted_probabilities)]
return predictions
# Gradio interface: one free-text input; predict's return value is shown in a textbox.
demo = gr.Interface(fn=predict, inputs="text", outputs="text")

# Only start the server when run as a script (e.g. `python app.py`), so importing
# this module does not launch a web server as a side effect.
if __name__ == "__main__":
    demo.launch(share=True)  # Share your demo with just 1 extra parameter 🚀