|
|
"""A collection of useful helper functions""" |
|
|
|
|
|
import os |
|
|
import logging |
|
|
import json |
|
|
|
|
|
import torch |
|
|
from torch.profiler import profile, record_function, ProfilerActivity |
|
|
import pandas as pd |
|
|
from torchmetrics.functional import( |
|
|
scale_invariant_signal_noise_ratio as si_snr, |
|
|
signal_noise_ratio as snr, |
|
|
signal_distortion_ratio as sdr, |
|
|
scale_invariant_signal_distortion_ratio as si_sdr) |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
class Params():
    """Hyperparameter container populated from a json file.

    Every top-level entry of the json object is copied into the instance's
    attribute dictionary, so values can be read and reassigned with dot
    access.

    Example:
    ```
    params = Params(json_path)
    print(params.learning_rate)
    params.learning_rate = 0.5  # change the value of learning_rate in params
    ```
    """

    def __init__(self, json_path):
        """Read `json_path` and expose its entries as attributes."""
        with open(json_path) as json_file:
            loaded = json.load(json_file)
        vars(self).update(loaded)

    def save(self, json_path):
        """Write the current attributes to `json_path` as json (indent of 4)."""
        with open(json_path, 'w') as json_file:
            json.dump(vars(self), json_file, indent=4)

    def update(self, json_path):
        """Loads parameters from json file, overriding entries already set"""
        with open(json_path) as json_file:
            loaded = json.load(json_file)
        vars(self).update(loaded)

    @property
    def dict(self):
        """Gives dict-like access to Params instance by `params.dict['learning_rate']`"""
        # This is the live attribute dictionary, not a copy.
        return vars(self)
|
|
|
|
|
def save_graph(train_metrics, test_metrics, save_dir):
    """Dump the recorded metrics to `results.csv` and plot them to `results.png`.

    Args:
        train_metrics: mapping indexed by 'loss' and by the `__name__` of
            each plotted metric function (`snr`, `si_snr`); every entry is
            plotted against its index, labelled 'Epoch' on the x axis.
        test_metrics: same layout as `train_metrics`, drawn as the 'test'
            curve.
        save_dir: directory in which both output files are written.
    """
    metric_names = [fn.__name__ for fn in (snr, si_snr)]

    # Column order of the csv: both losses first, then a train/test pair
    # for each metric.
    table = {'train_loss': train_metrics['loss'],
             'test_loss': test_metrics['loss']}
    for name in metric_names:
        table["train_" + name] = train_metrics[name]
        table["test_" + name] = test_metrics[name]

    pd.DataFrame(table).to_csv(os.path.join(save_dir, 'results.csv'))

    # 2x3 grid flattened row by row; only the first 1 + len(metric_names)
    # panels are drawn on.
    fig, axes_grid = plt.subplots(2, 3, figsize=(15, 10))
    panels = [ax for row in axes_grid for ax in row]

    epochs = range(len(train_metrics['loss']))

    loss_ax = panels[0]
    loss_ax.plot(epochs, train_metrics['loss'], label='train')
    loss_ax.plot(epochs, test_metrics['loss'], label='test')
    loss_ax.set(ylabel='Loss')
    loss_ax.set(xlabel='Epoch')
    loss_ax.set_title('loss', fontweight='bold')
    loss_ax.legend()

    for ax, name in zip(panels[1:], metric_names):
        ax.plot(epochs, train_metrics[name], label='train')
        ax.plot(epochs, test_metrics[name], label='test')
        ax.set(xlabel='Epoch')
        ax.set_title(name, fontweight='bold')
        ax.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'results.png'))
    plt.close(fig)
|
|
|
|
|
def set_logger(log_path):
    """Set the logger to log info in terminal and file `log_path`.

    In general, it is useful to have a logger so that every output to the terminal is saved
    in a permanent file. Here we save it to `model_dir/train.log`.

    The root logger is reconfigured on every call: its level is set to INFO,
    all handlers currently attached to it are removed and closed, and a fresh
    file handler plus a fresh stream handler are attached.

    Example:
    ```
    logging.info("Starting training...")
    ```

    Args:
        log_path: (string) where to log
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Remove AND close the handlers from any previous configuration. Merely
    # clearing the list would leave earlier log files open until the handler
    # objects happen to be garbage collected.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    # Logging to a file: timestamp and level are included.
    file_handler = logging.FileHandler(log_path)
    file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
    logger.addHandler(file_handler)

    # Logging to the console: bare message only.
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter('%(message)s'))
    logger.addHandler(stream_handler)
|
|
|
|
|
def load_checkpoint(checkpoint, model, optim=None, lr_sched=None, data_parallel=False):
    """Loads model parameters (state_dict) from file_path.

    Args:
        checkpoint: (string) filename which needs to be loaded
        model: (torch.nn.Module) model for which the parameters are loaded
        optim: (optional) object whose `load_state_dict` is called with the
            checkpoint's 'optim_state_dict' entry; skipped when None
        lr_sched: (optional) object whose `load_state_dict` is called with
            the checkpoint's 'lr_sched_state_dict' entry; skipped when None
        data_parallel: (bool) if the model is a data parallel model; when
            True every parameter key is prefixed with 'module.' before
            loading

    Returns:
        Tuple `(epoch, train_metrics, val_metrics)` read from the checkpoint.

    Raises:
        FileNotFoundError: if `checkpoint` does not exist.
    """
    if not os.path.exists(checkpoint):
        # Raising a bare string is itself a TypeError; raise a real exception.
        raise FileNotFoundError("File doesn't exist {}".format(checkpoint))

    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    state_dict = torch.load(checkpoint)

    model_state_dict = state_dict['model_state_dict']
    if data_parallel:
        # Keys are stored without the wrapper prefix, so add it back.
        model_state_dict = {
            'module.' + key: value for key, value in model_state_dict.items()}
    model.load_state_dict(model_state_dict)

    if optim is not None:
        optim.load_state_dict(state_dict['optim_state_dict'])

    if lr_sched is not None:
        lr_sched.load_state_dict(state_dict['lr_sched_state_dict'])

    return state_dict['epoch'], state_dict['train_metrics'], \
        state_dict['val_metrics']
|
|
|
|
|
def save_checkpoint(checkpoint, epoch, model, optim=None, lr_sched=None,
                    train_metrics=None, val_metrics=None, data_parallel=False):
    """Saves model parameters (state_dict) to file_path.

    Args:
        checkpoint: (string) filename the checkpoint is written to; it must
            not exist yet
        epoch: value stored under the 'epoch' key
        model: (torch.nn.Module) model whose parameters are saved
        optim: (optional) object whose `state_dict()` is stored under
            'optim_state_dict'; None is stored when omitted
        lr_sched: (optional) object whose `state_dict()` is stored under
            'lr_sched_state_dict'; None is stored when omitted
        train_metrics: value stored under the 'train_metrics' key
        val_metrics: value stored under the 'val_metrics' key
        data_parallel: (bool) if the model is a data parallel model; when
            True a leading 'module.' is stripped from every parameter key

    Raises:
        FileExistsError: if `checkpoint` already exists.
    """
    if os.path.exists(checkpoint):
        # Raising a bare string is itself a TypeError; raise a real exception.
        raise FileExistsError("File already exists {}".format(checkpoint))

    model_state_dict = model.state_dict()
    if data_parallel:
        # Strip only a *leading* 'module.' prefix. Keys without the prefix
        # are kept as they are, and a 'module.' occurring later in a key is
        # left alone.
        prefix = 'module.'
        model_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in model_state_dict.items()}

    # Explicit None checks, matching how load_checkpoint treats these args.
    optim_state_dict = None if optim is None else optim.state_dict()
    lr_sched_state_dict = None if lr_sched is None else lr_sched.state_dict()

    state_dict = {
        'epoch': epoch,
        'model_state_dict': model_state_dict,
        'optim_state_dict': optim_state_dict,
        'lr_sched_state_dict': lr_sched_state_dict,
        'train_metrics': train_metrics,
        'val_metrics': val_metrics
    }

    torch.save(state_dict, checkpoint)
|
|
|
|
|
def model_size(model):
    """
    Returns size of the `model` in millions of parameters.

    Only parameters with `requires_grad` set contribute; the element count
    is divided by 1e6.
    """
    trainable_elements = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_elements += param.numel()
    return trainable_elements / 1e6
|
|
|
|
|
def run_time(model, inputs, profiling=False):
    """
    Returns runtime of a model in ms.

    The model is first called 100 times as a warm-up, then a single call is
    recorded with the CPU profiler under the "model_inference" label.

    Args:
        model: callable invoked as `model(*inputs)`
        inputs: iterable of positional arguments for `model`
        profiling: (bool) when True, print the profiler table (top 20 rows
            sorted by self CPU time)

    Returns:
        The profiler's `self_cpu_time_total` divided by 1000 (ms, assuming
        the profiler reports microseconds).
    """
    # Warm-up passes, not measured.
    warmup_passes = 100
    for _ in range(warmup_passes):
        output = model(*inputs)

    # Measured pass.
    with profile(activities=[ProfilerActivity.CPU],
                 record_shapes=True) as prof, \
            record_function("model_inference"):
        output = model(*inputs)

    if profiling:
        summary = prof.key_averages().table(sort_by="self_cpu_time_total",
                                            row_limit=20)
        print(summary)

    total_self_cpu = prof.profiler.self_cpu_time_total
    return total_self_cpu / 1000
|
|
|
|
|
def format_lr_info(optimizer):
    """Summarize every parameter group of `optimizer` in a single string.

    For each entry of `optimizer.param_groups` one
    ` {group <index>: params=<count>M lr=<lr>}` chunk is produced, where
    `<count>` is the summed `numel()` of the group's 'params' divided by
    1024 ** 2 and `<lr>` is the group's 'lr' value.

    Returns:
        The chunks concatenated in group order; an empty string when there
        are no parameter groups.
    """
    template = " {group %d: params=%.5fM lr=%.1E}"
    chunks = []
    for index, group in enumerate(optimizer.param_groups):
        element_count = sum(p.numel() for p in group['params'])
        chunks.append(
            template % (index, element_count / (1024 ** 2), group['lr']))
    return "".join(chunks)
|
|
|
|
|
|