Spaces:
Running
on
Zero
Running
on
Zero
import matplotlib.pyplot as plt | |
import torch | |
import torchvision.utils as vutils | |
from torch import nn | |
def plot_example_grid( | |
X, | |
nrow=4, | |
column_titles=None, | |
cmap=None, | |
normalize=True, | |
renormalize_fn=None, | |
title=None, | |
figsize=None, | |
save_path=None, | |
dpi=80, | |
): | |
""" | |
X: tensor (B, C, H, W) | |
nrow: number of images in each row | |
column_titles: list of str - only if B is a multiple of nrow | |
""" | |
X = X.detach().cpu() | |
if renormalize_fn: | |
X = renormalize_fn(X) | |
normalize = False | |
grid = vutils.make_grid( | |
X, nrow=nrow, normalize=normalize, scale_each=True, padding=2 | |
) | |
npimg = grid.permute(1, 2, 0).numpy() | |
H, W = npimg.shape[:2] | |
img_h = H // (len(X) // nrow) | |
img_w = W // nrow | |
if figsize is not None: | |
plt.figure(figsize=figsize, dpi=dpi) | |
else: | |
plt.figure(figsize=(min(20, 3.5 * nrow), 3.5 * (len(X) // nrow + 1)), dpi=dpi) | |
plt.imshow(npimg, cmap=cmap or ("gray" if npimg.shape[-1] == 1 else None)) | |
# Add column titles | |
if column_titles is not None: | |
for i, col_title in enumerate(column_titles): | |
x_center = i * img_w + img_w / 2 | |
plt.text(x_center, y=-5, s=col_title, fontsize=12, ha="center", va="bottom") | |
if title: | |
plt.title(title) | |
plt.axis("off") | |
plt.tight_layout() | |
if save_path is not None: | |
plt.savefig(save_path, bbox_inches="tight", pad_inches=0) | |
else: | |
plt.show() | |
def plot_function( | |
f, | |
x_range=(-5, 5), | |
num_points=1000, | |
title="Function plot", | |
xlabel="x", | |
ylabel="f(x)", | |
dpi=80, | |
): | |
x = torch.linspace(x_range[0], x_range[1], num_points) | |
y = f(x) | |
y = y.detach().cpu().numpy() | |
x = x.detach().cpu().numpy() | |
plt.figure(figsize=(8, 4), dpi=dpi) | |
plt.plot(x, y, label="f(x)") | |
plt.axhline(0, color="black", linewidth=1, linestyle="--") # y=0 axis | |
plt.title(title) | |
plt.xlabel(xlabel) | |
plt.ylabel(ylabel) | |
plt.grid(True) | |
plt.legend() | |
plt.tight_layout() | |
plt.show() | |
def maxpool2d_param_extractor(child): | |
return { | |
"kernel_size": child.kernel_size, | |
"stride": child.stride, | |
"padding": child.padding, | |
"dilation": child.dilation, | |
"return_indices": child.return_indices, | |
"ceil_mode": child.ceil_mode, | |
} | |
def replace_module_with_custom_( | |
module, custom_cls, original_cls=None, by_name=None, param_extractor=None | |
): | |
for name, child in module.named_children(): | |
if (by_name is not None and name == by_name) or ( | |
original_cls is not None and isinstance(child, original_cls) | |
): | |
params = param_extractor(child) if param_extractor is not None else {} | |
setattr(module, name, custom_cls(**params)) | |
else: | |
replace_module_with_custom_( | |
child, | |
custom_cls, | |
original_cls=original_cls, | |
by_name=by_name, | |
param_extractor=param_extractor, | |
) | |
def show_images(images, adv_images, k=5): | |
uimages = images.unflatten(0, (5, 1)) | |
uadv_images = adv_images.unflatten(0, (5, k)) | |
udiff = uadv_images - uimages | |
show_adv = torch.cat([uimages, uadv_images], dim=1).flatten(0, 1) | |
show_diff = torch.cat([uimages, udiff], dim=1).flatten(0, 1) | |
return show_adv, show_diff | |