chrisjay committed on
Commit
35ee063
1 Parent(s): 866cafe

work on training and dashboard statistics 2

Browse files
Files changed (8) hide show
  1. .gitignore +2 -1
  2. .gitmodules +3 -0
  3. app.py +11 -13
  4. data_mnist +1 -1
  5. metrics.json +1 -0
  6. model.pth +1 -1
  7. optimizer.pth +1 -1
  8. utils.py +2 -2
.gitignore CHANGED
@@ -1,3 +1,4 @@
1
  __pycache__/*
2
  data_local/*
3
- flagged/*
 
 
1
  __pycache__/*
2
  data_local/*
3
+ flagged/*
4
+ data_mnist/*
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "data_mnist"]
2
+ path = data_mnist
3
+ url = https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset
app.py CHANGED
@@ -22,7 +22,7 @@ learning_rate = 0.01
22
  momentum = 0.5
23
  log_interval = 10
24
  random_seed = 1
25
- TRAIN_CUTOFF = 5
26
  WHAT_TO_DO=WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)
27
  METRIC_PATH = './metrics.json'
28
  REPOSITORY_DIR = "data"
@@ -216,8 +216,8 @@ def test():
216
  test_losses.append(test_loss)
217
  acc = 100. * correct / len(test_loader.dataset)
218
  acc = acc.item()
219
- test_metric = '〽Current test metric - Avg. loss: `{:.4f}`, Accuracy: `{}/{}` (`{:.0f}%`)\n'.format(
220
- test_loss, correct, len(test_loader.dataset),acc )
221
  return test_metric,acc
222
 
223
 
@@ -349,7 +349,7 @@ def flag(input_image,correct_result,adversarial_number):
349
  test_metric_ = train_and_test()
350
  test_metric = f"<html> {test_metric_} </html>"
351
  output = f'<div> ✔ ({adversarial_number}) Successfully saved your adversarial data and trained the model on adversarial data! </div>'
352
- return output,test_metric,adversarial_number
353
 
354
  def get_number_dict(DATA_DIR):
355
  files = [f.name for f in os.scandir(DATA_DIR)]
@@ -376,10 +376,11 @@ def get_statistics():
376
  DATA_DIR = './data_mnist/data'
377
  numbers_count_keys,numbers_count_values = get_number_dict(DATA_DIR)
378
 
 
379
 
380
  plt_digits = plot_bar(numbers_count_values,numbers_count_keys,'Number of adversarial samples',"Digit",f"Distribution of adversarial samples over digits")
381
 
382
- fig_d, ax_d = plt.subplots(figsize=(10,4),tight_layout=True)
383
 
384
  if os.path.exists(METRIC_PATH):
385
  metric_dict = read_json(METRIC_PATH)
@@ -401,7 +402,7 @@ def get_statistics():
401
  </div>
402
  """
403
 
404
- return plt_digits,fig_d,done_html
405
 
406
 
407
 
@@ -417,7 +418,7 @@ def main():
417
  with gr.Tabs():
418
  with gr.TabItem('MNIST'):
419
  gr.Markdown(WHAT_TO_DO)
420
- test_metric = gr.outputs.HTML(DEFAULT_TEST_METRIC)
421
  with gr.Row():
422
 
423
 
@@ -435,23 +436,20 @@ def main():
435
 
436
 
437
  submit.click(image_classifier,inputs = [image_input],outputs=[label_output])
438
- flag_btn.click(flag,inputs=[image_input,number_dropdown,adversarial_number],outputs=[output_result,test_metric,adversarial_number])
439
 
440
  with gr.TabItem('Dashboard') as dashboard:
441
  notification = gr.HTML("""<div style="color: green">
442
  <p> ⌛ Creating statistics... </p>
443
  </div>
444
  """)
445
- _,numbers_count_values_ = get_number_dict('./data_mnist/data')
446
 
447
- STATS_EXPLANATION_ = STATS_EXPLANATION.format(num_adv_samples = sum(numbers_count_values_))
448
-
449
- gr.Markdown(STATS_EXPLANATION_)
450
  stat_adv_image =gr.Plot(type="matplotlib")
451
  gr.Markdown(DASHBOARD_EXPLANATION)
452
  test_results=gr.Plot(type="matplotlib")
453
 
454
- dashboard.select(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,notification])
455
 
456
 
457
 
 
22
  momentum = 0.5
23
  log_interval = 10
24
  random_seed = 1
25
+ TRAIN_CUTOFF = 10
26
  WHAT_TO_DO=WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)
27
  METRIC_PATH = './metrics.json'
28
  REPOSITORY_DIR = "data"
 
216
  test_losses.append(test_loss)
217
  acc = 100. * correct / len(test_loader.dataset)
218
  acc = acc.item()
219
+ test_metric = '〽Current test metric -> Avg. loss: `{:.4f}`, Accuracy: `{:.0f}%`\n'.format(
220
+ test_loss,acc)
221
  return test_metric,acc
222
 
223
 
 
349
  test_metric_ = train_and_test()
350
  test_metric = f"<html> {test_metric_} </html>"
351
  output = f'<div> ✔ ({adversarial_number}) Successfully saved your adversarial data and trained the model on adversarial data! </div>'
352
+ return output,adversarial_number
353
 
354
  def get_number_dict(DATA_DIR):
355
  files = [f.name for f in os.scandir(DATA_DIR)]
 
376
  DATA_DIR = './data_mnist/data'
377
  numbers_count_keys,numbers_count_values = get_number_dict(DATA_DIR)
378
 
379
+ STATS_EXPLANATION_ = STATS_EXPLANATION.format(num_adv_samples = sum(numbers_count_values))
380
 
381
  plt_digits = plot_bar(numbers_count_values,numbers_count_keys,'Number of adversarial samples',"Digit",f"Distribution of adversarial samples over digits")
382
 
383
+ fig_d, ax_d = plt.subplots(tight_layout=True)
384
 
385
  if os.path.exists(METRIC_PATH):
386
  metric_dict = read_json(METRIC_PATH)
 
402
  </div>
403
  """
404
 
405
+ return plt_digits,fig_d,done_html,STATS_EXPLANATION_
406
 
407
 
408
 
 
418
  with gr.Tabs():
419
  with gr.TabItem('MNIST'):
420
  gr.Markdown(WHAT_TO_DO)
421
+ #test_metric = gr.outputs.HTML("")
422
  with gr.Row():
423
 
424
 
 
436
 
437
 
438
  submit.click(image_classifier,inputs = [image_input],outputs=[label_output])
439
+ flag_btn.click(flag,inputs=[image_input,number_dropdown,adversarial_number],outputs=[output_result,adversarial_number])
440
 
441
  with gr.TabItem('Dashboard') as dashboard:
442
  notification = gr.HTML("""<div style="color: green">
443
  <p> ⌛ Creating statistics... </p>
444
  </div>
445
  """)
 
446
 
447
+ stats = gr.Markdown()
 
 
448
  stat_adv_image =gr.Plot(type="matplotlib")
449
  gr.Markdown(DASHBOARD_EXPLANATION)
450
  test_results=gr.Plot(type="matplotlib")
451
 
452
+ dashboard.select(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,notification,stats])
453
 
454
 
455
 
data_mnist CHANGED
@@ -1 +1 @@
1
- Subproject commit eb1e3cf9de597112c1da3b921ffcd07c8e4419c1
 
1
+ Subproject commit c6d1292ac6318c7c44131ca2fb18d37535ae1383
metrics.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"all": [10.55875015258789], "0": [0.0], "1": [0.0], "2": [0.0], "3": [43.33333206176758], "4": [86.66666412353516], "5": [0.0], "6": [0.0], "7": [0.0], "8": [0.0], "9": [0.0]}
model.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:70d9632141eaf1062d2973b2fc9c4c9b286d6ae563297328ed929be8401b4a15
3
  size 89871
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0615455222f0654123d29490ed6fa00db335abb7bc856a817ed8069c03cfaf42
3
  size 89871
optimizer.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4c6a2a19c5c0dc6aabfa195bfc337c3d9c29e5ef78c546b65b08aa86dcc60287
3
  size 89807
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9cfa224990c352a3ad53a41d439d5dd790358bb1e0acb9d3d63379f5c9d0ba7e
3
  size 89807
utils.py CHANGED
@@ -22,7 +22,7 @@ WHAT_TO_DO="""
22
  MODEL_IS_WRONG = """
23
  ---
24
 
25
- > Did the model get it wrong? Choose the correct prediction below and flag it. When you flag it, the instance is saved to our dataset and the model is trained on it.
26
  """
27
  DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
28
 
@@ -65,7 +65,7 @@ def dump_json(thing,file):
65
 
66
 
67
  def plot_bar(value,name,x_name,y_name,title):
68
- fig, ax = plt.subplots(figsize=(10,4),tight_layout=True)
69
 
70
  ax.set(xlabel=x_name, ylabel=y_name,title=title)
71
 
 
22
  MODEL_IS_WRONG = """
23
  ---
24
 
25
+ ### Did the model get it wrong? Choose the correct prediction below and flag it. When you flag it, the instance is saved to our dataset and the model is trained on it.
26
  """
27
  DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
28
 
 
65
 
66
 
67
  def plot_bar(value,name,x_name,y_name,title):
68
+ fig, ax = plt.subplots(tight_layout=True)
69
 
70
  ax.set(xlabel=x_name, ylabel=y_name,title=title)
71