# AvianVision/utils.py
import torch
from pathlib import Path
def save_model(model, optimizer, epoch, loss, directory, model_name='model', **kwargs):
    """
    Save a PyTorch model checkpoint.

    Args:
        model: Trained model.
        optimizer: Optimizer used for training.
        epoch: The last epoch the model was trained on.
        loss: The last loss recorded during training.
        directory: The directory in which to save the checkpoint.
        model_name: Base name for the checkpoint file, defaults to 'model'.
        kwargs: Additional keyword arguments representing metrics to be included in the filename.

    Example:
        >>> save_model(model, optimizer, epoch, loss, './model_dir', f1_score=val_f1score)
    """
    # Create the directory if it does not exist
    Path(directory).mkdir(parents=True, exist_ok=True)
    # Build the filename; each metric in kwargs is appended as key=value
    metrics_str = ''.join(f'_{key}={value:.4f}' for key, value in kwargs.items())
    filename = f'{directory}/{model_name}_epoch={epoch}_loss={loss:.4f}{metrics_str}.pth'
    # Save the model checkpoint
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        **kwargs
    }, filename)
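# Illustrative example (hypothetical values): a call such as
#   save_model(model, optimizer, epoch=5, loss=0.1234, directory='./checkpoints', f1_score=0.91)
# writes './checkpoints/model_epoch=5_loss=0.1234_f1_score=0.9100.pth', so the filename
# alone records when the checkpoint was taken and how the model was performing.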
def get_device() -> torch.device:
    """
    Retrieve the appropriate Torch device for running computations (CUDA, MPS, or CPU).

    Returns:
        torch.device: The Torch device to be used for computations.

    Examples:
        >>> device = get_device()
        >>> print(device)
        cuda
    """
    if torch.cuda.is_available():
        device = "cuda"  # NVIDIA GPU
    elif torch.backends.mps.is_available():
        device = "mps"  # Apple GPU
    else:
        device = "cpu"  # Fall back to the CPU if no NVIDIA or Apple GPU is available
    # print(f"Using {device} device")
    return torch.device(device)
def load_checkpoint(model, optimizer, filename):
    """
    Load a PyTorch model checkpoint.

    Args:
        model: Model to load the weights into.
        optimizer: Optimizer to load the state into.
        filename: The path of the checkpoint file.

    Returns:
        The epoch at which training was stopped, the last loss recorded, and a dict of any
        additional metrics stored in the checkpoint.
    """
    # Map tensors onto the current device so a checkpoint saved on a GPU also loads on a CPU-only machine
    checkpoint = torch.load(filename, map_location=get_device())
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    # Extract additional metrics
    metrics = {key: value for key, value in checkpoint.items() if
               key not in ['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss']}
    return epoch, loss, metrics
# To use the function, you would do something like this:
# epoch, loss, metrics = load_checkpoint(model, optimizer, 'model_checkpoint.pth')
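
if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): a tiny linear model and SGD
    # optimizer stand in for a real training setup so the save/load roundtrip
    # can be exercised end to end without a dataset; all numbers are made up.
    import tempfile

    import torch.nn as nn

    device = get_device()
    model = nn.Linear(4, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Pretend one epoch of training finished with these numbers.
        save_model(model, optimizer, epoch=1, loss=0.5678, directory=tmp_dir,
                   accuracy=0.8123)

        # The checkpoint filename encodes the epoch, loss, and extra metrics.
        checkpoint_path = next(Path(tmp_dir).glob('*.pth'))
        print(f'Saved checkpoint: {checkpoint_path.name}')

        # Restore the state into fresh instances and resume from the saved epoch.
        new_model = nn.Linear(4, 2).to(device)
        new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.01)
        epoch, loss, metrics = load_checkpoint(new_model, new_optimizer, str(checkpoint_path))
        print(f'Resumed at epoch {epoch} with loss {loss:.4f} and metrics {metrics}')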