chrisjay commited on
Commit
6142233
2 Parent(s): e4a62fe 329b525

Merge branch 'main' of https://huggingface.co/spaces/chrisjay/mnist-adversarial

Browse files
.gitignore CHANGED
@@ -4,4 +4,5 @@ flagged/*
4
  data_mnist/*
5
  model/*
6
  model
7
- data_mnist
 
4
  data_mnist/*
5
  model/*
6
  model
7
+ data_mnist
8
+ slurm*
README.md CHANGED
@@ -10,3 +10,4 @@ pinned: false
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
13
+
app.py CHANGED
@@ -20,11 +20,12 @@ n_epochs = 10
20
  batch_size_train = 128
21
  batch_size_test = 1000
22
  learning_rate = 0.01
 
23
  momentum = 0.5
24
  log_interval = 10
25
  random_seed = 1
26
  TRAIN_CUTOFF = 10
27
- TEST_PER_SAMPLE = 1500
28
  DASHBOARD_EXPLANATION = DASHBOARD_EXPLANATION.format(TEST_PER_SAMPLE=TEST_PER_SAMPLE)
29
  WHAT_TO_DO=WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)
30
  MODEL_PATH = 'model'
@@ -163,7 +164,6 @@ TRAIN_TRANSFORM = torchvision.transforms.Compose([
163
  test_loader = torch.utils.data.DataLoader(MNISTCorrupted(TRAIN_TRANSFORM),
164
  batch_size=batch_size_test, shuffle=False)
165
 
166
-
167
  # Source: https://nextjournal.com/gkoehler/pytorch-mnist
168
  class MNIST_Model(nn.Module):
169
  def __init__(self):
@@ -221,6 +221,7 @@ def test():
221
  acc = acc.item()
222
  test_metric = '〽Current test metric -> Avg. loss: `{:.4f}`, Accuracy: `{:.0f}%`\n'.format(
223
  test_loss,acc)
 
224
  return test_metric,acc
225
 
226
 
@@ -234,6 +235,34 @@ optimizer = optim.SGD(network.parameters(), lr=learning_rate,
234
  momentum=momentum)
235
 
236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  def train_and_test(train_model=True):
238
 
239
  if train_model:
@@ -245,6 +274,7 @@ def train_and_test(train_model=True):
245
 
246
  test_metric,test_acc = test()
247
 
 
248
  if os.path.exists(METRIC_PATH):
249
  metric_dict = read_json(METRIC_PATH)
250
  metric_dict['all'] = metric_dict['all']+ [test_acc] if 'all' in metric_dict else [] + [test_acc]
@@ -274,6 +304,7 @@ def train_and_test(train_model=True):
274
  return test_metric
275
 
276
 
 
277
  model_state_dict = MODEL_WEIGHTS_PATH
278
  optimizer_state_dict = OPTIMIZER_PATH
279
  model_repo.git_pull()
@@ -288,9 +319,14 @@ else:
288
  # Use best weights
289
  BEST_WEIGHTS_MODEL = "best_weights/mnist_model.pth"
290
  BEST_WEIGHTS_OPTIMIZER = "best_weights/optimizer.pth"
291
- torch.save(network.state_dict(), BEST_WEIGHTS_MODEL)
292
- torch.save(optimizer.state_dict(), BEST_WEIGHTS_OPTIMIZER)
293
- _ = train_and_test(False)
 
 
 
 
 
294
 
295
 
296
  def image_classifier(inp):
@@ -306,20 +342,23 @@ def image_classifier(inp):
306
  model_repo.git_pull()
307
  model_state_dict = MODEL_WEIGHTS_PATH
308
  optimizer_state_dict = OPTIMIZER_PATH
 
309
 
310
  if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
 
311
  network_state_dict = torch.load(model_state_dict)
312
  network.load_state_dict(network_state_dict)
313
  optimizer_state_dict = torch.load(optimizer_state_dict)
314
  optimizer.load_state_dict(optimizer_state_dict)
315
  else:
316
- # Use best weights
 
317
  BEST_WEIGHTS_MODEL = "best_weights/mnist_model.pth"
318
  BEST_WEIGHTS_OPTIMIZER = "best_weights/optimizer.pth"
319
  network.load_state_dict(torch.load(BEST_WEIGHTS_MODEL))
320
  optimizer.load_state_dict(torch.load(BEST_WEIGHTS_OPTIMIZER))
321
-
322
- input_image = torchvision.transforms.ToTensor()(inp).unsqueeze(0)
323
  with torch.no_grad():
324
 
325
  prediction = torch.nn.functional.softmax(network(input_image)[0], dim=0)
20
  batch_size_train = 128
21
  batch_size_test = 1000
22
  learning_rate = 0.01
23
+ adv_learning_rate= 0.001
24
  momentum = 0.5
25
  log_interval = 10
26
  random_seed = 1
27
  TRAIN_CUTOFF = 10
28
+ TEST_PER_SAMPLE = 5000
29
  DASHBOARD_EXPLANATION = DASHBOARD_EXPLANATION.format(TEST_PER_SAMPLE=TEST_PER_SAMPLE)
30
  WHAT_TO_DO=WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)
31
  MODEL_PATH = 'model'
164
  test_loader = torch.utils.data.DataLoader(MNISTCorrupted(TRAIN_TRANSFORM),
165
  batch_size=batch_size_test, shuffle=False)
166
 
 
167
  # Source: https://nextjournal.com/gkoehler/pytorch-mnist
168
  class MNIST_Model(nn.Module):
169
  def __init__(self):
221
  acc = acc.item()
222
  test_metric = '〽Current test metric -> Avg. loss: `{:.4f}`, Accuracy: `{:.0f}%`\n'.format(
223
  test_loss,acc)
224
+ print(test_metric)
225
  return test_metric,acc
226
 
227
 
235
  momentum=momentum)
236
 
237
 
238
+
239
+ train_loader = torch.utils.data.DataLoader(
240
+ torchvision.datasets.MNIST('./files/', train=True, download=True,
241
+ transform=TRAIN_TRANSFORM),
242
+ batch_size=batch_size_train, shuffle=True)
243
+
244
+ test_iid_loader = torch.utils.data.DataLoader(
245
+ torchvision.datasets.MNIST('./files/', train=False, download=True,
246
+ transform=TRAIN_TRANSFORM),
247
+ batch_size=batch_size_test, shuffle=True)
248
+
249
+ model_state_dict = MODEL_WEIGHTS_PATH
250
+ optimizer_state_dict = OPTIMIZER_PATH
251
+ if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
252
+ network_state_dict = torch.load(model_state_dict)
253
+ network.load_state_dict(network_state_dict)
254
+
255
+ optimizer_state_dict = torch.load(optimizer_state_dict)
256
+ optimizer.load_state_dict(optimizer_state_dict)
257
+
258
+ # Train model
259
+ #n_epochs=20
260
+ #train(n_epochs,network,optimizer,train_loader)
261
+ #test()
262
+
263
+
264
+
265
+
266
  def train_and_test(train_model=True):
267
 
268
  if train_model:
274
 
275
  test_metric,test_acc = test()
276
 
277
+ network.eval()
278
  if os.path.exists(METRIC_PATH):
279
  metric_dict = read_json(METRIC_PATH)
280
  metric_dict['all'] = metric_dict['all']+ [test_acc] if 'all' in metric_dict else [] + [test_acc]
304
  return test_metric
305
 
306
 
307
+ # Update model weights again
308
  model_state_dict = MODEL_WEIGHTS_PATH
309
  optimizer_state_dict = OPTIMIZER_PATH
310
  model_repo.git_pull()
319
  # Use best weights
320
  BEST_WEIGHTS_MODEL = "best_weights/mnist_model.pth"
321
  BEST_WEIGHTS_OPTIMIZER = "best_weights/optimizer.pth"
322
+
323
+ network_state_dict = torch.load(BEST_WEIGHTS_MODEL)
324
+ network.load_state_dict(network_state_dict)
325
+
326
+ optimizer_state_dict = torch.load(BEST_WEIGHTS_OPTIMIZER)
327
+ optimizer.load_state_dict(optimizer_state_dict)
328
+ if not os.path.exists(METRIC_PATH):
329
+ _ = train_and_test(False)
330
 
331
 
332
  def image_classifier(inp):
342
  model_repo.git_pull()
343
  model_state_dict = MODEL_WEIGHTS_PATH
344
  optimizer_state_dict = OPTIMIZER_PATH
345
+ which_weights=''
346
 
347
  if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
348
+ which_weights = "Using weights from model repo"
349
  network_state_dict = torch.load(model_state_dict)
350
  network.load_state_dict(network_state_dict)
351
  optimizer_state_dict = torch.load(optimizer_state_dict)
352
  optimizer.load_state_dict(optimizer_state_dict)
353
  else:
354
+ # Use best weights
355
+ which_weights = "Using default best weights"
356
  BEST_WEIGHTS_MODEL = "best_weights/mnist_model.pth"
357
  BEST_WEIGHTS_OPTIMIZER = "best_weights/optimizer.pth"
358
  network.load_state_dict(torch.load(BEST_WEIGHTS_MODEL))
359
  optimizer.load_state_dict(torch.load(BEST_WEIGHTS_OPTIMIZER))
360
+ network.eval()
361
+ input_image = TRAIN_TRANSFORM(inp).unsqueeze(0)
362
  with torch.no_grad():
363
 
364
  prediction = torch.nn.functional.softmax(network(input_image)[0], dim=0)
best_weights/mnist_model.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ba8d282674beb300db53069e4972cfed358f8c7c627cf449215e44b365fcdc54
3
  size 89871
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:148112958ca9545938f0660cec604ac4c7f52dca3523091e1e8e4e6a26e1ebc7
3
  size 89871
best_weights/optimizer.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:fe255c0ca501d01ae3c2083ea760ea95759fcfe9075e39fba299c57a9907bf1b
3
- size 623
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aac1c136737d50c2665563392a5b220396398cf1e2a2049dbefd7dc95473f5a5
3
+ size 89807
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
  torch
2
  torchvision
3
- matplotlib
 
 
1
  torch
2
  torchvision
3
+ matplotlib
4
+ gradio
5
+ huggingface_hub