SpecX / app.py
Mushfi's picture
Update app.py
425b074 verified
raw history blame
No virus
2.25 kB
import gradio as gr
import tensorflow as tf
from transformers import TFAutoModel, AutoTokenizer
import os
import numpy as np
# Hugging Face checkpoint used for the tokenizer and (presumably) as the base
# transformer the fine-tuned classifier in model.h5 was built on.
model_name = 'cardiffnlp/twitter-roberta-base-sentiment-latest'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the saved Keras classifier from model.h5. The saved model presumably
# contains a layer serialized under the name 'TFRobertaModel', which Keras
# cannot resolve by itself, so it is supplied through custom_objects.
# NOTE(review): this maps the name to a pretrained model *instance* rather
# than a class; confirm this is what Keras deserialization expects for the
# TensorFlow/Transformers versions this Space pins.
model = tf.keras.models.load_model(
    "model.h5",
    custom_objects={
        'TFRobertaModel': TFAutoModel.from_pretrained(model_name)
    }
)
# Specialist class names, one per line, indexed by the position of the
# classifier's output unit. The order must match the label encoding used
# when model.h5 was trained (NOTE(review): it appears to be alphabetical;
# confirm against the training code).
labels = """\
Cardiologist
Dermatologist
ENT Specialist
Gastro-enterologist
General-Physicians
Neurologist/Gastro-enterologist
Ophthalmologist
Orthopedist
Psychiatrist
Respirologist
Rheumatologist
Rheumatologist/Gastro-enterologist
Rheumatologist/Orthopedist
Surgeon""".splitlines()

# Fixed token length every input is padded/truncated to before inference.
seq_len = 152
def prep_data(text):
    """Tokenize *text* into the two tensors the classifier consumes.

    The text is wrapped in the tokenizer's special tokens, truncated or
    padded to exactly ``seq_len`` tokens, and returned as TensorFlow tensors.

    Args:
        text: The raw input sentence.

    Returns:
        A dict with keys ``'input_ids'`` and ``'attention_mask'`` (in that
        order), each taken directly from the tokenizer output.
    """
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        padding='max_length',
        truncation=True,
        max_length=seq_len,
        return_tensors='tf',
    )
    # Keep only the inputs the model was built with; drop anything else the
    # tokenizer may return.
    model_inputs = ('input_ids', 'attention_mask')
    return {key: encoding[key] for key in model_inputs}
def inference(text):
    """Map a symptom description to a score for every specialist label.

    Args:
        text: The sentence entered in the UI textbox.

    Returns:
        A dict ``{label: score}`` with one entry per item of ``labels``,
        scores as built-in Python floats, usable directly as the value of
        a ``gr.Label`` output.

    Raises:
        ValueError: If the model's output size differs from ``len(labels)``;
            ``zip`` would otherwise silently drop the surplus and return a
            misleading result.
    """
    encoded_text = prep_data(text)
    # NOTE(review): whether these scores are probabilities depends on the
    # output activation of model.h5; the name follows the original code.
    probs = np.asarray(model.predict_on_batch(encoded_text))
    # prep_data encodes a single string, so the batch holds one row. Select
    # that row explicitly: flatten() would concatenate rows from a larger
    # batch and pair scores with the wrong labels.
    scores = probs[0].ravel()
    if scores.size != len(labels):
        raise ValueError(
            f"model produced {scores.size} scores but {len(labels)} labels "
            "are defined"
        )
    # Convert numpy scalars to plain floats so the returned dict is ordinary,
    # JSON-serialisable data for the Gradio frontend.
    return {name: float(score) for name, score in zip(labels, scores)}
# Extra CSS injected into the page: text areas get a fully transparent
# background (alpha 00) and a semi-transparent border (#6366f1, alpha 0x60).
css = """
textarea {
background-color: #00000000;
border: 1px solid #6366f160;
}
"""
# Build the UI. Inside gr.Blocks, components are placed in the order they are
# created, so the statement order below determines the page layout.
# NOTE(review): original indentation was lost in this copy; the nesting of
# `label` and the click handler below is reconstructed; confirm.
with gr.Blocks(title="SpecX", css=css, theme=gr.themes.Soft()) as demo:
    # Header row: spacer div plus the centered page title.
    with gr.Row():
        textmd = gr.Markdown('''
<div style="margin: 50px 0;"></div>
<h1 style="width:100%; text-align: center;">SpecX: Find the Right Specialist For Your Symptoms!</h1>
''')
    # Main row: input box, submit button, example sentences and result label.
    with gr.Row():
        with gr.Column(scale=1, min_width=600):
            text_box = gr.Textbox(label="Explain your problem in one sentence.")
            submit_btn = gr.Button("Submit", elem_id="warningk", variant='primary')
            # Example sentences bound to text_box; selecting one loads it
            # into the textbox.
            examples = gr.Examples(examples=[
                "When I remember her I feel down",
                "The area around my heart doesn't feel good.",
                "I have a split on my thumb that will not heal."
            ], inputs=text_box)
            # Displays the 4 highest-scoring entries of the dict returned by
            # inference().
            label = gr.Label(num_top_classes=4, label="Recommended Specialist")
    # On Submit, pass the textbox contents to inference() and show the result
    # in the label component.
    submit_btn.click(inference, inputs=text_box, outputs=label)
# Start the Gradio server.
demo.launch()