| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | import torch |
| | import logging |
| | import numpy as np |
| | import io |
| | import math |
| | import random |
| | from PIL import Image |
| | from matplotlib import pyplot as plt |
| | from torch.utils.tensorboard import SummaryWriter |
| | from torchvision import transforms |
| | from .functions import REG_FUNCTION_MAP |
| |
|
| |
|
| | |
| | |
| | |
@torch.no_grad()
def register(
        flags: dict,
        tensor: float | torch.Tensor,
        valid_mask: torch.Tensor,
        epoch: int,
        writer: SummaryWriter,
        logger: logging.Logger,
        tensorboard_required: bool,
        parameter_name: str = ''
):
    """
    Registers a parameter according to the register flags (DEFAULT_WATCHER style).

    What gets written depends on the type of ``tensor``:

    * ``torch.nn.Parameter`` -- one entry per enabled key of ``flags['parameters']``.
      Each value is ``REG_FUNCTION_MAP[flag_key](tensor, valid_mask)``, written under
      ``parameters/<flag_key>/<parameter_name>/``. Keys containing ``'grad'`` are
      skipped while the parameter has no gradient. Keys missing from
      ``REG_FUNCTION_MAP`` are skipped with a warning.
    * any other ``torch.Tensor`` -- ``float(tensor)`` under ``activations/<parameter_name>/``.
    * ``float`` -- the value itself under ``train/<parameter_name>/``.

    :param flags: The watch flags. Only ``flags['parameters']`` (flag key -> enabled) is read,
        and only when ``tensor`` is a Parameter.
    :param tensor: The tensor (or float) to register.
    :param valid_mask: The valid mask, forwarded to the ``REG_FUNCTION_MAP`` functions.
    :param epoch: The current epoch (tensorboard global step).
    :param writer: The tensorboard writer.
    :param logger: The logger.
    :param tensorboard_required: Whether a missing tensorboard writer should be warned about.
    :param parameter_name: The name of the parameter, used in the tensorboard tag.
    :raises ValueError: If ``tensor`` is not a Parameter, a Tensor or a float.
    """
    if isinstance(tensor, torch.nn.Parameter):
        flag_type = 'parameters'
    elif isinstance(tensor, torch.Tensor):
        flag_type = 'activations'
    elif isinstance(tensor, float):
        flag_type = 'train'
    else:
        raise ValueError(f"{type(tensor)} is not a torch.nn.Parameter, torch.Tensor or float.")

    if flag_type != 'parameters':
        # Activations and train values are written as a single scalar; flags are not consulted.
        # NOTE(review): float() raises for multi-element tensors, so activation tensors are
        # assumed to hold a single value -- confirm against callers.
        write_tensorboard(
            name=f'{flag_type}/{parameter_name}/',
            value=float(tensor),
            epoch=epoch,
            writer=writer,
            logger=logger,
            tensorboard_required=tensorboard_required,
        )
        return

    for flag_key, enabled in flags['parameters'].items():
        if not enabled:
            continue
        # Gradient-based statistics can only be computed once a gradient exists;
        # all other statistics are computed regardless.
        if 'grad' in flag_key and tensor.grad is None:
            continue
        function = REG_FUNCTION_MAP.get(flag_key)
        if function is None:
            logger.warning("No register function for flag '%s'; skipping.", flag_key)
            continue
        transformation = function(tensor, valid_mask)
        if transformation is not None:
            write_tensorboard(
                name=f'{flag_type}/{flag_key}/{parameter_name}/',
                value=transformation,
                epoch=epoch,
                writer=writer,
                logger=logger,
                tensorboard_required=tensorboard_required,
            )
| |
|
| | |
| | |
| | |
@torch.no_grad()
def register_replay(
        predicted: torch.Tensor,
        target: torch.Tensor,
        epoch: int,
        writer: SummaryWriter,
        logger: logging.Logger,
        valid_mask: torch.Tensor = Ellipsis,
        element: int = None,
        tensorboard_required: bool = True,
) -> plt.Figure:
    """
    Registers a replay as an image.

    One batch element of ``predicted`` and ``target`` is selected. The positions kept by
    ``valid_mask`` are flattened, zero-padded to the next perfect square, and drawn side
    by side as annotated heat maps. The figure is written to tensorboard under
    ``'replay/'`` and returned.

    :param predicted: The predicted value (prediction); the first axis indexes batch elements.
    :param target: The expected value (labels). If the selected element is 0-d, it is treated
        as an index and expanded into a one-hot array shaped like the prediction.
    :param epoch: The current epoch.
    :param writer: The tensorboard writer.
    :param logger: The logger.
    :param valid_mask: A valid mask tensor of same shape. False positions are ignored.
        ``None`` or the default ``Ellipsis`` keeps every position.
    :param element: The element to register. None chooses a random batch element; other
        values are clamped to ``[0, len(predicted) - 1]``.
    :param tensorboard_required: Whether a missing tensorboard writer should be warned about.
    :return: A matplotlib figure.
    """
    last = len(predicted) - 1
    if element is None:
        element = random.randint(0, last)
    else:
        element = min(last, max(0, element))

    predicted_np = predicted[element].detach().cpu().numpy()
    target_np = target[element].detach().cpu().numpy()

    # A 0-d target is an index: expand it to a one-hot array matching the prediction.
    if not target_np.shape:
        one_hot = np.zeros_like(predicted_np)
        one_hot[target_np] = 1.
        target_np = one_hot

    # Both None and the default Ellipsis mean "no mask" (Ellipsis is not subscriptable).
    if valid_mask is None or valid_mask is Ellipsis:
        mask_np = np.ones_like(predicted_np, dtype=bool)
    else:
        mask_np = valid_mask[element].detach().cpu().numpy().astype(bool)

    predicted_flat = predicted_np[mask_np].flatten()
    target_flat = target_np[mask_np].flatten()

    # Smallest square grid holding every kept value. At least 1x1, so an all-False mask
    # still yields a drawable (zero-filled) grid instead of an empty 0x0 image.
    size = predicted_flat.shape[0]
    side = max(1, math.ceil(math.sqrt(size)))
    pad = side * side - size

    predicted_grid = np.pad(predicted_flat, (0, pad), constant_values=0.0).reshape(side, side)
    target_grid = np.pad(target_flat, (0, pad), constant_values=0.0).reshape(side, side)

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    plot_with_values(axs[0], predicted_grid, "Predicted (y_hat)")
    plot_with_values(axs[1], target_grid, "Target (y)")
    fig.tight_layout()
    write_tensorboard(
        'replay/',
        fig,
        epoch=epoch,
        writer=writer,
        logger=logger,
        tensorboard_required=tensorboard_required,
    )
    return fig
| |
|
def plot_with_values(ax, data, title):
    """
    Draw ``data`` as a heat map on ``ax`` and print each cell's value on top of it.

    Cells whose value is below 0.5 get white labels; all others (including NaN) get black.
    Presumably this keeps labels legible on ``viridis`` for data roughly in [0, 1]. Note that
    the 0.5 threshold is absolute, while ``imshow`` scales colours to the data range.

    :param ax: A matplotlib axes to draw on.
    :param data: A 2-D numpy array; ``data[i, j]`` is labelled at x=j, y=i.
    :param title: The title shown above the heat map.
    :return: None.
    """
    ax.imshow(data, cmap='viridis', interpolation='nearest')
    ax.set_title(title)
    ax.axis('off')

    n_rows, n_cols = data.shape[0], data.shape[1]
    cells = ((row, col) for row in range(n_rows) for col in range(n_cols))
    for row, col in cells:
        cell = data[row, col]
        label_color = "white" if cell < 0.5 else "black"
        ax.text(col, row, f"{cell:.2f}", ha="center", va="center", color=label_color, fontsize=8)
| |
|
| | |
| | |
| | |
def write_tensorboard(
        name: str,
        value: int | float | np.integer | np.floating | str | bytes | list | plt.Figure | np.ndarray | torch.Tensor,
        epoch: int,
        writer: SummaryWriter,
        logger: logging.Logger,
        tensorboard_required: bool = True,
) -> None:
    """
    Write to tensorboard.

    The writer method is chosen from the type of ``value``:

    * ``int`` / ``float`` / numpy integer or floating scalar -> ``add_scalar`` (as a Python float)
    * ``torch.Tensor`` / ``list`` / ``np.ndarray`` -> ``add_histogram``
    * ``str`` -> ``add_text``
    * ``bytes`` (encoded image data readable by PIL) -> ``add_image``
    * ``plt.Figure`` -> rendered to PNG and written with ``add_image``; that figure is then closed

    The call returns without writing when ``writer``, ``value`` or ``name`` is None.
    A warning is logged in each case, except for a missing writer when
    ``tensorboard_required`` is False.

    :param name: The tensorboard tag.
    :param value: The value to write.
    :param epoch: The current epoch (global step).
    :param writer: The tensorboard writer.
    :param logger: The logger.
    :param tensorboard_required: Whether a missing tensorboard writer should be warned about.
    :raises ValueError: If the type of ``value`` is not supported.
    """
    if writer is None:
        if tensorboard_required:
            logger.warning("Writer is None. Please set the writer first.")
        return

    if value is None:
        logger.warning("Value is None. Please set the value first.")
        return

    if name is None:
        logger.warning("Name is None. Please set the name first.")
        return

    # bool is an int subclass, so booleans are logged as 0.0 / 1.0 scalars.
    if isinstance(value, (int, float, np.integer, np.floating)):
        writer.add_scalar(name, float(value), epoch)
    elif isinstance(value, torch.Tensor):
        writer.add_histogram(name, value.detach().cpu().numpy(), epoch)
    elif isinstance(value, (list, np.ndarray)):
        writer.add_histogram(name, np.asarray(value), epoch)
    elif isinstance(value, str):
        writer.add_text(name, value, epoch)
    elif isinstance(value, bytes):
        # Convert inside the `with` so the lazily-loaded PIL image is read before closing.
        with io.BytesIO(value) as buf, Image.open(buf) as image:
            image_tensor = transforms.ToTensor()(image)
        writer.add_image(name, image_tensor, epoch)
    elif isinstance(value, plt.Figure):
        try:
            with io.BytesIO() as buf:
                value.savefig(buf, format='png')
                buf.seek(0)
                with Image.open(buf) as image:
                    image_tensor = transforms.ToTensor()(image)
            writer.add_image(name, image_tensor, epoch)
        finally:
            # Close this figure specifically: a bare plt.close() closes whichever
            # figure is current, which need not be `value`.
            plt.close(value)
    else:
        raise ValueError(f"Type {type(value)} not supported.")
| | |
| | |
| | |
| |
|