chrisjay commited on
Commit
603879a
1 Parent(s): c4c6bd6

fixed issue with test acc on line 249

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -239,15 +239,14 @@ def train_and_test(train_model=True):
239
  # Train for one epoch and test
240
  train_dataset = MNISTAdversarial_Dataset('./data_mnist',TRAIN_TRANSFORM)
241
 
242
- train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size_test, shuffle=True
243
- )
244
  train(n_epochs,network,optimizer,train_loader)
245
 
246
  test_metric,test_acc = test()
247
 
248
  if os.path.exists(METRIC_PATH):
249
  metric_dict = read_json(METRIC_PATH)
250
- metric_dict['all'] = metric_dict['all'] if 'all' in metric_dict else [] + [test_acc]
251
  else:
252
  metric_dict={}
253
  metric_dict['all'] = [test_acc]
239
  # Train for one epoch and test
240
  train_dataset = MNISTAdversarial_Dataset('./data_mnist',TRAIN_TRANSFORM)
241
 
242
+ train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size_test, shuffle=True)
 
243
  train(n_epochs,network,optimizer,train_loader)
244
 
245
  test_metric,test_acc = test()
246
 
247
  if os.path.exists(METRIC_PATH):
248
  metric_dict = read_json(METRIC_PATH)
249
+ metric_dict['all'] = metric_dict['all']+ [test_acc] if 'all' in metric_dict else [] + [test_acc]
250
  else:
251
  metric_dict={}
252
  metric_dict['all'] = [test_acc]