from abc import ABC, abstractmethod import warnings from typing import Any, Union, Sequence, Optional from lightning.pytorch.utilities.types import STEP_OUTPUT from omegaconf import DictConfig import lightning.pytorch as pl import torch import numpy as np from PIL import Image import wandb import einops class BasePytorchAlgo(pl.LightningModule, ABC): """ A base class for Pytorch algorithms using Pytorch Lightning. See https://lightning.ai/docs/pytorch/stable/starter/introduction.html for more details. """ def __init__(self, cfg: DictConfig): super().__init__() self.cfg = cfg self._build_model() @abstractmethod def _build_model(self): """ Create all pytorch nn.Modules here. """ raise NotImplementedError @abstractmethod def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: r"""Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger. Args: batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`. batch_idx: The index of this batch. dataloader_idx: (only if multiple dataloaders used) The index of the dataloader that produced this batch. Return: Any of these options: - :class:`~torch.Tensor` - The loss tensor - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``. - ``None`` - Skip to the next batch. This is only supported for automatic optimization. This is not supported for multi-GPU, TPU, IPU, or DeepSpeed. In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific. Example:: def training_step(self, batch, batch_idx): x, y, z = batch out = self.encoder(x) loss = self.loss(out, x) return loss To use multiple optimizers, you can switch to 'manual optimization' and control their stepping: .. code-block:: python def __init__(self): super().__init__() self.automatic_optimization = False # Multiple optimizers (e.g.: GANs) def training_step(self, batch, batch_idx): opt1, opt2 = self.optimizers() # do training_step with encoder ... opt1.step() # do training_step with decoder ... opt2.step() Note: When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically normalized by ``accumulate_grad_batches`` internally. """ return super().training_step(*args, **kwargs) def configure_optimizers(self): """ Return an optimizer. If you need to use more than one optimizer, refer to pytorch lightning documentation: https://lightning.ai/docs/pytorch/stable/common/optimization.html """ parameters = self.parameters() return torch.optim.Adam(parameters, lr=self.cfg.lr) def log_video( self, key: str, video: Union[np.ndarray, torch.Tensor], mean: Union[np.ndarray, torch.Tensor, Sequence, float] = None, std: Union[np.ndarray, torch.Tensor, Sequence, float] = None, fps: int = 5, format: str = "mp4", ): """ Log video to wandb. WandbLogger in pytorch lightning does not support video logging yet, so we call wandb directly. Args: video: a numpy array or tensor, either in form (time, channel, height, width) or in the form (batch, time, channel, height, width). The content must be be in 0-255 if under dtype uint8 or [0, 1] otherwise. mean: optional, the mean to unnormalize video tensor, assuming unnormalized data is in [0, 1]. std: optional, the std to unnormalize video tensor, assuming unnormalized data is in [0, 1]. key: the name of the video. fps: the frame rate of the video. format: the format of the video. Can be either "mp4" or "gif". """ if isinstance(video, torch.Tensor): video = video.detach().cpu().numpy() expand_shape = [1] * (len(video.shape) - 2) + [3, 1, 1] if std is not None: if isinstance(std, (float, int)): std = [std] * 3 if isinstance(std, torch.Tensor): std = std.detach().cpu().numpy() std = np.array(std).reshape(*expand_shape) video = video * std if mean is not None: if isinstance(mean, (float, int)): mean = [mean] * 3 if isinstance(mean, torch.Tensor): mean = mean.detach().cpu().numpy() mean = np.array(mean).reshape(*expand_shape) video = video + mean if video.dtype != np.uint8: video = np.clip(video, a_min=0, a_max=1) * 255 video = video.astype(np.uint8) self.logger.experiment.log( { key: wandb.Video(video, fps=fps, format=format), }, step=self.global_step, ) def log_image( self, key: str, image: Union[np.ndarray, torch.Tensor, Image.Image, Sequence[Image.Image]], mean: Union[np.ndarray, torch.Tensor, Sequence, float] = None, std: Union[np.ndarray, torch.Tensor, Sequence, float] = None, **kwargs: Any, ): """ Log image(s) using WandbLogger. Args: key: the name of the video. image: a single image or a batch of images. If a batch of images, the shape should be (batch, channel, height, width). mean: optional, the mean to unnormalize image tensor, assuming unnormalized data is in [0, 1]. std: optional, the std to unnormalize tensor, assuming unnormalized data is in [0, 1]. kwargs: optional, WandbLogger log_image kwargs, such as captions=xxx. """ if isinstance(image, Image.Image): image = [image] elif len(image) and not isinstance(image[0], Image.Image): if isinstance(image, torch.Tensor): image = image.detach().cpu().numpy() if len(image.shape) == 3: image = image[None] if image.shape[1] == 3: if image.shape[-1] == 3: warnings.warn(f"Two channels in shape {image.shape} have size 3, assuming channel first.") image = einops.rearrange(image, "b c h w -> b h w c") if std is not None: if isinstance(std, (float, int)): std = [std] * 3 if isinstance(std, torch.Tensor): std = std.detach().cpu().numpy() std = np.array(std)[None, None, None] image = image * std if mean is not None: if isinstance(mean, (float, int)): mean = [mean] * 3 if isinstance(mean, torch.Tensor): mean = mean.detach().cpu().numpy() mean = np.array(mean)[None, None, None] image = image + mean if image.dtype != np.uint8: image = np.clip(image, a_min=0.0, a_max=1.0) * 255 image = image.astype(np.uint8) image = [img for img in image] self.logger.log_image(key=key, images=image, **kwargs) def log_gradient_stats(self): """Log gradient statistics such as the mean or std of norm.""" with torch.no_grad(): grad_norms = [] gpr = [] # gradient-to-parameter ratio for param in self.parameters(): if param.grad is not None: grad_norms.append(torch.norm(param.grad).item()) gpr.append(torch.norm(param.grad) / torch.norm(param)) if len(grad_norms) == 0: return grad_norms = torch.tensor(grad_norms) gpr = torch.tensor(gpr) self.log_dict( { "train/grad_norm/min": grad_norms.min(), "train/grad_norm/max": grad_norms.max(), "train/grad_norm/std": grad_norms.std(), "train/grad_norm/mean": grad_norms.mean(), "train/grad_norm/median": torch.median(grad_norms), "train/gpr/min": gpr.min(), "train/gpr/max": gpr.max(), "train/gpr/std": gpr.std(), "train/gpr/mean": gpr.mean(), "train/gpr/median": torch.median(gpr), } ) def register_data_mean_std( self, mean: Union[str, float, Sequence], std: Union[str, float, Sequence], namespace: str = "data" ): """ Register mean and std of data as tensor buffer. Args: mean: the mean of data. std: the std of data. namespace: the namespace of the registered buffer. """ for k, v in [("mean", mean), ("std", std)]: if isinstance(v, str): if v.endswith(".npy"): v = torch.from_numpy(np.load(v)) elif v.endswith(".pt"): v = torch.load(v) else: raise ValueError(f"Unsupported file type {v.split('.')[-1]}.") else: v = torch.tensor(v) self.register_buffer(f"{namespace}_{k}", v.float().to(self.device))