# SpecX / app.py
# Last commit: "Update app.py" (96e0773, verified) by Mushfi
import gradio as gr
import tensorflow as tf
from transformers import TFAutoModel, AutoTokenizer
import os
import numpy as np
import shap
import scipy.special
# Hugging Face checkpoint used both for the tokenizer and for the RoBERTa
# backbone registered as a custom object when the Keras model is loaded.
model_name = 'cardiffnlp/twitter-roberta-base-sentiment-latest'
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The saved Keras model ("model.h5") references a layer class named
# 'TFRobertaModel', which Keras cannot rebuild without being told what it is.
# NOTE(review): an *instance* (loaded via from_pretrained, i.e. with a weight
# download) is registered here rather than a class — confirm this is intended.
model = tf.keras.models.load_model(
    "model.h5",
    custom_objects=dict(TFRobertaModel=TFAutoModel.from_pretrained(model_name)),
)
# Specialist classes, one per model output column: inference() pairs entry i
# of this list with output score i, so the order must match the model's output
# order (presumably the label order used at training time).
labels = [
    'Cardiologist',
    'Dermatologist',
    'ENT Specialist',
    'Gastro-enterologist',
    'General-Physicians',
    'Neurologist/Gastro-enterologist',
    'Ophthalmologist',
    'Orthopedist',
    'Psychiatrist',
    'Respirologist',
    'Rheumatologist',
    'Rheumatologist/Gastro-enterologist',
    'Rheumatologist/Orthopedist',
    'Surgeon',
]

# Token length every input is truncated/padded to before reaching the model.
seq_len = 152
def prep_data(text):
    """Encode *text* into the fixed-length inputs the classifier consumes.

    Parameters
    ----------
    text : str or list of str
        Text(s) to tokenize.

    Returns
    -------
    dict
        ``{'input_ids': ..., 'attention_mask': ...}`` as TensorFlow tensors,
        with special tokens added and every sequence truncated/padded to
        exactly ``seq_len`` tokens.
    """
    encoding = tokenizer(
        text,
        max_length=seq_len,
        truncation=True,
        padding='max_length',
        add_special_tokens=True,
        return_tensors='tf',
    )
    # Keep only the two tensors the model takes, in a stable key order.
    wanted_keys = ('input_ids', 'attention_mask')
    return {key: encoding[key] for key in wanted_keys}
def inference(text):
    """Classify *text* and map every specialist label to its model score.

    Parameters
    ----------
    text : str
        Free-text description of the user's problem.

    Returns
    -------
    dict[str, float]
        ``{label: score}`` for every entry of ``labels``, in the format
        ``gr.Label`` expects. Scores are the model's outputs for the single
        input row (shown to the user as confidences; presumably softmax
        probabilities — confirm against the model's final layer).

    Raises
    ------
    ValueError
        If the model returns a different number of scores than there are
        labels (previously ``zip`` would have silently truncated).
    """
    encoded_text = prep_data(text)
    probs = np.asarray(model.predict_on_batch(encoded_text))
    # One input text -> one output row. Index it explicitly rather than
    # flattening, which would silently blend rows if a batch ever came back.
    row = probs[0]
    if len(row) != len(labels):
        raise ValueError(
            f"model returned {len(row)} scores but {len(labels)} labels are defined"
        )
    # Cast to built-in float: numpy float32 scalars are not JSON-serializable.
    return {name: float(score) for name, score in zip(labels, row)}
def predictor(x, class_index=1):
    """Score text(s) with the classifier on a logit scale, for SHAP.

    Parameters
    ----------
    x : str or list of str
        A single text or a batch of texts.
    class_index : int or None, default 1
        Output column to return. The default reproduces the historical
        behaviour (column 1, i.e. ``labels[1]``); ``None`` returns the scores
        of every class.

    Returns
    -------
    numpy.ndarray
        ``logit(softmax(model_output))`` with shape ``(n_texts,)`` when
        *class_index* is an int, or ``(n_texts, n_classes)`` when it is
        ``None``.
    """
    # Tokenize once and reuse both tensors (the text used to be tokenized
    # twice, once per tensor, with identical settings).
    encoded = prep_data(x)
    outputs = model.predict(
        [encoded['input_ids'], encoded['attention_mask']], verbose=0
    )
    # NOTE(review): inference() presents the raw model outputs as
    # probabilities, which suggests the model already ends in a softmax; if so,
    # this second softmax compresses the scores. Left unchanged pending
    # confirmation of the model's final activation.
    probas = tf.nn.softmax(outputs).numpy()
    if class_index is not None:
        probas = probas[:, class_index]
    # SHAP attributions of this function are additive in log-odds space.
    return scipy.special.logit(probas)


def f_batch(x):
    """Model function handed to SHAP: score a batch of (masked) texts.

    Returns an array of shape ``(len(x), n_classes)``, one logit per
    specialist class. Previously only column 1 ("Dermatologist") was
    returned, so every explanation described that class regardless of what
    the model actually predicted.
    """
    texts = [str(t) for t in x]
    if not texts:
        # Nothing to score; don't call the tokenizer/model on an empty batch.
        return np.zeros((0, len(labels)))
    # One batched call instead of one model call per text; model.predict
    # splits large inputs into mini-batches itself.
    return predictor(texts, class_index=None)


# Multi-output explainer. output_names lets an explanation be sliced by
# specialist label (see get_shap_data).
explainer_roberta = shap.Explainer(f_batch, tokenizer, output_names=labels)
# Import-time explanation whose result is not used by get_shap_data (which
# computes its own); presumably a warm-up so the first user request is faster.
shap_values = explainer_roberta(["When I remember her I feel down"])


def get_shap_data(input_text):
    """Return an HTML SHAP text plot for the specialist the model ranks first.

    Parameters
    ----------
    input_text : str
        The user's problem description.

    Returns
    -------
    str
        HTML naming the explained specialist followed by the SHAP text plot.
        A short prompt is returned instead when *input_text* is blank.
    """
    if not input_text or not input_text.strip():
        return "<p>Please describe your problem before requesting a SHAP analysis.</p>"
    # Explain the class the model actually ranks highest for this text.
    class_scores = predictor(input_text, class_index=None)[0]
    top_label = labels[int(np.argmax(class_scores))]
    shap_values = explainer_roberta([input_text])
    # Slice the multi-output explanation down to that single class by name.
    plot_html = shap.plots.text(shap_values[:, :, top_label], display=False)
    return f"<p><b>Explaining prediction:</b> {top_label}</p>" + plot_html
# Advice texts shared verbatim by several symptoms below.
_WOUND_ADVICE = "Keep the wound clean and dry. If signs of infection develop, seek immediate medical attention. Consult with a surgeon for evaluation."
_ORTHOPEDIST_ADVICE = "Rest, gentle stretching, and over-the-counter pain relievers may provide relief. If symptoms persist, consult with an orthopedist for evaluation."
_RHEUMATOLOGIST_ADVICE = "Rest, gentle exercise, and over-the-counter pain relievers may provide relief. If symptoms persist, consult with a rheumatologist for evaluation."

# Symptom -> canned recommendation. Insertion order is the order the symptoms
# are offered in the "Personalised Treatment" dropdown.
all_options = dict([
    ('Heart hurts', "Immediate medical attention is required. Please seek emergency care."),
    ('Acne', "Consider dermatological treatment options such as topical creams, medications, or lifestyle changes."),
    ('Hair falling out', "Consult a dermatologist to address hair-related concerns."),
    ('Infected wound', _WOUND_ADVICE),
    ('Skin issue', "Seek advice from a dermatologist for personalized treatment options."),
    ('Cough', "Stay hydrated, get plenty of rest, and consider over-the-counter cough medications."),
    ('Ear ache', "Avoid inserting anything into the ear and consult an ENT specialist for proper examination and treatment."),
    ('Feeling cold', "Rest, stay warm, and drink plenty of fluids. Seek medical attention if symptoms worsen."),
    ('Stomach ache', "Avoid spicy or fatty foods, and consider over-the-counter antacids for temporary relief."),
    ('Internal pain', "It's important to get a proper diagnosis. Please consult with a gastroenterologist for further evaluation."),
    ('Open wound', _WOUND_ADVICE),
    ('Feeling dizzy', "Ensure you are well-hydrated and rested. If symptoms persist, consult with a general physician for evaluation."),
    ('Body feels weak', "Get plenty of rest, eat a balanced diet, and consider consulting with a general physician for further evaluation."),
    ('Head ache', "Ensure you are well-hydrated and rested. If headaches persist, consult with a neurologist for evaluation."),
    ('Blurry vision', "Ensure you are well-rested and take regular breaks from screens. If vision problems persist, consult with an ophthalmologist for evaluation."),
    ('Joint pain', _ORTHOPEDIST_ADVICE),
    ('Knee pain', _ORTHOPEDIST_ADVICE),
    ('Back pain', _ORTHOPEDIST_ADVICE),
    ('Emotional pain', "Seek support from friends and family, and consider speaking with a therapist or psychiatrist for evaluation and support."),
    ('Hard to breath', "If experiencing difficulty breathing, seek immediate medical attention. Consult with a respirologist for evaluation."),
    ('Muscle pain', _RHEUMATOLOGIST_ADVICE),
    ('Injury from sports', _RHEUMATOLOGIST_ADVICE),
    ('Foot ache', _RHEUMATOLOGIST_ADVICE),
    ('Shoulder pain', "Rest, gentle stretching, and over-the-counter pain relievers may provide relief. If symptoms persist, consult with a rheumatologist or gastroenterologist for evaluation."),
    ('Neck pain', "Rest, gentle stretching, and over-the-counter pain relievers may provide relief. If symptoms persist, consult with a rheumatologist or orthopedist for evaluation."),
])
def pt_f(x):
    """Look up the canned recommendation for the symptom picked in the dropdown.

    Returns ``None`` (an empty textbox) when *x* is not a known symptom,
    e.g. when nothing has been selected yet.
    """
    try:
        return all_options[x]
    except KeyError:
        return None
# Page-level CSS: transparent textareas (8-digit hex with alpha 00) with a
# translucent indigo border.
css = "\n".join([
    "",
    "textarea {",
    "background-color: #00000000;",
    "border: 1px solid #6366f160;",
    "}",
    "",
])
# Vertical spacer placed between page sections.
_SPACER = '''
<div style="margin: 50px 0;"></div>
'''

# Inline SVG icons for the stats footer, coloured via Gradio theme variables.
_ICON_BADGE_CHECK = '''<svg width="50px" height="50px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path opacity="0.4" d="M10.7509 2.45007C11.4409 1.86007 12.5709 1.86007 13.2709 2.45007L14.8509 3.81007C15.1509 4.07007 15.7109 4.28007 16.1109 4.28007H17.8109C18.8709 4.28007 19.7409 5.15007 19.7409 6.21007V7.91007C19.7409 8.30007 19.9509 8.87007 20.2109 9.17007L21.5709 10.7501C22.1609 11.4401 22.1609 12.5701 21.5709 13.2701L20.2109 14.8501C19.9509 15.1501 19.7409 15.7101 19.7409 16.1101V17.8101C19.7409 18.8701 18.8709 19.7401 17.8109 19.7401H16.1109C15.7209 19.7401 15.1509 19.9501 14.8509 20.2101L13.2709 21.5701C12.5809 22.1601 11.4509 22.1601 10.7509 21.5701L9.17086 20.2101C8.87086 19.9501 8.31086 19.7401 7.91086 19.7401H6.18086C5.12086 19.7401 4.25086 18.8701 4.25086 17.8101V16.1001C4.25086 15.7101 4.04086 15.1501 3.79086 14.8501L2.44086 13.2601C1.86086 12.5701 1.86086 11.4501 2.44086 10.7601L3.79086 9.17007C4.04086 8.87007 4.25086 8.31007 4.25086 7.92007V6.20007C4.25086 5.14007 5.12086 4.27007 6.18086 4.27007H7.91086C8.30086 4.27007 8.87086 4.06007 9.17086 3.80007L10.7509 2.45007Z" fill="var(--neutral-700)"/>
<path d="M10.7905 15.17C10.5905 15.17 10.4005 15.09 10.2605 14.95L7.84055 12.53C7.55055 12.24 7.55055 11.76 7.84055 11.47C8.13055 11.18 8.61055 11.18 8.90055 11.47L10.7905 13.36L15.0905 9.06003C15.3805 8.77003 15.8605 8.77003 16.1505 9.06003C16.4405 9.35003 16.4405 9.83003 16.1505 10.12L11.3205 14.95C11.1805 15.09 10.9905 15.17 10.7905 15.17Z" fill="#ececec"/>
</svg>'''

_ICON_CLOUD_CHECK = '''<svg width="50px" height="50px" viewBox="0 0 48 48" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M12 33C8.66666 33 4 31.5 4 25.5C4 18.5 11 17 13 17C14 13.5 16 8 24 8C31 8 34 12 35 15.5C35 15.5 44 16.5 44 25C44 31 40 33 36 33" stroke="var(--neutral-700)" stroke-width="4" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M18 33L24 38L32 28" stroke="var(--neutral-700)" stroke-width="4" stroke-linecap="round" stroke-linejoin="round"/>
</svg>'''

_ICON_USERS = '''<svg width="50px" height="50px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path opacity="0.4" d="M17.9981 7.16C17.9381 7.15 17.8681 7.15 17.8081 7.16C16.4281 7.11 15.3281 5.98 15.3281 4.58C15.3281 3.15 16.4781 2 17.9081 2C19.3381 2 20.4881 3.16 20.4881 4.58C20.4781 5.98 19.3781 7.11 17.9981 7.16Z" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path opacity="0.4" d="M16.9675 14.4402C18.3375 14.6702 19.8475 14.4302 20.9075 13.7202C22.3175 12.7802 22.3175 11.2402 20.9075 10.3002C19.8375 9.59016 18.3075 9.35016 16.9375 9.59016" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path opacity="0.4" d="M5.96656 7.16C6.02656 7.15 6.09656 7.15 6.15656 7.16C7.53656 7.11 8.63656 5.98 8.63656 4.58C8.63656 3.15 7.48656 2 6.05656 2C4.62656 2 3.47656 3.16 3.47656 4.58C3.48656 5.98 4.58656 7.11 5.96656 7.16Z" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path opacity="0.4" d="M6.9975 14.4402C5.6275 14.6702 4.1175 14.4302 3.0575 13.7202C1.6475 12.7802 1.6475 11.2402 3.0575 10.3002C4.1275 9.59016 5.6575 9.35016 7.0275 9.59016" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M12.0001 14.6302C11.9401 14.6202 11.8701 14.6202 11.8101 14.6302C10.4301 14.5802 9.33008 13.4502 9.33008 12.0502C9.33008 10.6202 10.4801 9.47021 11.9101 9.47021C13.3401 9.47021 14.4901 10.6302 14.4901 12.0502C14.4801 13.4502 13.3801 14.5902 12.0001 14.6302Z" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M9.0907 17.7804C7.6807 18.7204 7.6807 20.2603 9.0907 21.2003C10.6907 22.2703 13.3107 22.2703 14.9107 21.2003C16.3207 20.2603 16.3207 18.7204 14.9107 17.7804C13.3207 16.7204 10.6907 16.7204 9.0907 17.7804Z" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>'''

_ICON_BRAIN = '''<svg fill="var(--neutral-700)" width="50px" height="50px" viewBox="0 0 100 100" id="Layer_1" version="1.1" xml:space="preserve" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<g>
<path d="M93.998,45.312c0-3.676-1.659-7.121-4.486-9.414c0.123-0.587,0.184-1.151,0.184-1.706c0-4.579-3.386-8.382-7.785-9.037 c0.101-0.526,0.149-1.042,0.149-1.556c0-4.875-3.842-8.858-8.655-9.111c-0.079-0.013-0.159-0.024-0.242-0.024 c-0.04,0-0.079,0.005-0.12,0.006c-0.04-0.001-0.079-0.006-0.12-0.006c-0.458,0-0.919,0.041-1.406,0.126 c-0.846-4.485-4.753-7.825-9.437-7.825c-5.311,0-9.632,4.321-9.632,9.633v65.918c0,6.723,5.469,12.191,12.191,12.191 c4.46,0,8.508-2.413,10.646-6.246c0.479,0.104,0.939,0.168,1.401,0.198c2.903,0.185,5.73-0.766,7.926-2.693 c2.196-1.927,3.51-4.594,3.7-7.51c0.079-1.215-0.057-2.434-0.403-3.638c3.796-2.691,6.027-6.952,6.027-11.621 c0-3.385-1.219-6.635-3.445-9.224C92.731,51.505,93.998,48.471,93.998,45.312z M90.938,62.999c0,3.484-1.582,6.68-4.295,8.819 c-2.008-3.196-5.57-5.237-9.427-5.237c-0.828,0-1.5,0.672-1.5,1.5s0.672,1.5,1.5,1.5c3.341,0,6.384,2.093,7.582,5.208 c0.41,1.088,0.592,2.189,0.521,3.274c-0.138,2.116-1.091,4.051-2.685,5.449c-1.594,1.399-3.641,2.094-5.752,1.954 c-0.594-0.039-1.208-0.167-1.933-0.402c-0.74-0.242-1.541,0.124-1.846,0.84c-1.445,3.404-4.768,5.604-8.465,5.604 c-5.068,0-9.191-4.123-9.191-9.191V16.399c0-3.657,2.975-6.633,6.632-6.633c3.398,0,6.194,2.562,6.558,5.908 c-2.751,1.576-4.612,4.535-4.612,7.926c0,0.829,0.672,1.5,1.5,1.5s1.5-0.671,1.5-1.5c0-3.343,2.689-6.065,6.016-6.13 c3.327,0.065,6.016,2.787,6.016,6.129c0,0.622-0.117,1.266-0.359,1.971c-0.057,0.166-0.084,0.34-0.081,0.515 c0.001,0.041,0.003,0.079,0.007,0.115c-0.006,0.021-0.01,0.035-0.01,0.035c-0.118,0.465-0.006,0.959,0.301,1.328 c0.307,0.369,0.765,0.569,1.251,0.538c0.104-0.007,0.208-0.02,0.392-0.046c3.383,0,6.136,2.753,6.136,6.136 c0,0.572-0.103,1.159-0.322,1.849c-0.203,0.635,0.038,1.328,0.591,1.7c2.434,1.639,3.909,4.329,4.014,7.242 c0,0.004-0.001,0.008-0.001,0.012c0,5.03-4.092,9.123-9.122,9.123s-9.123-4.093-9.123-9.123c0-0.829-0.672-1.5-1.5-1.5 s-1.5,0.671-1.5,1.5c0,6.685,5.438,12.123,12.123,12.123c2.228,0,4.31-0.615,6.106-1.668C89.88,57.539,90.938,60.212,90.938,62.999 z"/>
<path d="M38.179,6.766c-4.684,0-8.59,3.34-9.435,7.825c-0.488-0.085-0.949-0.126-1.407-0.126c-0.04,0-0.079,0.005-0.12,0.006 c-0.04-0.001-0.079-0.006-0.12-0.006c-0.083,0-0.163,0.011-0.242,0.024c-4.813,0.253-8.654,4.236-8.654,9.111 c0,0.514,0.049,1.03,0.149,1.556c-4.399,0.655-7.785,4.458-7.785,9.037c0,0.554,0.061,1.118,0.184,1.706 c-2.827,2.293-4.486,5.738-4.486,9.414c0,3.159,1.266,6.193,3.505,8.463c-2.227,2.589-3.446,5.839-3.446,9.224 c0,4.669,2.231,8.929,6.027,11.621c-0.347,1.204-0.482,2.423-0.402,3.639c0.19,2.915,1.503,5.582,3.699,7.509 c2.196,1.928,5.015,2.879,7.926,2.693c0.455-0.03,0.919-0.096,1.4-0.199c2.138,3.834,6.186,6.247,10.646,6.247 c6.722,0,12.191-5.469,12.191-12.191V16.399C47.811,11.087,43.49,6.766,38.179,6.766z M44.811,82.317 c0,5.068-4.123,9.191-9.191,9.191c-3.697,0-7.02-2.2-8.464-5.604c-0.241-0.567-0.793-0.914-1.381-0.914 c-0.154,0-0.311,0.023-0.465,0.074c-0.724,0.235-1.338,0.363-1.933,0.402c-2.119,0.139-4.158-0.556-5.751-1.954 c-1.594-1.398-2.547-3.333-2.685-5.449c-0.076-1.16,0.125-2.336,0.598-3.495c0.007-0.017,0.005-0.036,0.011-0.053 c1.342-3.056,4.225-4.953,7.597-4.953c0.829,0,1.5-0.672,1.5-1.5s-0.671-1.5-1.5-1.5c-3.938,0-7.501,2.007-9.548,5.239 c-2.701-2.139-4.277-5.327-4.277-8.802c0-2.787,1.06-5.46,2.978-7.549c1.796,1.053,3.879,1.668,6.107,1.668 c6.685,0,12.123-5.438,12.123-12.123c0-0.829-0.671-1.5-1.5-1.5s-1.5,0.671-1.5,1.5c0,5.03-4.092,9.123-9.123,9.123 s-9.123-4.093-9.123-9.123c0-0.002-0.001-0.004-0.001-0.006c0.103-2.915,1.578-5.607,4.013-7.248 c0.553-0.372,0.793-1.064,0.591-1.699c-0.22-0.691-0.322-1.278-0.322-1.85c0-3.376,2.741-6.125,6.195-6.125 c0.007,0,0.015,0,0.022,0c0.103,0.014,0.206,0.027,0.311,0.034c0.485,0.03,0.948-0.171,1.254-0.542 c0.307-0.372,0.417-0.868,0.294-1.334c0-0.001-0.003-0.014-0.008-0.031c0.003-0.035,0.006-0.067,0.007-0.095 c0.005-0.18-0.022-0.359-0.081-0.529c-0.242-0.707-0.359-1.352-0.359-1.972c0-3.342,2.688-6.065,6.016-6.129 
c3.328,0.065,6.016,2.787,6.016,6.13c0,0.829,0.671,1.5,1.5,1.5s1.5-0.671,1.5-1.5c0-3.391-1.861-6.35-4.612-7.926 c0.364-3.346,3.16-5.908,6.558-5.908c3.657,0,6.632,2.976,6.632,6.633V82.317z"/>
</g>
</svg>'''


def _stat_card(icon_svg, caption):
    """Return the centred icon-plus-heading HTML shown in one stats column."""
    return (
        '\n'
        '<div style="display: flex; justify-content: center; align-items: center; flex-direction: column;">\n'
        '<div style="margin: 20px 0;"></div>\n'
        f'{icon_svg}\n'
        f'<h2 style="text-align: center">{caption}</h2>\n'
        '</div>\n'
    )


# NOTE: the original file lost its indentation; the nesting below is a
# reconstruction (input/output side by side, then stacked sections).
with gr.Blocks(title="SpecX", css=css, theme=gr.themes.Soft()) as demo:
    # Header: logo and page title.
    with gr.Column():
        gr.HTML('''
<img width="50" src="https://i.ibb.co/V9QgwnL/specx-logo.png" alt="specx-logo" border="0">
''')
        gr.Markdown('''
<div style="margin: 50px 0;"></div>
<h1 style="width:100%; text-align: center;">SpecX: Find the Right Specialist For Your Symptoms!</h1>
''')

    # Specialist prediction: free-text input on the left, results on the right.
    with gr.Row():
        with gr.Column(scale=1, min_width=600):
            text_box = gr.Textbox(label="Explain your problem in one sentence.")
            with gr.Row():
                submit_btn = gr.Button("Submit", elem_id="warningk", variant='primary')
                submit_btn_shap = gr.Button("Get SHAP Analysis", variant='secondary')
            gr.Examples(examples=[
                "When I remember her I feel down",
                "The area around my heart doesn't feel good.",
                "I have a split on my thumb that will not heal."
            ], inputs=text_box)
        with gr.Column():
            label = gr.Label(num_top_classes=4, label="Recommended Specialist")
            gr.Markdown("## SHAP Analysis")
            shap_content_box = gr.HTML()
    submit_btn.click(inference, inputs=text_box, outputs=label)
    submit_btn_shap.click(get_shap_data, inputs=text_box, outputs=shap_content_box)

    # Personalised treatment: canned advice for a symptom chosen from a list.
    with gr.Row():
        gr.Markdown(_SPACER)
    with gr.Column():
        gr.Markdown('''
## Personalised Treatment
''')
    with gr.Row():
        with gr.Column():
            # The choices are symptoms (keys of all_options), not specialist
            # types, so the label says so.
            specialist_symptom = gr.Dropdown(
                label="Select Your Symptom",
                choices=list(all_options.keys()),
            )
            # DOM ids must be unique; "warningk" is already used by the
            # prediction Submit button above.
            submit_btn2 = gr.Button("Submit", elem_id="warningk-treatment", variant='primary')
        with gr.Column():
            pt_text = gr.Textbox(label="Recommendation")
    submit_btn2.click(pt_f, inputs=specialist_symptom, outputs=pt_text)

    # Telemedicine booking (embedded Google Form) and app download QR code.
    with gr.Row():
        gr.Markdown(_SPACER)
    with gr.Column():
        gr.Markdown('''
## Telemedicine Appointment
''')
    with gr.Row():
        gr.HTML("""<iframe src="https://docs.google.com/forms/d/e/1FAIpQLSerfTGjCtMwHDTRcxKEhbe6u78YR5OjsEFqmFNbuLRu2ocWcw/viewform?embedded=true" width="100%" height="1700" frameborder="0" marginheight="0" marginwidth="0">Loading…</iframe>""")
    with gr.Column():
        gr.Markdown('''
## Download SpecX App
''')
        gr.HTML("""
<img width="300" src="https://i.ibb.co/W25SHDD/Screenshot-2024-04-28-235911.png" alt="QR for SpecX App" border="0">
""")

    # Stats footer, one column per card.
    # NOTE(review): these figures are hard-coded literals, not computed from
    # any data in this file — confirm they are accurate before publishing.
    with gr.Row():
        for icon, caption in (
            (_ICON_BADGE_CHECK, "98% Success Rate"),
            (_ICON_CLOUD_CHECK, "99.9% Model uptime"),
            (_ICON_USERS, "77,223 Monthly active users"),
            (_ICON_BRAIN, "990,224 total model predictions made"),
        ):
            with gr.Column():
                gr.Markdown(_stat_card(icon, caption))

demo.launch()