|
"""
|
|
Contains various utility functions for PyTorch model training and saving.
|
|
"""
|
|
import torch
|
|
from pathlib import Path
|
|
import matplotlib.pyplot as plt
|
|
import torchvision
|
|
from PIL import Image
|
|
from torch.utils.tensorboard.writer import SummaryWriter
|
|
|
|
def save_model(model: torch.nn.Module,
               target_dir: str,
               model_name: str):
    """Save a PyTorch model's ``state_dict`` to a target directory.

    Args:
        model: target PyTorch model; only its ``state_dict()`` is saved.
        target_dir: path of the directory to store the saved model in. It is
            created (including any missing parents) if it does not exist.
        model_name: filename for the saved model. Must end with ".pth" or
            ".pt" as the file extension.

    Raises:
        ValueError: if ``model_name`` does not end with ".pth" or ".pt".
    """
    # Validate before touching the filesystem so a bad name does not leave an
    # empty directory behind. An explicit raise (rather than ``assert``) keeps
    # the check active when Python runs with -O.
    if not model_name.endswith((".pth", ".pt")):
        raise ValueError("model name should end with .pt or .pth")

    target_dir_path = Path(target_dir)
    target_dir_path.mkdir(parents=True, exist_ok=True)

    model_save_path = target_dir_path / model_name

    print(f"[INFO] Saving model to: {model_save_path}")
    torch.save(obj=model.state_dict(), f=model_save_path)
|
|
|
|
def pred_and_plot_image(
    model: torch.nn.Module,
    image_path: str,
    class_names: list[str] = None,
    transform=None,
    device: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Make a prediction on a target image with a trained model and plot it.

    The image is drawn with ``plt.imshow`` and titled with the predicted label
    and its softmax probability. Nothing is returned; the plot is drawn on the
    current matplotlib figure.

    Args:
        model (torch.nn.Module): trained PyTorch image classification model.
        image_path (str): filepath to target image.
        class_names (list[str], optional): class names indexed by the model's
            predicted output index. If None (or empty), the raw predicted index
            is shown instead. Defaults to None.
        transform (callable, optional): transform applied to the PIL image; it
            is expected to return a (C, H, W) tensor. Defaults to None, in which
            case ``torchvision.transforms.ToTensor()`` is used.
        device (torch.device, optional): target device to compute on. Defaults
            to "cuda" if torch.cuda.is_available() else "cpu".

    Example usage:
        pred_and_plot_image(model=model,
                            image_path="some_image.jpeg",
                            class_names=["class_1", "class_2", "class_3"],
                            transform=torchvision.transforms.ToTensor(),
                            device=device)
    """
    # Without a transform there would be no tensor to feed the model; fall
    # back to the minimal PIL -> tensor conversion.
    if transform is None:
        transform = torchvision.transforms.ToTensor()

    # The context manager closes the underlying file once the image is converted.
    with Image.open(image_path) as img:
        target_image = transform(img)

    model.to(device)
    model.eval()
    with torch.inference_mode():
        # Add a batch dimension: (C, H, W) -> (1, C, H, W).
        target_image = target_image.unsqueeze(dim=0)
        target_image_pred = model(target_image.to(device))
        target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
        target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)

    # Plain Python numbers so the title does not show tensor reprs.
    pred_index = target_image_pred_label.item()
    pred_prob = target_image_pred_probs.max().item()

    # Drop only the batch dim (a bare squeeze() would also drop size-1 C/H/W),
    # move channels last for matplotlib, then collapse a single grayscale
    # channel so the image is passed to imshow as a 2-D (H, W) array.
    display_image = target_image.squeeze(0).permute(1, 2, 0).squeeze(-1).cpu()
    plt.imshow(display_image)

    pred_text = class_names[pred_index] if class_names else pred_index
    plt.title(f"Pred: {pred_text} | Prob: {pred_prob:.3f}")
    plt.axis(False)
|
|
|
|
def set_seeds(seed: int = 42):
    """Seed PyTorch's random number generators for reproducible runs.

    Args:
        seed (int, optional): value passed to ``torch.manual_seed`` and then
            ``torch.cuda.manual_seed``. Defaults to 42.
    """
    # CPU generator first, then the CUDA generator.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
|
|
|
|
|
|
def create_writer(experiment_name: str,
                  model_name: str,
                  extra: str = None) -> "SummaryWriter":
    """Create a ``SummaryWriter`` instance saving to a specific ``log_dir``.

    ``log_dir`` is ``runs/<timestamp>/<experiment_name>/<model_name>[/<extra>]``,
    where ``timestamp`` is the current local date in YYYY-MM-DD format.

    Args:
        experiment_name (str): name of the experiment.
        model_name (str): name of the model.
        extra (str, optional): anything extra to append to the directory path.
            Omitted when None or empty. Defaults to None.

    Returns:
        SummaryWriter: writer instance saving to ``log_dir``.

    Example usage:
        # Creates a writer saving to
        # "runs/2022-06-04/data_10_percent/effnetb2/5_epochs"
        writer = create_writer(experiment_name="data_10_percent",
                               model_name="effnetb2",
                               extra="5_epochs")

        # This is the same as:
        writer = SummaryWriter(log_dir="runs/2022-06-04/data_10_percent/effnetb2/5_epochs")
    """
    import os
    from datetime import datetime

    timestamp = datetime.now().strftime("%Y-%m-%d")

    path_parts = ["runs", timestamp, experiment_name, model_name]
    # A falsy ``extra`` (None or "") adds no trailing path component.
    if extra:
        path_parts.append(extra)
    log_dir = os.path.join(*path_parts)

    print(f"[INFO] Created SummaryWriter(), saving to: {log_dir}")
    return SummaryWriter(log_dir=log_dir)
|
|
|