File size: 13,205 Bytes
0495d53
 
 
 
 
baa9cea
 
0495d53
 
 
 
4255cac
0495d53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
baa9cea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0495d53
 
 
 
 
 
 
 
 
baa9cea
0495d53
 
 
 
baa9cea
0495d53
 
 
baa9cea
 
 
0495d53
 
 
 
 
baa9cea
 
 
 
0495d53
baa9cea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0495d53
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import gradio as gr
import tensorflow as tf
from transformers import TFAutoModel, AutoTokenizer
import os
import numpy as np
import shap
import scipy.special

# Hugging Face checkpoint supplying the tokenizer (and, through
# custom_objects below, the RoBERTa encoder layer). NOTE(review): the
# checkpoint is a sentiment model; presumably it only serves as the backbone
# of the specialist classifier stored in model.h5 — confirm.
model_name = 'cardiffnlp/twitter-roberta-base-sentiment-latest'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the fine-tuned Keras classifier from the local "model.h5" file.
# custom_objects maps the saved layer name 'TFRobertaModel' to a pretrained
# *instance* downloaded right here, not to a class. NOTE(review): presumably
# Keras rebuilds the layer through this mapping and model.h5 then supplies
# the final weights, in which case the download only serves deserialization —
# confirm, since it runs on every start-up.
model = tf.keras.models.load_model(
    "model.h5",
    custom_objects={
        'TFRobertaModel': TFAutoModel.from_pretrained(model_name)
    }
)

# Specialist classes, paired positionally with the model's output scores.
# They are listed in alphabetical order. NOTE(review): presumably this
# matches the label encoding used when model.h5 was trained — confirm, since
# a mismatch silently mislabels every prediction.
labels = [
    'Cardiologist',
    'Dermatologist',
    'ENT Specialist',
    'Gastro-enterologist',
    'General-Physicians',
    'Neurologist/Gastro-enterologist',
    'Ophthalmologist',
    'Orthopedist',
    'Psychiatrist',
    'Respirologist',
    'Rheumatologist',
    'Rheumatologist/Gastro-enterologist',
    'Rheumatologist/Orthopedist',
    'Surgeon',
]

# Token length every tokenizer call truncates/pads to.
# Presumably the sequence length model.h5 was trained with — confirm.
seq_len = 152

def prep_data(text):
    """Tokenize *text* into the fixed-length tensors the classifier consumes.

    Every call truncates/pads to ``seq_len`` tokens (special tokens included)
    and returns TensorFlow tensors. Only ``'input_ids'`` and
    ``'attention_mask'`` are kept; any other tokenizer outputs are dropped.
    """
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        padding='max_length',
        truncation=True,
        max_length=seq_len,
        return_tensors='tf',
    )
    wanted_keys = ('input_ids', 'attention_mask')
    return {key: encoding[key] for key in wanted_keys}

def inference(text):
    """Score *text* against every specialist label for display in ``gr.Label``.

    Args:
        text: Free-text description of the user's symptoms.

    Returns:
        Dict mapping each entry of ``labels`` to the model's score for it.
        Values are plain Python floats, so the dict is JSON-serializable
        without relying on numpy support in the UI layer.

    Raises:
        ValueError: If the model emits a different number of scores than
            there are labels. ``zip`` would otherwise silently truncate and
            misalign the result.
    """
    encoded_text = prep_data(text)
    probs = np.asarray(model.predict_on_batch(encoded_text))
    # prep_data encodes a single string, so flattening the batch (presumably
    # shaped (1, n_classes)) yields one score per class.
    scores = probs.reshape(-1)
    if scores.shape[0] != len(labels):
        raise ValueError(
            f"model returned {scores.shape[0]} scores for {len(labels)} labels"
        )
    # NOTE(review): these raw outputs are shown as probabilities, while
    # predictor() applies a softmax to the same model output. Confirm whether
    # model.h5 already ends in a softmax layer.
    return {label: float(score) for label, score in zip(labels, scores)}

def predictor(x, class_index=1):
    """Return the log-odds of one output class, for SHAP to explain.

    Args:
        x: A single string or a list of strings. Either way ``prep_data``
            tokenizes it into a batch of fixed-length inputs.
        class_index: Column of the model output to explain. Defaults to 1,
            the original hard-coded value. NOTE(review): index 1 is
            ``labels[1]`` ('Dermatologist') if the outputs follow the order
            of ``labels``. The value looks inherited from a binary-sentiment
            example; consider explaining the predicted class instead.

    Returns:
        1-D NumPy array with one finite log-odds value per input string.
    """
    # Tokenize once and feed the inputs by name, exactly as inference() does.
    # This avoids tokenizing twice and relying on the model's positional
    # input order.
    encoded = prep_data(x)
    outputs = model.predict_on_batch(encoded)
    # NOTE(review): inference() presents the raw model outputs as
    # probabilities, yet a softmax is applied here. If model.h5 already ends
    # in a softmax, this is a second softmax that flattens the scores — confirm.
    probas = tf.nn.softmax(outputs).numpy()
    # float32 softmax can round to exactly 0 or 1, where logit is -inf/+inf
    # and the SHAP values turn into NaN. Keep probabilities strictly in (0, 1).
    eps = np.finfo(probas.dtype).eps
    class_probas = np.clip(probas[:, class_index], eps, 1.0 - eps)
    return scipy.special.logit(class_probas)

def f_batch(x, batch_size=32):
    """Score a collection of texts for the SHAP explainer.

    Args:
        x: Iterable of strings. Presumably these are the masked variants of
            the user's text that SHAP generates.
        batch_size: Maximum number of strings per model call. It is keyword
            with a default, so existing ``f_batch(x)`` calls are unchanged.

    Returns:
        1-D float64 array holding one ``predictor`` score per input, in order.
        Empty input gives an empty array.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    texts = [str(t) for t in x]
    if not texts:
        return np.array([])
    # predictor() tokenizes lists natively, so score whole chunks per model
    # call instead of one call per string. Concatenating chunks once also
    # avoids the quadratic copying of repeated np.append.
    chunks = [
        np.asarray(predictor(texts[start:start + batch_size]), dtype=float).reshape(-1)
        for start in range(0, len(texts), batch_size)
    ]
    return np.concatenate(chunks)

# SHAP explainer over f_batch (one score per input text). The Hugging Face
# tokenizer is passed as the masker, so SHAP perturbs the text through it.
explainer_roberta = shap.Explainer(model=f_batch, masker=tokenizer)

# One explanation computed at import time. Its result (shap_values) is not
# read anywhere else in this file. Presumably it is a warm-up so the first UI
# request is faster. NOTE(review): confirm it is still wanted, since it adds
# start-up latency.
shap_values = explainer_roberta(
    ["When I remember her I feel down"]
)

def get_shap_data(input_text):
    """Build the SHAP token-attribution HTML for *input_text*.

    Args:
        input_text: Text from the UI textbox. It may be empty or ``None``.

    Returns:
        For non-blank input, the HTML string produced by
        ``shap.plots.text(..., display=False)``. For blank input, a short
        prompt asking the user for text, so the explainer is never run on
        empty text.
    """
    if input_text is None or not str(input_text).strip():
        return "<p>Please describe your symptoms before requesting a SHAP analysis.</p>"
    explanation = explainer_roberta([input_text])
    return shap.plots.text(explanation, display=False)

# Extra page CSS passed to gr.Blocks. Textareas get a fully transparent
# background (#00000000) and a 1px semi-transparent #6366f1 border
# (alpha 0x60, roughly 38% opacity).
css = """
textarea {
    background-color: #00000000;
    border: 1px solid #6366f160;
}
"""
# Page layout, built top to bottom in statement order:
#   1. header row,
#   2. input/output row (plus the two button handlers),
#   3. spacer row,
#   4. row of four static "statistics" cards.
with gr.Blocks(title="SpecX", css=css, theme=gr.themes.Soft()) as demo:
  # Header: vertical spacer followed by the centered page title.
  with gr.Row():
    textmd = gr.Markdown('''
    <div style="margin: 50px 0;"></div>

    <h1 style="width:100%; text-align: center;">SpecX: Find the Right Specialist For Your Symptoms!</h1>

    ''')

  # Main row. Left: symptom textbox, the two buttons and clickable examples.
  # Right: recommended specialists and the SHAP HTML output.
  with gr.Row():
    with gr.Column(scale=1, min_width=600):
      text_box = gr.Textbox(label="Explain your problem in one sentence.")
      with gr.Row():
        submit_btn = gr.Button("Submit", elem_id="warningk", variant='primary')
        submit_btn_shap = gr.Button("Get SHAP Analysis", variant='secondary')
      examples = gr.Examples(examples=[
          "When I remember her I feel down",
          "The area around my heart doesn't feel good.",
          "I have a split on my thumb that will not heal."
          ], inputs=text_box)
    with gr.Column():
      # num_top_classes=4: at most the four highest-scoring specialists are shown.
      label = gr.Label(num_top_classes=4, label="Recommended Specialist")
      gr.Markdown("## SHAP Analysis")
      shap_content_box = gr.HTML()
    # Both handlers read the same textbox:
    #   "Submit" -> inference() fills the Label component.
    #   "Get SHAP Analysis" -> get_shap_data() fills the HTML component.
    submit_btn.click(inference, inputs=text_box, outputs=label)
    submit_btn_shap.click(get_shap_data, inputs=text_box, outputs=shap_content_box)
  # Spacer row between the app and the statistics cards.
  with gr.Row():
    textmd1 = gr.Markdown('''
      <div style="margin: 50px 0;"></div>
    ''')
  # Statistics cards, each an inline SVG icon plus a caption.
  # NOTE(review): every figure below ("98% Success Rate", "99.9% Model uptime",
  # "77,223 Monthly active users", "990,224 total model predictions made") is a
  # hard-coded string, not computed from anything in this file. Confirm the
  # figures are accurate before presenting them in a symptom-triage tool.
  # The names textmd1/textmd3 are reassigned here and never read elsewhere
  # in this file.
  with gr.Row():
    # Card 1: success rate.
    with gr.Column():
      textmd1 = gr.Markdown('''
      <div style="display: flex; justify-content: center; align-items: center; flex-direction: column;">
      <div style="margin: 20px 0;"></div>
      <svg width="50px" height="50px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
        <path opacity="0.4" d="M10.7509 2.45007C11.4409 1.86007 12.5709 1.86007 13.2709 2.45007L14.8509 3.81007C15.1509 4.07007 15.7109 4.28007 16.1109 4.28007H17.8109C18.8709 4.28007 19.7409 5.15007 19.7409 6.21007V7.91007C19.7409 8.30007 19.9509 8.87007 20.2109 9.17007L21.5709 10.7501C22.1609 11.4401 22.1609 12.5701 21.5709 13.2701L20.2109 14.8501C19.9509 15.1501 19.7409 15.7101 19.7409 16.1101V17.8101C19.7409 18.8701 18.8709 19.7401 17.8109 19.7401H16.1109C15.7209 19.7401 15.1509 19.9501 14.8509 20.2101L13.2709 21.5701C12.5809 22.1601 11.4509 22.1601 10.7509 21.5701L9.17086 20.2101C8.87086 19.9501 8.31086 19.7401 7.91086 19.7401H6.18086C5.12086 19.7401 4.25086 18.8701 4.25086 17.8101V16.1001C4.25086 15.7101 4.04086 15.1501 3.79086 14.8501L2.44086 13.2601C1.86086 12.5701 1.86086 11.4501 2.44086 10.7601L3.79086 9.17007C4.04086 8.87007 4.25086 8.31007 4.25086 7.92007V6.20007C4.25086 5.14007 5.12086 4.27007 6.18086 4.27007H7.91086C8.30086 4.27007 8.87086 4.06007 9.17086 3.80007L10.7509 2.45007Z" fill="var(--neutral-700)"/>
        <path d="M10.7905 15.17C10.5905 15.17 10.4005 15.09 10.2605 14.95L7.84055 12.53C7.55055 12.24 7.55055 11.76 7.84055 11.47C8.13055 11.18 8.61055 11.18 8.90055 11.47L10.7905 13.36L15.0905 9.06003C15.3805 8.77003 15.8605 8.77003 16.1505 9.06003C16.4405 9.35003 16.4405 9.83003 16.1505 10.12L11.3205 14.95C11.1805 15.09 10.9905 15.17 10.7905 15.17Z" fill="#ececec"/>
      </svg>

      <h2 style="text-align: center">98% Success Rate</h2>
      </div>
      ''')
    # Card 2: model uptime.
    with gr.Column():
      textmd2 = gr.Markdown('''
      <div style="display: flex; justify-content: center; align-items: center; flex-direction: column;">
      <div style="margin: 20px 0;"></div>
      <svg width="50px" height="50px" viewBox="0 0 48 48" fill="none" xmlns="http://www.w3.org/2000/svg">
      <path d="M12 33C8.66666 33 4 31.5 4 25.5C4 18.5 11 17 13 17C14 13.5 16 8 24 8C31 8 34 12 35 15.5C35 15.5 44 16.5 44 25C44 31 40 33 36 33" stroke="var(--neutral-700)" stroke-width="4" stroke-linecap="round" stroke-linejoin="round"/>
      <path d="M18 33L24 38L32 28" stroke="var(--neutral-700)" stroke-width="4" stroke-linecap="round" stroke-linejoin="round"/>
      </svg>

      <h2 style="text-align: center">99.9% Model uptime</h2>
      </div>
      ''')
    # Card 3: monthly active users.
    with gr.Column():
      textmd3 = gr.Markdown('''
      <div style="display: flex; justify-content: center; align-items: center; flex-direction: column;">
      <div style="margin: 20px 0;"></div>
      <svg width="50px" height="50px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
      <path opacity="0.4" d="M17.9981 7.16C17.9381 7.15 17.8681 7.15 17.8081 7.16C16.4281 7.11 15.3281 5.98 15.3281 4.58C15.3281 3.15 16.4781 2 17.9081 2C19.3381 2 20.4881 3.16 20.4881 4.58C20.4781 5.98 19.3781 7.11 17.9981 7.16Z" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
      <path opacity="0.4" d="M16.9675 14.4402C18.3375 14.6702 19.8475 14.4302 20.9075 13.7202C22.3175 12.7802 22.3175 11.2402 20.9075 10.3002C19.8375 9.59016 18.3075 9.35016 16.9375 9.59016" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
      <path opacity="0.4" d="M5.96656 7.16C6.02656 7.15 6.09656 7.15 6.15656 7.16C7.53656 7.11 8.63656 5.98 8.63656 4.58C8.63656 3.15 7.48656 2 6.05656 2C4.62656 2 3.47656 3.16 3.47656 4.58C3.48656 5.98 4.58656 7.11 5.96656 7.16Z" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
      <path opacity="0.4" d="M6.9975 14.4402C5.6275 14.6702 4.1175 14.4302 3.0575 13.7202C1.6475 12.7802 1.6475 11.2402 3.0575 10.3002C4.1275 9.59016 5.6575 9.35016 7.0275 9.59016" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
      <path d="M12.0001 14.6302C11.9401 14.6202 11.8701 14.6202 11.8101 14.6302C10.4301 14.5802 9.33008 13.4502 9.33008 12.0502C9.33008 10.6202 10.4801 9.47021 11.9101 9.47021C13.3401 9.47021 14.4901 10.6302 14.4901 12.0502C14.4801 13.4502 13.3801 14.5902 12.0001 14.6302Z" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
      <path d="M9.0907 17.7804C7.6807 18.7204 7.6807 20.2603 9.0907 21.2003C10.6907 22.2703 13.3107 22.2703 14.9107 21.2003C16.3207 20.2603 16.3207 18.7204 14.9107 17.7804C13.3207 16.7204 10.6907 16.7204 9.0907 17.7804Z" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
      </svg>

      <h2 style="text-align: center">77,223 Monthly active users</h2>
      </div>
      ''')
    # Card 4: total predictions made.
    with gr.Column():
      textmd3 = gr.Markdown('''
      <div style="display: flex; justify-content: center; align-items: center; flex-direction: column;">
      <div style="margin: 20px 0;"></div>
      <svg fill="var(--neutral-700)" width="50px" height="50px" viewBox="0 0 100 100" id="Layer_1" version="1.1" xml:space="preserve" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
      <g>
      <path d="M93.998,45.312c0-3.676-1.659-7.121-4.486-9.414c0.123-0.587,0.184-1.151,0.184-1.706c0-4.579-3.386-8.382-7.785-9.037   c0.101-0.526,0.149-1.042,0.149-1.556c0-4.875-3.842-8.858-8.655-9.111c-0.079-0.013-0.159-0.024-0.242-0.024   c-0.04,0-0.079,0.005-0.12,0.006c-0.04-0.001-0.079-0.006-0.12-0.006c-0.458,0-0.919,0.041-1.406,0.126   c-0.846-4.485-4.753-7.825-9.437-7.825c-5.311,0-9.632,4.321-9.632,9.633v65.918c0,6.723,5.469,12.191,12.191,12.191   c4.46,0,8.508-2.413,10.646-6.246c0.479,0.104,0.939,0.168,1.401,0.198c2.903,0.185,5.73-0.766,7.926-2.693   c2.196-1.927,3.51-4.594,3.7-7.51c0.079-1.215-0.057-2.434-0.403-3.638c3.796-2.691,6.027-6.952,6.027-11.621   c0-3.385-1.219-6.635-3.445-9.224C92.731,51.505,93.998,48.471,93.998,45.312z M90.938,62.999c0,3.484-1.582,6.68-4.295,8.819   c-2.008-3.196-5.57-5.237-9.427-5.237c-0.828,0-1.5,0.672-1.5,1.5s0.672,1.5,1.5,1.5c3.341,0,6.384,2.093,7.582,5.208   c0.41,1.088,0.592,2.189,0.521,3.274c-0.138,2.116-1.091,4.051-2.685,5.449c-1.594,1.399-3.641,2.094-5.752,1.954   c-0.594-0.039-1.208-0.167-1.933-0.402c-0.74-0.242-1.541,0.124-1.846,0.84c-1.445,3.404-4.768,5.604-8.465,5.604   c-5.068,0-9.191-4.123-9.191-9.191V16.399c0-3.657,2.975-6.633,6.632-6.633c3.398,0,6.194,2.562,6.558,5.908   c-2.751,1.576-4.612,4.535-4.612,7.926c0,0.829,0.672,1.5,1.5,1.5s1.5-0.671,1.5-1.5c0-3.343,2.689-6.065,6.016-6.13   c3.327,0.065,6.016,2.787,6.016,6.129c0,0.622-0.117,1.266-0.359,1.971c-0.057,0.166-0.084,0.34-0.081,0.515   c0.001,0.041,0.003,0.079,0.007,0.115c-0.006,0.021-0.01,0.035-0.01,0.035c-0.118,0.465-0.006,0.959,0.301,1.328   c0.307,0.369,0.765,0.569,1.251,0.538c0.104-0.007,0.208-0.02,0.392-0.046c3.383,0,6.136,2.753,6.136,6.136   c0,0.572-0.103,1.159-0.322,1.849c-0.203,0.635,0.038,1.328,0.591,1.7c2.434,1.639,3.909,4.329,4.014,7.242   c0,0.004-0.001,0.008-0.001,0.012c0,5.03-4.092,9.123-9.122,9.123s-9.123-4.093-9.123-9.123c0-0.829-0.672-1.5-1.5-1.5   
s-1.5,0.671-1.5,1.5c0,6.685,5.438,12.123,12.123,12.123c2.228,0,4.31-0.615,6.106-1.668C89.88,57.539,90.938,60.212,90.938,62.999   z"/>
      <path d="M38.179,6.766c-4.684,0-8.59,3.34-9.435,7.825c-0.488-0.085-0.949-0.126-1.407-0.126c-0.04,0-0.079,0.005-0.12,0.006   c-0.04-0.001-0.079-0.006-0.12-0.006c-0.083,0-0.163,0.011-0.242,0.024c-4.813,0.253-8.654,4.236-8.654,9.111   c0,0.514,0.049,1.03,0.149,1.556c-4.399,0.655-7.785,4.458-7.785,9.037c0,0.554,0.061,1.118,0.184,1.706   c-2.827,2.293-4.486,5.738-4.486,9.414c0,3.159,1.266,6.193,3.505,8.463c-2.227,2.589-3.446,5.839-3.446,9.224   c0,4.669,2.231,8.929,6.027,11.621c-0.347,1.204-0.482,2.423-0.402,3.639c0.19,2.915,1.503,5.582,3.699,7.509   c2.196,1.928,5.015,2.879,7.926,2.693c0.455-0.03,0.919-0.096,1.4-0.199c2.138,3.834,6.186,6.247,10.646,6.247   c6.722,0,12.191-5.469,12.191-12.191V16.399C47.811,11.087,43.49,6.766,38.179,6.766z M44.811,82.317   c0,5.068-4.123,9.191-9.191,9.191c-3.697,0-7.02-2.2-8.464-5.604c-0.241-0.567-0.793-0.914-1.381-0.914   c-0.154,0-0.311,0.023-0.465,0.074c-0.724,0.235-1.338,0.363-1.933,0.402c-2.119,0.139-4.158-0.556-5.751-1.954   c-1.594-1.398-2.547-3.333-2.685-5.449c-0.076-1.16,0.125-2.336,0.598-3.495c0.007-0.017,0.005-0.036,0.011-0.053   c1.342-3.056,4.225-4.953,7.597-4.953c0.829,0,1.5-0.672,1.5-1.5s-0.671-1.5-1.5-1.5c-3.938,0-7.501,2.007-9.548,5.239   c-2.701-2.139-4.277-5.327-4.277-8.802c0-2.787,1.06-5.46,2.978-7.549c1.796,1.053,3.879,1.668,6.107,1.668   c6.685,0,12.123-5.438,12.123-12.123c0-0.829-0.671-1.5-1.5-1.5s-1.5,0.671-1.5,1.5c0,5.03-4.092,9.123-9.123,9.123   s-9.123-4.093-9.123-9.123c0-0.002-0.001-0.004-0.001-0.006c0.103-2.915,1.578-5.607,4.013-7.248   c0.553-0.372,0.793-1.064,0.591-1.699c-0.22-0.691-0.322-1.278-0.322-1.85c0-3.376,2.741-6.125,6.195-6.125   c0.007,0,0.015,0,0.022,0c0.103,0.014,0.206,0.027,0.311,0.034c0.485,0.03,0.948-0.171,1.254-0.542   c0.307-0.372,0.417-0.868,0.294-1.334c0-0.001-0.003-0.014-0.008-0.031c0.003-0.035,0.006-0.067,0.007-0.095   c0.005-0.18-0.022-0.359-0.081-0.529c-0.242-0.707-0.359-1.352-0.359-1.972c0-3.342,2.688-6.065,6.016-6.129   
c3.328,0.065,6.016,2.787,6.016,6.13c0,0.829,0.671,1.5,1.5,1.5s1.5-0.671,1.5-1.5c0-3.391-1.861-6.35-4.612-7.926   c0.364-3.346,3.16-5.908,6.558-5.908c3.657,0,6.632,2.976,6.632,6.633V82.317z"/>
      </g>

      </svg>
      <h2 style="text-align: center">990,224 total model predictions made</h2>
      </div>
      ''')

if __name__ == "__main__":
    # Start the (blocking) Gradio server only when this file is run as a
    # script. Importing the module, e.g. from tests or another module, then
    # builds `demo` without launching it.
    demo.launch()