|
import numpy as np |
|
import math |
|
import PIL |
|
|
|
def postprocess(x):
    """Convert values in [0, 1] to uint8 values in [0, 255].

    The input is scaled by 255, clipped to [0, 255] (so out-of-range values
    saturate instead of wrapping around) and cast to ``np.uint8``. The cast
    drops the fractional part, e.g. 0.5 -> 127.

    Args:
        x: Array-like of numbers.

    Returns:
        Array of dtype ``uint8`` with the same shape as ``x``.
    """
    # Convert up front so that plain Python sequences are scaled
    # element-wise (``255 * [...]`` would repeat the list instead).
    x = np.asarray(x)
    x = np.clip(255 * x, 0, 255)
    # ``np.cast`` was removed in NumPy 2.0; ``asarray(...).astype`` is the
    # equivalent spelling that works on every NumPy version.
    return np.asarray(x).astype(np.uint8)
|
|
|
def tile(X, rows, cols):
    """Lay out a batch of images on a ``rows`` x ``cols`` grid.

    ``X`` is indexed as (image, height, width, channel). Image ``k`` is
    copied into grid cell ``(k // cols, k % cols)``. Cells without an image
    stay zero, and images that do not fit into the grid are left out.

    Returns:
        Array of shape ``(rows * height, cols * width, channels)`` with the
        dtype of ``X``.
    """
    height = X.shape[1]
    width = X.shape[2]
    canvas = np.zeros(
        (rows * height, cols * width, X.shape[3]),
        dtype=X.dtype,
    )
    # Only as many images as both the batch and the grid can provide/hold.
    n_placed = min(X.shape[0], rows * cols)
    for k in range(n_placed):
        grid_row, grid_col = divmod(k, cols)
        top = grid_row * height
        left = grid_col * width
        canvas[top:top + height, left:left + width, :] = X[k, ...]
    return canvas
|
|
|
|
|
def plot_batch(X, out_path):
    """Save a batch of images as a single tiled image file.

    The batch is converted to uint8 with ``postprocess``, arranged on a
    square grid with ``tile`` and written to ``out_path`` via PIL.

    Args:
        X: Array indexed as (image, height, width, channel). Values are
            passed to ``postprocess``, which maps [0, 1] to uint8.
        out_path: Destination passed to ``PIL.Image.Image.save``.
    """
    # A bare ``import PIL`` does not load the ``Image`` submodule, so import
    # it explicitly here instead of relying on another module having done so.
    from PIL import Image

    n_channels = X.shape[3]
    if n_channels > 3:
        # Show three randomly picked channels. ``replace=False`` keeps them
        # distinct; the default would allow the same channel several times.
        X = X[:, :, :, np.random.choice(n_channels, size=3, replace=False)]
    X = postprocess(X)
    # Smallest square grid that holds every image of the batch.
    rows = cols = int(math.ceil(math.sqrt(X.shape[0])))
    canvas = tile(X, rows, cols)
    # Drop only a single-channel axis. A bare ``np.squeeze`` would also
    # collapse a height or width of 1 and hand PIL a wrongly shaped array.
    if canvas.shape[2] == 1:
        canvas = canvas[:, :, 0]
    Image.fromarray(canvas).save(out_path)