# Page header captured with this file: Horace He · "Renamed files" · a38ccb4 · raw · history · blame · 7.91 kB
import math
import cutlass.cute as cute
import cutlass
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import ast
def visualize_tv_layout(
    tiler_mn: tuple[int, int],
    tv_layout,  # (((thr_shape),(val_shape)),
                #  ((thr_stride),(val_stride)))
    *,
    font_size: int = 10,
    cell_px: int = 70,
    grid_lw: float = 1.5,
    dpi: int = 100,
    max_rows: int = None,
    max_cols: int = None,
    color_fn=None,  # optional (tid,vid) -> colour
):
    """Draw a T/V checkerboard for an arbitrary TV layout.

    Every (thread id, value id) pair is mapped through the layout to a linear
    offset, which is decoded column-major against the tile
    (``m = offset % M``, ``n = offset // M``). The cell at row ``m``, column
    ``n`` is then labelled ``T<tid>`` / ``V<vid>`` and coloured.

    Args:
        tiler_mn: ``(M, N)`` extent of the tile, in rows and columns.
        tv_layout: ``(shape, stride)`` pair handed to ``cute.make_layout``.
            ``shape[0]`` is the thread mode and ``shape[1]`` the value mode;
            either may be an int or an arbitrarily nested tuple of ints.
        font_size: Font size of the cell labels; the title uses
            ``font_size + 2``.
        cell_px: Cell edge length. The figure size in inches is
            ``cells * cell_px / 100`` whatever ``dpi`` is, so this equals
            pixels only at ``dpi=100``.
        grid_lw: Line width of the grid.
        dpi: Resolution passed to ``plt.subplots``.
        max_rows: If not ``None`` and positive, show at most this many rows.
        max_cols: If not ``None`` and positive, show at most this many columns.
        color_fn: Optional callable ``(tid, vid) -> colour`` returning anything
            ``matplotlib.colors.to_rgb`` accepts. By default each thread gets
            one colour, cycling through the Set3, Set2 and Set1 palettes.

    Returns:
        The matplotlib ``Figure``. It is created through ``pyplot``, so the
        caller should ``plt.close`` it once it is no longer needed.

    Raises:
        ValueError: If ``tiler_mn`` is not a pair of positive extents, or
            ``tv_layout`` is not a ``(shape, stride)`` pair whose shape has a
            thread mode and a value mode.
    """
    # -----------------------------------------------------------------
    # 1) Validate the input and size the display grid
    # -----------------------------------------------------------------
    try:
        shape, stride = tv_layout
    except (TypeError, ValueError):
        raise ValueError("tv_layout must be a (shape, stride) pair") from None
    if not isinstance(shape, (tuple, list)) or len(shape) < 2:
        raise ValueError(
            "tv_layout shape must contain a thread mode and a value mode, "
            "e.g. ((thr_shape), (val_shape))"
        )

    try:
        full_M, full_N = tiler_mn
    except (TypeError, ValueError):
        raise ValueError("tiler_mn must be an (M, N) pair") from None
    if full_M <= 0 or full_N <= 0:
        raise ValueError("tiler_mn extents must be positive")

    def compute_recursive_size(shape):
        # Total element count of a (possibly nested) shape.
        if isinstance(shape, int):
            return shape
        return math.prod(compute_recursive_size(i) for i in shape)

    n_thr = compute_recursive_size(shape[0])
    n_val = compute_recursive_size(shape[1])

    # Apply max rows/cols limits if specified; offsets are still decoded
    # against the full tile height (full_M) further down.
    M, N = full_M, full_N
    if max_rows is not None and max_rows > 0:
        M = min(M, max_rows)
    if max_cols is not None and max_cols > 0:
        N = min(N, max_cols)

    # -1 marks a cell that no (tid, vid) pair has claimed yet.
    thr_ids = np.full((M, N), -1, dtype=int)
    val_ids = np.full((M, N), -1, dtype=int)

    # -----------------------------------------------------------------
    # 2) Query CuTe for every (tid, vid) → (m,n)
    # -----------------------------------------------------------------
    # The body of g() is kept exactly as written because it is processed by
    # the CuTe DSL decorator.
    @cute.jit
    def g():
        tv_layout = cute.make_layout(shape, stride=stride)
        tid_vals = []
        for tid in cutlass.range_constexpr(n_thr):
            vid_vals = []
            for vid in cutlass.range_constexpr(n_val):
                vid_vals.append(tv_layout((tid, vid)))
            tid_vals.append(vid_vals)
        return tid_vals

    vals = g()

    for tid in range(n_thr):
        for vid in range(n_val):
            pos = vals[tid][vid]
            n = pos // full_M
            m = pos % full_M
            # Skip anything outside the displayed window. The lower bound
            # matters: a negative offset would otherwise wrap around through
            # NumPy's negative indexing and land in the wrong cell.
            if not (0 <= m < M and 0 <= n < N):
                continue
            # First writer wins when several (tid, vid) pairs share a cell.
            if thr_ids[m, n] >= 0:
                continue
            thr_ids[m, n] = tid
            val_ids[m, n] = vid

    # -----------------------------------------------------------------
    # 3) Colours (default: pastel per-thread)
    # -----------------------------------------------------------------
    if color_fn is None:
        palette = [
            rgb
            for cm in (plt.cm.Set3, plt.cm.Set2, plt.cm.Set1)
            for rgb in cm.colors
        ]

        def color_fn(t, v):
            # One colour per thread, cycling once the palette is exhausted.
            return palette[t % len(palette)]

    occupied = np.argwhere(thr_ids >= 0)  # (row, col) pairs, row-major order
    bg_rgb = np.zeros((M, N, 3))  # unclaimed cells stay black
    for m, n in occupied:
        bg_rgb[m, n] = mcolors.to_rgb(color_fn(thr_ids[m, n], val_ids[m, n]))

    # -----------------------------------------------------------------
    # 4) Draw
    # -----------------------------------------------------------------
    # Figure size is in inches; dividing by a fixed 100 keeps the cell/label
    # geometry constant while dpi only changes the output resolution.
    fig_w, fig_h = N * cell_px / 100, M * cell_px / 100
    fig, ax = plt.subplots(figsize=(fig_w, fig_h), dpi=dpi)
    ax.imshow(bg_rgb, interpolation="none")
    for m, n in occupied:
        ax.text(
            n, m, f"T{thr_ids[m,n]}\nV{val_ids[m,n]}",
            ha="center", va="center",
            fontsize=font_size, weight="bold"
        )

    # Ticks sit on the cell boundaries so the major grid outlines each cell.
    ax.set_xticks(np.arange(N + 1) - 0.5)
    ax.set_yticks(np.arange(M + 1) - 0.5)
    ax.set_xticklabels([str(i) for i in range(N + 1)])
    ax.set_yticklabels([str(i) for i in range(M + 1)])
    ax.tick_params(axis='both', which='both', length=6, width=1)
    ax.tick_params(axis='x', which='both', top=True, bottom=False, labeltop=True, labelbottom=False)
    ax.tick_params(axis='y', which='both', left=True, right=False)
    ax.grid(which="major", color="black", linewidth=grid_lw)
    ax.set_xlim(-.5, N - .5)
    ax.set_ylim(M - .5, -.5)  # inverted so row 0 is at the top

    # Format title with colon notation
    ax.set_title(f"tv_layout = {shape}:{stride}", fontsize=font_size + 2, pad=12)
    # Lay out this figure explicitly instead of pyplot's "current" figure.
    fig.tight_layout()
    return fig
# visualize_tv_layout((32, 16), (((4,8,2,2),((2,2),(1,1))),((64,1,16,256),((32,8),(0,0)))), dpi=100, max_rows=16, max_cols=16)
# exit(0)
def gradio_visualize(tiler_mn_str, tv_layout_str, dpi, max_rows, max_cols):
    """Gradio wrapper for visualize_tv_layout.

    Parses the textual inputs, renders the layout and returns the figure.
    Any failure (parsing or rendering) is reported as a figure containing the
    error message, so this function itself does not raise for bad input.

    Args:
        tiler_mn_str: Python literal for the tile extents, e.g. ``"(8, 8)"``.
        tv_layout_str: Either colon notation ``"shape:stride"`` or a single
            Python literal ``"(shape, stride)"``.
        dpi: Forwarded to ``visualize_tv_layout``.
        max_rows: Forwarded to ``visualize_tv_layout``.
        max_cols: Forwarded to ``visualize_tv_layout``.

    Returns:
        A matplotlib ``Figure`` showing either the layout or the error text.
    """
    try:
        # Parse input strings
        tiler_mn = ast.literal_eval(tiler_mn_str.strip())
        # Support colon notation: (128,64):(1,128) or comma notation
        if ':' in tv_layout_str:
            # Split by colon to get shape and stride parts
            parts = tv_layout_str.split(':')
            if len(parts) != 2:
                raise ValueError("Colon format must be shape:stride")
            # Strip each part so whitespace around the colon is accepted on
            # every Python version.
            shape = ast.literal_eval(parts[0].strip())
            stride = ast.literal_eval(parts[1].strip())
            tv_layout = (shape, stride)
        else:
            # Traditional nested tuple format
            tv_layout = ast.literal_eval(tv_layout_str.strip())
        fig = visualize_tv_layout(tiler_mn, tv_layout, dpi=dpi, max_rows=max_rows, max_cols=max_cols)
    except Exception as e:
        # UI boundary: show the error to the user instead of failing the
        # request.
        fig, ax = plt.subplots(figsize=(8, 4))
        ax.text(0.5, 0.5, f"Error: {str(e)}",
                ha='center', va='center', fontsize=12, color='red')
        ax.axis('off')
    # Drop the figure from pyplot's registry so repeated requests do not
    # accumulate open figures; the Figure object itself can still be rendered
    # by whoever receives it.
    plt.close(fig)
    return fig
# Defaults shown when the page loads; the first example row repeats them.
_DEFAULT_TILER = "(8, 8)"
_DEFAULT_LAYOUT = "((2, 2, 2), (2, 2, 2)):((1, 16, 4), (8, 2, 32))"
_DEFAULT_DPI = 200

# Each row is ordered like the input components:
# tiler, layout, dpi, max rows, max cols.
_EXAMPLE_ROWS = [
    [_DEFAULT_TILER, _DEFAULT_LAYOUT, _DEFAULT_DPI, None, None],
    ["(4, 8)", "((4, 2), 4):((1, 16), 4)", 200, None, None],
    ["(8, 4)", "((4, 2), 4):((8, 4), 1)", 200, None, None],
]

# Build the Gradio UI: parameters on the left, rendered plot on the right.
with gr.Blocks(title="CuTe TV Layout Visualizer") as demo:
    gr.Markdown("# CuTe TV Layout Visualizer")
    gr.Markdown("Visualize thread/value (T/V) layouts for CuTe tensor operations.")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Layout Parameters")
            tiler_mn = gr.Textbox(
                label="Tiler Dimensions (M, N)",
                value=_DEFAULT_TILER,
                placeholder="(8, 8)",
            )
            tv_layout = gr.Textbox(
                label="TV Layout",
                value=_DEFAULT_LAYOUT,
                lines=2,
            )
            dpi = gr.Number(label="DPI", value=_DEFAULT_DPI, precision=0)
            max_rows = gr.Number(
                label="Max Rows (leave empty for no limit)",
                value=None,
                precision=0,
            )
            max_cols = gr.Number(
                label="Max Cols (leave empty for no limit)",
                value=None,
                precision=0,
            )
            visualize_btn = gr.Button("Visualize", variant="primary")

        with gr.Column():
            output_plot = gr.Plot(label="TV Layout Visualization")

    # The button and the example table feed the same five components.
    _layout_inputs = [tiler_mn, tv_layout, dpi, max_rows, max_cols]

    visualize_btn.click(
        fn=gradio_visualize,
        inputs=_layout_inputs,
        outputs=output_plot,
    )

    # Clickable presets that fill in the input components.
    gr.Examples(
        examples=_EXAMPLE_ROWS,
        inputs=_layout_inputs,
    )

# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()