# worldmem/algorithms/common/base_pytorch_algo.py
from abc import ABC, abstractmethod
import warnings
from typing import Any, Union, Sequence, Optional
from lightning.pytorch.utilities.types import STEP_OUTPUT
from omegaconf import DictConfig
import lightning.pytorch as pl
import torch
import numpy as np
from PIL import Image
import wandb
import einops
class BasePytorchAlgo(pl.LightningModule, ABC):
"""
    A base class for PyTorch algorithms using PyTorch Lightning.
See https://lightning.ai/docs/pytorch/stable/starter/introduction.html for more details.
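
    Example of wiring a subclass into a Lightning trainer (``MyAlgo``, the config
    fields, and ``train_loader`` are illustrative assumptions)::

        cfg = DictConfig({"lr": 1e-4})
        algo = MyAlgo(cfg)  # MyAlgo subclasses BasePytorchAlgo and implements the hooks below
        trainer = pl.Trainer(max_steps=1000)
        trainer.fit(algo, train_dataloaders=train_loader)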
"""
def __init__(self, cfg: DictConfig):
super().__init__()
self.cfg = cfg
self._build_model()
@abstractmethod
def _build_model(self):
"""
        Create all PyTorch nn.Modules here.
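
        A minimal sketch of an implementation; the module layout and ``cfg`` fields
        are illustrative assumptions::

            def _build_model(self):
                self.encoder = torch.nn.Sequential(
                    torch.nn.Linear(self.cfg.in_dim, self.cfg.hidden_dim),
                    torch.nn.ReLU(),
                )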
"""
raise NotImplementedError
@abstractmethod
def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
r"""Here you compute and return the training loss and some additional metrics for e.g. the progress bar or
logger.
Args:
batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
batch_idx: The index of this batch.
dataloader_idx: (only if multiple dataloaders used) The index of the dataloader that produced this batch.
Return:
Any of these options:
- :class:`~torch.Tensor` - The loss tensor
- ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
- ``None`` - Skip to the next batch. This is only supported for automatic optimization.
This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.
In this step you'd normally do the forward pass and calculate the loss for a batch.
You can also do fancier things like multiple forward passes or something model specific.
Example::
def training_step(self, batch, batch_idx):
x, y, z = batch
out = self.encoder(x)
loss = self.loss(out, x)
return loss
To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:
.. code-block:: python
def __init__(self):
super().__init__()
self.automatic_optimization = False
# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
opt1, opt2 = self.optimizers()
# do training_step with encoder
...
opt1.step()
# do training_step with decoder
...
opt2.step()
Note:
When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically
normalized by ``accumulate_grad_batches`` internally.
"""
return super().training_step(*args, **kwargs)
def configure_optimizers(self):
"""
        Return an optimizer. If you need to use more than one optimizer, refer to the PyTorch Lightning documentation:
https://lightning.ai/docs/pytorch/stable/common/optimization.html
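
        A sketch of an override that adds a learning-rate scheduler; the optimizer and
        scheduler choices are illustrative assumptions::

            def configure_optimizers(self):
                optimizer = torch.optim.AdamW(self.parameters(), lr=self.cfg.lr)
                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
                return {"optimizer": optimizer, "lr_scheduler": scheduler}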
"""
parameters = self.parameters()
return torch.optim.Adam(parameters, lr=self.cfg.lr)
def log_video(
self,
key: str,
video: Union[np.ndarray, torch.Tensor],
        mean: Optional[Union[np.ndarray, torch.Tensor, Sequence, float]] = None,
        std: Optional[Union[np.ndarray, torch.Tensor, Sequence, float]] = None,
fps: int = 5,
format: str = "mp4",
):
"""
Log video to wandb. WandbLogger in pytorch lightning does not support video logging yet, so we call wandb directly.
Args:
video: a numpy array or tensor, either in form (time, channel, height, width) or in the form
(batch, time, channel, height, width). The content must be be in 0-255 if under dtype uint8
or [0, 1] otherwise.
mean: optional, the mean to unnormalize video tensor, assuming unnormalized data is in [0, 1].
std: optional, the std to unnormalize video tensor, assuming unnormalized data is in [0, 1].
key: the name of the video.
fps: the frame rate of the video.
format: the format of the video. Can be either "mp4" or "gif".
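
        Example (shapes, key, and fps are illustrative; assumes the trainer logger is a WandbLogger)::

            video = torch.rand(16, 3, 64, 64)  # (time, channel, height, width), values in [0, 1]
            self.log_video("val/rollout", video, fps=10, format="gif")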
"""
if isinstance(video, torch.Tensor):
video = video.detach().cpu().numpy()
        # Build a broadcastable shape such as (1, ..., 3, 1, 1) with the same ndim as the
        # video, so per-channel mean/std broadcast over the channel axis.
        expand_shape = [1] * (len(video.shape) - 3) + [3, 1, 1]
if std is not None:
if isinstance(std, (float, int)):
std = [std] * 3
if isinstance(std, torch.Tensor):
std = std.detach().cpu().numpy()
std = np.array(std).reshape(*expand_shape)
video = video * std
if mean is not None:
if isinstance(mean, (float, int)):
mean = [mean] * 3
if isinstance(mean, torch.Tensor):
mean = mean.detach().cpu().numpy()
mean = np.array(mean).reshape(*expand_shape)
video = video + mean
if video.dtype != np.uint8:
video = np.clip(video, a_min=0, a_max=1) * 255
video = video.astype(np.uint8)
self.logger.experiment.log(
{
key: wandb.Video(video, fps=fps, format=format),
},
step=self.global_step,
)
def log_image(
self,
key: str,
image: Union[np.ndarray, torch.Tensor, Image.Image, Sequence[Image.Image]],
        mean: Optional[Union[np.ndarray, torch.Tensor, Sequence, float]] = None,
        std: Optional[Union[np.ndarray, torch.Tensor, Sequence, float]] = None,
**kwargs: Any,
):
"""
Log image(s) using WandbLogger.
Args:
            key: the name of the image(s).
image: a single image or a batch of images. If a batch of images, the shape should be (batch, channel, height, width).
mean: optional, the mean to unnormalize image tensor, assuming unnormalized data is in [0, 1].
std: optional, the std to unnormalize tensor, assuming unnormalized data is in [0, 1].
            kwargs: optional, extra kwargs forwarded to WandbLogger.log_image, such as caption=[...].
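
        Example (shapes and captions are illustrative)::

            images = torch.rand(2, 3, 32, 32)  # (batch, channel, height, width), values in [0, 1]
            self.log_image("val/samples", images, caption=["sample 0", "sample 1"])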
"""
if isinstance(image, Image.Image):
image = [image]
elif len(image) and not isinstance(image[0], Image.Image):
if isinstance(image, torch.Tensor):
image = image.detach().cpu().numpy()
if len(image.shape) == 3:
image = image[None]
if image.shape[1] == 3:
if image.shape[-1] == 3:
                warnings.warn(f"Both the channel dimension and the last dimension of shape {image.shape} have size 3; assuming channel-first.")
image = einops.rearrange(image, "b c h w -> b h w c")
if std is not None:
if isinstance(std, (float, int)):
std = [std] * 3
if isinstance(std, torch.Tensor):
std = std.detach().cpu().numpy()
std = np.array(std)[None, None, None]
image = image * std
if mean is not None:
if isinstance(mean, (float, int)):
mean = [mean] * 3
if isinstance(mean, torch.Tensor):
mean = mean.detach().cpu().numpy()
mean = np.array(mean)[None, None, None]
image = image + mean
if image.dtype != np.uint8:
image = np.clip(image, a_min=0.0, a_max=1.0) * 255
image = image.astype(np.uint8)
image = [img for img in image]
self.logger.log_image(key=key, images=image, **kwargs)
def log_gradient_stats(self):
"""Log gradient statistics such as the mean or std of norm."""
with torch.no_grad():
grad_norms = []
gpr = [] # gradient-to-parameter ratio
for param in self.parameters():
if param.grad is not None:
grad_norms.append(torch.norm(param.grad).item())
                    # .item() so the list holds plain floats, which torch.tensor() below accepts cleanly
                    gpr.append((torch.norm(param.grad) / torch.norm(param)).item())
if len(grad_norms) == 0:
return
grad_norms = torch.tensor(grad_norms)
gpr = torch.tensor(gpr)
self.log_dict(
{
"train/grad_norm/min": grad_norms.min(),
"train/grad_norm/max": grad_norms.max(),
"train/grad_norm/std": grad_norms.std(),
"train/grad_norm/mean": grad_norms.mean(),
"train/grad_norm/median": torch.median(grad_norms),
"train/gpr/min": gpr.min(),
"train/gpr/max": gpr.max(),
"train/gpr/std": gpr.std(),
"train/gpr/mean": gpr.mean(),
"train/gpr/median": torch.median(gpr),
}
)
def register_data_mean_std(
self, mean: Union[str, float, Sequence], std: Union[str, float, Sequence], namespace: str = "data"
):
"""
Register mean and std of data as tensor buffer.
Args:
            mean: the mean of the data, or a path to a ``.npy`` or ``.pt`` file containing it.
            std: the std of the data, or a path to a ``.npy`` or ``.pt`` file containing it.
            namespace: the namespace of the registered buffers.
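
        Example (values are illustrative)::

            self.register_data_mean_std(mean=[0.5, 0.5, 0.5], std=[0.25, 0.25, 0.25])
            # the buffers are then accessible as self.data_mean and self.data_std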
"""
for k, v in [("mean", mean), ("std", std)]:
if isinstance(v, str):
if v.endswith(".npy"):
v = torch.from_numpy(np.load(v))
elif v.endswith(".pt"):
v = torch.load(v)
else:
raise ValueError(f"Unsupported file type {v.split('.')[-1]}.")
else:
v = torch.tensor(v)
self.register_buffer(f"{namespace}_{k}", v.float().to(self.device))