swetchareddytukkani
Initial commit with PhysMamba rPPG application
1c6711c
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter, MaxNLocator
import os
import pickle
class BaseTrainer:
    """Base trainer providing shared CLI arguments, test-output saving and plotting.

    The ``train``/``valid``/``test`` hooks are no-ops in this class.
    ``save_test_outputs`` and ``plot_losses_and_lrs`` read ``self.model_file_name`` when
    ``config.TOOLBOX_MODE == 'train_and_test'``; this class never assigns that attribute,
    so presumably subclasses set it before calling those methods (confirm against callers).
    """

    @staticmethod
    def add_trainer_args(parser):
        """Add the arguments used by the training process to ``parser``.

        Args:
            parser: An ``argparse.ArgumentParser`` (or any object exposing a compatible
                ``add_argument`` method).

        Returns:
            The same ``parser`` object, so calls can be chained.
        """
        parser.add_argument('--lr', default=None, type=float)
        # The model file name is concatenated into output file names, so it has to be
        # parsed as a string; parsing it as a float rejected every non-numeric name.
        parser.add_argument('--model_file_name', default=None, type=str)
        return parser

    def __init__(self):
        """Initialise the trainer; the base implementation does nothing."""
        pass

    def train(self, data_loader):
        """Training hook; the base implementation does nothing."""
        pass

    def valid(self, data_loader):
        """Validation hook; the base implementation does nothing."""
        pass

    def test(self):
        """Testing hook; the base implementation does nothing."""
        pass

    def save_test_outputs(self, predictions, labels, config):
        """Pickle the test predictions and labels, plus label type and sampling rate.

        The file is written to ``config.TEST.OUTPUT_SAVE_DIR`` as
        ``<filename_id>_outputs.pickle`` and holds a dict with the keys ``'predictions'``,
        ``'labels'``, ``'label_type'`` and ``'fs'``.

        Args:
            predictions: Object stored under ``'predictions'``; must be picklable.
            labels: Object stored under ``'labels'``; must be picklable.
            config: Configuration object. Reads ``TOOLBOX_MODE``, ``TEST.OUTPUT_SAVE_DIR``,
                ``TEST.DATA.PREPROCESS.LABEL_TYPE`` and ``TEST.DATA.FS``; in ``only_test``
                mode also ``INFERENCE.MODEL_PATH`` and ``TEST.DATA.DATASET``.

        Raises:
            ValueError: If ``config.TOOLBOX_MODE`` is neither ``'train_and_test'`` nor
                ``'only_test'``.
        """
        # Resolve the filename ID first so an unsupported mode fails before any
        # directory is created on disk.
        if config.TOOLBOX_MODE == 'train_and_test':
            filename_id = self.model_file_name
        elif config.TOOLBOX_MODE == 'only_test':
            # basename() also copes with the platform's native path separator, unlike a
            # plain split on "/"; everything before the first ".pth" is the model root.
            model_file_root = os.path.basename(config.INFERENCE.MODEL_PATH).split(".pth")[0]
            filename_id = model_file_root + "_" + config.TEST.DATA.DATASET
        else:
            raise ValueError('Metrics.py evaluation only supports train_and_test and only_test!')

        output_dir = config.TEST.OUTPUT_SAVE_DIR
        # exist_ok=True already makes this a no-op when the directory exists.
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, filename_id + '_outputs.pickle')

        data = {
            'predictions': predictions,
            'labels': labels,
            'label_type': config.TEST.DATA.PREPROCESS.LABEL_TYPE,
            'fs': config.TEST.DATA.FS,
        }

        with open(output_path, 'wb') as handle:
            pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

        print('Saving outputs to:', output_path)

    def plot_losses_and_lrs(self, train_loss, valid_loss, lrs, config):
        """Save a loss plot and a learning-rate plot as PDF files.

        Both files go to ``<config.LOG.PATH>/<config.TRAIN.DATA.EXP_DATA_NAME>/plots`` and
        are named ``<model_file_name>_losses.pdf`` and ``<model_file_name>_learning_rates.pdf``.

        Args:
            train_loss: Sequence of training losses, plotted against their index on an
                axis labelled 'Epoch'.
            valid_loss: Sequence of validation losses plotted on the same epoch axis; it
                is skipped (with a printed notice) when empty, otherwise it must have the
                same length as ``train_loss``.
            lrs: Sequence of learning rates, plotted against their index on an axis
                labelled 'Scheduler Step'.
            config: Configuration object. Reads ``TOOLBOX_MODE``, ``LOG.PATH`` and
                ``TRAIN.DATA.EXP_DATA_NAME``.

        Raises:
            ValueError: If ``config.TOOLBOX_MODE`` is not ``'train_and_test'``.
        """
        # Only train_and_test produces loss/LR histories; check before touching the disk.
        if config.TOOLBOX_MODE != 'train_and_test':
            raise ValueError('Plotting losses and learning rates only supports train_and_test!')
        filename_id = self.model_file_name

        output_dir = os.path.join(config.LOG.PATH, config.TRAIN.DATA.EXP_DATA_NAME, 'plots')
        os.makedirs(output_dir, exist_ok=True)

        # Single figure holding both training and validation losses.
        plt.figure(figsize=(10, 6))
        epochs = range(len(train_loss))  # integer values for the x-axis
        plt.plot(epochs, train_loss, label='Training Loss')
        if len(valid_loss) > 0:
            plt.plot(epochs, valid_loss, label='Validation Loss')
        else:
            print("The list of validation losses is empty. The validation loss will not be plotted!")
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title(f'{filename_id} Losses')
        plt.legend()
        plt.xticks(epochs)

        # Non-integer y ticks; prune='both' drops the first and last tick.
        ax = plt.gca()
        ax.yaxis.set_major_locator(MaxNLocator(integer=False, prune='both'))

        loss_plot_filename = os.path.join(output_dir, filename_id + '_losses.pdf')
        plt.savefig(loss_plot_filename, dpi=300)
        plt.close()

        # Separate figure for the learning-rate schedule.
        plt.figure(figsize=(6, 4))
        scheduler_steps = range(len(lrs))
        plt.plot(scheduler_steps, lrs, label='Learning Rate')
        plt.xlabel('Scheduler Step')
        plt.ylabel('Learning Rate')
        plt.title(f'{filename_id} LR Schedule')
        plt.legend()

        # Show y-axis values in scientific notation.
        ax = plt.gca()
        ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=True, useOffset=False))
        ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))  # force scientific notation

        lr_plot_filename = os.path.join(output_dir, filename_id + '_learning_rates.pdf')
        plt.savefig(lr_plot_filename, bbox_inches='tight', dpi=300)
        plt.close()

        print('Saving plots of losses and learning rates to:', output_dir)