Rename model.py to train.py
Browse files- model.py → train.py +7 -7
model.py → train.py
RENAMED
@@ -1,16 +1,15 @@
|
|
1 |
-
import torch
|
2 |
-
import torchvision.transforms as transforms
|
3 |
-
from torchvision.datasets import ImageFolder
|
4 |
-
from torchvision import models
|
5 |
import os
|
6 |
import tarfile
|
7 |
import shutil
|
8 |
-
|
|
|
|
|
|
|
|
|
9 |
|
10 |
# Set the paths
|
11 |
data_dir = "C:/Users/kazia/Downloads/Trees"
|
12 |
train_dir = os.path.join(data_dir, "Train")
|
13 |
-
test_dir = os.path.join(data_dir, "Test")
|
14 |
checkpoint_dir = "resnet_c_s_redwood_finetuned"
|
15 |
tar_file = "resnet_c_s_redwood_finetuned.tar.gz"
|
16 |
|
@@ -77,6 +76,7 @@ torch.save(model.state_dict(), os.path.join(checkpoint_dir, "model.pth"))
|
|
77 |
print("Creating tar.gz file...")
|
78 |
with tarfile.open(tar_file, "w:gz") as tar:
|
79 |
tar.add(checkpoint_dir, arcname=os.path.basename(checkpoint_dir))
|
|
|
80 |
# Move the tar.gz file to the desired location
|
81 |
destination_path = "C:/Users/kazia/Downloads/resnet_c_s_redwood_finetuned.tar.gz"
|
82 |
shutil.move(tar_file, destination_path)
|
@@ -84,4 +84,4 @@ shutil.move(tar_file, destination_path)
|
|
84 |
# Cleanup the intermediate checkpoint directory
|
85 |
shutil.rmtree(checkpoint_dir)
|
86 |
|
87 |
-
print("Fine-tuned model saved and compressed as tar.gz.")
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import tarfile
|
3 |
import shutil
|
4 |
+
import torch
|
5 |
+
import torchvision.models as models
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
from torchvision.datasets import ImageFolder
|
8 |
+
|
9 |
|
10 |
# Set the paths
|
11 |
data_dir = "C:/Users/kazia/Downloads/Trees"
|
12 |
train_dir = os.path.join(data_dir, "Train")
|
|
|
13 |
checkpoint_dir = "resnet_c_s_redwood_finetuned"
|
14 |
tar_file = "resnet_c_s_redwood_finetuned.tar.gz"
|
15 |
|
|
|
76 |
print("Creating tar.gz file...")
|
77 |
with tarfile.open(tar_file, "w:gz") as tar:
|
78 |
tar.add(checkpoint_dir, arcname=os.path.basename(checkpoint_dir))
|
79 |
+
|
80 |
# Move the tar.gz file to the desired location
|
81 |
destination_path = "C:/Users/kazia/Downloads/resnet_c_s_redwood_finetuned.tar.gz"
|
82 |
shutil.move(tar_file, destination_path)
|
|
|
84 |
# Cleanup the intermediate checkpoint directory
|
85 |
shutil.rmtree(checkpoint_dir)
|
86 |
|
87 |
+
print("Fine-tuned model saved and compressed as tar.gz.")
|