yourusername's picture
:beers: cheers
66a6dc0
raw history blame
No virus
2.98 kB
import numpy as np
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import deepafx_st.utils as utils
class LogParametersCallback(pl.callbacks.Callback):
    """Log predicted processor parameters as Markdown tables during validation.

    For the first validation batch of each epoch, one table per example is
    written to the experiment logger through ``logger.experiment.add_text``
    (presumably a TensorBoard ``SummaryWriter`` -- confirm against the
    trainer configuration).

    Args:
        num_examples: Maximum number of examples from the first validation
            batch for which a parameter table is logged.
    """

    def __init__(self, num_examples=4):
        super().__init__()
        # Honour the caller's value; previously this was hard-coded to 4 and
        # the argument was silently ignored.
        self.num_examples = num_examples

    def on_validation_epoch_start(self, trainer, pl_module):
        """At the start of validation init storage for parameters."""
        self.params = []

    def on_validation_batch_end(
        self,
        trainer,
        pl_module,
        outputs,
        batch,
        batch_idx,
        dataloader_idx=0,
    ):
        """Called when the validation batch ends.

        Here we log the parameters only from the first batch.

        Args:
            trainer: Supplies ``global_step`` and ``logger``.
            pl_module: Module whose ``processor.ports`` describes the
                parameters being logged.
            outputs: Mapping read at keys ``"x"`` (its first dimension
                bounds the number of examples) and ``"p"`` (the parameters).
                Nothing is logged when it is ``None``.
            batch: Unused.
            batch_idx: Index of the batch within the validation epoch.
            dataloader_idx: Unused. Defaults to 0 to match the current
                Lightning hook signature, so the hook can still be invoked
                when the argument is not supplied.
        """
        if outputs is None or batch_idx != 0:
            return

        # Never index past the end of a batch smaller than ``num_examples``.
        examples = min(self.num_examples, outputs["x"].shape[0])
        for n in range(examples):
            self.log_parameters(
                outputs,
                n,
                pl_module.processor.ports,
                trainer.global_step,
                trainer.logger,
                True,
            )

    def on_validation_epoch_end(self, trainer, pl_module):
        """No-op; tables are emitted per batch, not per epoch."""
        pass

    def log_parameters(self, outputs, batch_idx, ports, global_step, logger, log=True):
        """Format one example's parameters as a table and optionally log it.

        Args:
            outputs: Mapping whose ``"p"`` entry holds the normalized
                parameters, indexed by example along its first dimension.
            batch_idx: Index of the example within the batch (the callback
                passes the example number here, not the batch number).
            ports: Iterable of port lists; each port is a mapping with the
                keys ``"name"``, ``"units"``, ``"min"``, ``"max"`` and
                ``"default"``.
            global_step: Step value attached to the logged text.
            logger: Object exposing ``experiment.add_text(tag, text, step)``.
            log: When false the table is built but not written anywhere.
        """
        p = outputs["p"][batch_idx, ...]
        table = self._format_table(p, ports)

        if log:
            logger.experiment.add_text(f"params/{batch_idx+1}", table, global_step)

    @staticmethod
    def _format_table(p, ports):
        """Return a Markdown table with one row per port, in port order."""
        pieces = [
            "| Index| Name | Value | Units | Min | Max | Default | Raw Value | \n",
            "|------|------|------:|:------|----:|----:|--------:| ---------:| \n",
        ]

        # ``p`` is a flat vector covering every port list back to back, so
        # the running index continues across port lists.
        start_idx = 0
        for port_list in ports:
            for port in port_list:
                param_max = port["max"]
                param_min = port["min"]
                param_val = p[start_idx]
                denorm_val = utils.denormalize(param_val, param_max, param_min)

                # Large magnitudes get one decimal, small ones three, so the
                # value, range and default columns stay readable.
                precision = 1 if np.abs(denorm_val) > 10 else 3

                pieces.append(
                    f"| {start_idx + 1} | {port['name']} "
                    f"| {denorm_val:0.{precision}f} "
                    f"| {port['units']} "
                    f"| {param_min:0.{precision}f} | {param_max:0.{precision}f} "
                    f"| {port['default']:0.{precision}f} "
                    f"| {np.squeeze(param_val):0.2f} | \n"
                )
                start_idx += 1

        pieces.append("\n\n")
        return "".join(pieces)