taesiri committed
Commit
a8ed63e
1 Parent(s): 6e25e63

initial commit

Files changed (4)
  1. app.py +431 -0
  2. images/intro.jpg +0 -0
  3. showresults.py +98 -0
  4. utils.py +20 -0
app.py ADDED
@@ -0,0 +1,431 @@
+ import csv
+ import json
+ import math
+ import os
+ import random
+ import string
+ import sys
+ import time
+ import zipfile
+ from pathlib import Path
+
+ import gradio as gr
+ import matplotlib.backends.backend_agg as agg
+ import numpy as np
+ from huggingface_hub import CommitScheduler, login, snapshot_download
+ from PIL import Image
+
+ from utils import string_to_image
+
+
+ # Seed NumPy with the current time and lift the CSV field-size limit.
+ np.random.seed(int(time.time()))
+ csv.field_size_limit(sys.maxsize)
+
+
+ ###############################################################################################################
+ session_token = os.environ.get("SessionToken")
+ login(token=session_token)
+
+ # Use snapshot_download to fetch the dataset, then extract the bundled archive.
+ snapshot_download(
+     repo_id="XAI/PEEB-Data",
+     repo_type="dataset",
+     local_dir="./",
+     cache_dir="./hf_cache",
+ )
+
+ with zipfile.ZipFile("./data.zip", "r") as zip_ref:
+     zip_ref.extractall("./")
+
+
+ NUMBER_OF_IMAGES = 30
+ intro_screen = Image.open("./images/intro.jpg")
+
+ meta_top1 = json.load(open("./dogs/top1/metadata.json"))
+ meta_topK = json.load(open("./dogs/topK/metadata.json"))
+
+ all_data = {}
+ all_data["top1"] = meta_top1
+ all_data["topK"] = meta_topK
+
+
+ # Tag every datapoint in all_data["top1"] and all_data["topK"] with its type.
+ for k in all_data["top1"].keys():
+     all_data["top1"][k]["type"] = "top1"
+
+ for k in all_data["topK"].keys():
+     all_data["topK"][k]["type"] = "topK"
+
+
+ REPO_URL = "taesiri/AdvisingNetworksReviewDataExtension"
+ JSON_DATASET_DIR = Path("responses")
+
+ ################################################################################################################
+
+ # Periodically push every file written under JSON_DATASET_DIR to a private dataset repo.
+ scheduler = CommitScheduler(
+     repo_id=REPO_URL,
+     repo_type="dataset",
+     folder_path=JSON_DATASET_DIR,
+     path_in_repo="./data",
+     every=1,
+     private=True,
+ )
+
+
+ if not JSON_DATASET_DIR.exists():
+     JSON_DATASET_DIR.mkdir()
+
+
+ def generate_data(type_of_nns):
+     # Randomly pick NUMBER_OF_IMAGES datapoints of the requested type (top1 or topK).
+     keys = list(all_data[type_of_nns].keys())
+     sample_data = random.sample(keys, NUMBER_OF_IMAGES)
+
+     data = []
+     for k in sample_data:
+         new_datapoint = all_data[type_of_nns][k]
+         new_datapoint["image-path"] = f"./dogs/{type_of_nns}/{k}.jpeg"
+         data.append(new_datapoint)
+
+     return data
+
+
+ def load_sample(data, current_index):
+     current_datapoint = data[current_index]
+
+     image_path = current_datapoint["image-path"]
+     image = Image.open(image_path)
+     top_1 = current_datapoint["top1-label"]
+     top_1_score = current_datapoint["top1-score"]
+
+     q_template = (
+         "<div style='font-size: 24px;'>Sam guessed the Input image is "
+         "<span style='font-weight: bold;'>{}</span> "
+         "with <span style='font-weight: bold;'>{}%</span> "
+         "confidence.<br>Is this bird a <span style='font-weight: bold;'>{}</span>?"
+         "</div>"
+     )
+
+     # Convert the score to a percentage and round it up to the nearest integer.
+     top_1_score = round(top_1_score * 100, 2)
+     rounded_up_score = int(math.ceil(top_1_score))
+     question = q_template.format(top_1, str(rounded_up_score), top_1)
+
+     accept_reject = current_datapoint["Accept/Reject"]
+
+     return image, top_1, rounded_up_score, question, accept_reject
+
+
+ def preprocessing(data, type_of_nns, current_index, history, username):
+     print("preprocessing")
+     data = generate_data(type_of_nns)
+     print("data generated")
+
+     # Append a random suffix to the username so that result files never collide.
+     random_text = "".join(
+         random.choice(string.ascii_lowercase + string.digits) for _ in range(8)
+     )
+
+     if username == "":
+         username = "username"
+
+     username = f"{username}-{random_text}"
+
+     current_index = 0
+     print("loading sample ....")
+     qimage, top_1, top_1_score, question, accept_reject = load_sample(
+         data, current_index
+     )
+
+     return (
+         qimage,
+         top_1,
+         top_1_score,
+         question,
+         accept_reject,
+         current_index,
+         history,
+         data,
+         username,
+     )
+
+
+ def update_app(decision, data, current_index, history, username):
+     # No samples loaded yet: show a placeholder image asking the user to start.
+     if current_index == -1:
+         gr.Error("Please Enter your username and load samples")
+
+         fake_plot = string_to_image("Please Enter your username and load samples")
+         canvas = agg.FigureCanvasAgg(fake_plot)
+         canvas.draw()
+         # buffer_rgba() already matches PIL's RGBA channel order.
+         empty_image = Image.frombytes(
+             "RGBA", canvas.get_width_height(), bytes(canvas.buffer_rgba())
+         )
+
+         return (
+             empty_image,
+             "",
+             "",
+             "",
+             "",
+             current_index,
+             history,
+             data,
+             0,
+             gr.update(interactive=False),
+             gr.update(interactive=False),
+             "",
+         )
+
+     # Last sample: record the decision, save the results, and upload them.
+     if current_index == NUMBER_OF_IMAGES - 1:
+         time_stamp = int(time.time())
+
+         # Add the final decision to the history
+         current_dictionary = data[current_index].copy()
+         current_dictionary["user_decision"] = decision
+         current_dictionary["user_id"] = username
+         accept_reject_string = "Accept" if decision == "YES" else "Reject"
+         current_dictionary["is_user_correct"] = (
+             current_dictionary["Accept/Reject"] == accept_reject_string
+         )
+         history.append(current_dictionary)
+
+         final_decision_data = {
+             "user_id": username,
+             "time": time_stamp,
+             "history": history,
+         }
+
+         # Save the responses to disk; the CommitScheduler uploads them to the Hub.
+         temp_filename = f"./responses/results_{username}.json"
+         with open(temp_filename, "w") as f:
+             json.dump(final_decision_data, f)
+
+         fake_plot = string_to_image("Thank you for your time!")
+         canvas = agg.FigureCanvasAgg(fake_plot)
+         canvas.draw()
+         empty_image = Image.frombytes(
+             "RGBA", canvas.get_width_height(), bytes(canvas.buffer_rgba())
+         )
+
+         # Calculate the user's accuracy over the whole session and report it.
+         all_is_user_correct = [d["is_user_correct"] for d in history]
+         accuracy = round(np.mean(all_is_user_correct) * 100, 2)
+
+         return (
+             empty_image,
+             "",
+             "",
+             "",
+             "",
+             current_index,
+             history,
+             data,
+             current_index + 1,
+             gr.update(interactive=False),
+             gr.update(interactive=False),
+             f"User Accuracy: {accuracy}%",
+         )
+
+     # Intermediate sample: record the decision and move on to the next image.
+     if current_index >= 0 and current_index < NUMBER_OF_IMAGES - 1:
+         current_dictionary = data[current_index].copy()
+         current_dictionary["user_decision"] = decision
+         current_dictionary["user_id"] = username
+         accept_reject_string = "Accept" if decision == "YES" else "Reject"
+         current_dictionary["is_user_correct"] = (
+             current_dictionary["Accept/Reject"] == accept_reject_string
+         )
+
+         print(f" accept/reject : {current_dictionary['Accept/Reject']}")
+         print(
+             f" accept/reject status: {current_dictionary['Accept/Reject'] == accept_reject_string}"
+         )
+
+         history.append(current_dictionary)
+
+         current_index += 1
+         qimage, top_1, top_1_score, question, accept_reject = load_sample(
+             data, current_index
+         )
+
+         return (
+             qimage,
+             top_1,
+             top_1_score,
+             question,
+             accept_reject,
+             current_index,
+             history,
+             data,
+             current_index,
+             gr.update(interactive=True),
+             gr.update(interactive=True),
+             "",
+         )
+
+
+ def disable_component():
+     return gr.update(interactive=False)
+
+
+ def enable_component():
+     return gr.update(interactive=True)
+
+
+ def hide_component():
+     return gr.update(visible=False)
+
+
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     data_state = gr.State({})
+     current_index = gr.State(-1)
+     history = gr.State([])
+
+     gr.Markdown("# Advising Networks")
+     gr.Markdown("## Accept/Reject AI predicted label using Explanations")
+
+     with gr.Column():
+         with gr.Row():
+             username_textbox = gr.Textbox(label="Username", value="username")
+             labeled_images_textbox = gr.Textbox(label="Labeled Images", value="0")
+             total_images_textbox = gr.Textbox(
+                 label="Total Images", value=str(NUMBER_OF_IMAGES)
+             )
+             type_of_nns_dropdown = gr.Dropdown(
+                 label="Type of NNs",
+                 choices=["top1", "topK"],
+                 value="top1",
+             )
+
+         prepare_btn = gr.Button(value="Start The Experiment")
+
+     with gr.Column():
+         with gr.Row():
+             question_textbox = gr.HTML("")
+             # question_textbox = gr.Markdown("")
+
+     with gr.Column(elem_id="parent_row"):
+         query_image = gr.Image(
+             type="pil", label="Query", show_label=False, value="./images/intro.jpg"
+         )
+
+         with gr.Row():
+             accept_btn = gr.Button(value="YES", interactive=False)
+             reject_btn = gr.Button(value="NO", interactive=False)
+
+     with gr.Column(elem_id="parent_row"):
+         top_1_textbox = gr.Textbox(label="Top 1", value="", visible=False)
+         top_1_score_textbox = gr.Textbox(
+             label="Top 1 Score", value="", visible=False
+         )
+         accept_reject_textbox = gr.Textbox(
+             label="Accept/Reject", value="", visible=False
+         )
+
+     with gr.Column():
+         with gr.Row():
+             final_results = gr.HTML("")
+
+     # Load the samples, then lock the setup widgets and enable the YES/NO buttons.
+     prepare_btn.click(
+         preprocessing,
+         inputs=[
+             data_state,
+             type_of_nns_dropdown,
+             current_index,
+             history,
+             username_textbox,
+         ],
+         outputs=[
+             query_image,
+             top_1_textbox,
+             top_1_score_textbox,
+             question_textbox,
+             accept_reject_textbox,
+             current_index,
+             history,
+             data_state,
+             username_textbox,
+         ],
+     ).then(
+         fn=disable_component, outputs=[type_of_nns_dropdown]
+     ).then(
+         fn=disable_component, outputs=[username_textbox]
+     ).then(
+         fn=disable_component, outputs=[prepare_btn]
+     ).then(
+         fn=enable_component, outputs=[accept_btn]
+     ).then(
+         fn=enable_component, outputs=[reject_btn]
+     ).then(
+         fn=hide_component, outputs=[prepare_btn]
+     )
+
+     # The buttons pass their own label ('YES'/'NO') as the decision.
+     accept_btn.click(
+         update_app,
+         inputs=[accept_btn, data_state, current_index, history, username_textbox],
+         outputs=[
+             query_image,
+             top_1_textbox,
+             top_1_score_textbox,
+             question_textbox,
+             accept_reject_textbox,
+             current_index,
+             history,
+             data_state,
+             labeled_images_textbox,
+             accept_btn,
+             reject_btn,
+             final_results,
+         ],
+     )
+
+     reject_btn.click(
+         update_app,
+         inputs=[reject_btn, data_state, current_index, history, username_textbox],
+         outputs=[
+             query_image,
+             top_1_textbox,
+             top_1_score_textbox,
+             question_textbox,
+             accept_reject_textbox,
+             current_index,
+             history,
+             data_state,
+             labeled_images_textbox,
+             accept_btn,
+             reject_btn,
+             final_results,
+         ],
+     )
+
+
+ demo.launch(debug=False, server_name="0.0.0.0")
+ # demo.launch(debug=False)
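
Note: update_app writes one results_{username}.json file under responses/, which the CommitScheduler pushes to the dataset repo and showresults.py later reads back. A minimal sketch of that file's layout, written as a Python dict (field names are taken from app.py; the example values are hypothetical):

{
    "user_id": "username-a1b2c3d4",   # username plus the random 8-character suffix
    "time": 1700000000,               # int(time.time()) at submission
    "history": [                      # one entry per labelled image
        {
            "top1-label": "...",            # fields copied from the sampled datapoint
            "top1-score": 0.87,             # hypothetical value
            "Accept/Reject": "Accept",
            "type": "top1",
            "image-path": "./dogs/top1/....jpeg",
            "user_decision": "YES",         # added by update_app
            "user_id": "username-a1b2c3d4",
            "is_user_correct": True,
        },
        # ... NUMBER_OF_IMAGES entries in total
    ],
}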
images/intro.jpg ADDED
showresults.py ADDED
@@ -0,0 +1,98 @@
+ from glob import glob
+ import json
+
+ import gradio as gr
+ import numpy as np
+
+
+ def calculate_the_results():
+     all_jsons_path = glob("./responses/*.json")
+     all_jsons = [json.load(open(path)) for path in all_jsons_path]
+
+     # For each response file, determine the type of NNs and compute the user's accuracy.
+     top1_results = []
+     top1_acc = []
+     topK_results = []
+     topK_acc = []
+
+     for js in all_jsons:
+         # Read one history entry to determine the type of NNs used in that session.
+         type_of_NNs = js["history"][0]["type"]
+         acc = np.mean([entry["is_user_correct"] for entry in js["history"]])
+         if type_of_NNs == "topK":
+             topK_acc.append(round(acc * 100, 2))
+             topK_results.append(js)
+         else:
+             top1_results.append(js)
+             top1_acc.append(round(acc * 100, 2))
+
+     top1_output = (
+         f"# of top1: {len(top1_results)}\n"
+         f"top1 Accuracy: {top1_acc}\n"
+         f"top1 std: {np.std(top1_acc)}\n"
+         f"top1 mean: {np.mean(top1_acc)}\n"
+         "----------------------------------\n"
+     )
+     topK_output = (
+         f"# of topK: {len(topK_results)}\n"
+         f"topK Accuracy: {topK_acc}\n"
+         f"topK std: {np.std(topK_acc)}\n"
+         f"topK mean: {np.mean(topK_acc)}"
+     )
+
+     return top1_output + topK_output
+
+
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     update_btn = gr.Button("Calculate the results")
+     results_textbox = gr.Textbox(lines=10, label="Results")
+
+     update_btn.click(fn=calculate_the_results, outputs=results_textbox)
+
+
+ demo.launch(debug=False, server_name="0.0.0.0", server_port=9911)
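
Note: showresults.py assumes the collected responses are already present locally under ./responses. Since app.py's CommitScheduler pushes them into the data/ folder of the private dataset repo, a minimal sketch of pulling them back down first (repo id and folder names are taken from app.py; the token is a placeholder and read access to the private dataset is assumed):

import shutil
from pathlib import Path

from huggingface_hub import login, snapshot_download

login(token="hf_...")  # placeholder token with read access to the private dataset
snapshot_download(
    repo_id="taesiri/AdvisingNetworksReviewDataExtension",
    repo_type="dataset",
    local_dir="./collected",
)

# Copy the synced JSON files to where showresults.py expects them.
Path("responses").mkdir(exist_ok=True)
for path in Path("./collected/data").glob("*.json"):
    shutil.copy(path, Path("responses") / path.name)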
utils.py ADDED
@@ -0,0 +1,20 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+
+ def string_to_image(text):
+     """Render `text` centred on a blank white figure and return the matplotlib Figure."""
+     text = text.replace("_", " ").lower().replace(", ", "\n")
+     # Create a blank white image to draw the text on.
+     img = np.ones((220, 75, 3))
+
+     fig, ax = plt.subplots(figsize=(6, 2.25))
+     ax.imshow(img, extent=[0, 1, 0, 1])
+     ax.text(0.5, 0.75, text, fontsize=18, ha="center", va="center")
+     # Hide the ticks, tick labels, and the axes frame.
+     ax.set_xticks([])
+     ax.set_yticks([])
+     ax.set_xticklabels([])
+     ax.set_yticklabels([])
+     for spine in ax.spines.values():
+         spine.set_visible(False)
+
+     return fig
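
Note: string_to_image returns a matplotlib Figure rather than a ready-made image; app.py rasterises it through FigureCanvasAgg, but for a quick standalone check the figure can simply be saved to disk (the output file name here is arbitrary):

from utils import string_to_image

fig = string_to_image("Thank you for your time!")
fig.savefig("placeholder.png", dpi=150, bbox_inches="tight")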