# tinyvgg / train.py
# ajitsi's picture
# adding config files
# 9a437fa
"""
Trains a PyTorch image classification model using device-agnostic code.
"""
import os
import torch
import data_setup, engine, model_builder, utils
from torchvision import transforms
from timeit import default_timer as timer
from pathlib import Path
from get_data import fetch_data
import json
def train_torch(config_path="config.json"):
    """Fetch the dataset, train a TinyVGG model and save its weights.

    Hyperparameters are read from a JSON config file that must provide the
    keys ``input_shape``, ``hidden_units``, ``num_epochs``, ``batch_size``
    and ``learning_rate``. Training runs on CUDA when it is available and
    on the CPU otherwise. Timing and the results returned by
    ``engine.train`` are printed. The model is then saved to
    ``models/05_going_modular_script_mode_tinyvgg_model.pth`` via
    ``utils.save_model``.

    Args:
        config_path: Path to the JSON configuration file. Defaults to
            ``"config.json"``, resolved against the current working
            directory, which matches the previous hard-coded behavior.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        KeyError: If a required configuration key is missing.
    """
    # Load hyperparameters from the JSON config.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Model architecture parameters.
    input_shape = config["input_shape"]
    hidden_units = config["hidden_units"]

    # Training hyperparameters.
    num_epochs = config["num_epochs"]
    batch_size = config["batch_size"]
    learning_rate = config["learning_rate"]

    # Dataset source and on-disk locations.
    url = "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip"
    data_path = Path("/data")
    image_path = data_path / "pizza_steak_sushi"

    # Presumably downloads/extracts the archive into image_path when it is
    # not already present; see get_data.fetch_data to confirm.
    fetch_data(url, data_path, image_path)

    # Derive the split directories from image_path so the loader inputs
    # can never drift from the location fetch_data was given.
    train_dir = str(image_path / "train")
    test_dir = str(image_path / "test")

    # Device-agnostic setup: prefer the GPU when one is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Resize every image to 64x64 and convert it to a tensor.
    data_transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    # Build the train/test dataloaders and the list of class names.
    train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
        train_dir=train_dir,
        test_dir=test_dir,
        transform=data_transform,
        batch_size=batch_size,
    )

    # The model's output size matches the number of classes found on disk.
    model = model_builder.TinyVGG(
        input_shape=input_shape,
        hidden_units=hidden_units,
        output_shape=len(class_names),
    ).to(device)

    # Loss and optimizer.
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Train (see engine.train) and report the wall-clock duration.
    start_time = timer()
    results = engine.train(
        model=model,
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        epochs=num_epochs,
        device=device,
    )
    end_time = timer()
    print(f"[INFO] Total training time: {end_time-start_time:.3f} seconds")
    print(f"[INFO] results dict: {results}")

    # Persist the trained weights.
    utils.save_model(
        model=model,
        target_dir="models",
        model_name="05_going_modular_script_mode_tinyvgg_model.pth",
    )
if __name__ == "__main__":
    # Run the full training pipeline only when executed as a script,
    # not when this module is imported.
    train_torch()