|
""" |
|
Trains a PyTorch image classification model using device-agnostic code.
|
""" |
|
|
|
import os |
|
import torch |
|
import data_setup, engine, model_builder, utils |
|
from torchvision import transforms |
|
from timeit import default_timer as timer |
|
from pathlib import Path |
|
from get_data import fetch_data |
|
import json |
|
|
|
|
|
|
|
|
|
def train_torch(config_path="config.json"):
    """Train a TinyVGG image classifier on the pizza/steak/sushi dataset.

    Reads hyperparameters from a JSON config file and retrieves the dataset
    (from a fixed GitHub URL) via ``fetch_data``. It then builds train/test
    dataloaders and trains the model on CUDA when available, else on CPU.
    Finally it prints the wall-clock training time and the results, and saves
    the trained model into the ``models`` directory.

    Args:
        config_path: Path to the JSON config file. It must provide the keys
            ``input_shape``, ``hidden_units``, ``num_epochs``, ``batch_size``
            and ``learning_rate``. Defaults to ``"config.json"``, which is
            resolved relative to the current working directory.

    Returns:
        Whatever ``engine.train`` returns (printed as the "results dict";
        presumably per-epoch metrics -- confirm against ``engine.train``).

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        KeyError: If a required hyperparameter key is missing from the config.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Model architecture hyperparameters.
    input_shape = config["input_shape"]
    hidden_units = config["hidden_units"]

    # Training hyperparameters.
    num_epochs = config["num_epochs"]
    batch_size = config["batch_size"]
    learning_rate = config["learning_rate"]

    url = "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip"
    data_path = Path("/data")
    image_path = data_path / "pizza_steak_sushi"

    # NOTE(review): presumably downloads/extracts the archive into image_path
    # when it is not already present -- confirm against get_data.fetch_data.
    fetch_data(url, data_path, image_path)

    # Derive the split directories from image_path so they always point at the
    # location passed to fetch_data, rather than repeating the path literally.
    train_dir = str(image_path / "train")
    test_dir = str(image_path / "test")

    # Device-agnostic: use the GPU if one is visible to PyTorch.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Resize every image to 64x64 and convert it to a tensor.
    data_transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
        train_dir=train_dir,
        test_dir=test_dir,
        transform=data_transform,
        batch_size=batch_size,
    )

    # One output unit per class discovered by the dataloader setup.
    model = model_builder.TinyVGG(
        input_shape=input_shape,
        hidden_units=hidden_units,
        output_shape=len(class_names),
    ).to(device)

    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    start_time = timer()
    results = engine.train(
        model=model,
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        epochs=num_epochs,
        device=device,
    )
    end_time = timer()

    print(f"[INFO] Total training time: {end_time-start_time:.3f} seconds")
    print(f"[INFO] results dict: {results}")

    utils.save_model(
        model=model,
        target_dir="models",
        model_name="05_going_modular_script_mode_tinyvgg_model.pth",
    )

    return results
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the full train-and-save pipeline.
    train_torch()
|
|