"""
Trains a PyTorch image classification model using device-agnostic code.
"""

import json
import os
from pathlib import Path
from timeit import default_timer as timer

import torch
from torchvision import transforms

import data_setup
import engine
import model_builder
import utils
from get_data import fetch_data


def train_torch(config_path: str = "config.json") -> None:
    """Train a TinyVGG model on the pizza/steak/sushi dataset and save it.

    Reads hyperparameters from a JSON config file, calls ``fetch_data`` for
    the dataset location, builds the dataloaders and model, runs the training
    loop from ``engine.train``, prints the elapsed time and the results, and
    saves the trained model through ``utils.save_model``.

    Args:
        config_path: Path to the JSON configuration file. It must contain the
            keys ``input_shape``, ``hidden_units``, ``num_epochs``,
            ``batch_size`` and ``learning_rate``. Defaults to ``config.json``
            relative to the current working directory.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        KeyError: If a required key is missing from the configuration.
    """
    # Load the configuration; explicit encoding keeps the read independent
    # of the platform's default locale.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Model configuration
    input_shape = config["input_shape"]
    hidden_units = config["hidden_units"]

    # Training hyperparameters
    num_epochs = config["num_epochs"]
    batch_size = config["batch_size"]
    learning_rate = config["learning_rate"]

    # Define the dataset URL and local paths
    url = "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip"
    data_path = Path("/data")
    image_path = data_path / "pizza_steak_sushi"

    # NOTE(review): fetch_data presumably downloads and extracts the archive
    # into image_path -- confirm against get_data.py.
    fetch_data(url, data_path, image_path)

    # Derive the split directories from image_path so they cannot drift from
    # the data root. They are passed as strings, as before, because the
    # signature of data_setup.create_dataloaders is not visible here.
    train_dir = str(image_path / "train")
    test_dir = str(image_path / "test")

    # Setup target device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Create transforms: resize every image to 64x64 and convert to a tensor
    data_transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    # Create dataloaders with help from data_setup.py
    train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
        train_dir=train_dir,
        test_dir=test_dir,
        transform=data_transform,
        batch_size=batch_size,
    )

    # Create model with help from model_builder.py; one output per class
    model = model_builder.TinyVGG(
        input_shape=input_shape,
        hidden_units=hidden_units,
        output_shape=len(class_names),
    ).to(device)

    # Set loss and optimizer
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Start the timer
    start_time = timer()

    # Start training with help from engine.py
    results = engine.train(
        model=model,
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        epochs=num_epochs,
        device=device,
    )

    # End the timer and print out how long it took
    end_time = timer()
    print(f"[INFO] Total training time: {end_time-start_time:.3f} seconds")
    print(f"[INFO] results dict: {results}")

    # Save the model with help from utils.py
    utils.save_model(
        model=model,
        target_dir="models",
        model_name="05_going_modular_script_mode_tinyvgg_model.pth",
    )


if __name__ == '__main__':
    train_torch()