Abdullah-Nazhat
commited on
Commit
•
60f4ce6
1
Parent(s):
5e17b26
Update train.py
Browse files
train.py
CHANGED
@@ -158,7 +158,7 @@ def test(dataloader, model, loss_fn):
|
|
158 |
|
159 |
# apply train and test
|
160 |
|
161 |
-
logname = "/
|
162 |
if not os.path.exists(logname):
|
163 |
with open(logname, 'w') as logfile:
|
164 |
logwriter = csv.writer(logfile, delimiter=',')
|
@@ -170,9 +170,7 @@ epochs = 100
|
|
170 |
for epoch in range(epochs):
|
171 |
print(f"Epoch {epoch+1}\n-----------------------------------")
|
172 |
train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer)
|
173 |
-
|
174 |
-
#if scheduler is not None:
|
175 |
-
# scheduler.step()
|
176 |
test_loss, test_acc = test(test_dataloader, model, loss_fn)
|
177 |
with open(logname, 'a') as logfile:
|
178 |
logwriter = csv.writer(logfile, delimiter=',')
|
@@ -182,7 +180,7 @@ print("Done!")
|
|
182 |
|
183 |
# saving trained model
|
184 |
|
185 |
-
path = "/
|
186 |
model_name = "FourierizerImageClassification_cifar10"
|
187 |
torch.save(model.state_dict(), f"{path}/{model_name}.pth")
|
188 |
print(f"Saved Model State to {path}/{model_name}.pth ")
|
|
|
158 |
|
159 |
# apply train and test
|
160 |
|
161 |
+
logname = "/PATH/Fourierizer/Experiments_cifar10/logs_fourierizer/logs_cifar10.csv"
|
162 |
if not os.path.exists(logname):
|
163 |
with open(logname, 'w') as logfile:
|
164 |
logwriter = csv.writer(logfile, delimiter=',')
|
|
|
170 |
for epoch in range(epochs):
|
171 |
print(f"Epoch {epoch+1}\n-----------------------------------")
|
172 |
train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer)
|
173 |
+
|
|
|
|
|
174 |
test_loss, test_acc = test(test_dataloader, model, loss_fn)
|
175 |
with open(logname, 'a') as logfile:
|
176 |
logwriter = csv.writer(logfile, delimiter=',')
|
|
|
180 |
|
181 |
# saving trained model
|
182 |
|
183 |
+
path = "/PATH/Fourierizer/Experiments_cifar10/weights_fourierizer"
|
184 |
model_name = "FourierizerImageClassification_cifar10"
|
185 |
torch.save(model.state_dict(), f"{path}/{model_name}.pth")
|
186 |
print(f"Saved Model State to {path}/{model_name}.pth ")
|