# Last commit 74a67e7 (author: Akazi): "Rename model.py to train.py"
import os
import shutil
import tarfile
import tempfile

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
# Set the paths. The dataset root defaults to the author's local copy but can
# be redirected without editing the script through the REDWOOD_DATA_DIR
# environment variable; training images are read from its "Train" sub-folder.
data_dir = os.environ.get("REDWOOD_DATA_DIR", "C:/Users/kazia/Downloads/Trees")
train_dir = os.path.join(data_dir, "Train")
# Name of the checkpoint directory that gets archived, and of the archive file.
checkpoint_dir = "resnet_c_s_redwood_finetuned"
tar_file = "resnet_c_s_redwood_finetuned.tar.gz"
# Training-time preprocessing: light augmentation (small random rotation,
# horizontal flip, brightness jitter), a fixed 224x224 resize, conversion to a
# tensor, then per-channel normalization.
# The model is initialised from ImageNet-pretrained ResNet-50 weights, so the
# inputs are normalized with the ImageNet channel statistics those weights were
# trained with, keeping the pretrained filters' input scale unchanged.
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)
image_transforms = transforms.Compose([
    transforms.RandomRotation(10),  # random angle in [-10, 10] degrees
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),  # 3-channel (RGB) input
])
# Build the training dataset from the images under train_dir; each sample is
# passed through image_transforms (augmentation + normalization) when read.
print("Loading dataset with data augmentation...")
train_dataset = ImageFolder(root=train_dir, transform=image_transforms)
# Load the pre-trained ResNet-50 model
print("Loading pre-trained model...")
# Request the ImageNet weights explicitly. IMAGENET1K_V1 is the checkpoint the
# legacy `pretrained=True` flag loads (the enum's DEFAULT is the newer V2), so
# the starting weights are unchanged while the deprecated flag is avoided on
# torchvision >= 0.13; older releases only understand `pretrained=True`.
if hasattr(models, "ResNet50_Weights"):
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
else:
    model = models.resnet50(pretrained=True)

# Replace the classification head with one sized for this task.
num_features = model.fc.in_features
num_labels = 2  # Number of classes (Sierra Redwood, Coastal Redwood)
# The head size is hard-coded, so check it against the classes the dataset
# actually found; a mismatch would otherwise surface mid-training as an
# out-of-range label, or leave a head output that no sample ever targets.
if len(train_dataset.classes) != num_labels:
    raise ValueError(
        f"Expected {num_labels} class folders in {train_dir!r}, "
        f"found {len(train_dataset.classes)}: {train_dataset.classes}"
    )
model.fc = torch.nn.Linear(num_features, num_labels)

# Enable GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Optimizer and learning-rate schedule. Adam updates every model parameter
# (pretrained backbone and new head alike) starting from lr = 1e-4; StepLR
# multiplies the learning rate by 0.1 after every 5 calls to lr_scheduler.step().
# NOTE(review): step_size=5 with num_epochs=10 suggests one scheduler step per
# epoch — confirm train() steps it outside its batch loop.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    gamma=0.1,
    step_size=5,
)
# Define the training loop
def train(model, optimizer, lr_scheduler, dataloader, device=None):
    """Run one epoch of training, then advance the learning-rate scheduler once.

    Args:
        model: classifier to train; it is switched to training mode here.
        optimizer: optimizer over ``model``'s parameters.
        lr_scheduler: scheduler stepped exactly once, after all batches, so
            epoch-based schedules such as ``StepLR(step_size=5)`` decay per
            epoch rather than every few batches.
        dataloader: iterable of batches whose first two items are
            ``(inputs, labels)``.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        Mean cross-entropy loss per sample over the epoch, or ``nan`` if
        ``dataloader`` yielded no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    total_loss = 0.0
    total_samples = 0
    for batch in dataloader:
        inputs = batch[0].to(device)
        labels = batch[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()
        # `loss` is the batch mean; weight it by batch size so the epoch figure
        # is a true per-sample mean even when the last batch is smaller.
        batch_size = inputs.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size

    # Once per epoch, after the batch loop — never per batch.
    lr_scheduler.step()
    return total_loss / total_samples if total_samples else float("nan")
# Train the model: shuffled mini-batches of 32 images, one train() call per epoch.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)
num_epochs = 10
print("Starting training...")
for epoch_number in range(1, num_epochs + 1):
    print(f"Epoch {epoch_number}/{num_epochs}")
    train(model, optimizer, lr_scheduler, train_dataloader)
# Save the fine-tuned weights and publish them as a .tar.gz at destination_path
# containing <checkpoint_dir>/model.pth.
# Everything is staged in a private temporary directory, so an unrelated,
# pre-existing directory named checkpoint_dir in the working directory is
# neither packed into the archive nor deleted, and the staging files are
# removed even if a step fails.
destination_path = "C:/Users/kazia/Downloads/resnet_c_s_redwood_finetuned.tar.gz"
with tempfile.TemporaryDirectory() as staging_dir:
    staged_checkpoint = os.path.join(staging_dir, os.path.basename(checkpoint_dir))
    os.makedirs(staged_checkpoint)
    torch.save(model.state_dict(), os.path.join(staged_checkpoint, "model.pth"))

    # Create a tar.gz file
    print("Creating tar.gz file...")
    staged_tar = os.path.join(staging_dir, os.path.basename(tar_file))
    with tarfile.open(staged_tar, "w:gz") as tar:
        tar.add(staged_checkpoint, arcname=os.path.basename(checkpoint_dir))

    # Move the finished archive to the desired location (creating its folder
    # if needed); only a complete archive ever reaches destination_path.
    os.makedirs(os.path.dirname(destination_path), exist_ok=True)
    shutil.move(staged_tar, destination_path)
print("Fine-tuned model saved and compressed as tar.gz.")