Akazi commited on
Commit
74a67e7
1 Parent(s): c96be90

Rename model.py to train.py

Browse files
Files changed (1) hide show
  1. model.py → train.py +7 -7
model.py → train.py RENAMED
@@ -1,16 +1,15 @@
1
- import torch
2
- import torchvision.transforms as transforms
3
- from torchvision.datasets import ImageFolder
4
- from torchvision import models
5
  import os
6
  import tarfile
7
  import shutil
8
- from PIL import Image
 
 
 
 
9
 
10
  # Set the paths
11
  data_dir = "C:/Users/kazia/Downloads/Trees"
12
  train_dir = os.path.join(data_dir, "Train")
13
- test_dir = os.path.join(data_dir, "Test")
14
  checkpoint_dir = "resnet_c_s_redwood_finetuned"
15
  tar_file = "resnet_c_s_redwood_finetuned.tar.gz"
16
 
@@ -77,6 +76,7 @@ torch.save(model.state_dict(), os.path.join(checkpoint_dir, "model.pth"))
77
  print("Creating tar.gz file...")
78
  with tarfile.open(tar_file, "w:gz") as tar:
79
  tar.add(checkpoint_dir, arcname=os.path.basename(checkpoint_dir))
 
80
  # Move the tar.gz file to the desired location
81
  destination_path = "C:/Users/kazia/Downloads/resnet_c_s_redwood_finetuned.tar.gz"
82
  shutil.move(tar_file, destination_path)
@@ -84,4 +84,4 @@ shutil.move(tar_file, destination_path)
84
  # Cleanup the intermediate checkpoint directory
85
  shutil.rmtree(checkpoint_dir)
86
 
87
- print("Fine-tuned model saved and compressed as tar.gz.")
 
 
 
 
 
1
  import os
2
  import tarfile
3
  import shutil
4
+ import torch
5
+ import torchvision.models as models
6
+ import torchvision.transforms as transforms
7
+ from torchvision.datasets import ImageFolder
8
+
9
 
10
  # Set the paths
11
  data_dir = "C:/Users/kazia/Downloads/Trees"
12
  train_dir = os.path.join(data_dir, "Train")
 
13
  checkpoint_dir = "resnet_c_s_redwood_finetuned"
14
  tar_file = "resnet_c_s_redwood_finetuned.tar.gz"
15
 
 
76
  print("Creating tar.gz file...")
77
  with tarfile.open(tar_file, "w:gz") as tar:
78
  tar.add(checkpoint_dir, arcname=os.path.basename(checkpoint_dir))
79
+
80
  # Move the tar.gz file to the desired location
81
  destination_path = "C:/Users/kazia/Downloads/resnet_c_s_redwood_finetuned.tar.gz"
82
  shutil.move(tar_file, destination_path)
 
84
  # Cleanup the intermediate checkpoint directory
85
  shutil.rmtree(checkpoint_dir)
86
 
87
+ print("Fine-tuned model saved and compressed as tar.gz.")