Mushfi commited on
Commit
baa9cea
1 Parent(s): e8f0723

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -3
app.py CHANGED
@@ -3,6 +3,8 @@ import tensorflow as tf
3
  from transformers import TFAutoModel, AutoTokenizer
4
  import os
5
  import numpy as np
 
 
6
 
7
  model_name = 'cardiffnlp/twitter-roberta-base-sentiment-latest'
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -49,6 +51,29 @@ def inference(text):
49
  probabilities = {i:j for i,j in zip(labels, list(probs.flatten()))}
50
  return probabilities
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  css = """
53
  textarea {
54
  background-color: #00000000;
@@ -58,21 +83,87 @@ textarea {
58
  with gr.Blocks(title="SpecX", css=css, theme=gr.themes.Soft()) as demo:
59
  with gr.Row():
60
  textmd = gr.Markdown('''
61
- <div style="margin: 50px 0;"></div>
62
 
63
  <h1 style="width:100%; text-align: center;">SpecX: Find the Right Specialist For Your Symptoms!</h1>
64
 
65
  ''')
 
66
  with gr.Row():
67
  with gr.Column(scale=1, min_width=600):
68
  text_box = gr.Textbox(label="Explain your problem in one sentence.")
69
- submit_btn = gr.Button("Submit", elem_id="warningk", variant='primary')
 
 
70
  examples = gr.Examples(examples=[
71
  "When I remember her I feel down",
72
  "The area around my heart doesn't feel good.",
73
  "I have a split on my thumb that will not heal."
74
  ], inputs=text_box)
75
- label = gr.Label(num_top_classes=4, label="Recommended Specialist")
 
 
 
76
  submit_btn.click(inference, inputs=text_box, outputs=label)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  demo.launch()
 
3
  from transformers import TFAutoModel, AutoTokenizer
4
  import os
5
  import numpy as np
6
+ import shap
7
+ import scipy.special
8
 
9
  model_name = 'cardiffnlp/twitter-roberta-base-sentiment-latest'
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
51
  probabilities = {i:j for i,j in zip(labels, list(probs.flatten()))}
52
  return probabilities
53
 
54
+ def predictor(x):
55
+ input_ids = tokenizer(x, max_length=seq_len, truncation=True, padding='max_length', add_special_tokens=True, return_tensors='tf')['input_ids']
56
+ attention_mask = tokenizer(x, max_length=seq_len, truncation=True, padding='max_length', add_special_tokens=True, return_tensors='tf')['attention_mask']
57
+ outputs = model.predict([input_ids, attention_mask])
58
+ probas = tf.nn.softmax(outputs).numpy()
59
+ val = scipy.special.logit(probas[:,1])
60
+ return val
61
+
62
+ def f_batch(x):
63
+ val = np.array([])
64
+ for i in x:
65
+ val = np.append(val, predictor(i))
66
+ return val
67
+
68
+ explainer_roberta = shap.Explainer(f_batch, tokenizer)
69
+
70
+ shap_values = explainer_roberta(["When I remember her I feel down"])
71
+
72
+ def get_shap_data(input_text):
73
+ shap_values = explainer_roberta([input_text])
74
+ html_shap_content = shap.plots.text(shap_values, display=False)
75
+ return html_shap_content
76
+
77
  css = """
78
  textarea {
79
  background-color: #00000000;
 
83
  with gr.Blocks(title="SpecX", css=css, theme=gr.themes.Soft()) as demo:
84
  with gr.Row():
85
  textmd = gr.Markdown('''
86
+ <div style="margin: 50px 0;"></div>
87
 
88
  <h1 style="width:100%; text-align: center;">SpecX: Find the Right Specialist For Your Symptoms!</h1>
89
 
90
  ''')
91
+
92
  with gr.Row():
93
  with gr.Column(scale=1, min_width=600):
94
  text_box = gr.Textbox(label="Explain your problem in one sentence.")
95
+ with gr.Row():
96
+ submit_btn = gr.Button("Submit", elem_id="warningk", variant='primary')
97
+ submit_btn_shap = gr.Button("Get SHAP Analysis", variant='secondary')
98
  examples = gr.Examples(examples=[
99
  "When I remember her I feel down",
100
  "The area around my heart doesn't feel good.",
101
  "I have a split on my thumb that will not heal."
102
  ], inputs=text_box)
103
+ with gr.Column():
104
+ label = gr.Label(num_top_classes=4, label="Recommended Specialist")
105
+ gr.Markdown("## SHAP Analysis")
106
+ shap_content_box = gr.HTML()
107
  submit_btn.click(inference, inputs=text_box, outputs=label)
108
+ submit_btn_shap.click(get_shap_data, inputs=text_box, outputs=shap_content_box)
109
+ with gr.Row():
110
+ textmd1 = gr.Markdown('''
111
+ <div style="margin: 50px 0;"></div>
112
+ ''')
113
+ with gr.Row():
114
+ with gr.Column():
115
+ textmd1 = gr.Markdown('''
116
+ <div style="display: flex; justify-content: center; align-items: center; flex-direction: column;">
117
+ <div style="margin: 20px 0;"></div>
118
+ <svg width="50px" height="50px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
119
+ <path opacity="0.4" d="M10.7509 2.45007C11.4409 1.86007 12.5709 1.86007 13.2709 2.45007L14.8509 3.81007C15.1509 4.07007 15.7109 4.28007 16.1109 4.28007H17.8109C18.8709 4.28007 19.7409 5.15007 19.7409 6.21007V7.91007C19.7409 8.30007 19.9509 8.87007 20.2109 9.17007L21.5709 10.7501C22.1609 11.4401 22.1609 12.5701 21.5709 13.2701L20.2109 14.8501C19.9509 15.1501 19.7409 15.7101 19.7409 16.1101V17.8101C19.7409 18.8701 18.8709 19.7401 17.8109 19.7401H16.1109C15.7209 19.7401 15.1509 19.9501 14.8509 20.2101L13.2709 21.5701C12.5809 22.1601 11.4509 22.1601 10.7509 21.5701L9.17086 20.2101C8.87086 19.9501 8.31086 19.7401 7.91086 19.7401H6.18086C5.12086 19.7401 4.25086 18.8701 4.25086 17.8101V16.1001C4.25086 15.7101 4.04086 15.1501 3.79086 14.8501L2.44086 13.2601C1.86086 12.5701 1.86086 11.4501 2.44086 10.7601L3.79086 9.17007C4.04086 8.87007 4.25086 8.31007 4.25086 7.92007V6.20007C4.25086 5.14007 5.12086 4.27007 6.18086 4.27007H7.91086C8.30086 4.27007 8.87086 4.06007 9.17086 3.80007L10.7509 2.45007Z" fill="var(--neutral-700)"/>
120
+ <path d="M10.7905 15.17C10.5905 15.17 10.4005 15.09 10.2605 14.95L7.84055 12.53C7.55055 12.24 7.55055 11.76 7.84055 11.47C8.13055 11.18 8.61055 11.18 8.90055 11.47L10.7905 13.36L15.0905 9.06003C15.3805 8.77003 15.8605 8.77003 16.1505 9.06003C16.4405 9.35003 16.4405 9.83003 16.1505 10.12L11.3205 14.95C11.1805 15.09 10.9905 15.17 10.7905 15.17Z" fill="#ececec"/>
121
+ </svg>
122
+
123
+ <h2 style="text-align: center">98% Success Rate</h2>
124
+ </div>
125
+ ''')
126
+ with gr.Column():
127
+ textmd2 = gr.Markdown('''
128
+ <div style="display: flex; justify-content: center; align-items: center; flex-direction: column;">
129
+ <div style="margin: 20px 0;"></div>
130
+ <svg width="50px" height="50px" viewBox="0 0 48 48" fill="none" xmlns="http://www.w3.org/2000/svg">
131
+ <path d="M12 33C8.66666 33 4 31.5 4 25.5C4 18.5 11 17 13 17C14 13.5 16 8 24 8C31 8 34 12 35 15.5C35 15.5 44 16.5 44 25C44 31 40 33 36 33" stroke="var(--neutral-700)" stroke-width="4" stroke-linecap="round" stroke-linejoin="round"/>
132
+ <path d="M18 33L24 38L32 28" stroke="var(--neutral-700)" stroke-width="4" stroke-linecap="round" stroke-linejoin="round"/>
133
+ </svg>
134
+
135
+ <h2 style="text-align: center">99.9% Model uptime</h2>
136
+ </div>
137
+ ''')
138
+ with gr.Column():
139
+ textmd3 = gr.Markdown('''
140
+ <div style="display: flex; justify-content: center; align-items: center; flex-direction: column;">
141
+ <div style="margin: 20px 0;"></div>
142
+ <svg width="50px" height="50px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
143
+ <path opacity="0.4" d="M17.9981 7.16C17.9381 7.15 17.8681 7.15 17.8081 7.16C16.4281 7.11 15.3281 5.98 15.3281 4.58C15.3281 3.15 16.4781 2 17.9081 2C19.3381 2 20.4881 3.16 20.4881 4.58C20.4781 5.98 19.3781 7.11 17.9981 7.16Z" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
144
+ <path opacity="0.4" d="M16.9675 14.4402C18.3375 14.6702 19.8475 14.4302 20.9075 13.7202C22.3175 12.7802 22.3175 11.2402 20.9075 10.3002C19.8375 9.59016 18.3075 9.35016 16.9375 9.59016" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
145
+ <path opacity="0.4" d="M5.96656 7.16C6.02656 7.15 6.09656 7.15 6.15656 7.16C7.53656 7.11 8.63656 5.98 8.63656 4.58C8.63656 3.15 7.48656 2 6.05656 2C4.62656 2 3.47656 3.16 3.47656 4.58C3.48656 5.98 4.58656 7.11 5.96656 7.16Z" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
146
+ <path opacity="0.4" d="M6.9975 14.4402C5.6275 14.6702 4.1175 14.4302 3.0575 13.7202C1.6475 12.7802 1.6475 11.2402 3.0575 10.3002C4.1275 9.59016 5.6575 9.35016 7.0275 9.59016" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
147
+ <path d="M12.0001 14.6302C11.9401 14.6202 11.8701 14.6202 11.8101 14.6302C10.4301 14.5802 9.33008 13.4502 9.33008 12.0502C9.33008 10.6202 10.4801 9.47021 11.9101 9.47021C13.3401 9.47021 14.4901 10.6302 14.4901 12.0502C14.4801 13.4502 13.3801 14.5902 12.0001 14.6302Z" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
148
+ <path d="M9.0907 17.7804C7.6807 18.7204 7.6807 20.2603 9.0907 21.2003C10.6907 22.2703 13.3107 22.2703 14.9107 21.2003C16.3207 20.2603 16.3207 18.7204 14.9107 17.7804C13.3207 16.7204 10.6907 16.7204 9.0907 17.7804Z" stroke="var(--neutral-700)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
149
+ </svg>
150
+
151
+ <h2 style="text-align: center">77,223 Monthly active users</h2>
152
+ </div>
153
+ ''')
154
+ with gr.Column():
155
+ textmd3 = gr.Markdown('''
156
+ <div style="display: flex; justify-content: center; align-items: center; flex-direction: column;">
157
+ <div style="margin: 20px 0;"></div>
158
+ <svg fill="var(--neutral-700)" width="50px" height="50px" viewBox="0 0 100 100" id="Layer_1" version="1.1" xml:space="preserve" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
159
+ <g>
160
+ <path d="M93.998,45.312c0-3.676-1.659-7.121-4.486-9.414c0.123-0.587,0.184-1.151,0.184-1.706c0-4.579-3.386-8.382-7.785-9.037 c0.101-0.526,0.149-1.042,0.149-1.556c0-4.875-3.842-8.858-8.655-9.111c-0.079-0.013-0.159-0.024-0.242-0.024 c-0.04,0-0.079,0.005-0.12,0.006c-0.04-0.001-0.079-0.006-0.12-0.006c-0.458,0-0.919,0.041-1.406,0.126 c-0.846-4.485-4.753-7.825-9.437-7.825c-5.311,0-9.632,4.321-9.632,9.633v65.918c0,6.723,5.469,12.191,12.191,12.191 c4.46,0,8.508-2.413,10.646-6.246c0.479,0.104,0.939,0.168,1.401,0.198c2.903,0.185,5.73-0.766,7.926-2.693 c2.196-1.927,3.51-4.594,3.7-7.51c0.079-1.215-0.057-2.434-0.403-3.638c3.796-2.691,6.027-6.952,6.027-11.621 c0-3.385-1.219-6.635-3.445-9.224C92.731,51.505,93.998,48.471,93.998,45.312z M90.938,62.999c0,3.484-1.582,6.68-4.295,8.819 c-2.008-3.196-5.57-5.237-9.427-5.237c-0.828,0-1.5,0.672-1.5,1.5s0.672,1.5,1.5,1.5c3.341,0,6.384,2.093,7.582,5.208 c0.41,1.088,0.592,2.189,0.521,3.274c-0.138,2.116-1.091,4.051-2.685,5.449c-1.594,1.399-3.641,2.094-5.752,1.954 c-0.594-0.039-1.208-0.167-1.933-0.402c-0.74-0.242-1.541,0.124-1.846,0.84c-1.445,3.404-4.768,5.604-8.465,5.604 c-5.068,0-9.191-4.123-9.191-9.191V16.399c0-3.657,2.975-6.633,6.632-6.633c3.398,0,6.194,2.562,6.558,5.908 c-2.751,1.576-4.612,4.535-4.612,7.926c0,0.829,0.672,1.5,1.5,1.5s1.5-0.671,1.5-1.5c0-3.343,2.689-6.065,6.016-6.13 c3.327,0.065,6.016,2.787,6.016,6.129c0,0.622-0.117,1.266-0.359,1.971c-0.057,0.166-0.084,0.34-0.081,0.515 c0.001,0.041,0.003,0.079,0.007,0.115c-0.006,0.021-0.01,0.035-0.01,0.035c-0.118,0.465-0.006,0.959,0.301,1.328 c0.307,0.369,0.765,0.569,1.251,0.538c0.104-0.007,0.208-0.02,0.392-0.046c3.383,0,6.136,2.753,6.136,6.136 c0,0.572-0.103,1.159-0.322,1.849c-0.203,0.635,0.038,1.328,0.591,1.7c2.434,1.639,3.909,4.329,4.014,7.242 c0,0.004-0.001,0.008-0.001,0.012c0,5.03-4.092,9.123-9.122,9.123s-9.123-4.093-9.123-9.123c0-0.829-0.672-1.5-1.5-1.5 s-1.5,0.671-1.5,1.5c0,6.685,5.438,12.123,12.123,12.123c2.228,0,4.31-0.615,6.106-1.668C89.88,57.539,90.938,60.212,90.938,62.999 z"/>
161
+ <path d="M38.179,6.766c-4.684,0-8.59,3.34-9.435,7.825c-0.488-0.085-0.949-0.126-1.407-0.126c-0.04,0-0.079,0.005-0.12,0.006 c-0.04-0.001-0.079-0.006-0.12-0.006c-0.083,0-0.163,0.011-0.242,0.024c-4.813,0.253-8.654,4.236-8.654,9.111 c0,0.514,0.049,1.03,0.149,1.556c-4.399,0.655-7.785,4.458-7.785,9.037c0,0.554,0.061,1.118,0.184,1.706 c-2.827,2.293-4.486,5.738-4.486,9.414c0,3.159,1.266,6.193,3.505,8.463c-2.227,2.589-3.446,5.839-3.446,9.224 c0,4.669,2.231,8.929,6.027,11.621c-0.347,1.204-0.482,2.423-0.402,3.639c0.19,2.915,1.503,5.582,3.699,7.509 c2.196,1.928,5.015,2.879,7.926,2.693c0.455-0.03,0.919-0.096,1.4-0.199c2.138,3.834,6.186,6.247,10.646,6.247 c6.722,0,12.191-5.469,12.191-12.191V16.399C47.811,11.087,43.49,6.766,38.179,6.766z M44.811,82.317 c0,5.068-4.123,9.191-9.191,9.191c-3.697,0-7.02-2.2-8.464-5.604c-0.241-0.567-0.793-0.914-1.381-0.914 c-0.154,0-0.311,0.023-0.465,0.074c-0.724,0.235-1.338,0.363-1.933,0.402c-2.119,0.139-4.158-0.556-5.751-1.954 c-1.594-1.398-2.547-3.333-2.685-5.449c-0.076-1.16,0.125-2.336,0.598-3.495c0.007-0.017,0.005-0.036,0.011-0.053 c1.342-3.056,4.225-4.953,7.597-4.953c0.829,0,1.5-0.672,1.5-1.5s-0.671-1.5-1.5-1.5c-3.938,0-7.501,2.007-9.548,5.239 c-2.701-2.139-4.277-5.327-4.277-8.802c0-2.787,1.06-5.46,2.978-7.549c1.796,1.053,3.879,1.668,6.107,1.668 c6.685,0,12.123-5.438,12.123-12.123c0-0.829-0.671-1.5-1.5-1.5s-1.5,0.671-1.5,1.5c0,5.03-4.092,9.123-9.123,9.123 s-9.123-4.093-9.123-9.123c0-0.002-0.001-0.004-0.001-0.006c0.103-2.915,1.578-5.607,4.013-7.248 c0.553-0.372,0.793-1.064,0.591-1.699c-0.22-0.691-0.322-1.278-0.322-1.85c0-3.376,2.741-6.125,6.195-6.125 c0.007,0,0.015,0,0.022,0c0.103,0.014,0.206,0.027,0.311,0.034c0.485,0.03,0.948-0.171,1.254-0.542 c0.307-0.372,0.417-0.868,0.294-1.334c0-0.001-0.003-0.014-0.008-0.031c0.003-0.035,0.006-0.067,0.007-0.095 c0.005-0.18-0.022-0.359-0.081-0.529c-0.242-0.707-0.359-1.352-0.359-1.972c0-3.342,2.688-6.065,6.016-6.129 c3.328,0.065,6.016,2.787,6.016,6.13c0,0.829,0.671,1.5,1.5,1.5s1.5-0.671,1.5-1.5c0-3.391-1.861-6.35-4.612-7.926 c0.364-3.346,3.16-5.908,6.558-5.908c3.657,0,6.632,2.976,6.632,6.633V82.317z"/>
162
+ </g>
163
+
164
+ </svg>
165
+ <h2 style="text-align: center">990,224 total model predictions made</h2>
166
+ </div>
167
+ ''')
168
 
169
  demo.launch()