chrisjay committed on
Commit
c4c6bd6
1 Parent(s): d4894f5

enabled reloading of trained weights + using best weights

Browse files
Files changed (5) hide show
  1. app.py +50 -13
  2. best_weights/mnist_model.pth +3 -0
  3. best_weights/optimizer.pth +3 -0
  4. data_mnist +1 -1
  5. utils.py +1 -1
app.py CHANGED
@@ -24,6 +24,8 @@ momentum = 0.5
24
  log_interval = 10
25
  random_seed = 1
26
  TRAIN_CUTOFF = 10
 
 
27
  WHAT_TO_DO=WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)
28
  MODEL_PATH = 'model'
29
  METRIC_PATH = os.path.join(MODEL_PATH,'metrics.json')
@@ -86,7 +88,7 @@ class MNISTAdversarial_Dataset(Dataset):
86
  return img, label
87
 
88
  class MNISTCorrupted_By_Digit(Dataset):
89
- def __init__(self,transform,digit,limit=500):
90
  self.transform = transform
91
  self.digit = digit
92
  corrupted_dir="./mnist_c"
@@ -127,8 +129,8 @@ class MNISTCorrupted(Dataset):
127
  self.transform = transform
128
  corrupted_dir="./mnist_c"
129
  files = [f.name for f in os.scandir(corrupted_dir)]
130
- images = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_images.npy'))[:500] for f in files]
131
- labels = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_labels.npy'))[:500] for f in files]
132
  self.data = np.vstack(images)
133
  self.labels = np.hstack(labels)
134
 
@@ -283,24 +285,40 @@ if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
283
  optimizer.load_state_dict(optimizer_state_dict)
284
 
285
  else:
286
- # Evaluate model to get initial evaluation with no adversarial training
287
- torch.save(network.state_dict(), MODEL_WEIGHTS_PATH)
288
- torch.save(optimizer.state_dict(), OPTIMIZER_PATH)
 
 
289
  _ = train_and_test(False)
290
 
291
 
292
- # Train
293
- #train(n_epochs,network,optimizer)
294
-
295
-
296
  def image_classifier(inp):
297
  """
298
- It takes an image as input and returns a dictionary of class labels and their corresponding
299
- confidence scores.
300
 
301
  :param inp: the image to be classified
302
- :return: A dictionary of the class index and the confidence value.
303
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
  input_image = torchvision.transforms.ToTensor()(inp).unsqueeze(0)
305
  with torch.no_grad():
306
 
@@ -314,6 +332,19 @@ def image_classifier(inp):
314
 
315
 
316
  def flag(input_image,correct_result,adversarial_number):
 
 
 
 
 
 
 
 
 
 
 
 
 
317
 
318
  adversarial_number = 0 if None else adversarial_number
319
 
@@ -375,6 +406,12 @@ def get_number_dict(DATA_DIR):
375
 
376
 
377
  def get_statistics():
 
 
 
 
 
 
378
  model_repo.git_pull()
379
  model_state_dict = MODEL_WEIGHTS_PATH
380
  optimizer_state_dict = OPTIMIZER_PATH
 
24
  log_interval = 10
25
  random_seed = 1
26
  TRAIN_CUTOFF = 10
27
+ TEST_PER_SAMPLE = 1500
28
+ DASHBOARD_EXPLANATION = DASHBOARD_EXPLANATION.format(TEST_PER_SAMPLE=TEST_PER_SAMPLE)
29
  WHAT_TO_DO=WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)
30
  MODEL_PATH = 'model'
31
  METRIC_PATH = os.path.join(MODEL_PATH,'metrics.json')
 
88
  return img, label
89
 
90
  class MNISTCorrupted_By_Digit(Dataset):
91
+ def __init__(self,transform,digit,limit=TEST_PER_SAMPLE):
92
  self.transform = transform
93
  self.digit = digit
94
  corrupted_dir="./mnist_c"
 
129
  self.transform = transform
130
  corrupted_dir="./mnist_c"
131
  files = [f.name for f in os.scandir(corrupted_dir)]
132
+ images = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_images.npy'))[:TEST_PER_SAMPLE] for f in files]
133
+ labels = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_labels.npy'))[:TEST_PER_SAMPLE] for f in files]
134
  self.data = np.vstack(images)
135
  self.labels = np.hstack(labels)
136
 
 
285
  optimizer.load_state_dict(optimizer_state_dict)
286
 
287
  else:
288
+ # Use best weights
289
+ BEST_WEIGHTS_MODEL = "best_weights/mnist_model.pth"
290
+ BEST_WEIGHTS_OPTIMIZER = "best_weights/optimizer.pth"
291
+ torch.save(network.state_dict(), BEST_WEIGHTS_MODEL)
292
+ torch.save(optimizer.state_dict(), BEST_WEIGHTS_OPTIMIZER)
293
  _ = train_and_test(False)
294
 
295
 
 
 
 
 
296
  def image_classifier(inp):
297
  """
298
+ It loads the latest model weights from the model repository, and then uses those weights to make a
299
+ prediction on the input image.
300
 
301
  :param inp: the image to be classified
302
+ :return: A dictionary of the form {class_number: confidence}
303
  """
304
+
305
+ # Get latest model weights ----------------
306
+ model_repo.git_pull()
307
+ model_state_dict = MODEL_WEIGHTS_PATH
308
+ optimizer_state_dict = OPTIMIZER_PATH
309
+
310
+ if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
311
+ network_state_dict = torch.load(model_state_dict)
312
+ network.load_state_dict(network_state_dict)
313
+ optimizer_state_dict = torch.load(optimizer_state_dict)
314
+ optimizer.load_state_dict(optimizer_state_dict)
315
+ else:
316
+ # Use best weights
317
+ BEST_WEIGHTS_MODEL = "best_weights/mnist_model.pth"
318
+ BEST_WEIGHTS_OPTIMIZER = "best_weights/optimizer.pth"
319
+ network.load_state_dict(torch.load(BEST_WEIGHTS_MODEL))
320
+ optimizer.load_state_dict(torch.load(BEST_WEIGHTS_OPTIMIZER))
321
+
322
  input_image = torchvision.transforms.ToTensor()(inp).unsqueeze(0)
323
  with torch.no_grad():
324
 
 
332
 
333
 
334
  def flag(input_image,correct_result,adversarial_number):
335
+ """
336
+ It takes in an image, the correct result, and the number of adversarial images that have been
337
+ uploaded so far. It saves the image and metadata to a local directory, uploads the image and
338
+ metadata to the hub, and then pulls the data from the hub to the local directory. If the number of
339
+ images in the local directory is divisible by the TRAIN_CUTOFF, then it trains the model on the
340
+ adversarial data
341
+
342
+ :param input_image: The adversarial image that you want to save
343
+ :param correct_result: The correct number that the image represents
344
+ :param adversarial_number: This is the number of adversarial examples that have been uploaded to the
345
+ dataset
346
+ :return: The output is the output of the flag function.
347
+ """
348
 
349
  adversarial_number = 0 if None else adversarial_number
350
 
 
406
 
407
 
408
  def get_statistics():
409
+ """
410
+ It loads the model and optimizer state dicts, pulls the latest data from the repo, gets the number
411
+ of adversarial samples per digit, plots the distribution of adversarial samples per digit, plots the
412
+ test accuracy per digit per train step, and plots the test accuracy for all digits per train step
413
+ :return: the following:
414
+ """
415
  model_repo.git_pull()
416
  model_state_dict = MODEL_WEIGHTS_PATH
417
  optimizer_state_dict = OPTIMIZER_PATH
best_weights/mnist_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba8d282674beb300db53069e4972cfed358f8c7c627cf449215e44b365fcdc54
3
+ size 89871
best_weights/optimizer.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe255c0ca501d01ae3c2083ea760ea95759fcfe9075e39fba299c57a9907bf1b
3
+ size 623
data_mnist CHANGED
@@ -1 +1 @@
1
- Subproject commit b85a2ad15f628eb33a6595afbaba38cfb6a98ece
 
1
+ Subproject commit ed62a26e764902f519ff43df850842e07dfe2cc0
utils.py CHANGED
@@ -24,7 +24,7 @@ MODEL_IS_WRONG = """
24
  """
25
  DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
26
 
27
- DASHBOARD_EXPLANATION="To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543)."
28
  DASHBOARD_EXPLANATION_TEST="Test accuracy on out-of-distribution data for all numbers."
29
 
30
  STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
 
24
  """
25
  DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
26
 
27
+ DASHBOARD_EXPLANATION="To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543). We are using {TEST_PER_SAMPLE} samples per digit."
28
  DASHBOARD_EXPLANATION_TEST="Test accuracy on out-of-distribution data for all numbers."
29
 
30
  STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."