|
import os |
|
import tarfile |
|
import shutil |
|
import torch |
|
import torchvision.models as models |
|
import torchvision.transforms as transforms |
|
from torchvision.datasets import ImageFolder |
|
|
|
|
|
|
|
# ---- Paths ------------------------------------------------------------------
# Dataset root; ImageFolder reads one class per sub-folder of <data_dir>/Train.
data_dir = "C:/Users/kazia/Downloads/Trees"
train_dir = os.path.join(data_dir, "Train")

# Scratch directory for the checkpoint, and the name of the archive built from it.
checkpoint_dir = "resnet_c_s_redwood_finetuned"
tar_file = f"{checkpoint_dir}.tar.gz"
|
|
|
|
|
# Per-channel statistics that torchvision's ImageNet-pretrained models expect
# their inputs to be normalized with. Any inference code loading the saved
# checkpoint must normalize with these same values.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training-time augmentation followed by resizing and normalization.
image_transforms = transforms.Compose([
    transforms.RandomRotation(10),           # random rotation in [-10, 10] degrees
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2),
    transforms.Resize((224, 224)),           # fixed size so images batch together
    transforms.ToTensor(),
    # Was Normalize((0.5,)*3, (0.5,)*3), which does not match the statistics
    # the pretrained backbone was trained with.
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
|
|
|
|
|
# Build the labelled training set: each sub-folder of train_dir becomes one
# class, and image_transforms is applied to every image as it is fetched.
print("Loading dataset with data augmentation...")
train_dataset = ImageFolder(root=train_dir, transform=image_transforms)
|
|
|
|
|
print("Loading pre-trained model...")
# torchvision >= 0.13 replaced the deprecated `pretrained=True` flag with the
# `weights` enum; IMAGENET1K_V1 is the checkpoint `pretrained=True` loads.
# Older torchvision releases lack the enum, so fall back to the flag there.
_resnet50_weights = getattr(models, "ResNet50_Weights", None)
if _resnet50_weights is not None:
    model = models.resnet50(weights=_resnet50_weights.IMAGENET1K_V1)
else:
    model = models.resnet50(pretrained=True)

# Swap the pretrained classifier head for one with an output per class folder
# discovered by ImageFolder. A hard-coded count would make cross_entropy fail
# on out-of-range labels whenever Train/ holds a different number of classes.
num_features = model.fc.in_features
num_labels = len(train_dataset.classes)
model.fc = torch.nn.Linear(num_features, num_labels)
|
|
|
|
|
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

# Adam over every parameter, so the whole network is fine-tuned, not just the
# new head. StepLR multiplies the learning rate by 0.1 after every 5 calls to
# lr_scheduler.step().
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=5,
    gamma=0.1,
)
|
|
|
|
|
def train(model, optimizer, lr_scheduler, dataloader, device=None):
    """Run one training epoch over *dataloader*, then advance the LR schedule once.

    Args:
        model: network to optimize; it is switched to training mode.
        optimizer: optimizer over ``model``'s parameters.
        lr_scheduler: scheduler stepped exactly once, after the last batch, so a
            ``StepLR(step_size=k)`` decays the learning rate every k epochs.
        dataloader: iterable of batches whose first element is the input tensor
            and whose second is a tensor of class-index targets.
        device: where each batch is moved. Defaults to the device holding the
            model's first parameter (previously this read a module global).

    Returns:
        Mean per-sample cross-entropy loss over the epoch as a Python float, or
        NaN when the dataloader yielded no samples.
    """
    if device is None:
        device = next(model.parameters()).device

    model.train()
    total_loss = 0.0
    num_samples = 0
    for batch in dataloader:
        inputs = batch[0].to(device)
        labels = batch[1].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

        # cross_entropy averages within the batch; re-weight by batch size so a
        # short final batch does not skew the epoch mean.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size

    # Once per epoch (not per batch): StepLR counts scheduler.step() calls.
    lr_scheduler.step()
    return total_loss / num_samples if num_samples else float("nan")
|
|
|
|
|
# Mini-batches of 32 images, reshuffled on every pass over the dataset.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
)

num_epochs = 10
print("Starting training...")
for epoch in range(num_epochs):
    # Show a 1-based epoch counter to the user.
    print("Epoch {}/{}".format(epoch + 1, num_epochs))
    train(model, optimizer, lr_scheduler, train_dataloader)
|
|
|
|
|
# Training is finished: move the weights to CPU before saving. A state_dict
# saved from CUDA tensors cannot be restored with torch.load on a machine
# without CUDA unless the caller passes map_location.
model.to("cpu")
os.makedirs(checkpoint_dir, exist_ok=True)
torch.save(model.state_dict(), os.path.join(checkpoint_dir, "model.pth"))

# Pack the checkpoint directory (kept as the top-level entry in the archive).
print("Creating tar.gz file...")
with tarfile.open(tar_file, "w:gz") as tar:
    tar.add(checkpoint_dir, arcname=os.path.basename(checkpoint_dir))

# Move the archive out of the working directory. Any exception above propagates
# before the rmtree below, so the raw checkpoint is kept if packaging fails.
destination_path = "C:/Users/kazia/Downloads/resnet_c_s_redwood_finetuned.tar.gz"
shutil.move(tar_file, destination_path)

# The archive now holds the checkpoint, so drop the uncompressed copy.
shutil.rmtree(checkpoint_dir)

print("Fine-tuned model saved and compressed as tar.gz.")
|
|