chrisjay commited on
Commit
e4a62fe
·
1 Parent(s): 603879a

fix to dashboard not loading

Browse files
Files changed (3) hide show
  1. app.py +25 -18
  2. data_mnist +1 -1
  3. utils.py +20 -9
app.py CHANGED
@@ -37,7 +37,7 @@ os.makedirs(LOCAL_DIR,exist_ok=True)
37
 
38
 
39
 
40
-
41
  HF_TOKEN = os.getenv("HF_TOKEN")
42
  MODEL_REPO = 'mnist-adversarial-model'
43
  HF_DATASET ="mnist-adversarial-dataset"
@@ -74,10 +74,11 @@ class MNISTAdversarial_Dataset(Dataset):
74
 
75
  image_path =os.path.join(self.FOLDER,'image.png')
76
  if os.path.exists(image_path) and os.path.exists(metadata_path):
77
- img = Image.open(image_path)
78
- self.images.append(img)
79
  metadata = read_json_lines(metadata_path)
80
- self.numbers.append(metadata[0]['correct_number'])
 
 
 
81
  assert len(self.images)==len(self.numbers), f"Length of images and numbers must be the same. Got {len(self.images)} for images and {len(self.numbers)} for numbers."
82
  def __len__(self):
83
  return len(self.images)
@@ -395,8 +396,15 @@ def flag(input_image,correct_result,adversarial_number):
395
  return output,adversarial_number
396
 
397
  def get_number_dict(DATA_DIR):
 
 
 
 
 
 
398
  files = [f.name for f in os.scandir(DATA_DIR)]
399
- numbers = [read_json_lines(os.path.join(os.path.join(DATA_DIR,f),'metadata.jsonl'))[0]['correct_number'] for f in files]
 
400
  numbers_count = Counter(numbers)
401
  numbers_count_keys = list(numbers_count.keys())
402
  numbers_count_values = [numbers_count[k] for k in numbers_count_keys]
@@ -425,10 +433,8 @@ def get_statistics():
425
  repo.git_pull()
426
  DATA_DIR = './data_mnist/data'
427
  numbers_count_keys,numbers_count_values = get_number_dict(DATA_DIR)
428
-
429
  STATS_EXPLANATION_ = STATS_EXPLANATION.format(num_adv_samples = sum(numbers_count_values))
430
-
431
- plt_digits = plot_bar(numbers_count_values,numbers_count_keys,'Number of adversarial samples',"Digit",f"Distribution of adversarial samples per digit")
432
 
433
  fig_d, ax_d = plt.subplots(tight_layout=True)
434
 
@@ -440,26 +446,25 @@ def get_statistics():
440
  ax_d.plot(x_i, metric_dict[str(i)],label=str(i))
441
  except Exception:
442
  continue
 
443
 
444
  else:
445
  metric_dict={}
446
 
447
  fig_d.legend()
448
  ax_d.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy over digits per train step")
449
-
450
- done_html = """<div style="color: green">
451
- <p> ✅ Statistics loaded successfully!</p>
452
  </div>
453
  """
454
-
455
  # Plot for total test accuracy for all digits
456
  fig_all, ax_all = plt.subplots(tight_layout=True)
457
  x_i = [i+1 for i in range(len(metric_dict['all']))]
458
 
459
  ax_all.plot(x_i, metric_dict['all'])
460
- fig_all.legend()
461
  ax_all.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy for all digits")
462
-
 
463
  return plt_digits,ax_d.figure,ax_all.figure,done_html,STATS_EXPLANATION_
464
 
465
 
@@ -485,7 +490,7 @@ def main():
485
 
486
  number_dropdown = gr.Dropdown(choices=[i for i in range(10)],type='value',default=None,label="What was the correct prediction?")
487
 
488
-
489
  flag_btn = gr.Button("Flag")
490
 
491
  output_result = gr.outputs.HTML()
@@ -496,8 +501,9 @@ def main():
496
  flag_btn.click(flag,inputs=[image_input,number_dropdown,adversarial_number],outputs=[output_result,adversarial_number])
497
 
498
  with gr.TabItem('Dashboard') as dashboard:
499
- notification = gr.HTML("""<div style="color: green">
500
- <p> Creating statistics... </p>
 
501
  </div>
502
  """)
503
 
@@ -508,7 +514,8 @@ def main():
508
  gr.Markdown(DASHBOARD_EXPLANATION_TEST)
509
  test_results_all=gr.Plot(type="matplotlib")
510
 
511
- dashboard.select(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,test_results_all,notification,stats])
 
512
 
513
 
514
 
 
37
 
38
 
39
 
40
+ GET_STATISTICS_MESSAGE = "Get Statistics"
41
  HF_TOKEN = os.getenv("HF_TOKEN")
42
  MODEL_REPO = 'mnist-adversarial-model'
43
  HF_DATASET ="mnist-adversarial-dataset"
 
74
 
75
  image_path =os.path.join(self.FOLDER,'image.png')
76
  if os.path.exists(image_path) and os.path.exists(metadata_path):
 
 
77
  metadata = read_json_lines(metadata_path)
78
+ if metadata is not None:
79
+ img = Image.open(image_path)
80
+ self.images.append(img)
81
+ self.numbers.append(metadata[0]['correct_number'])
82
  assert len(self.images)==len(self.numbers), f"Length of images and numbers must be the same. Got {len(self.images)} for images and {len(self.numbers)} for numbers."
83
  def __len__(self):
84
  return len(self.images)
 
396
  return output,adversarial_number
397
 
398
  def get_number_dict(DATA_DIR):
399
+ """
400
+ It takes a directory as input, and returns a list of the number of times each number appears in the
401
+ metadata.jsonl files in that directory
402
+
403
+ :param DATA_DIR: The directory where the data is stored
404
+ """
405
  files = [f.name for f in os.scandir(DATA_DIR)]
406
+ metadata_jsons = [read_json_lines(os.path.join(os.path.join(DATA_DIR,f),'metadata.jsonl')) for f in files]
407
+ numbers = [m[0]['correct_number'] for m in metadata_jsons if m is not None]
408
  numbers_count = Counter(numbers)
409
  numbers_count_keys = list(numbers_count.keys())
410
  numbers_count_values = [numbers_count[k] for k in numbers_count_keys]
 
433
  repo.git_pull()
434
  DATA_DIR = './data_mnist/data'
435
  numbers_count_keys,numbers_count_values = get_number_dict(DATA_DIR)
 
436
  STATS_EXPLANATION_ = STATS_EXPLANATION.format(num_adv_samples = sum(numbers_count_values))
437
+ plt_digits = plot_bar(numbers_count_values,numbers_count_keys,'Number of adversarial samples',"Digit",f"Distribution of adversarial samples per digit",True)
 
438
 
439
  fig_d, ax_d = plt.subplots(tight_layout=True)
440
 
 
446
  ax_d.plot(x_i, metric_dict[str(i)],label=str(i))
447
  except Exception:
448
  continue
449
+ ax_d.set_xticks(range(0, len(metric_dict['0'])+1, 1))
450
 
451
  else:
452
  metric_dict={}
453
 
454
  fig_d.legend()
455
  ax_d.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy over digits per train step")
456
+ done_html = f"""<div style="color: green">
457
+ <p> Statistics loaded successfully! Click `{GET_STATISTICS_MESSAGE}`to reload.</p>
 
458
  </div>
459
  """
 
460
  # Plot for total test accuracy for all digits
461
  fig_all, ax_all = plt.subplots(tight_layout=True)
462
  x_i = [i+1 for i in range(len(metric_dict['all']))]
463
 
464
  ax_all.plot(x_i, metric_dict['all'])
 
465
  ax_all.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy for all digits")
466
+ ax_all.set_xticks(range(0, x_i[-1]+1, 1))
467
+
468
  return plt_digits,ax_d.figure,ax_all.figure,done_html,STATS_EXPLANATION_
469
 
470
 
 
490
 
491
  number_dropdown = gr.Dropdown(choices=[i for i in range(10)],type='value',default=None,label="What was the correct prediction?")
492
 
493
+ gr.Markdown('Please wait a while after you press `Flag`. It takes time.')
494
  flag_btn = gr.Button("Flag")
495
 
496
  output_result = gr.outputs.HTML()
 
501
  flag_btn.click(flag,inputs=[image_input,number_dropdown,adversarial_number],outputs=[output_result,adversarial_number])
502
 
503
  with gr.TabItem('Dashboard') as dashboard:
504
+ get_stat = gr.Button(f'{GET_STATISTICS_MESSAGE}')
505
+ notification = gr.HTML(f"""<div style="color: green">
506
+ <p> ⌛ Click `{GET_STATISTICS_MESSAGE}` to generate statistics... </p>
507
  </div>
508
  """)
509
 
 
514
  gr.Markdown(DASHBOARD_EXPLANATION_TEST)
515
  test_results_all=gr.Plot(type="matplotlib")
516
 
517
+ #dashboard.select(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,notification,stats])
518
+ get_stat.click(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,test_results_all,notification,stats])
519
 
520
 
521
 
data_mnist CHANGED
@@ -1 +1 @@
1
- Subproject commit ed62a26e764902f519ff43df850842e07dfe2cc0
 
1
+ Subproject commit 0d5120c897f5b71d2f99b7fb2ef5dc28e3d7000d
utils.py CHANGED
@@ -3,6 +3,7 @@ import json
3
  import hashlib
4
  import random
5
  import string
 
6
  import matplotlib.pyplot as plt
7
 
8
  TITLE = "# MNIST Adversarial: Try to fool this MNIST model"
@@ -25,7 +26,7 @@ MODEL_IS_WRONG = """
25
  DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
26
 
27
  DASHBOARD_EXPLANATION="To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543). We are using {TEST_PER_SAMPLE} samples per digit."
28
- DASHBOARD_EXPLANATION_TEST="Test accuracy on out-of-distribution data for all numbers."
29
 
30
  STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
31
 
@@ -39,12 +40,16 @@ def read_json(file):
39
  return json.load(f)
40
 
41
  def read_json_lines(file):
42
- with open(file,'r',encoding="utf8") as f:
43
- lines = f.readlines()
44
- data=[]
45
- for l in lines:
46
- data.append(json.loads(l))
47
- return data
 
 
 
 
48
 
49
 
50
  def json_dump(thing):
@@ -63,11 +68,17 @@ def dump_json(thing,file):
63
  json.dump(thing,f)
64
 
65
 
66
- def plot_bar(value,name,x_name,y_name,title):
67
  fig, ax = plt.subplots(tight_layout=True)
68
 
69
  ax.set(xlabel=x_name, ylabel=y_name,title=title)
70
 
 
 
 
 
 
 
71
  ax.barh(name, value)
72
 
73
- return ax.figure
 
3
  import hashlib
4
  import random
5
  import string
6
+ import warnings
7
  import matplotlib.pyplot as plt
8
 
9
  TITLE = "# MNIST Adversarial: Try to fool this MNIST model"
 
26
  DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
27
 
28
  DASHBOARD_EXPLANATION="To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543). We are using {TEST_PER_SAMPLE} samples per digit."
29
+ DASHBOARD_EXPLANATION_TEST="Test accuracy on out-of-distribution data for all numbers combined."
30
 
31
  STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
32
 
 
40
  return json.load(f)
41
 
42
  def read_json_lines(file):
43
+ try:
44
+ with open(file,'r',encoding="utf8") as f:
45
+ lines = f.readlines()
46
+ data=[]
47
+ for l in lines:
48
+ data.append(json.loads(l))
49
+ return data
50
+ except Exception as err:
51
+ warnings.warn(f"{err}")
52
+ return None
53
 
54
 
55
  def json_dump(thing):
 
68
  json.dump(thing,f)
69
 
70
 
71
+ def plot_bar(value,name,x_name,y_name,title,set_yticks=False,set_xticks=False):
72
  fig, ax = plt.subplots(tight_layout=True)
73
 
74
  ax.set(xlabel=x_name, ylabel=y_name,title=title)
75
 
76
+ if set_yticks:
77
+ ax.set_yticks(range(min(name), max(name)+1, 1))
78
+ if set_xticks:
79
+ ax.set_xticks(range(min(name), max(name)+1, 1))
80
+
81
+
82
  ax.barh(name, value)
83
 
84
+ return ax.figure