| | import torch |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | import torch.nn.functional as F |
| |
|
| | from matplotlib.colors import ListedColormap, BoundaryNorm |
| | from matplotlib.lines import Line2D |
| | import matplotlib.animation as animation |
| | import scienceplots |
| |
|
def resize(seq, size):
    """Bilinearly resize the spatial dimensions of ``seq`` and clamp to [0, 1].

    Dimension 2 of ``seq`` is squeezed before interpolation and re-inserted
    afterwards. Bilinear ``F.interpolate`` works on 4-D input, so ``seq`` is
    presumably ``(B, T, 1, H, W)`` -- TODO confirm against callers.

    Args:
        seq: torch tensor to resize.
        size: target spatial size, forwarded to ``F.interpolate``.

    Returns:
        The resized tensor, values clamped to ``[0, 1]``, with dimension 2
        restored via ``unsqueeze(2)``.
    """
    frames = seq.squeeze(dim=2)
    resized = F.interpolate(frames, size=size, mode='bilinear', align_corners=False)
    return resized.clamp(0, 1).unsqueeze(2)
| |
|
| | |
| | |
| | |
def to_cpu_tensor(*args):
    """Convert any number of arrays/tensors to CPU ``torch.Tensor`` objects.

    NumPy arrays are wrapped with ``torch.Tensor(...)``; torch tensors are
    moved to the CPU with ``.cpu()``. ``isinstance`` is used rather than an
    exact type comparison so that subclasses (e.g. ``torch.nn.Parameter`` or
    ``np.ndarray`` subclasses) are handled as well. Any other object is
    passed through unchanged.

    Args:
        *args: arrays, tensors, or arbitrary objects.

    Returns:
        The single converted object when exactly one argument is given,
        otherwise a list of converted objects in argument order (an empty
        list when called without arguments).
    """
    out = []
    for item in args:
        if isinstance(item, np.ndarray):
            item = torch.Tensor(item)
        if isinstance(item, torch.Tensor):
            item = item.cpu()
        out.append(item)

    if len(out) == 1:
        return out[0]
    return out
| |
|
| | from tempfile import NamedTemporaryFile |
| |
|
# Module-level side effect: select the 'science' and 'no-latex' matplotlib
# styles (presumably registered by the `scienceplots` import -- verify).
plt.style.use(['science', 'no-latex'])
# Eleven RGB triples with components in [0, 1]; used below to build the
# ListedColormap for every imshow call in this module.
VIL_COLORS = [[0, 0, 0],
              [0.30196078431372547, 0.30196078431372547, 0.30196078431372547],
              [0.1568627450980392, 0.7450980392156863, 0.1568627450980392],
              [0.09803921568627451, 0.5882352941176471, 0.09803921568627451],
              [0.0392156862745098, 0.4117647058823529, 0.0392156862745098],
              [0.0392156862745098, 0.29411764705882354, 0.0392156862745098],
              [0.9607843137254902, 0.9607843137254902, 0.0],
              [0.9294117647058824, 0.6745098039215687, 0.0],
              [0.9411764705882353, 0.43137254901960786, 0.0],
              [0.6274509803921569, 0.0, 0.0],
              [0.9058823529411765, 0.0, 1.0]]

# Eleven bin boundaries spanning 0..255, passed to BoundaryNorm together with
# the colormap above. Frames are multiplied by 255 before plotting, so inputs
# are presumably normalised to [0, 1] -- TODO confirm.
VIL_LEVELS = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0]
| |
|
def gradio_visualize(sequence):
    """Plot a sequence of frames side by side with a colorbar and save a PNG.

    Each frame is squeezed, multiplied by 255 and drawn with the
    ``VIL_COLORS`` / ``VIL_LEVELS`` colormap. Panel ``i`` is labelled
    ``t-k`` with ``k = len(sequence) - i - 1``, so the last panel is ``t-0``.

    Args:
        sequence: sized iterable of frames (numpy arrays or CPU torch
            tensors). Each frame must squeeze to a 2-D image; the input is
            presumably ``(T, 1, H, W)`` with values in [0, 1] -- TODO confirm.

    Returns:
        Path of the PNG written to a temporary file. The file is created
        with ``delete=False``; removing it is the caller's responsibility.

    Raises:
        ValueError: if ``sequence`` is empty.
    """
    n_frames = len(sequence)
    if n_frames == 0:
        raise ValueError("sequence must contain at least one frame")

    # Build the colormap and norm once instead of once per frame.
    cmap = ListedColormap(VIL_COLORS)
    norm = BoundaryNorm(VIL_LEVELS, cmap.N)

    fig_size = 3
    # squeeze=False always returns a 2-D array of Axes, so a single-frame
    # sequence can be indexed exactly like a longer one.
    fig, axes = plt.subplots(
        1, n_frames,
        figsize=(fig_size * n_frames, fig_size),
        tight_layout=True,
        squeeze=False,
    )
    axes = axes.ravel()

    try:
        fig.subplots_adjust(hspace=0.0, wspace=0.0)
        plt.setp(axes, xticks=[], yticks=[])

        im = None
        for i, frame in enumerate(sequence):
            axes[i].set_xlabel(f'$t-{n_frames - i - 1}$', fontsize=12)
            frame = frame.squeeze()
            im = axes[i].imshow(frame * 255, cmap=cmap, norm=norm)

        # Colorbar lives in its own axes at the right edge of the figure and
        # is attached to the last image drawn.
        cax = fig.add_axes([1, 0.05, 0.02, 0.5])
        fig.colorbar(im, cax=cax)

        # Reserve a temp file name, close the handle, then let matplotlib
        # write to the path.
        with NamedTemporaryFile(suffix=".png", delete=False) as ff:
            file_path = ff.name
        fig.savefig(file_path)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)

    return file_path
| |
|
| |
|
def gradio_gif(sequences, T):
    """Animate several frame sequences side by side and save them as a GIF.

    One panel is drawn per entry of ``sequences``, labelled with the entry's
    key. For each step ``t`` in ``range(T)`` the panel shows
    ``sequence[t].squeeze() * 255`` using the ``VIL_COLORS`` / ``VIL_LEVELS``
    colormap, and the figure title is set to ``t + <t>``.

    Args:
        sequences: mapping of label -> sequence (``.items()`` / ``.values()``
            are used, so a plain list is not accepted). Each sequence is
            indexed along its first axis by frame; presumably ``(T, 1, H, W)``
            numpy arrays or CPU torch tensors with values in [0, 1] -- TODO
            confirm. No batch axis is indexed here.
        T: number of frames to render; forwarded as ``frames`` to
            ``FuncAnimation``.

    Returns:
        Path of the GIF written to a temporary file. The file is created
        with ``delete=False``; removing it is the caller's responsibility.

    Raises:
        ValueError: if ``sequences`` is empty.
    """
    horizontal = len(sequences)
    if horizontal == 0:
        raise ValueError("sequences must contain at least one entry")

    # Build the colormap and norm once and share them across all panels.
    cmap = ListedColormap(VIL_COLORS)
    norm = BoundaryNorm(VIL_LEVELS, cmap.N)

    fig_size = 3
    # squeeze=False always returns a 2-D array of Axes, which removes the
    # need for a separate single-panel code path.
    fig, axes = plt.subplots(
        nrows=1, ncols=horizontal,
        figsize=(fig_size * horizontal, fig_size),
        tight_layout=True,
        squeeze=False,
    )
    axes = axes.ravel()

    try:
        fig.subplots_adjust(hspace=0.0, wspace=0.0)
        plt.setp(axes, xticks=[], yticks=[])

        # Create one image artist per panel; the animation only swaps its
        # data instead of stacking a new imshow() artist on every frame.
        images = []
        for ax, (key, sequence) in zip(axes, sequences.items()):
            ax.set_xlabel(f'{key}', fontsize=12)
            frame = sequence[0].squeeze()
            images.append(
                ax.imshow(frame * 255, cmap=cmap, norm=norm, animated=True)
            )

        title = fig.suptitle('', y=0.9, x=0.505, fontsize=16)

        def animate(t):
            """Show frame ``t`` of every sequence and update the title."""
            for image, sequence in zip(images, sequences.values()):
                image.set_data(sequence[t].squeeze() * 255)
            fig.subplots_adjust(hspace=0.0, wspace=0.0)
            title.set_text(f'$t + {t}$')
            return [*images, title]

        # Blitting is left off: the animation is only written to disk, never
        # shown interactively. The explicit fps passed to save() below sets
        # the playback rate of the file.
        ani = animation.FuncAnimation(
            fig, animate, frames=T, interval=750, blit=False, repeat_delay=50,
        )

        # Reserve a temp file name, close the handle, then let the writer
        # produce the file at that path.
        with NamedTemporaryFile(suffix=".gif", delete=False) as ff:
            file_path = ff.name
        ani.save(file_path, writer='pillow', fps=5)
    finally:
        # Close this specific figure (not merely the "current" one), even
        # when rendering fails.
        plt.close(fig)

    return file_path