HDTV / app.py
AUST001's picture
Update app.py
8023908
import numpy as np
import torch
import matplotlib.pyplot as plt
import gradio as gr
import io
import numpy as np
from PIL import Image
from einops.layers.torch import Rearrange, Reduce
def visualize_matrices(matrices_text, show_colorbar=False):
    """Render ``torch.arange`` laid out in the requested shape as an annotated heatmap.

    The tensor ``torch.arange(prod(shape)).reshape(shape)`` is built from the
    shape literal in ``matrices_text`` (e.g. ``"[2, 3, 4]"``), so every cell is
    labelled with its own flat row-major index.

    Layout:
      * 0-D and 1-D tensors are drawn as a single row; a 3-D tensor whose first
        dimension is 1 is drawn as its single 2-D matrix.
      * 2-D tensors are drawn as one matrix (cell values rounded to 3 places).
      * Higher-rank tensors are drawn as a grid of their trailing 2-D matrices
        (cell values rounded to 2 places). Leading dimensions at even positions
        (0, 2, ...) multiply into the number of grid rows, those at odd
        positions (1, 3, ...) into the number of grid columns.

    Args:
        matrices_text: Python literal giving the tensor shape, e.g. ``"[2, 3, 4]"``
            (a bare int such as ``"5"`` is treated as a 1-D shape).
        show_colorbar: If True, attach a colorbar to every heatmap.

    Returns:
        PIL.Image.Image: PNG rendering of the figure.

    Raises:
        ValueError: If ``matrices_text`` is not an int or a list/tuple of ints.
    """
    import ast
    import math

    def _parse_shape(text):
        # ast.literal_eval instead of eval: this text comes straight from a
        # public web form, and eval would execute arbitrary Python code.
        try:
            shape = ast.literal_eval(text.strip())
        except (ValueError, SyntaxError) as err:
            raise ValueError(f"Invalid tensor shape: {text!r}") from err
        if isinstance(shape, int):
            shape = (shape,)
        if not isinstance(shape, (list, tuple)) or not all(isinstance(d, int) for d in shape):
            raise ValueError(f"Invalid tensor shape: {text!r}")
        return tuple(shape)

    def _annotate(ax, data, ndigits):
        # Print each cell's value at its (column, row) position and hide ticks.
        for r in range(data.shape[0]):
            for c in range(data.shape[1]):
                ax.text(c, r, str(round(data[r, c], ndigits)), ha='center', va='center', fontsize=12, color='black')
        ax.set_xticks([])
        ax.set_yticks([])

    def _to_image(fig):
        # Save the figure as a tightly cropped PNG and wrap it as a PIL image.
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        buf.seek(0)
        return Image.open(buf)

    shape = _parse_shape(matrices_text)
    matrices = torch.arange(math.prod(shape)).reshape(shape)

    if matrices.dim() < 2:
        # 0-D and 1-D tensors are shown as a single row.
        matrices = matrices.reshape(1, -1)
    if matrices.dim() == 3 and matrices.shape[0] == 1:
        matrices = matrices[0]

    if matrices.dim() == 2:
        # Flipping the rows and drawing with origin='lower' cancel out, so row 0
        # of the tensor is displayed at the top, as in a printed matrix.
        data = torch.flip(matrices, (0,)).numpy()
        # plt.matshow always opens its own figure (sized to the matrix aspect).
        cax = plt.matshow(data, cmap='coolwarm', origin='lower')
        ax = cax.axes
        fig = ax.figure
        try:
            _annotate(ax, data, 3)
            if show_colorbar:
                fig.colorbar(cax, ax=ax)
            return _to_image(fig)
        finally:
            # Close (not just clear) the figure: pyplot otherwise keeps every
            # figure alive for the lifetime of the server process.
            plt.close(fig)

    # Fold the leading dimensions alternately into grid rows and columns.
    leading = matrices.shape[:-2]
    rows = math.prod(leading[0::2])
    cols = math.prod(leading[1::2])
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 5, rows * 5), squeeze=False)
    try:
        stack = matrices.reshape(rows * cols, *matrices.shape[-2:])
        for ax, matrix in zip(axes.flat, stack):
            data = torch.flip(matrix, (0,)).numpy()
            cax = ax.matshow(data, cmap='coolwarm', origin='lower')
            _annotate(ax, data, 2)
            if show_colorbar:
                fig.colorbar(cax, ax=ax)
        fig.tight_layout()
        return _to_image(fig)
    finally:
        plt.close(fig)
def visualize_second_matrices(matrices_text, do_what, show_colorbar=False):
    """Build an ``arange`` tensor, apply user-supplied ops, and render the result.

    The tensor ``torch.arange(prod(shape)).reshape(shape)`` is built from the
    shape literal in ``matrices_text``. Then each ``&``-separated expression in
    ``do_what`` is evaluated to a callable and applied in order, e.g.
    ``"Rearrange('c h w -> c w h')&Rearrange('c h w -> c w h')"``. Blank
    segments are ignored, so an empty ``do_what`` renders the unchanged tensor.
    Because every cell holds its original flat index, the image shows where each
    element moved.

    Layout is the same as for the untransformed view:
      * 0-D and 1-D results are drawn as a single row; a 3-D result whose first
        dimension is 1 is drawn as its single 2-D matrix.
      * 2-D results are drawn as one matrix (cell values rounded to 3 places).
      * Higher-rank results are drawn as a grid of their trailing 2-D matrices
        (cell values rounded to 2 places). Leading dimensions at even positions
        multiply into grid rows, those at odd positions into grid columns.

    Args:
        matrices_text: Python literal giving the tensor shape, e.g. ``"[2, 3, 4]"``.
        do_what: ``&``-separated Python expressions, each evaluating to a callable
            that takes and returns a tensor. An ``&`` inside an expression is
            also treated as a separator.
        show_colorbar: If True, attach a colorbar to every heatmap.

    Returns:
        PIL.Image.Image: PNG rendering of the figure.

    Raises:
        ValueError: If the shape literal is invalid or an op does not return a
            pytorch tensor.
    """
    import ast
    import math

    def _parse_shape(text):
        # ast.literal_eval instead of eval: this text comes straight from a
        # public web form, and eval would execute arbitrary Python code.
        try:
            shape = ast.literal_eval(text.strip())
        except (ValueError, SyntaxError) as err:
            raise ValueError(f"Invalid tensor shape: {text!r}") from err
        if isinstance(shape, int):
            shape = (shape,)
        if not isinstance(shape, (list, tuple)) or not all(isinstance(d, int) for d in shape):
            raise ValueError(f"Invalid tensor shape: {text!r}")
        return tuple(shape)

    def _annotate(ax, data, ndigits):
        # Print each cell's value at its (column, row) position and hide ticks.
        for r in range(data.shape[0]):
            for c in range(data.shape[1]):
                ax.text(c, r, str(round(data[r, c], ndigits)), ha='center', va='center', fontsize=12, color='black')
        ax.set_xticks([])
        ax.set_yticks([])

    def _to_image(fig):
        # Save the figure as a tightly cropped PNG and wrap it as a PIL image.
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        buf.seek(0)
        return Image.open(buf)

    shape = _parse_shape(matrices_text)
    matrices = torch.arange(math.prod(shape)).reshape(shape)

    for op_text in do_what.split('&'):
        op_text = op_text.strip()
        if not op_text:
            # Tolerate an empty "what to do" field or a stray '&'.
            continue
        # SECURITY NOTE(review): eval() on text from a public web form allows
        # arbitrary code execution by anyone who can reach this app. Evaluating
        # these expressions (e.g. einops Rearrange/Reduce calls) is the feature,
        # so it is kept; a safe version needs a whitelisting parser.
        matrices = eval(op_text)(matrices)

    # Only pytorch tensors are supported.
    if not torch.is_tensor(matrices):
        raise ValueError("Input should be a pytorch tensor.")
    if matrices.dim() < 2:
        # 0-D and 1-D tensors are shown as a single row.
        matrices = matrices.reshape(1, -1)
    if matrices.dim() == 3 and matrices.shape[0] == 1:
        matrices = matrices[0]

    if matrices.dim() == 2:
        # Flipping the rows and drawing with origin='lower' cancel out, so row 0
        # of the tensor is displayed at the top, as in a printed matrix.
        data = torch.flip(matrices, (0,)).numpy()
        # plt.matshow always opens its own figure (sized to the matrix aspect).
        cax = plt.matshow(data, cmap='coolwarm', origin='lower')
        ax = cax.axes
        fig = ax.figure
        try:
            _annotate(ax, data, 3)
            if show_colorbar:
                fig.colorbar(cax, ax=ax)
            return _to_image(fig)
        finally:
            # Close (not just clear) the figure: pyplot otherwise keeps every
            # figure alive for the lifetime of the server process.
            plt.close(fig)

    # Fold the leading dimensions alternately into grid rows and columns.
    leading = matrices.shape[:-2]
    rows = math.prod(leading[0::2])
    cols = math.prod(leading[1::2])
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 5, rows * 5), squeeze=False)
    try:
        stack = matrices.reshape(rows * cols, *matrices.shape[-2:])
        for ax, matrix in zip(axes.flat, stack):
            data = torch.flip(matrix, (0,)).numpy()
            cax = ax.matshow(data, cmap='coolwarm', origin='lower')
            _annotate(ax, data, 2)
            if show_colorbar:
                fig.colorbar(cax, ax=ax)
        fig.tight_layout()
        return _to_image(fig)
    finally:
        plt.close(fig)
def generate_images(text1, text2):
    """Gradio callback producing the "before" and "after" views of one tensor.

    Args:
        text1: Shape literal for the demo tensor, e.g. ``"[2, 3, 4]"``.
        text2: ``&``-separated op expressions applied to that tensor, e.g.
            ``"Rearrange('c h w -> c w h')"``.

    Returns:
        tuple: (image of the original tensor, image after applying ``text2``).
    """
    return visualize_matrices(text1), visualize_second_matrices(text1, text2)
# UI wiring. Top-level gr.Textbox / gr.Image replace gr.inputs.* / gr.outputs.*,
# which were deprecated in Gradio 3 and removed in Gradio 4.0. The arguments
# are unchanged.
# Inputs: the tensor shape literal, and the '&'-separated ops to apply to it.
inputs = [gr.Textbox(lines=2, placeholder="tensor dims"),
          gr.Textbox(lines=2, placeholder="what to do?")]
# Outputs: the tensor before and after the ops, as PIL images.
outputs = [gr.Image(type="pil"),
           gr.Image(type="pil")]
# Title: "High-dimensional data visualization tool". The description lists
# three keys to understanding dimension transforms (what each axis means, e.g.
# (b,c,h,w) / (b,l,e); what reshape/view really does; what transposing a
# high-rank tensor really does). It also notes that attention matmuls multiply
# each displayed matrix by the weights independently, and that Linear
# multiplies each row independently.
demo = gr.Interface(fn=generate_images, inputs=inputs, outputs=outputs,
                    title="高维数据可视化工具",
                    description="""
理解维度变换的三个关键:
1.理解每个维度代表的含义,例如(b,c,h,w)(b,l,e)等
2.理解reshape/view的本质
3.理解高维张量转置的本质
矩阵乘和Linear的理解:
1.attention中的矩阵乘就是用下图中的每一个矩阵和权重矩阵相乘,矩阵和矩阵之间没有特征交互
2.Linear中的矩阵乘就是用下图中的每一个矩阵的每一行和权重矩阵相乘,行与行之间没有特征交互
""",
                    examples=[
                        ["[2, 3, 4]", "Rearrange('c h w -> c w h')"],
                        ["[2, 3, 4]", "Rearrange('c h w -> c w h')&Rearrange('c h w -> c w h')&Rearrange('c h w -> c w h')"],
                        ["[2, 3, 4, 4]", "Rearrange('b c h w -> b c (h w)')"],
                        ["[2, 3, 4, 4]", "Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = 2, p2 = 2)"],
                        ["[2, 3, 4, 4]", "Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = 2, p2 = 2)&Rearrange('b h w (c s) -> b w c (h s)', s=2)"]
                    ]
                    )
if __name__ == "__main__":
    demo.launch()