|
import numpy as np |
|
import torch |
|
import matplotlib.pyplot as plt |
|
import gradio as gr |
|
import io |
|
import numpy as np |
|
from PIL import Image |
|
from einops.layers.torch import Rearrange, Reduce |
|
|
|
def _build_tensor(matrices_text):
    """Build ``torch.arange(N)`` reshaped to the dims typed into the UI.

    ``matrices_text`` comes straight from the web form, so it is parsed with
    ``ast.literal_eval`` (Python literals only) rather than ``eval``. It may be
    a list/tuple of ints such as ``"[2, 3, 4]"`` or a single int such as
    ``"5"`` (treated as the 1-D shape ``(5,)``). ``N`` is the product of the
    dims, so every cell holds its own flat (row-major) index.

    Raises:
        ValueError: if the text is not an int or a sequence of non-negative
            ints (``ast.literal_eval`` may also raise ValueError/SyntaxError
            for text that is not a Python literal at all).
    """
    import ast  # local import: only this helper needs it

    parsed = ast.literal_eval(matrices_text.strip())
    dims = (parsed,) if isinstance(parsed, int) else parsed
    try:
        dims = tuple(dims)
    except TypeError:
        raise ValueError(
            f"Tensor dims must be an int or a sequence of ints, got {matrices_text!r}."
        ) from None
    if not all(isinstance(d, int) and d >= 0 for d in dims):
        raise ValueError(
            f"Tensor dims must be non-negative ints, got {matrices_text!r}."
        )
    shape = torch.Size(dims)
    return torch.arange(shape.numel()).reshape(shape)


def _figure_to_image(fig):
    """Render ``fig`` to an in-memory PNG, close it, and return a PIL image."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    # Close (not merely clear) the figure: pyplot keeps every open figure
    # alive, so a long-running Gradio server would leak one per request.
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def _draw_matrix(ax, matrix, decimals, show_colorbar):
    """Draw a 2-D tensor on ``ax`` as a coolwarm heat map with cell values.

    Rows are flipped and then drawn with ``origin='lower'``, so row 0 of
    ``matrix`` ends up at the top of the plot. Each value is rounded to
    ``decimals`` places and written in its cell; axis ticks are hidden.
    """
    values = torch.flip(matrix, (0,)).numpy()
    heatmap = ax.matshow(values, cmap='coolwarm', origin='lower')
    for row in range(values.shape[0]):
        for col in range(values.shape[1]):
            ax.text(col, row, str(round(values[row, col], decimals)),
                    ha='center', va='center', fontsize=12, color='black')
    ax.set_xticks([])
    ax.set_yticks([])
    if show_colorbar:
        ax.figure.colorbar(heatmap, ax=ax)


def _render_tensor(matrices, show_colorbar=False):
    """Render a tensor as a PIL image of annotated heat-map matrices.

    0-D and 1-D tensors are shown as a single row, and a 3-D tensor whose
    leading dimension is 1 is squeezed to 2-D. A 2-D tensor is drawn as one
    matrix (values rounded to 3 decimals). Higher-rank tensors are split into
    their trailing 2-D matrices and laid out on a grid (values rounded to 2
    decimals): leading dims at even positions (0, 2, ...) multiply into the
    grid's row count, those at odd positions into its column count.

    Raises:
        ValueError: if ``matrices`` is not a torch tensor.
    """
    if not torch.is_tensor(matrices):
        raise ValueError("Input should be a pytorch tensor.")
    if matrices.dim() == 0:
        matrices = matrices.reshape(1, 1)
    elif matrices.dim() == 1:
        matrices = matrices.reshape(1, matrices.shape[0])
    elif matrices.dim() == 3 and matrices.shape[0] == 1:
        matrices = matrices.reshape(matrices.shape[1], matrices.shape[2])

    if matrices.dim() == 2:
        # Size the figure to the matrix's aspect ratio, as plt.matshow does
        # when it opens its own figure (which is what the image used to be).
        fig = plt.figure(figsize=plt.figaspect(matrices.numpy()))
        _draw_matrix(fig.add_subplot(111), matrices, 3, show_colorbar)
        return _figure_to_image(fig)

    rows = cols = 1
    for position, size in enumerate(matrices.shape[:-2]):
        if position % 2 == 0:
            rows *= size
        else:
            cols *= size

    # squeeze=False keeps ``axes`` a 2-D array even when rows == cols == 1
    # (e.g. shape (1, 1, h, w)); otherwise plt.subplots returns a bare Axes.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 5, rows * 5),
                             squeeze=False)
    flat_matrices = matrices.reshape(-1, matrices.shape[-2], matrices.shape[-1])
    for ax, matrix in zip(axes.flat, flat_matrices):
        _draw_matrix(ax, matrix, 2, show_colorbar)
    fig.tight_layout()
    return _figure_to_image(fig)


def visualize_matrices(matrices_text, show_colorbar=False):
    """Visualize ``torch.arange`` reshaped to the dims in ``matrices_text``.

    Args:
        matrices_text: tensor dims as a Python literal, e.g. ``"[2, 3, 4]"``.
        show_colorbar: draw a colorbar next to each matrix.

    Returns:
        A PIL image of the tensor's 2-D matrices with their values.
    """
    return _render_tensor(_build_tensor(matrices_text), show_colorbar)


def visualize_second_matrices(matrices_text, do_what, show_colorbar=False):
    """Visualize the ``arange`` tensor after applying a chain of transforms.

    Args:
        matrices_text: tensor dims as a Python literal, e.g. ``"[2, 3, 4]"``.
        do_what: ``'&'``-separated Python expressions, each evaluating to a
            callable that is applied to the tensor in order, e.g.
            ``"Rearrange('c h w -> c w h')"``. Empty segments are skipped, so
            an empty string shows the untransformed tensor.
        show_colorbar: draw a colorbar next to each matrix.

    Returns:
        A PIL image of the transformed tensor's 2-D matrices.

    Raises:
        ValueError: if the transforms do not produce a torch tensor.
    """
    matrices = _build_tensor(matrices_text)
    for expression in do_what.split('&'):
        expression = expression.strip()
        if not expression:
            continue
        # SECURITY(review): eval() executes arbitrary Python typed into the
        # web UI. It is what lets users write einops layers here, but this
        # app must not be exposed to untrusted users as-is.
        matrices = eval(expression)(matrices)
    return _render_tensor(matrices, show_colorbar)
|
|
|
|
|
def generate_images(text1, text2):
    """Produce both output images for the Gradio interface.

    Args:
        text1: tensor dims typed into the first textbox, e.g. "[2, 3, 4]".
        text2: '&'-separated transform expressions from the second textbox.

    Returns:
        A tuple ``(original_image, transformed_image)``.
    """
    before = visualize_matrices(text1)
    after = visualize_second_matrices(text1, text2)
    return before, after
|
|
|
# Gradio 4 removed the ``gr.inputs`` / ``gr.outputs`` namespaces (deprecated
# in Gradio 3); the top-level components below work on Gradio >= 3.0.
inputs = [
    gr.Textbox(lines=2, placeholder="tensor dims"),
    gr.Textbox(lines=2, placeholder="what to do?"),
]

outputs = [
    gr.Image(type="pil"),
    gr.Image(type="pil"),
]

# The description (kept in Chinese, as shown to users) lists three keys to
# understanding dimension transforms (what each axis means, e.g. (b,c,h,w) or
# (b,l,e); what reshape/view really does; what a high-dimensional transpose
# really does) and explains matmul in attention and in Linear in terms of the
# rendered matrices.
demo = gr.Interface(
    fn=generate_images,
    inputs=inputs,
    outputs=outputs,
    title="高维数据可视化工具",
    description="""
理解维度变换的三个关键:
1.理解每个维度代表的含义,例如(b,c,h,w)(b,l,e)等
2.理解reshape/view的本质
3.理解高维张量转置的本质

矩阵乘和Linear的理解:
1.attention中的矩阵乘就是用下图中的每一个矩阵和权重矩阵相乘,矩阵和矩阵之间没有特征交互
2.Linear中的矩阵乘就是用下图中的每一个矩阵的每一行和权重矩阵相乘,行与行之间没有特征交互
""",
    # Each example is [tensor dims, '&'-separated chain of transforms].
    examples=[
        ["[2, 3, 4]", "Rearrange('c h w -> c w h')"],
        ["[2, 3, 4]", "Rearrange('c h w -> c w h')&Rearrange('c h w -> c w h')&Rearrange('c h w -> c w h')"],
        ["[2, 3, 4, 4]", "Rearrange('b c h w -> b c (h w)')"],
        ["[2, 3, 4, 4]", "Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = 2, p2 = 2)"],
        ["[2, 3, 4, 4]", "Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = 2, p2 = 2)&Rearrange('b h w (c s) -> b w c (h s)', s=2)"],
    ],
)

if __name__ == "__main__":
    demo.launch()