Spaces:
Runtime error
Runtime error
| """Module to define utility functions for the project.""" | |
| import os | |
| import torch | |
def get_num_workers(model_run_location):
    """Return the number of workers to be used for data loading.

    Args:
        model_run_location: Identifier of where the model is run. Only the
            exact value ``"colab"`` results in a non-zero worker count.

    Returns:
        int: ``0`` for every location other than ``"colab"``. For
        ``"colab"``, one less than the reported CPU count when more than
        3 CPUs are reported, otherwise ``2`` (this fallback is also used
        when the CPU count cannot be determined).
    """
    # Any location other than colab gets no workers at all, so there is
    # no need to query the CPU count in that case.
    if model_run_location != "colab":
        return 0

    # os.cpu_count() returns None when the count is undetermined; treat
    # that as "few CPUs" instead of failing on the `> 3` comparison.
    cpu_total = os.cpu_count() or 0
    return cpu_total - 1 if cpu_total > 3 else 2
| # Function to save the model | |
| # https://debuggercafe.com/saving-and-loading-the-best-model-in-pytorch/ | |
def save_model(epoch, model, optimizer, scheduler, batch_size, criterion, file_name):
    """
    Persist a training checkpoint to ``file_name`` using ``torch.save``.

    The checkpoint is a dict with the epoch, the ``state_dict()`` of the
    model, optimizer and scheduler, the batch size, and the criterion
    (stored under the key ``"loss"``).
    """
    # Snapshot the three state dicts first, in the same order as they
    # appear in the checkpoint.
    model_state = model.state_dict()
    optimizer_state = optimizer.state_dict()
    scheduler_state = scheduler.state_dict()

    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model_state,
        "optimizer_state_dict": optimizer_state,
        "scheduler_state_dict": scheduler_state,
        "batch_size": batch_size,
        "loss": criterion,
    }
    torch.save(checkpoint, file_name)
| # Given a list of train_losses, train_accuracies, test_losses, | |
| # test_accuracies, loop through epoch and print the metrics | |
def pretty_print_metrics(num_epochs, results):
    """
    Print one line per epoch with train/test loss and accuracy.

    ``results`` is indexed with the keys "train_loss", "train_acc",
    "test_loss" and "test_acc"; each value is indexed by epoch position
    for the first ``num_epochs`` epochs.
    """
    # Pull the four metric series out of the results mapping
    loss_train = results["train_loss"]
    accuracy_train = results["train_acc"]
    loss_test = results["test_loss"]
    accuracy_test = results["test_acc"]

    for idx in range(num_epochs):
        # Epochs are reported 1-based; every metric uses 4 decimal places
        fields = [
            f"Epoch: {idx + 1:02d}",
            f"Train Loss: {loss_train[idx]:.4f}",
            f"Test Loss: {loss_test[idx]:.4f}",
            f"Train Accuracy: {accuracy_train[idx]:.4f}",
            f"Test Accuracy: {accuracy_test[idx]:.4f}",
        ]
        print(", ".join(fields))
| # Given a file path, extract the folder path and create folder recursively if it does not already exist | |
def create_folder_if_not_exists(file_path):
    """
    Create the parent folder of ``file_path`` if it does not exist.

    Args:
        file_path: Path to a file. Only its directory component (as given
            by ``os.path.dirname``) is used; the file itself is not created.

    The folder path is always printed, and a confirmation is printed when
    the folder is actually created. Missing intermediate folders are
    created as well.
    """
    # Extract the folder path
    folder_path = os.path.dirname(file_path)
    print(f"Folder path: {folder_path}")

    # A bare file name (e.g. "model.pth") has an empty directory component.
    # os.makedirs("") raises FileNotFoundError, so there is nothing to
    # create in that case.
    if not folder_path:
        return

    # Create the folder if it does not exist
    if not os.path.exists(folder_path):
        os.makedirs(folder_path, exist_ok=True)
        print(f"Created folder: {folder_path}")