chrisjay commited on
Commit
46eef9f
1 Parent(s): 213a820

save models and metrics to hub

Browse files
Files changed (3) hide show
  1. .gitignore +4 -1
  2. app.py +28 -19
  3. data_mnist +1 -1
.gitignore CHANGED
@@ -1,4 +1,7 @@
1
  __pycache__/*
2
  data_local/*
3
  flagged/*
4
- data_mnist/*
 
 
 
 
1
  __pycache__/*
2
  data_local/*
3
  flagged/*
4
+ data_mnist/*
5
+ model/*
6
+ model
7
+ data_mnist
app.py CHANGED
@@ -25,7 +25,10 @@ log_interval = 10
25
  random_seed = 1
26
  TRAIN_CUTOFF = 10
27
  WHAT_TO_DO=WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)
28
- METRIC_PATH = './metrics.json'
 
 
 
29
  REPOSITORY_DIR = "data"
30
  LOCAL_DIR = 'data_local'
31
  os.makedirs(LOCAL_DIR,exist_ok=True)
@@ -34,14 +37,21 @@ os.makedirs(LOCAL_DIR,exist_ok=True)
34
 
35
 
36
  HF_TOKEN = os.getenv("HF_TOKEN")
37
-
38
  HF_DATASET ="mnist-adversarial-dataset"
39
  DATASET_REPO_URL = f"https://huggingface.co/datasets/chrisjay/{HF_DATASET}"
 
 
40
  repo = Repository(
41
  local_dir="data_mnist", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
42
  )
43
  repo.git_pull()
44
 
 
 
 
 
 
45
  torch.backends.cudnn.enabled = False
46
  torch.manual_seed(random_seed)
47
 
@@ -76,7 +86,7 @@ class MNISTAdversarial_Dataset(Dataset):
76
  return img, label
77
 
78
  class MNISTCorrupted_By_Digit(Dataset):
79
- def __init__(self,transform,digit,limit=300):
80
  self.transform = transform
81
  self.digit = digit
82
  corrupted_dir="./mnist_c"
@@ -112,15 +122,13 @@ class MNISTCorrupted_By_Digit(Dataset):
112
 
113
  return image, label
114
 
115
-
116
-
117
  class MNISTCorrupted(Dataset):
118
  def __init__(self,transform):
119
  self.transform = transform
120
  corrupted_dir="./mnist_c"
121
  files = [f.name for f in os.scandir(corrupted_dir)]
122
- images = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_images.npy'))[:300] for f in files]
123
- labels = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_labels.npy'))[:300] for f in files]
124
  self.data = np.vstack(images)
125
  self.labels = np.hstack(labels)
126
 
@@ -142,7 +150,6 @@ class MNISTCorrupted(Dataset):
142
  return image, label
143
 
144
 
145
-
146
  TRAIN_TRANSFORM = torchvision.transforms.Compose([
147
  torchvision.transforms.ToTensor(),
148
  torchvision.transforms.Normalize(
@@ -191,8 +198,8 @@ def train(epochs,network,optimizer,train_loader):
191
  100. * batch_idx / len(train_loader), loss.item()))
192
  train_losses.append(loss.item())
193
 
194
- torch.save(network.state_dict(), 'model.pth')
195
- torch.save(optimizer.state_dict(), 'optimizer.pth')
196
 
197
  def test():
198
  test_losses=[]
@@ -224,19 +231,16 @@ optimizer = optim.SGD(network.parameters(), lr=learning_rate,
224
  momentum=momentum)
225
 
226
 
227
- model_state_dict = 'model.pth'
228
- optimizer_state_dict = 'optmizer.pth'
229
-
230
- if os.path.exists(model_state_dict):
231
  network_state_dict = torch.load(model_state_dict)
232
  network.load_state_dict(network_state_dict)
233
 
234
- if os.path.exists(optimizer_state_dict):
235
  optimizer_state_dict = torch.load(optimizer_state_dict)
236
  optimizer.load_state_dict(optimizer_state_dict)
237
 
238
-
239
-
240
  # Train
241
  #train(n_epochs,network,optimizer)
242
 
@@ -291,6 +295,10 @@ def train_and_test():
291
  metric_dict[str(i)] = [acc]
292
 
293
  dump_json(thing=metric_dict,file=METRIC_PATH)
 
 
 
 
294
  return test_metric
295
 
296
  def flag(input_image,correct_result,adversarial_number):
@@ -355,8 +363,9 @@ def get_number_dict(DATA_DIR):
355
 
356
 
357
  def get_statistics():
358
- model_state_dict = 'model.pth'
359
- optimizer_state_dict = 'optmizer.pth'
 
360
 
361
  if os.path.exists(model_state_dict):
362
  network_state_dict = torch.load(model_state_dict)
 
25
  random_seed = 1
26
  TRAIN_CUTOFF = 10
27
  WHAT_TO_DO=WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)
28
+ MODEL_PATH = 'model'
29
+ METRIC_PATH = os.path.join(MODEL_PATH,'metrics.json')
30
+ MODEL_WEIGHTS_PATH = os.path.join(MODEL_PATH,'mnist_model.pth')
31
+ OPTIMIZER_PATH = os.path.join(MODEL_PATH,'optimizer.pth')
32
  REPOSITORY_DIR = "data"
33
  LOCAL_DIR = 'data_local'
34
  os.makedirs(LOCAL_DIR,exist_ok=True)
 
37
 
38
 
39
  HF_TOKEN = os.getenv("HF_TOKEN")
40
+ MODEL_REPO = 'mnist-adversarial-model'
41
  HF_DATASET ="mnist-adversarial-dataset"
42
  DATASET_REPO_URL = f"https://huggingface.co/datasets/chrisjay/{HF_DATASET}"
43
+ MODEL_REPO_URL = f"https://huggingface.co/datasets/chrisjay/{MODEL_REPO}"
44
+
45
  repo = Repository(
46
  local_dir="data_mnist", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
47
  )
48
  repo.git_pull()
49
 
50
+ model_repo = Repository(
51
+ local_dir=MODEL_PATH, clone_from=MODEL_REPO_URL, use_auth_token=HF_TOKEN
52
+ )
53
+ model_repo.git_pull()
54
+
55
  torch.backends.cudnn.enabled = False
56
  torch.manual_seed(random_seed)
57
 
 
86
  return img, label
87
 
88
  class MNISTCorrupted_By_Digit(Dataset):
89
+ def __init__(self,transform,digit,limit=500):
90
  self.transform = transform
91
  self.digit = digit
92
  corrupted_dir="./mnist_c"
 
122
 
123
  return image, label
124
 
 
 
125
  class MNISTCorrupted(Dataset):
126
  def __init__(self,transform):
127
  self.transform = transform
128
  corrupted_dir="./mnist_c"
129
  files = [f.name for f in os.scandir(corrupted_dir)]
130
+ images = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_images.npy'))[:500] for f in files]
131
+ labels = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_labels.npy'))[:500] for f in files]
132
  self.data = np.vstack(images)
133
  self.labels = np.hstack(labels)
134
 
 
150
  return image, label
151
 
152
 
 
153
  TRAIN_TRANSFORM = torchvision.transforms.Compose([
154
  torchvision.transforms.ToTensor(),
155
  torchvision.transforms.Normalize(
 
198
  100. * batch_idx / len(train_loader), loss.item()))
199
  train_losses.append(loss.item())
200
 
201
+ torch.save(network.state_dict(), MODEL_WEIGHTS_PATH)
202
+ torch.save(optimizer.state_dict(), OPTIMIZER_PATH)
203
 
204
  def test():
205
  test_losses=[]
 
231
  momentum=momentum)
232
 
233
 
234
+ model_state_dict = MODEL_WEIGHTS_PATH
235
+ optimizer_state_dict = OPTIMIZER_PATH
236
+ model_repo.git_pull()
237
+ if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
238
  network_state_dict = torch.load(model_state_dict)
239
  network.load_state_dict(network_state_dict)
240
 
 
241
  optimizer_state_dict = torch.load(optimizer_state_dict)
242
  optimizer.load_state_dict(optimizer_state_dict)
243
 
 
 
244
  # Train
245
  #train(n_epochs,network,optimizer)
246
 
 
295
  metric_dict[str(i)] = [acc]
296
 
297
  dump_json(thing=metric_dict,file=METRIC_PATH)
298
+
299
+ # Push models and metrics to hub
300
+ model_repo.push_to_hub()
301
+
302
  return test_metric
303
 
304
  def flag(input_image,correct_result,adversarial_number):
 
363
 
364
 
365
  def get_statistics():
366
+ model_repo.git_pull()
367
+ model_state_dict = MODEL_WEIGHTS_PATH
368
+ optimizer_state_dict = OPTIMIZER_PATH
369
 
370
  if os.path.exists(model_state_dict):
371
  network_state_dict = torch.load(model_state_dict)
data_mnist CHANGED
@@ -1 +1 @@
1
- Subproject commit 650e2ac4a86b5e109a12b5adc7bf6436bbe578de
 
1
+ Subproject commit 5915a9276e314d92a5b533b5312616b28b9bcee5