Maciej Satkiewicz
app works locally
ed46d32
import matplotlib.pyplot as plt
import torch
import torchvision.utils as vutils
from torch import nn
def plot_example_grid(
    X,
    nrow=4,
    column_titles=None,
    cmap=None,
    normalize=True,
    renormalize_fn=None,
    title=None,
    figsize=None,
    save_path=None,
    dpi=80,
):
    """Plot a batch of images as a grid, then save it or show it.

    Args:
        X: image batch tensor of shape (B, C, H, W).
        nrow: number of images in each grid row (forwarded to ``make_grid``).
        column_titles: optional list of str drawn above the grid columns;
            meaningful mainly when B is a multiple of ``nrow``.
        cmap: matplotlib colormap passed to ``imshow``.
        normalize: forwarded to ``make_grid`` (with ``scale_each=True``, each
            image is scaled independently). Forced to False when
            ``renormalize_fn`` is given.
        renormalize_fn: optional callable applied to ``X`` before building
            the grid.
        title: optional figure title.
        figsize: figure size; when None a size is derived from ``nrow`` and B.
        save_path: if given, the figure is saved there and then closed;
            otherwise ``plt.show()`` is called.
        dpi: figure resolution.

    Raises:
        ValueError: if ``X`` contains no images.
    """
    X = X.detach().cpu()
    if len(X) == 0:
        raise ValueError("X must contain at least one image")
    if renormalize_fn:
        X = renormalize_fn(X)
        # Presumably renormalize_fn already yields displayable values, so
        # make_grid's own min/max scaling is skipped.
        normalize = False
    grid = vutils.make_grid(
        X, nrow=nrow, normalize=normalize, scale_each=True, padding=2
    )
    npimg = grid.permute(1, 2, 0).numpy()
    # make_grid uses min(nrow, B) columns, so divide the grid width by that
    # (not by nrow) to locate column centres; max(1, ...) guards nrow < 1.
    ncols = max(1, min(nrow, len(X)))
    img_w = npimg.shape[1] // ncols
    if figsize is None:
        figsize = (min(20, 3.5 * nrow), 3.5 * (len(X) // nrow + 1))
    plt.figure(figsize=figsize, dpi=dpi)
    # NOTE: make_grid expands single-channel input to 3 channels, so the
    # "gray" fallback only applies if the grid ends up with one channel.
    plt.imshow(npimg, cmap=cmap or ("gray" if npimg.shape[-1] == 1 else None))
    if column_titles is not None:
        for i, col_title in enumerate(column_titles):
            x_center = i * img_w + img_w / 2
            # A negative y places the text just above the top row of images.
            plt.text(
                x_center, y=-5, s=col_title, fontsize=12, ha="center", va="bottom"
            )
    if title:
        plt.title(title)
    plt.axis("off")
    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path, bbox_inches="tight", pad_inches=0)
        # Release the figure; repeated saves would otherwise accumulate
        # open figures in pyplot's global state.
        plt.close()
    else:
        plt.show()
def plot_function(
    f,
    x_range=(-5, 5),
    num_points=1000,
    title="Function plot",
    xlabel="x",
    ylabel="f(x)",
    dpi=80,
):
    """Evaluate ``f`` on an evenly spaced grid and display the resulting curve.

    ``f`` is called once with a ``torch.linspace`` of ``num_points`` samples
    from ``x_range[0]`` to ``x_range[1]``. Its result must support
    ``.detach().cpu().numpy()``. The plot includes a dashed y = 0 reference
    line, a grid, axis labels and a legend, and is shown with ``plt.show()``.
    """
    samples = torch.linspace(x_range[0], x_range[1], num_points)
    values = f(samples).detach().cpu().numpy()
    abscissa = samples.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(8, 4), dpi=dpi)
    ax.plot(abscissa, values, label="f(x)")
    # Dashed horizontal reference line marking y = 0.
    ax.axhline(0, color="black", linewidth=1, linestyle="--")
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(True)
    ax.legend()
    fig.tight_layout()
    plt.show()
def maxpool2d_param_extractor(child):
    """Collect the ``nn.MaxPool2d`` constructor settings stored on ``child``.

    Returns a dict keyed by constructor argument name (``kernel_size``,
    ``stride``, ``padding``, ``dilation``, ``return_indices``, ``ceil_mode``),
    with values read directly from the corresponding attributes of ``child``.
    """
    setting_names = (
        "kernel_size",
        "stride",
        "padding",
        "dilation",
        "return_indices",
        "ceil_mode",
    )
    return {setting: getattr(child, setting) for setting in setting_names}
def replace_module_with_custom_(
    module, custom_cls, original_cls=None, by_name=None, param_extractor=None
):
    """Recursively swap matching submodules of ``module`` in place.

    A direct child matches when its attribute name equals ``by_name`` or it
    is an instance of ``original_cls`` (each criterion is ignored when None).
    A matching child is replaced by ``custom_cls(**kwargs)``, where
    ``kwargs`` is ``param_extractor(child)`` or ``{}`` if no extractor is
    given; the replaced child's own subtree is not visited. Non-matching
    children are searched recursively with the same criteria.
    """
    for child_name, child in module.named_children():
        name_hit = by_name is not None and child_name == by_name
        type_hit = original_cls is not None and isinstance(child, original_cls)
        if not (name_hit or type_hit):
            replace_module_with_custom_(
                child,
                custom_cls,
                original_cls=original_cls,
                by_name=by_name,
                param_extractor=param_extractor,
            )
            continue
        ctor_kwargs = {} if param_extractor is None else param_extractor(child)
        setattr(module, child_name, custom_cls(**ctor_kwargs))
def show_images(images, adv_images, k=5):
    """Interleave clean images with their adversarial variants for display.

    Args:
        images: batch of N clean images, shape (N, ...).
        adv_images: batch of N * k adversarial images, shape (N * k, ...),
            grouped so that rows ``i*k`` .. ``i*k + k - 1`` belong to
            ``images[i]``.
        k: number of adversarial variants per clean image.

    Returns:
        (show_adv, show_diff): two tensors of shape (N * (1 + k), ...). In
        ``show_adv`` each clean image is followed by its k adversarial
        versions; in ``show_diff`` it is followed by the k perturbations
        (adversarial minus clean). Consecutive groups of ``1 + k`` rows belong
        to one clean image, e.g. for a grid with ``nrow = 1 + k``.

    Raises:
        ValueError: if ``adv_images`` does not hold exactly ``k`` images per
            clean image.
    """
    n = images.shape[0]
    if adv_images.shape[0] != n * k:
        raise ValueError(
            f"expected {n * k} adversarial images ({n} images x k={k}), "
            f"got {adv_images.shape[0]}"
        )
    # N comes from the batch itself, so any number of clean images works.
    uimages = images.unsqueeze(1)  # (N, 1, ...)
    uadv_images = adv_images.unflatten(0, (n, k))  # (N, k, ...)
    # Broadcasting the singleton dim subtracts each clean image from all k
    # of its adversarial variants.
    udiff = uadv_images - uimages
    show_adv = torch.cat([uimages, uadv_images], dim=1).flatten(0, 1)
    show_diff = torch.cat([uimages, udiff], dim=1).flatten(0, 1)
    return show_adv, show_diff