Spaces:
Runtime error
Runtime error
fix to dashboard not loading
Browse files
- app.py +25 -18
- data_mnist +1 -1
- utils.py +20 -9
app.py
CHANGED
@@ -37,7 +37,7 @@ os.makedirs(LOCAL_DIR,exist_ok=True)
|
|
37 |
|
38 |
|
39 |
|
40 |
-
|
41 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
42 |
MODEL_REPO = 'mnist-adversarial-model'
|
43 |
HF_DATASET ="mnist-adversarial-dataset"
|
@@ -74,10 +74,11 @@ class MNISTAdversarial_Dataset(Dataset):
|
|
74 |
|
75 |
image_path =os.path.join(self.FOLDER,'image.png')
|
76 |
if os.path.exists(image_path) and os.path.exists(metadata_path):
|
77 |
-
img = Image.open(image_path)
|
78 |
-
self.images.append(img)
|
79 |
metadata = read_json_lines(metadata_path)
|
80 |
-
|
|
|
|
|
|
|
81 |
assert len(self.images)==len(self.numbers), f"Length of images and numbers must be the same. Got {len(self.images)} for images and {len(self.numbers)} for numbers."
|
82 |
def __len__(self):
|
83 |
return len(self.images)
|
@@ -395,8 +396,15 @@ def flag(input_image,correct_result,adversarial_number):
|
|
395 |
return output,adversarial_number
|
396 |
|
397 |
def get_number_dict(DATA_DIR):
|
|
|
|
|
|
|
|
|
|
|
|
|
398 |
files = [f.name for f in os.scandir(DATA_DIR)]
|
399 |
-
|
|
|
400 |
numbers_count = Counter(numbers)
|
401 |
numbers_count_keys = list(numbers_count.keys())
|
402 |
numbers_count_values = [numbers_count[k] for k in numbers_count_keys]
|
@@ -425,10 +433,8 @@ def get_statistics():
|
|
425 |
repo.git_pull()
|
426 |
DATA_DIR = './data_mnist/data'
|
427 |
numbers_count_keys,numbers_count_values = get_number_dict(DATA_DIR)
|
428 |
-
|
429 |
STATS_EXPLANATION_ = STATS_EXPLANATION.format(num_adv_samples = sum(numbers_count_values))
|
430 |
-
|
431 |
-
plt_digits = plot_bar(numbers_count_values,numbers_count_keys,'Number of adversarial samples',"Digit",f"Distribution of adversarial samples per digit")
|
432 |
|
433 |
fig_d, ax_d = plt.subplots(tight_layout=True)
|
434 |
|
@@ -440,26 +446,25 @@ def get_statistics():
|
|
440 |
ax_d.plot(x_i, metric_dict[str(i)],label=str(i))
|
441 |
except Exception:
|
442 |
continue
|
|
|
443 |
|
444 |
else:
|
445 |
metric_dict={}
|
446 |
|
447 |
fig_d.legend()
|
448 |
ax_d.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy over digits per train step")
|
449 |
-
|
450 |
-
|
451 |
-
<p> ✅ Statistics loaded successfully!</p>
|
452 |
</div>
|
453 |
"""
|
454 |
-
|
455 |
# Plot for total test accuracy for all digits
|
456 |
fig_all, ax_all = plt.subplots(tight_layout=True)
|
457 |
x_i = [i+1 for i in range(len(metric_dict['all']))]
|
458 |
|
459 |
ax_all.plot(x_i, metric_dict['all'])
|
460 |
-
fig_all.legend()
|
461 |
ax_all.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy for all digits")
|
462 |
-
|
|
|
463 |
return plt_digits,ax_d.figure,ax_all.figure,done_html,STATS_EXPLANATION_
|
464 |
|
465 |
|
@@ -485,7 +490,7 @@ def main():
|
|
485 |
|
486 |
number_dropdown = gr.Dropdown(choices=[i for i in range(10)],type='value',default=None,label="What was the correct prediction?")
|
487 |
|
488 |
-
|
489 |
flag_btn = gr.Button("Flag")
|
490 |
|
491 |
output_result = gr.outputs.HTML()
|
@@ -496,8 +501,9 @@ def main():
|
|
496 |
flag_btn.click(flag,inputs=[image_input,number_dropdown,adversarial_number],outputs=[output_result,adversarial_number])
|
497 |
|
498 |
with gr.TabItem('Dashboard') as dashboard:
|
499 |
-
|
500 |
-
|
|
|
501 |
</div>
|
502 |
""")
|
503 |
|
@@ -508,7 +514,8 @@ def main():
|
|
508 |
gr.Markdown(DASHBOARD_EXPLANATION_TEST)
|
509 |
test_results_all=gr.Plot(type="matplotlib")
|
510 |
|
511 |
-
dashboard.select(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,
|
|
|
512 |
|
513 |
|
514 |
|
|
|
37 |
|
38 |
|
39 |
|
40 |
+
GET_STATISTICS_MESSAGE = "Get Statistics"
|
41 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
42 |
MODEL_REPO = 'mnist-adversarial-model'
|
43 |
HF_DATASET ="mnist-adversarial-dataset"
|
|
|
74 |
|
75 |
image_path =os.path.join(self.FOLDER,'image.png')
|
76 |
if os.path.exists(image_path) and os.path.exists(metadata_path):
|
|
|
|
|
77 |
metadata = read_json_lines(metadata_path)
|
78 |
+
if metadata is not None:
|
79 |
+
img = Image.open(image_path)
|
80 |
+
self.images.append(img)
|
81 |
+
self.numbers.append(metadata[0]['correct_number'])
|
82 |
assert len(self.images)==len(self.numbers), f"Length of images and numbers must be the same. Got {len(self.images)} for images and {len(self.numbers)} for numbers."
|
83 |
def __len__(self):
|
84 |
return len(self.images)
|
|
|
396 |
return output,adversarial_number
|
397 |
|
398 |
def get_number_dict(DATA_DIR):
|
399 |
+
"""
|
400 |
+
It takes a directory as input, and returns a list of the number of times each number appears in the
|
401 |
+
metadata.jsonl files in that directory
|
402 |
+
|
403 |
+
:param DATA_DIR: The directory where the data is stored
|
404 |
+
"""
|
405 |
files = [f.name for f in os.scandir(DATA_DIR)]
|
406 |
+
metadata_jsons = [read_json_lines(os.path.join(os.path.join(DATA_DIR,f),'metadata.jsonl')) for f in files]
|
407 |
+
numbers = [m[0]['correct_number'] for m in metadata_jsons if m is not None]
|
408 |
numbers_count = Counter(numbers)
|
409 |
numbers_count_keys = list(numbers_count.keys())
|
410 |
numbers_count_values = [numbers_count[k] for k in numbers_count_keys]
|
|
|
433 |
repo.git_pull()
|
434 |
DATA_DIR = './data_mnist/data'
|
435 |
numbers_count_keys,numbers_count_values = get_number_dict(DATA_DIR)
|
|
|
436 |
STATS_EXPLANATION_ = STATS_EXPLANATION.format(num_adv_samples = sum(numbers_count_values))
|
437 |
+
plt_digits = plot_bar(numbers_count_values,numbers_count_keys,'Number of adversarial samples',"Digit",f"Distribution of adversarial samples per digit",True)
|
|
|
438 |
|
439 |
fig_d, ax_d = plt.subplots(tight_layout=True)
|
440 |
|
|
|
446 |
ax_d.plot(x_i, metric_dict[str(i)],label=str(i))
|
447 |
except Exception:
|
448 |
continue
|
449 |
+
ax_d.set_xticks(range(0, len(metric_dict['0'])+1, 1))
|
450 |
|
451 |
else:
|
452 |
metric_dict={}
|
453 |
|
454 |
fig_d.legend()
|
455 |
ax_d.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy over digits per train step")
|
456 |
+
done_html = f"""<div style="color: green">
|
457 |
+
<p> ✅ Statistics loaded successfully! Click `{GET_STATISTICS_MESSAGE}`to reload.</p>
|
|
|
458 |
</div>
|
459 |
"""
|
|
|
460 |
# Plot for total test accuracy for all digits
|
461 |
fig_all, ax_all = plt.subplots(tight_layout=True)
|
462 |
x_i = [i+1 for i in range(len(metric_dict['all']))]
|
463 |
|
464 |
ax_all.plot(x_i, metric_dict['all'])
|
|
|
465 |
ax_all.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy for all digits")
|
466 |
+
ax_all.set_xticks(range(0, x_i[-1]+1, 1))
|
467 |
+
|
468 |
return plt_digits,ax_d.figure,ax_all.figure,done_html,STATS_EXPLANATION_
|
469 |
|
470 |
|
|
|
490 |
|
491 |
number_dropdown = gr.Dropdown(choices=[i for i in range(10)],type='value',default=None,label="What was the correct prediction?")
|
492 |
|
493 |
+
gr.Markdown('Please wait a while after you press `Flag`. It takes time.')
|
494 |
flag_btn = gr.Button("Flag")
|
495 |
|
496 |
output_result = gr.outputs.HTML()
|
|
|
501 |
flag_btn.click(flag,inputs=[image_input,number_dropdown,adversarial_number],outputs=[output_result,adversarial_number])
|
502 |
|
503 |
with gr.TabItem('Dashboard') as dashboard:
|
504 |
+
get_stat = gr.Button(f'{GET_STATISTICS_MESSAGE}')
|
505 |
+
notification = gr.HTML(f"""<div style="color: green">
|
506 |
+
<p> ⌛ Click `{GET_STATISTICS_MESSAGE}` to generate statistics... </p>
|
507 |
</div>
|
508 |
""")
|
509 |
|
|
|
514 |
gr.Markdown(DASHBOARD_EXPLANATION_TEST)
|
515 |
test_results_all=gr.Plot(type="matplotlib")
|
516 |
|
517 |
+
#dashboard.select(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,notification,stats])
|
518 |
+
get_stat.click(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,test_results_all,notification,stats])
|
519 |
|
520 |
|
521 |
|
data_mnist
CHANGED
@@ -1 +1 @@
|
|
1 |
-
Subproject commit
|
|
|
1 |
+
Subproject commit 0d5120c897f5b71d2f99b7fb2ef5dc28e3d7000d
|
utils.py
CHANGED
@@ -3,6 +3,7 @@ import json
|
|
3 |
import hashlib
|
4 |
import random
|
5 |
import string
|
|
|
6 |
import matplotlib.pyplot as plt
|
7 |
|
8 |
TITLE = "# MNIST Adversarial: Try to fool this MNIST model"
|
@@ -25,7 +26,7 @@ MODEL_IS_WRONG = """
|
|
25 |
DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
|
26 |
|
27 |
DASHBOARD_EXPLANATION="To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543). We are using {TEST_PER_SAMPLE} samples per digit."
|
28 |
-
DASHBOARD_EXPLANATION_TEST="Test accuracy on out-of-distribution data for all numbers."
|
29 |
|
30 |
STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
|
31 |
|
@@ -39,12 +40,16 @@ def read_json(file):
|
|
39 |
return json.load(f)
|
40 |
|
41 |
def read_json_lines(file):
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
48 |
|
49 |
|
50 |
def json_dump(thing):
|
@@ -63,11 +68,17 @@ def dump_json(thing,file):
|
|
63 |
json.dump(thing,f)
|
64 |
|
65 |
|
66 |
-
def plot_bar(value,name,x_name,y_name,title):
|
67 |
fig, ax = plt.subplots(tight_layout=True)
|
68 |
|
69 |
ax.set(xlabel=x_name, ylabel=y_name,title=title)
|
70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
ax.barh(name, value)
|
72 |
|
73 |
-
return ax.figure
|
|
|
3 |
import hashlib
|
4 |
import random
|
5 |
import string
|
6 |
+
import warnings
|
7 |
import matplotlib.pyplot as plt
|
8 |
|
9 |
TITLE = "# MNIST Adversarial: Try to fool this MNIST model"
|
|
|
26 |
DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
|
27 |
|
28 |
DASHBOARD_EXPLANATION="To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543). We are using {TEST_PER_SAMPLE} samples per digit."
|
29 |
+
DASHBOARD_EXPLANATION_TEST="Test accuracy on out-of-distribution data for all numbers combined."
|
30 |
|
31 |
STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
|
32 |
|
|
|
40 |
return json.load(f)
|
41 |
|
42 |
def read_json_lines(file):
    """Read a JSON Lines file and return its records as a list.

    Each non-blank line of ``file`` is parsed as one JSON document. Blank
    (whitespace-only) lines are skipped, so a stray empty line no longer
    causes every valid record in the file to be discarded.

    This is a best-effort reader: if the file cannot be opened or decoded,
    or a non-blank line is not valid JSON, a warning carrying the error
    text is emitted and ``None`` is returned instead of raising.

    :param file: Path of the ``.jsonl`` file to read.
    :return: A list with one parsed object per non-blank line (an empty
        list when the file has no records), or ``None`` on failure.
    """
    try:
        with open(file, 'r', encoding="utf8") as f:
            # Iterate the handle directly rather than materialising every
            # line with readlines() first.
            return [json.loads(line) for line in f if line.strip()]
    except (OSError, ValueError) as err:
        # OSError covers missing or unreadable files; ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError.
        warnings.warn(f"{err}")
        return None
|
53 |
|
54 |
|
55 |
def json_dump(thing):
|
|
|
68 |
json.dump(thing,f)
|
69 |
|
70 |
|
71 |
+
def plot_bar(value,name,x_name,y_name,title,set_yticks=False,set_xticks=False):
    """Draw a horizontal bar chart and return its matplotlib figure.

    The bars are drawn with ``ax.barh(name, value)``, so ``name`` supplies
    the positions on the y-axis and ``value`` the bar lengths along the
    x-axis.

    :param value: Bar lengths, one per entry of ``name``.
    :param name: Bar positions on the y-axis.
    :param x_name: Label for the x-axis.
    :param y_name: Label for the y-axis.
    :param title: Title of the plot.
    :param set_yticks: When true, put a y tick on every integer from
        ``min(name)`` to ``max(name)``; ``name`` must then hold integers.
    :param set_xticks: When true, put an x tick on every integer from 0 to
        ``max(value)``; ``value`` must then hold integers.
    :return: The figure that owns the plotted axes.
    """
    fig, ax = plt.subplots(tight_layout=True)

    ax.set(xlabel=x_name, ylabel=y_name, title=title)

    # min()/max() raise ValueError on an empty sequence, so only force ticks
    # when there is data to derive them from.
    if set_yticks and len(name) > 0:
        ax.set_yticks(range(min(name), max(name) + 1, 1))
    if set_xticks and len(value) > 0:
        # In a horizontal bar chart the x-axis carries the values, not the
        # names, so the x ticks must be derived from `value`.
        ax.set_xticks(range(0, max(value) + 1, 1))

    ax.barh(name, value)

    return fig
|