Abdullah-Nazhat commited on
Commit
60f4ce6
1 Parent(s): 5e17b26

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +3 -5
train.py CHANGED
@@ -158,7 +158,7 @@ def test(dataloader, model, loss_fn):
158
 
159
  # apply train and test
160
 
161
- logname = "/home/abdullah/Desktop/Proposals_experiments/Fourierizer/Experiments_cifar10/logs_fourierizer/logs_cifar10.csv"
162
  if not os.path.exists(logname):
163
  with open(logname, 'w') as logfile:
164
  logwriter = csv.writer(logfile, delimiter=',')
@@ -170,9 +170,7 @@ epochs = 100
170
  for epoch in range(epochs):
171
  print(f"Epoch {epoch+1}\n-----------------------------------")
172
  train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer)
173
- # learning rate scheduler
174
- #if scheduler is not None:
175
- # scheduler.step()
176
  test_loss, test_acc = test(test_dataloader, model, loss_fn)
177
  with open(logname, 'a') as logfile:
178
  logwriter = csv.writer(logfile, delimiter=',')
@@ -182,7 +180,7 @@ print("Done!")
182
 
183
  # saving trained model
184
 
185
- path = "/home/abdullah/Desktop/Proposals_experiments/Fourierizer/Experiments_cifar10/weights_fourierizer"
186
  model_name = "FourierizerImageClassification_cifar10"
187
  torch.save(model.state_dict(), f"{path}/{model_name}.pth")
188
  print(f"Saved Model State to {path}/{model_name}.pth ")
 
158
 
159
  # apply train and test
160
 
161
+ logname = "/PATH/Fourierizer/Experiments_cifar10/logs_fourierizer/logs_cifar10.csv"
162
  if not os.path.exists(logname):
163
  with open(logname, 'w') as logfile:
164
  logwriter = csv.writer(logfile, delimiter=',')
 
170
  for epoch in range(epochs):
171
  print(f"Epoch {epoch+1}\n-----------------------------------")
172
  train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer)
173
+
 
 
174
  test_loss, test_acc = test(test_dataloader, model, loss_fn)
175
  with open(logname, 'a') as logfile:
176
  logwriter = csv.writer(logfile, delimiter=',')
 
180
 
181
  # saving trained model
182
 
183
+ path = "/PATH/Fourierizer/Experiments_cifar10/weights_fourierizer"
184
  model_name = "FourierizerImageClassification_cifar10"
185
  torch.save(model.state_dict(), f"{path}/{model_name}.pth")
186
  print(f"Saved Model State to {path}/{model_name}.pth ")