chrisjay commited on
Commit
5bd1489
1 Parent(s): 7583157

fixing upload of models and metrics

Browse files
Files changed (3) hide show
  1. app.py +62 -41
  2. data_mnist +1 -1
  3. utils.py +1 -0
app.py CHANGED
@@ -231,46 +231,16 @@ optimizer = optim.SGD(network.parameters(), lr=learning_rate,
231
  momentum=momentum)
232
 
233
 
234
- model_state_dict = MODEL_WEIGHTS_PATH
235
- optimizer_state_dict = OPTIMIZER_PATH
236
- model_repo.git_pull()
237
- if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
238
- network_state_dict = torch.load(model_state_dict)
239
- network.load_state_dict(network_state_dict)
240
 
241
- optimizer_state_dict = torch.load(optimizer_state_dict)
242
- optimizer.load_state_dict(optimizer_state_dict)
 
243
 
244
- # Train
245
- #train(n_epochs,network,optimizer)
246
-
247
-
248
- def image_classifier(inp):
249
- """
250
- It takes an image as input and returns a dictionary of class labels and their corresponding
251
- confidence scores.
252
 
253
- :param inp: the image to be classified
254
- :return: A dictionary of the class index and the confidence value.
255
- """
256
- input_image = torchvision.transforms.ToTensor()(inp).unsqueeze(0)
257
- with torch.no_grad():
258
-
259
- prediction = torch.nn.functional.softmax(network(input_image)[0], dim=0)
260
- #pred_number = prediction.data.max(1, keepdim=True)[1]
261
- sorted_prediction = torch.sort(prediction,descending=True)
262
- confidences={}
263
- for s,v in zip(sorted_prediction.indices.numpy().tolist(),sorted_prediction.values.numpy().tolist()):
264
- confidences.update({s:v})
265
- return confidences
266
-
267
- def train_and_test():
268
- # Train for one epoch and test
269
- train_dataset = MNISTAdversarial_Dataset('./data_mnist',TRAIN_TRANSFORM)
270
-
271
- train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size_test, shuffle=True
272
- )
273
- train(n_epochs,network,optimizer,train_loader)
274
  test_metric,test_acc = test()
275
 
276
  if os.path.exists(METRIC_PATH):
@@ -301,6 +271,48 @@ def train_and_test():
301
 
302
  return test_metric
303
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
  def flag(input_image,correct_result,adversarial_number):
305
 
306
  adversarial_number = 0 if None else adversarial_number
@@ -380,7 +392,7 @@ def get_statistics():
380
 
381
  STATS_EXPLANATION_ = STATS_EXPLANATION.format(num_adv_samples = sum(numbers_count_values))
382
 
383
- plt_digits = plot_bar(numbers_count_values,numbers_count_keys,'Number of adversarial samples',"Digit",f"Distribution of adversarial samples over digits")
384
 
385
  fig_d, ax_d = plt.subplots(tight_layout=True)
386
 
@@ -392,7 +404,7 @@ def get_statistics():
392
  ax_d.plot(x_i, metric_dict[str(i)],label=str(i))
393
  except Exception:
394
  continue
395
- dump_json(thing=metric_dict,file=METRIC_PATH)
396
  else:
397
  metric_dict={}
398
 
@@ -404,8 +416,15 @@ def get_statistics():
404
  </div>
405
  """
406
 
407
- return plt_digits,fig_d,done_html,STATS_EXPLANATION_
 
 
 
 
 
 
408
 
 
409
 
410
 
411
 
@@ -453,8 +472,10 @@ def main():
453
  stat_adv_image =gr.Plot(type="matplotlib")
454
  gr.Markdown(DASHBOARD_EXPLANATION)
455
  test_results=gr.Plot(type="matplotlib")
 
 
456
 
457
- dashboard.select(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,notification,stats])
458
 
459
 
460
 
 
231
  momentum=momentum)
232
 
233
 
234
+ def train_and_test(train=True):
 
 
 
 
 
235
 
236
+ if train:
237
+ # Train for one epoch and test
238
+ train_dataset = MNISTAdversarial_Dataset('./data_mnist',TRAIN_TRANSFORM)
239
 
240
+ train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size_test, shuffle=True
241
+ )
242
+ train(n_epochs,network,optimizer,train_loader)
 
 
 
 
 
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  test_metric,test_acc = test()
245
 
246
  if os.path.exists(METRIC_PATH):
 
271
 
272
  return test_metric
273
 
274
+
275
+ model_state_dict = MODEL_WEIGHTS_PATH
276
+ optimizer_state_dict = OPTIMIZER_PATH
277
+ model_repo.git_pull()
278
+ if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
279
+ network_state_dict = torch.load(model_state_dict)
280
+ network.load_state_dict(network_state_dict)
281
+
282
+ optimizer_state_dict = torch.load(optimizer_state_dict)
283
+ optimizer.load_state_dict(optimizer_state_dict)
284
+
285
+ else:
286
+ # Evaluate model to get initial evaluation with no adversarial training
287
+ torch.save(network.state_dict(), MODEL_WEIGHTS_PATH)
288
+ torch.save(optimizer.state_dict(), OPTIMIZER_PATH)
289
+ _ = train_and_test(False)
290
+
291
+
292
+ # Train
293
+ #train(n_epochs,network,optimizer)
294
+
295
+
296
+ def image_classifier(inp):
297
+ """
298
+ It takes an image as input and returns a dictionary of class labels and their corresponding
299
+ confidence scores.
300
+
301
+ :param inp: the image to be classified
302
+ :return: A dictionary of the class index and the confidence value.
303
+ """
304
+ input_image = torchvision.transforms.ToTensor()(inp).unsqueeze(0)
305
+ with torch.no_grad():
306
+
307
+ prediction = torch.nn.functional.softmax(network(input_image)[0], dim=0)
308
+ #pred_number = prediction.data.max(1, keepdim=True)[1]
309
+ sorted_prediction = torch.sort(prediction,descending=True)
310
+ confidences={}
311
+ for s,v in zip(sorted_prediction.indices.numpy().tolist(),sorted_prediction.values.numpy().tolist()):
312
+ confidences.update({s:v})
313
+ return confidences
314
+
315
+
316
  def flag(input_image,correct_result,adversarial_number):
317
 
318
  adversarial_number = 0 if None else adversarial_number
 
392
 
393
  STATS_EXPLANATION_ = STATS_EXPLANATION.format(num_adv_samples = sum(numbers_count_values))
394
 
395
+ plt_digits = plot_bar(numbers_count_values,numbers_count_keys,'Number of adversarial samples',"Digit",f"Distribution of adversarial samples per digit")
396
 
397
  fig_d, ax_d = plt.subplots(tight_layout=True)
398
 
 
404
  ax_d.plot(x_i, metric_dict[str(i)],label=str(i))
405
  except Exception:
406
  continue
407
+
408
  else:
409
  metric_dict={}
410
 
 
416
  </div>
417
  """
418
 
419
+ # Plot for total test accuracy for all digits
420
+ fig_all, ax_all = plt.subplots(tight_layout=True)
421
+ x_i = [i+1 for i in range(len(metric_dict['all']))]
422
+
423
+ ax_all.plot(x_i, metric_dict['all'])
424
+ fig_all.legend()
425
+ ax_all.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy for all digits")
426
 
427
+ return plt_digits,ax_d.figure,ax_all.figure,done_html,STATS_EXPLANATION_
428
 
429
 
430
 
 
472
  stat_adv_image =gr.Plot(type="matplotlib")
473
  gr.Markdown(DASHBOARD_EXPLANATION)
474
  test_results=gr.Plot(type="matplotlib")
475
+ gr.Markdown(DASHBOARD_EXPLANATION_TEST)
476
+ test_results_all=gr.Plot(type="matplotlib")
477
 
478
+ dashboard.select(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,test_results_all,notification,stats])
479
 
480
 
481
 
data_mnist CHANGED
@@ -1 +1 @@
1
- Subproject commit fc4a0229b36d92306494d4d1cafb256d1f480046
 
1
+ Subproject commit b85a2ad15f628eb33a6595afbaba38cfb6a98ece
utils.py CHANGED
@@ -26,6 +26,7 @@ MODEL_IS_WRONG = """
26
  DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
27
 
28
  DASHBOARD_EXPLANATION="To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543)."
 
29
 
30
  STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
31
 
 
26
  DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
27
 
28
  DASHBOARD_EXPLANATION="To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543)."
29
+ DASHBOARD_EXPLANATION_TEST="Test accuracy on out-of-distribution data for all numbers."
30
 
31
  STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
32