File size: 13,205 Bytes
0495d53 baa9cea 0495d53 4255cac 0495d53 baa9cea 0495d53 baa9cea 0495d53 baa9cea 0495d53 baa9cea 0495d53 baa9cea 0495d53 baa9cea 0495d53 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import os

import gradio as gr
import numpy as np
import scipy.special
import shap
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel, TFRobertaModel
# Tokenizer of the RoBERTa checkpoint named below.
# NOTE(review): presumably the classifier saved in model.h5 was fine-tuned on
# top of this backbone, so its tokenizer must be used here — confirm.
model_name = 'cardiffnlp/twitter-roberta-base-sentiment-latest'
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the saved Keras classifier. custom_objects only has to tell Keras which
# class implements the serialized 'TFRobertaModel' layer: Keras rebuilds that
# layer via its from_config() from the config stored in model.h5 and then loads
# the saved weights into it. Mapping the name to the class (rather than to a
# TFAutoModel.from_pretrained(...) instance, whose downloaded pretrained
# weights were never used) avoids building a throwaway RoBERTa at startup.
model = tf.keras.models.load_model(
    "model.h5",
    custom_objects={'TFRobertaModel': TFRobertaModel},
)
# Specialist names reported to the user.
# NOTE(review): inference() pairs these with the model's scores positionally
# (via zip), so this order presumably has to match the model's output units —
# confirm against the label encoding used in training.
labels = [
    'Cardiologist',
    'Dermatologist',
    'ENT Specialist',
    'Gastro-enterologist',
    'General-Physicians',
    'Neurologist/Gastro-enterologist',
    'Ophthalmologist',
    'Orthopedist',
    'Psychiatrist',
    'Respirologist',
    'Rheumatologist',
    'Rheumatologist/Gastro-enterologist',
    'Rheumatologist/Orthopedist',
    'Surgeon'
]
# Fixed token length: every tokenizer call in this file truncates/pads inputs
# to exactly this many tokens (presumably the length the model was trained
# with — confirm before changing).
seq_len = 152
def prep_data(text):
    """Tokenize ``text`` into the fixed-length tensors fed to ``model``.

    Special tokens are added and every sequence is truncated or padded to
    exactly ``seq_len`` tokens; results come back as TensorFlow tensors.

    Args:
        text: Input text in any form ``tokenizer`` accepts (e.g. one string).

    Returns:
        dict holding the ``'input_ids'`` and ``'attention_mask'`` tensors,
        in that key order.
    """
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=seq_len,
        truncation=True,
        padding='max_length',
        return_tensors='tf',
    )
    wanted_keys = ('input_ids', 'attention_mask')
    return {key: encoding[key] for key in wanted_keys}
def inference(text):
    """Score ``text`` and map the model's outputs onto the specialist ``labels``.

    Args:
        text: The user's symptom description.

    Returns:
        dict mapping each entry of ``labels`` to the model's score for it,
        as plain Python floats so it can be passed straight to ``gr.Label``.

    Raises:
        ValueError: if the model produces a different number of scores than
            there are labels (``zip`` would otherwise drop the extras
            silently and the label/score pairing could be wrong).
    """
    encoded_text = prep_data(text)
    # A single input string yields a batch of one; flatten it to that row.
    scores = np.asarray(model.predict_on_batch(encoded_text)).reshape(-1)
    if scores.shape[0] != len(labels):
        raise ValueError(
            f"model returned {scores.shape[0]} scores but "
            f"{len(labels)} labels are defined"
        )
    # NOTE(review): these values are displayed as confidences, i.e. treated
    # as probabilities already, whereas predictor() applies softmax to the
    # same model output. Confirm which matches the model's final activation.
    return {label: float(score) for label, score in zip(labels, scores)}
def predictor(x):
    """Score ``x`` for SHAP: logit of the softmax probability of class 1.

    Args:
        x: Text to score — a single string or a list of strings.

    Returns:
        1-D numpy array with one score per input string.
    """
    # Tokenize once and reuse both tensors (the text used to be tokenized
    # twice with identical settings, once per tensor).
    encoded = prep_data(x)
    # Positional [input_ids, attention_mask] input, as before. verbose=0
    # silences the per-call Keras progress bar; SHAP calls this many times.
    outputs = model.predict(
        [encoded['input_ids'], encoded['attention_mask']], verbose=0
    )
    # NOTE(review): inference() shows the raw model output as probabilities;
    # if the model already ends in softmax, this second softmax compresses
    # the scores toward uniform — confirm.
    probas = tf.nn.softmax(outputs).numpy()
    # NOTE(review): column 1 is labels[1] ('Dermatologist'), so every SHAP
    # explanation targets that one class regardless of the prediction —
    # confirm which class (or all classes) should be explained.
    return scipy.special.logit(probas[:, 1])
def f_batch(x):
    """Batch scoring function handed to ``shap.Explainer``.

    Args:
        x: Iterable of texts supplied by SHAP (possibly a numpy array of
            strings).

    Returns:
        1-D float64 numpy array with one ``predictor`` score per text, in
        input order; an empty array for empty input.
    """
    # Plain ``str`` copies, in case SHAP hands over numpy string scalars.
    texts = [str(text) for text in x]
    if not texts:
        return np.array([])
    # Score the whole batch in a single call. The previous per-item
    # np.append loop re-copied the accumulated array on every iteration
    # (quadratic) and invoked the model once per text.
    return np.asarray(predictor(texts), dtype=np.float64).reshape(-1)
# SHAP explainer over the batch scoring function f_batch; passing the
# tokenizer as the masker lets SHAP build a text masker from it.
explainer_roberta = shap.Explainer(model=f_batch, masker=tokenizer)
# One explanation is computed at import time. Its result is bound to the
# module-level name shap_values but is not read elsewhere in this file as
# shown — presumably a warm-up for the first real request; confirm before
# removing.
shap_values = explainer_roberta(["When I remember her I feel down"])
def get_shap_data(input_text):
    """Explain ``input_text`` with SHAP and render the explanation as HTML.

    Args:
        input_text: Text taken from the Gradio textbox.

    Returns:
        Markup from ``shap.plots.text``; ``display=False`` makes it return
        the HTML instead of rendering it.
    """
    explanation = explainer_roberta([input_text])
    return shap.plots.text(explanation, display=False)
css = """
textarea {
background-color: #00000000;
border: 1px solid #6366f160;
}
"""
with gr.Blocks(title="SpecX", css=css, theme=gr.themes.Soft()) as demo:
with gr.Row():
textmd = gr.Markdown('''
<div style="margin: 50px 0;"></div>
<h1 style="width:100%; text-align: center;">SpecX: Find the Right Specialist For Your Symptoms!</h1>
''')
with gr.Row():
with gr.Column(scale=1, min_width=600):
text_box = gr.Textbox(label="Explain your problem in one sentence.")
with gr.Row():
submit_btn = gr.Button("Submit", elem_id="warningk", variant='primary')
submit_btn_shap = gr.Button("Get SHAP Analysis", variant='secondary')
examples = gr.Examples(examples=[
"When I remember her I feel down",
"The area around my heart doesn't feel good.",
"I have a split on my thumb that will not heal."
], inputs=text_box)
with gr.Column():
label = gr.Label(num_top_classes=4, label="Recommended Specialist")
gr.Markdown("## SHAP Analysis")
shap_content_box = gr.HTML()
submit_btn.click(inference, inputs=text_box, outputs=label)
submit_btn_shap.click(get_shap_data, inputs=text_box, outputs=shap_content_box)
with gr.Row():
textmd1 = gr.Markdown('''
<div style="margin: 50px 0;"></div>
''')
with gr.Row():
with gr.Column():
textmd1 = gr.Markdown('''
<div style="display: flex; justify-content: center; align-items: center; flex-direction: column;">
<div style="margin: 20px 0;"></div>
<svg width="50px" height="50px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path opacity="0.4" d="M10.7509 2.45007C11.4409 1.86007 12.5709 1.86007 13.2709 2.45007L14.8509 3.81007C15.1509 4.07007 15.7109 4.28007 16.1109 4.28007H17.8109C18.8709 4.28007 19.7409 5.15007 19.7409 6.21007V7.91007C19.7409 8.30007 19.9509 8.87007 20.2109 9.17007L21.5709 10.7501C22.1609 11.4401 22.1609 12.5701 21.5709 13.2701L20.2109 14.8501C19.9509 15.1501 19.7409 15.7101 19.7409 16.1101V17.8101C19.7409 18.8701 18.8709 19.7401 17.8109 19.7401H16.1109C15.7209 19.7401 15.1509 19.9501 14.8509 20.2101L13.2709 21.5701C12.5809 22.1601 11.4509 22.1601 10.7509 21.5701L9.17086 20.2101C8.87086 19.9501 8.31086 19.7401 7.91086 19.7401H6.18086C5.12086 19.7401 4.25086 18.8701 4.25086 17.8101V16.1001C4.25086 15.7101 4.04086 15.1501 3.79086 14.8501L2.44086 13.2601C1.86086 12.5701 1.86086 11.4501 2.44086 10.7601L3.79086 9.17007C4.04086 8.87007 4.25086 8.31007 4.25086 7.92007V6.20007C4.25086 5.14007 5.12086 4.27007 6.18086 4.27007H7.91086C8.30086 4.27007 8.87086 4.06007 9.17086 3.80007L10.7509 2.45007Z" fill="var(--neutral-700)"/>
<path d="M10.7905 15.17C10.5905 15.17 10.4005 15.09 10.2605 14.95L7.84055 12.53C7.55055 12.24 7.55055 11.76 7.84055 11.47C8.13055 11.18 8.61055 11.18 8.90055 11.47L10.7905 13.36L15.0905 9.06003C15.3805 8.77003 15.8605 8.77003 16.1505 9.06003C16.4405 9.35003 16.4405 9.83003 16.1505 10.12L11.3205 14.95C11.1805 15.09 10.9905 15.17 10.7905 15.17Z" fill="#ececec"/>
</svg>
<h2 style="text-align: center">98% Success Rate</h2>
</div>
''')
with gr.Column():
textmd2 = gr.Markdown('''
<div style="display: flex; justify-content: center; align-items: center; flex-direction: column;">
<div style="margin: 20px 0;"></div>
<svg width="50px" height="50px" viewBox="0 0 48 48" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M12 33C8.66666 33 4 31.5 4 25.5C4 18.5 11 17 13 17C14 13.5 16 8 24 8C31 8 34 12 35 15.5C35 15.5 44 16.5 44 25C44 31 40 33 36 33" stroke="var(--neutral-700)" stroke-width="4" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M18 33L24 38L32 28" stroke="var(--neutral-700)" stroke-width="4" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
<h2 style="text-align: center">99.9% Model uptime</h2>
</div>
''')
with gr.Column():
textmd3 = gr.Markdown('''
<div style="display: flex; justify-content: center; align-items: center; flex-direction: column;">
<div style="margin: 20px 0;"></div>
<svg width="50px" height="50px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path opacity="0.4" d="M17.9981 7.16C17.9381 7.15 17.8681 7.15 17.8081 7.16C16.4281 7.11 15.3281 5.98 15.3281 4.58C15.3281 3.15 16.4781 2 17.9081 2C19.3381 2 20.4881 3.16 20.4881 4.58C20.4781 5.98 19.3781 7.11 17.9981 7.16Z" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path opacity="0.4" d="M16.9675 14.4402C18.3375 14.6702 19.8475 14.4302 20.9075 13.7202C22.3175 12.7802 22.3175 11.2402 20.9075 10.3002C19.8375 9.59016 18.3075 9.35016 16.9375 9.59016" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path opacity="0.4" d="M5.96656 7.16C6.02656 7.15 6.09656 7.15 6.15656 7.16C7.53656 7.11 8.63656 5.98 8.63656 4.58C8.63656 3.15 7.48656 2 6.05656 2C4.62656 2 3.47656 3.16 3.47656 4.58C3.48656 5.98 4.58656 7.11 5.96656 7.16Z" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path opacity="0.4" d="M6.9975 14.4402C5.6275 14.6702 4.1175 14.4302 3.0575 13.7202C1.6475 12.7802 1.6475 11.2402 3.0575 10.3002C4.1275 9.59016 5.6575 9.35016 7.0275 9.59016" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M12.0001 14.6302C11.9401 14.6202 11.8701 14.6202 11.8101 14.6302C10.4301 14.5802 9.33008 13.4502 9.33008 12.0502C9.33008 10.6202 10.4801 9.47021 11.9101 9.47021C13.3401 9.47021 14.4901 10.6302 14.4901 12.0502C14.4801 13.4502 13.3801 14.5902 12.0001 14.6302Z" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M9.0907 17.7804C7.6807 18.7204 7.6807 20.2603 9.0907 21.2003C10.6907 22.2703 13.3107 22.2703 14.9107 21.2003C16.3207 20.2603 16.3207 18.7204 14.9107 17.7804C13.3207 16.7204 10.6907 16.7204 9.0907 17.7804Z" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
<h2 style="text-align: center">77,223 Monthly active users</h2>
</div>
''')
with gr.Column():
textmd3 = gr.Markdown('''
<div style="display: flex; justify-content: center; align-items: center; flex-direction: column;">
<div style="margin: 20px 0;"></div>
<svg fill="var(--neutral-700)" width="50px" height="50px" viewBox="0 0 100 100" id="Layer_1" version="1.1" xml:space="preserve" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<g>
<path d="M93.998,45.312c0-3.676-1.659-7.121-4.486-9.414c0.123-0.587,0.184-1.151,0.184-1.706c0-4.579-3.386-8.382-7.785-9.037 c0.101-0.526,0.149-1.042,0.149-1.556c0-4.875-3.842-8.858-8.655-9.111c-0.079-0.013-0.159-0.024-0.242-0.024 c-0.04,0-0.079,0.005-0.12,0.006c-0.04-0.001-0.079-0.006-0.12-0.006c-0.458,0-0.919,0.041-1.406,0.126 c-0.846-4.485-4.753-7.825-9.437-7.825c-5.311,0-9.632,4.321-9.632,9.633v65.918c0,6.723,5.469,12.191,12.191,12.191 c4.46,0,8.508-2.413,10.646-6.246c0.479,0.104,0.939,0.168,1.401,0.198c2.903,0.185,5.73-0.766,7.926-2.693 c2.196-1.927,3.51-4.594,3.7-7.51c0.079-1.215-0.057-2.434-0.403-3.638c3.796-2.691,6.027-6.952,6.027-11.621 c0-3.385-1.219-6.635-3.445-9.224C92.731,51.505,93.998,48.471,93.998,45.312z M90.938,62.999c0,3.484-1.582,6.68-4.295,8.819 c-2.008-3.196-5.57-5.237-9.427-5.237c-0.828,0-1.5,0.672-1.5,1.5s0.672,1.5,1.5,1.5c3.341,0,6.384,2.093,7.582,5.208 c0.41,1.088,0.592,2.189,0.521,3.274c-0.138,2.116-1.091,4.051-2.685,5.449c-1.594,1.399-3.641,2.094-5.752,1.954 c-0.594-0.039-1.208-0.167-1.933-0.402c-0.74-0.242-1.541,0.124-1.846,0.84c-1.445,3.404-4.768,5.604-8.465,5.604 c-5.068,0-9.191-4.123-9.191-9.191V16.399c0-3.657,2.975-6.633,6.632-6.633c3.398,0,6.194,2.562,6.558,5.908 c-2.751,1.576-4.612,4.535-4.612,7.926c0,0.829,0.672,1.5,1.5,1.5s1.5-0.671,1.5-1.5c0-3.343,2.689-6.065,6.016-6.13 c3.327,0.065,6.016,2.787,6.016,6.129c0,0.622-0.117,1.266-0.359,1.971c-0.057,0.166-0.084,0.34-0.081,0.515 c0.001,0.041,0.003,0.079,0.007,0.115c-0.006,0.021-0.01,0.035-0.01,0.035c-0.118,0.465-0.006,0.959,0.301,1.328 c0.307,0.369,0.765,0.569,1.251,0.538c0.104-0.007,0.208-0.02,0.392-0.046c3.383,0,6.136,2.753,6.136,6.136 c0,0.572-0.103,1.159-0.322,1.849c-0.203,0.635,0.038,1.328,0.591,1.7c2.434,1.639,3.909,4.329,4.014,7.242 c0,0.004-0.001,0.008-0.001,0.012c0,5.03-4.092,9.123-9.122,9.123s-9.123-4.093-9.123-9.123c0-0.829-0.672-1.5-1.5-1.5 s-1.5,0.671-1.5,1.5c0,6.685,5.438,12.123,12.123,12.123c2.228,0,4.31-0.615,6.106-1.668C89.88,57.539,90.938,60.212,90.938,62.999 z"/>
<path d="M38.179,6.766c-4.684,0-8.59,3.34-9.435,7.825c-0.488-0.085-0.949-0.126-1.407-0.126c-0.04,0-0.079,0.005-0.12,0.006 c-0.04-0.001-0.079-0.006-0.12-0.006c-0.083,0-0.163,0.011-0.242,0.024c-4.813,0.253-8.654,4.236-8.654,9.111 c0,0.514,0.049,1.03,0.149,1.556c-4.399,0.655-7.785,4.458-7.785,9.037c0,0.554,0.061,1.118,0.184,1.706 c-2.827,2.293-4.486,5.738-4.486,9.414c0,3.159,1.266,6.193,3.505,8.463c-2.227,2.589-3.446,5.839-3.446,9.224 c0,4.669,2.231,8.929,6.027,11.621c-0.347,1.204-0.482,2.423-0.402,3.639c0.19,2.915,1.503,5.582,3.699,7.509 c2.196,1.928,5.015,2.879,7.926,2.693c0.455-0.03,0.919-0.096,1.4-0.199c2.138,3.834,6.186,6.247,10.646,6.247 c6.722,0,12.191-5.469,12.191-12.191V16.399C47.811,11.087,43.49,6.766,38.179,6.766z M44.811,82.317 c0,5.068-4.123,9.191-9.191,9.191c-3.697,0-7.02-2.2-8.464-5.604c-0.241-0.567-0.793-0.914-1.381-0.914 c-0.154,0-0.311,0.023-0.465,0.074c-0.724,0.235-1.338,0.363-1.933,0.402c-2.119,0.139-4.158-0.556-5.751-1.954 c-1.594-1.398-2.547-3.333-2.685-5.449c-0.076-1.16,0.125-2.336,0.598-3.495c0.007-0.017,0.005-0.036,0.011-0.053 c1.342-3.056,4.225-4.953,7.597-4.953c0.829,0,1.5-0.672,1.5-1.5s-0.671-1.5-1.5-1.5c-3.938,0-7.501,2.007-9.548,5.239 c-2.701-2.139-4.277-5.327-4.277-8.802c0-2.787,1.06-5.46,2.978-7.549c1.796,1.053,3.879,1.668,6.107,1.668 c6.685,0,12.123-5.438,12.123-12.123c0-0.829-0.671-1.5-1.5-1.5s-1.5,0.671-1.5,1.5c0,5.03-4.092,9.123-9.123,9.123 s-9.123-4.093-9.123-9.123c0-0.002-0.001-0.004-0.001-0.006c0.103-2.915,1.578-5.607,4.013-7.248 c0.553-0.372,0.793-1.064,0.591-1.699c-0.22-0.691-0.322-1.278-0.322-1.85c0-3.376,2.741-6.125,6.195-6.125 c0.007,0,0.015,0,0.022,0c0.103,0.014,0.206,0.027,0.311,0.034c0.485,0.03,0.948-0.171,1.254-0.542 c0.307-0.372,0.417-0.868,0.294-1.334c0-0.001-0.003-0.014-0.008-0.031c0.003-0.035,0.006-0.067,0.007-0.095 c0.005-0.18-0.022-0.359-0.081-0.529c-0.242-0.707-0.359-1.352-0.359-1.972c0-3.342,2.688-6.065,6.016-6.129 
c3.328,0.065,6.016,2.787,6.016,6.13c0,0.829,0.671,1.5,1.5,1.5s1.5-0.671,1.5-1.5c0-3.391-1.861-6.35-4.612-7.926 c0.364-3.346,3.16-5.908,6.558-5.908c3.657,0,6.632,2.976,6.632,6.633V82.317z"/>
</g>
</svg>
<h2 style="text-align: center">990,224 total model predictions made</h2>
</div>
''')
# Start the Gradio server only when this file is executed directly, so that
# importing the module does not block inside launch(). (The stray "|" scrape
# residue that made this line a SyntaxError is removed.)
if __name__ == "__main__":
    demo.launch()