|
import os
|
|
|
|
class WandbLogger:
    """
    Log using `Weights and Biases`.

    Thin wrapper around the ``wandb`` module: it starts (or reuses) a W&B run
    on construction and exposes helpers for logging metrics, images, model
    checkpoint artifacts and evaluation/inference tables.
    """

    # Column layouts used when log_eval_data has to create a table itself.
    # The counts match the number of values each branch adds per row.
    _EVAL_COLUMNS = ('fake_image', 'sr_image', 'hr_image', 'psnr', 'ssim')
    _INFER_COLUMNS = ('fake_image', 'sr_image', 'hr_image')

    def __init__(self, config=None, project='diff_derain', save_dir='./experiments'):
        """
        Import ``wandb`` and start a run if none is active.

        config: optional dict forwarded to ``wandb.init(config=...)`` when a
            new run is started. ``log_checkpoint`` reads
            ``config.path['checkpoint']``, so pass it if checkpoints will be
            logged. Not forwarded when a run is already active.
        project: W&B project name used when a new run is started.
        save_dir: directory passed to ``wandb.init(dir=...)``.

        Raises ImportError if ``wandb`` cannot be imported.
        """
        try:
            import wandb
        except ImportError as err:
            # Note the trailing space: adjacent string literals are joined
            # with no separator.
            raise ImportError(
                "To use the Weights and Biases Logger please install wandb. "
                "Run `pip install wandb` to install it."
            ) from err

        self._wandb = wandb

        # Reuse an already-active run instead of starting a second one.
        if self._wandb.run is None:
            init_kwargs = {'project': project, 'dir': save_dir}
            if config is not None:
                init_kwargs['config'] = config
            self._wandb.init(**init_kwargs)

        # The active run's config; log_checkpoint reads `.path` from it.
        self.config = self._wandb.config

        # Created on first use by log_eval_data (callers may also assign
        # their own wandb.Table here beforehand).
        self.eval_table = None
        self.infer_table = None

    def log_metrics(self, metrics, commit=True):
        """
        Log train/validation metrics onto W&B.

        metrics: dictionary of metrics to be logged
        commit: forwarded to ``wandb.log``
        """
        self._wandb.log(metrics, commit=commit)

    def log_image(self, key_name, image_array):
        """
        Log a single image onto W&B under ``key_name``.

        key_name: name of the key
        image_array: image data accepted by ``wandb.Image`` (e.g. a numpy array)
        """
        self._wandb.log({key_name: self._wandb.Image(image_array)})

    def log_images(self, key_name, list_images):
        """
        Log a list of images onto W&B under a single key.

        key_name: name of the key
        list_images: iterable of image data accepted by ``wandb.Image``
        """
        self._wandb.log({key_name: [self._wandb.Image(img) for img in list_images]})

    def log_checkpoint(self, current_epoch, current_step):
        """
        Upload the checkpoint files for this step/epoch as a W&B model artifact.

        current_epoch: the current epoch
        current_step: the current batch step

        The files ``I{step}_E{epoch}_gen.pth`` and ``I{step}_E{epoch}_opt.pth``
        are taken from ``config.path['checkpoint']`` and should already exist
        there. The artifact is named ``<run id>_model`` and tagged ``latest``.
        """
        model_artifact = self._wandb.Artifact(
            f"{self._wandb.run.id}_model", type="model"
        )

        checkpoint_dir = self.config.path['checkpoint']
        prefix = f'I{current_step}_E{current_epoch}'
        gen_path = os.path.join(checkpoint_dir, f'{prefix}_gen.pth')
        opt_path = os.path.join(checkpoint_dir, f'{prefix}_opt.pth')

        model_artifact.add_file(gen_path)
        model_artifact.add_file(opt_path)
        self._wandb.log_artifact(model_artifact, aliases=["latest"])

    def log_eval_data(self, fake_img, sr_img, hr_img, psnr=None, ssim=None):
        """
        Add one row of images (and optionally metrics) to a W&B table.

        When both ``psnr`` and ``ssim`` are given, the row (three images plus
        the two metrics) goes to the evaluation table. Otherwise the row
        (three images only) goes to the inference table, and a single
        supplied metric is not recorded. The target table is created on
        first use if it does not exist yet.
        """
        images = (
            self._wandb.Image(fake_img),
            self._wandb.Image(sr_img),
            self._wandb.Image(hr_img),
        )

        if psnr is not None and ssim is not None:
            if self.eval_table is None:
                self.eval_table = self._wandb.Table(columns=list(self._EVAL_COLUMNS))
            self.eval_table.add_data(*images, psnr, ssim)
        else:
            if self.infer_table is None:
                self.infer_table = self._wandb.Table(columns=list(self._INFER_COLUMNS))
            self.infer_table.add_data(*images)

    def log_eval_table(self, commit=False):
        """
        Log the accumulated table onto W&B.

        Logs the evaluation table under ``eval_data`` if it exists, otherwise
        the inference table under ``infer_data``; at most one is logged per
        call, and nothing happens if neither exists.

        commit: forwarded to ``wandb.log``
        """
        if self.eval_table is not None:
            self._wandb.log({'eval_data': self.eval_table}, commit=commit)
        elif self.infer_table is not None:
            self._wandb.log({'infer_data': self.infer_table}, commit=commit)
|
|
|