import math

import numpy as np
import PIL.Image  # binds ``PIL`` AND loads the ``Image`` submodule; plain ``import PIL`` does not


def postprocess(x):
    """Convert values nominally in [0, 1] to ``uint8``.

    Values are scaled by 255, clipped to [0, 255], then truncated (not
    rounded) to integers, so e.g. 0.5 maps to 127.

    Args:
        x: Array-like of floats, any shape.

    Returns:
        A ``np.uint8`` array with the same shape as ``x``.
    """
    x = np.clip(255 * np.asarray(x), 0, 255)
    # ``np.cast`` was deprecated in NumPy 1.25 and removed in 2.0;
    # ``astype`` performs the same truncating conversion.
    return x.astype(np.uint8)


def tile(X, rows, cols):
    """Arrange a batch of images in a ``rows`` x ``cols`` grid.

    Images are placed in row-major order. Grid cells beyond ``X.shape[0]``
    are left as zeros; images beyond ``rows * cols`` are not drawn.

    Args:
        X: Array of shape ``(N, H, W, C)``.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        Array of shape ``(rows * H, cols * W, C)`` with ``X``'s dtype.
    """
    n, h, w, c = X.shape
    tiling = np.zeros((rows * h, cols * w, c), dtype=X.dtype)
    for idx in range(min(n, rows * cols)):
        i, j = divmod(idx, cols)
        tiling[i * h:(i + 1) * h, j * w:(j + 1) * w, :] = X[idx]
    return tiling


def plot_batch(X, out_path):
    """Tile a batch of images into a square grid and save it with Pillow.

    The grid has ``ceil(sqrt(N))`` rows and columns. Batches with more than
    3 channels are reduced to 3 randomly chosen distinct channels.
    Single-channel batches are saved as a 2-D (grayscale) array.

    Args:
        X: Float array of shape ``(N, H, W, C)`` with values nominally in [0, 1].
        out_path: Destination passed to Pillow's ``Image.save``.

    Raises:
        ValueError: If ``X`` contains no images.
    """
    n_images, _, _, n_channels = X.shape
    if n_images == 0:
        raise ValueError("plot_batch: X contains no images")
    if n_channels > 3:
        # Sample WITHOUT replacement; numpy's default (replace=True) could
        # pick the same channel more than once.
        X = X[:, :, :, np.random.choice(n_channels, size=3, replace=False)]
    X = postprocess(X)
    side = math.ceil(math.sqrt(n_images))
    canvas = tile(X, side, side)
    if canvas.shape[-1] == 1:
        # Drop only the channel axis. np.squeeze would also remove a height
        # or width axis of size 1 and corrupt the image layout.
        canvas = canvas[:, :, 0]
    PIL.Image.fromarray(canvas).save(out_path)