import matplotlib.pyplot as plt
import numpy


def _plot_image_grid(images, cell_shape, figsize_scale, subplots_kwargs,
                     imshow_kwargs=None, margin=0.01):
    """Draw ``images`` on a near-square grid of axes and show the figure.

    Args:
        images: Sequence/array of images, indexed along the first axis;
            each item is handed to ``Axes.imshow`` unchanged.
        cell_shape: ``(rows, cols)`` pixel size of one grid cell, used only
            to derive a default ``figsize``.
        figsize_scale: Multiplier applied to the default ``figsize``.
        subplots_kwargs: Keyword arguments forwarded to ``plt.subplots``.
            An explicit ``'figsize'`` entry suppresses the default sizing.
        imshow_kwargs: Optional keyword arguments forwarded to every
            ``imshow`` call.
        margin: Spacing passed as ``wspace``/``hspace`` to
            ``plt.subplots_adjust``.

    Returns nothing. An empty ``images`` is a no-op: no figure is created.
    """
    count = len(images)
    if count == 0:
        # A zero-column grid cannot be built (it would divide by zero below).
        return
    imshow_kwargs = {} if imshow_kwargs is None else imshow_kwargs

    # Near-square layout: as many columns as ceil(sqrt(count)), then just
    # enough rows to hold every image.
    width = int(numpy.ceil(numpy.sqrt(count)))
    height = int(numpy.ceil(count / float(width)))

    subplots_kwargs = dict(subplots_kwargs)  # never mutate the caller's dict
    if 'figsize' not in subplots_kwargs:
        # Size figure to one display pixel per data pixel
        dpi = plt.rcParams['figure.dpi']
        subplots_kwargs['figsize'] = (
            figsize_scale * (width * cell_shape[1] / dpi),
            figsize_scale * (height * cell_shape[0] / dpi))

    _, axarr = plt.subplots(height, width, **subplots_kwargs)
    # plt.subplots squeezes by default: a lone Axes for a 1x1 grid and a 1-D
    # array for a single row. Normalise to a 2-D (height, width) array.
    axes = numpy.atleast_2d(axarr)

    # ``flat`` walks the grid row by row, i.e. cell (i // width, i % width).
    for index, ax in enumerate(axes.flat):
        if index < count:
            ax.imshow(images[index], **imshow_kwargs)
        # Hide ticks and frame on every cell, including unused trailing ones.
        ax.axis('off')

    plt.subplots_adjust(wspace=margin, hspace=margin,
                        left=0, right=1, bottom=0, top=1)
    plt.show()


def plot_tensor_images(data, **kwargs):
    """Show a batch of images as a near-square grid.

    Args:
        data: 4-D tensor-like object supporting ``permute``, ``byte``,
            ``cpu`` and ``numpy`` (presumably a torch tensor laid out as
            (N, C, H, W) with values in [-1, 1] -- TODO confirm with
            callers). It is rescaled with ``(data + 1) / 2 * 255``, axis 1
            is moved last, and the result is cast to bytes for display.
        **kwargs: Forwarded to ``plt.subplots``. If ``figsize`` is absent,
            the figure is sized to roughly one display pixel per data pixel.

    Returns nothing; the figure is displayed with ``plt.show()``. An empty
    batch draws nothing.
    """
    margin = 0.01
    images = ((data + 1) / 2 * 255).permute(0, 2, 3, 1).byte().cpu().numpy()
    _plot_image_grid(
        images,
        cell_shape=images.shape[1:3],
        # Default figure is enlarged by the margin fraction (kept as in the
        # original sizing rule for this function).
        figsize_scale=1 + margin,
        subplots_kwargs=kwargs,
        margin=margin)


def plot_max_heatmap(data, shape=None, **kwargs):
    """Show per-item heatmaps of the maximum over axis 1.

    Args:
        data: Tensor-like object with at least three axes supporting
            ``max(dim)`` returning ``(values, indices)``, ``cpu`` and
            ``numpy`` (presumably a torch tensor shaped (N, C, H, W) --
            TODO confirm with callers).
        shape: ``(rows, cols)`` pixel size used for the default figure size
            of each cell. Defaults to ``data.shape[2:]``. It does not
            resize the data itself.
        **kwargs: Forwarded to ``plt.subplots``. If ``figsize`` is absent,
            the figure is sized to one display pixel per ``shape`` pixel.

    All heatmaps share one colour scale (global min/max of the reduced
    data) and use the ``'hot'`` colormap. Returns nothing; the figure is
    displayed with ``plt.show()``. An empty batch draws nothing.
    """
    if shape is None:
        shape = data.shape[2:]
    if data.shape[0] == 0:
        # Nothing to draw; also avoids min()/max() on an empty array.
        return
    heatmaps = data.max(1)[0].cpu().numpy()
    _plot_image_grid(
        heatmaps,
        cell_shape=shape,
        figsize_scale=1.0,
        subplots_kwargs=kwargs,
        # Common vmin/vmax so colours are comparable across cells.
        imshow_kwargs=dict(vmin=heatmaps.min(), vmax=heatmaps.max(),
                           cmap='hot'),
        margin=0.01)