from abc import ABC, abstractmethod
import warnings
from typing import Any, Optional, Sequence, Union

import einops
import lightning.pytorch as pl
import numpy as np
import torch
import wandb
from lightning.pytorch.utilities.types import STEP_OUTPUT
from omegaconf import DictConfig
from PIL import Image


class BasePytorchAlgo(pl.LightningModule, ABC):
"""
    A base class for PyTorch algorithms using PyTorch Lightning.
See https://lightning.ai/docs/pytorch/stable/starter/introduction.html for more details.
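
    A minimal subclass sketch (illustrative; assumes ``cfg`` provides an ``lr`` field
    for the default optimizer plus whatever fields ``_build_model`` needs; ``in_dim``
    below is a hypothetical config field)::

        class MyAlgo(BasePytorchAlgo):
            def _build_model(self):
                # self.cfg is already set when this is called from __init__.
                self.net = torch.nn.Linear(self.cfg.in_dim, 1)

            def training_step(self, batch, batch_idx):
                x, y = batch
                loss = torch.nn.functional.mse_loss(self.net(x), y)
                self.log("train/loss", loss)
                return loss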
"""
def __init__(self, cfg: DictConfig):
super().__init__()
self.cfg = cfg
self._build_model()
@abstractmethod
def _build_model(self):
"""
Create all pytorch nn.Modules here.
"""
raise NotImplementedError
@abstractmethod
def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
r"""Here you compute and return the training loss and some additional metrics for e.g. the progress bar or
logger.
Args:
batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
batch_idx: The index of this batch.
dataloader_idx: (only if multiple dataloaders used) The index of the dataloader that produced this batch.
Return:
Any of these options:
- :class:`~torch.Tensor` - The loss tensor
- ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
- ``None`` - Skip to the next batch. This is only supported for automatic optimization.
This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.
In this step you'd normally do the forward pass and calculate the loss for a batch.
You can also do fancier things like multiple forward passes or something model specific.
Example::
def training_step(self, batch, batch_idx):
x, y, z = batch
out = self.encoder(x)
loss = self.loss(out, x)
return loss
To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:
.. code-block:: python
def __init__(self):
super().__init__()
self.automatic_optimization = False
# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
opt1, opt2 = self.optimizers()
# do training_step with encoder
...
opt1.step()
# do training_step with decoder
...
opt2.step()
Note:
When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically
normalized by ``accumulate_grad_batches`` internally.
"""
return super().training_step(*args, **kwargs)
def configure_optimizers(self):
"""
        Return an optimizer. If you need to use more than one optimizer or a learning-rate
        scheduler, refer to the PyTorch Lightning documentation:
        https://lightning.ai/docs/pytorch/stable/common/optimization.html
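
        Example of overriding to add a learning-rate scheduler (a sketch; ``T_max=100``
        is an illustrative value)::

            def configure_optimizers(self):
                optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.lr)
                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
                return {"optimizer": optimizer, "lr_scheduler": scheduler}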
"""
parameters = self.parameters()
return torch.optim.Adam(parameters, lr=self.cfg.lr)
def log_video(
self,
key: str,
video: Union[np.ndarray, torch.Tensor],
        mean: Optional[Union[np.ndarray, torch.Tensor, Sequence, float]] = None,
        std: Optional[Union[np.ndarray, torch.Tensor, Sequence, float]] = None,
fps: int = 5,
format: str = "mp4",
):
"""
        Log a video to wandb. WandbLogger in PyTorch Lightning does not support video logging yet,
        so we call wandb directly.

        Args:
            key: the name of the video.
            video: a numpy array or tensor of shape (time, channel, height, width) or
                (batch, time, channel, height, width). Values must be in [0, 255] if the dtype
                is uint8, or in [0, 1] otherwise.
            mean: optional, the per-channel mean used to unnormalize the video, assuming the
                unnormalized data lies in [0, 1].
            std: optional, the per-channel std used to unnormalize the video, assuming the
                unnormalized data lies in [0, 1].
            fps: the frame rate of the video.
            format: the format of the video, either "mp4" or "gif".
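
        Example (illustrative shapes and values)::

            video = torch.rand(16, 3, 64, 64)  # (time, channel, height, width) in [0, 1]
            self.log_video("val/rollout", video, fps=10)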
"""
if isinstance(video, torch.Tensor):
video = video.detach().cpu().numpy()
        # Reshape per-channel stats so they broadcast against (..., channel, height, width).
        expand_shape = [1] * (len(video.shape) - 3) + [3, 1, 1]
if std is not None:
if isinstance(std, (float, int)):
std = [std] * 3
if isinstance(std, torch.Tensor):
std = std.detach().cpu().numpy()
std = np.array(std).reshape(*expand_shape)
video = video * std
if mean is not None:
if isinstance(mean, (float, int)):
mean = [mean] * 3
if isinstance(mean, torch.Tensor):
mean = mean.detach().cpu().numpy()
mean = np.array(mean).reshape(*expand_shape)
video = video + mean
if video.dtype != np.uint8:
video = np.clip(video, a_min=0, a_max=1) * 255
video = video.astype(np.uint8)
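        # wandb.Video expects a uint8 array shaped (time, channel, height, width)
        # or (batch, time, channel, height, width).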
self.logger.experiment.log(
{
key: wandb.Video(video, fps=fps, format=format),
},
step=self.global_step,
)
def log_image(
self,
key: str,
image: Union[np.ndarray, torch.Tensor, Image.Image, Sequence[Image.Image]],
        mean: Optional[Union[np.ndarray, torch.Tensor, Sequence, float]] = None,
        std: Optional[Union[np.ndarray, torch.Tensor, Sequence, float]] = None,
**kwargs: Any,
):
"""
        Log image(s) using WandbLogger.

        Args:
            key: the name of the image(s).
            image: a single image or a batch of images. If a batch, the shape should be
                (batch, channel, height, width).
            mean: optional, the per-channel mean used to unnormalize the image, assuming the
                unnormalized data lies in [0, 1].
            std: optional, the per-channel std used to unnormalize the image, assuming the
                unnormalized data lies in [0, 1].
            kwargs: optional, extra keyword arguments forwarded to ``WandbLogger.log_image``,
                such as ``caption=[...]``.
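
        Example (illustrative values; ``caption`` is forwarded to ``WandbLogger.log_image``)::

            images = torch.rand(4, 3, 32, 32)  # (batch, channel, height, width) in [0, 1]
            self.log_image("val/samples", images, caption=["a", "b", "c", "d"])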
"""
if isinstance(image, Image.Image):
image = [image]
elif len(image) and not isinstance(image[0], Image.Image):
if isinstance(image, torch.Tensor):
image = image.detach().cpu().numpy()
if len(image.shape) == 3:
image = image[None]
            if image.shape[1] == 3:
                if image.shape[-1] == 3:
                    warnings.warn(
                        f"Both the second and last dimensions of shape {image.shape} have size 3; "
                        "assuming channel-first layout."
                    )
                # wandb.Image expects channel-last arrays.
                image = einops.rearrange(image, "b c h w -> b h w c")
if std is not None:
if isinstance(std, (float, int)):
std = [std] * 3
if isinstance(std, torch.Tensor):
std = std.detach().cpu().numpy()
std = np.array(std)[None, None, None]
image = image * std
if mean is not None:
if isinstance(mean, (float, int)):
mean = [mean] * 3
if isinstance(mean, torch.Tensor):
mean = mean.detach().cpu().numpy()
mean = np.array(mean)[None, None, None]
image = image + mean
if image.dtype != np.uint8:
image = np.clip(image, a_min=0.0, a_max=1.0) * 255
image = image.astype(np.uint8)
            image = list(image)  # unbind the batch into a list of per-image arrays
self.logger.log_image(key=key, images=image, **kwargs)
def log_gradient_stats(self):
"""Log gradient statistics such as the mean or std of norm."""
with torch.no_grad():
grad_norms = []
gpr = [] # gradient-to-parameter ratio
for param in self.parameters():
if param.grad is not None:
grad_norms.append(torch.norm(param.grad).item())
                    gpr.append((torch.norm(param.grad) / torch.norm(param)).item())
if len(grad_norms) == 0:
return
grad_norms = torch.tensor(grad_norms)
gpr = torch.tensor(gpr)
self.log_dict(
{
"train/grad_norm/min": grad_norms.min(),
"train/grad_norm/max": grad_norms.max(),
"train/grad_norm/std": grad_norms.std(),
"train/grad_norm/mean": grad_norms.mean(),
"train/grad_norm/median": torch.median(grad_norms),
"train/gpr/min": gpr.min(),
"train/gpr/max": gpr.max(),
"train/gpr/std": gpr.std(),
"train/gpr/mean": gpr.mean(),
"train/gpr/median": torch.median(gpr),
}
)
def register_data_mean_std(
self, mean: Union[str, float, Sequence], std: Union[str, float, Sequence], namespace: str = "data"
):
"""
Register mean and std of data as tensor buffer.
        Args:
            mean: the mean of the data, given as a number, a sequence, or a path to a ``.npy`` or ``.pt`` file.
            std: the std of the data, given as a number, a sequence, or a path to a ``.npy`` or ``.pt`` file.
            namespace: the namespace of the registered buffers, e.g. ``"data"`` yields
                ``data_mean`` and ``data_std``.
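
        Example (illustrative values and paths; buffers are then available as
        ``self.data_mean`` and ``self.data_std``)::

            self.register_data_mean_std(mean=0.5, std=0.25)
            self.register_data_mean_std(mean="stats/mean.npy", std="stats/std.npy")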
"""
for k, v in [("mean", mean), ("std", std)]:
if isinstance(v, str):
if v.endswith(".npy"):
v = torch.from_numpy(np.load(v))
elif v.endswith(".pt"):
v = torch.load(v)
else:
raise ValueError(f"Unsupported file type {v.split('.')[-1]}.")
else:
v = torch.tensor(v)
self.register_buffer(f"{namespace}_{k}", v.float().to(self.device))