Spaces:
Running
Running
| import math | |
| import cutlass.cute as cute | |
| import cutlass | |
| import gradio as gr | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import matplotlib.colors as mcolors | |
| import ast | |
def visualize_tv_layout(
    tiler_mn: tuple[int, int],
    tv_layout,  # ((thr_shape, val_shape), (thr_stride, val_stride))
    *,
    font_size: int = 10,
    cell_px: int = 70,
    grid_lw: float = 1.5,
    dpi: int = 100,
    max_rows: "int | None" = None,
    max_cols: "int | None" = None,
    color_fn=None,  # optional (tid, vid) -> colour
):
    """Draw a T/V checkerboard for an arbitrary TV layout.

    Every (thread, value) coordinate is evaluated through the layout to a
    flat offset. The offset is split as ``m = offset % M`` and
    ``n = offset // M`` (``M`` being the first tiler extent), and the cell
    ``(m, n)`` is labelled with the first (thread, value) pair that lands on
    it; later pairs mapping to an already-filled cell are ignored.

    Args:
        tiler_mn: ``(M, N)`` extents of the tile being visualised.
        tv_layout: ``(shape, stride)`` pair; ``shape[0]``/``shape[1]`` are
            the (possibly nested) thread and value shapes.
        font_size: Font size of the cell labels; the title uses
            ``font_size + 2``.
        cell_px: Cell edge length in pixels at 100 dpi. The figure size in
            inches is ``cells * cell_px / 100``, so the rendered pixel size
            scales with ``dpi``.
        grid_lw: Line width of the grid.
        dpi: Resolution passed to ``plt.subplots``.
        max_rows: If a positive number, show at most this many rows.
        max_cols: If a positive number, show at most this many columns.
        color_fn: Optional ``(tid, vid) -> colour`` callable; the result is
            passed to ``matplotlib.colors.to_rgb``. By default each thread
            gets a colour from the Set3/Set2/Set1 palettes, cycling when
            there are more threads than colours.

    Returns:
        The matplotlib figure. It is created through ``plt.subplots`` and is
        not closed here, so closing it is up to the caller.

    Raises:
        ValueError: If either tiler extent is not positive.
    """
    # -----------------------------------------------------------------
    # 1) Unpack the layout and size the display grid
    # -----------------------------------------------------------------
    shape, stride = tv_layout

    full_M, full_N = tiler_mn
    if full_M <= 0 or full_N <= 0:
        # Fail early with a clear message instead of a ZeroDivisionError in
        # the offset decomposition further down.
        raise ValueError(
            f"tiler_mn must contain positive extents, got {tiler_mn!r}"
        )

    def compute_recursive_size(extent):
        # An int is a leaf mode; anything else is treated as a nested mode
        # whose size is the product of its children.
        if isinstance(extent, int):
            return extent
        return math.prod(compute_recursive_size(sub) for sub in extent)

    n_thr = compute_recursive_size(shape[0])
    n_val = compute_recursive_size(shape[1])

    # Apply max rows/cols limits if specified
    M, N = full_M, full_N
    if max_rows is not None and max_rows > 0:
        M = min(M, max_rows)
    if max_cols is not None and max_cols > 0:
        N = min(N, max_cols)

    thr_ids = np.full((M, N), -1, dtype=int)
    val_ids = np.full((M, N), -1, dtype=int)
    filled = np.zeros((M, N), dtype=bool)

    # -----------------------------------------------------------------
    # 2) Query CuTe for every (tid, vid) -> flat offset
    # -----------------------------------------------------------------
    # NOTE(review): the CuTe DSL is called here from plain host Python with
    # no JIT wrapper visible in this file -- confirm that the installed
    # cutlass version supports this.
    def evaluate_offsets():
        layout = cute.make_layout(shape, stride=stride)
        tid_vals = []
        for tid in cutlass.range_constexpr(n_thr):
            vid_vals = []
            for vid in cutlass.range_constexpr(n_val):
                vid_vals.append(layout((tid, vid)))
            tid_vals.append(vid_vals)
        return tid_vals

    vals = evaluate_offsets()

    for tid in range(n_thr):
        for vid in range(n_val):
            pos = vals[tid][vid]
            # Decompose against the *full* tile height, not the clipped one,
            # so clipping the display never changes where a cell lands.
            n = pos // full_M
            m = pos % full_M
            # Skip anything outside the displayed window. The lower-bound
            # checks matter: a negative offset yields a negative column,
            # which NumPy would otherwise wrap around silently.
            if m < 0 or n < 0 or m >= M or n >= N:
                continue
            # First (tid, vid) to reach a cell keeps it.
            if filled[m, n]:
                continue
            thr_ids[m, n] = tid
            val_ids[m, n] = vid
            filled[m, n] = True

    # -----------------------------------------------------------------
    # 3) Colours (default: one palette colour per thread, cycling)
    # -----------------------------------------------------------------
    if color_fn is None:
        palette = [
            *plt.cm.Set3.colors,
            *plt.cm.Set2.colors,
            *plt.cm.Set1.colors,
        ]

        def _thread_colour(tid, vid):
            # Index by modulo rather than materialising a list with one
            # entry per thread.
            return palette[tid % len(palette)]

        color_fn = _thread_colour

    # Only filled cells need colouring/labelling; the rest stay black.
    cells = np.argwhere(filled)

    bg_rgb = np.zeros((M, N, 3))
    for m, n in cells:
        bg_rgb[m, n] = mcolors.to_rgb(color_fn(thr_ids[m, n], val_ids[m, n]))

    # -----------------------------------------------------------------
    # 4) Draw
    # -----------------------------------------------------------------
    fig_w, fig_h = N * cell_px / 100, M * cell_px / 100
    fig, ax = plt.subplots(figsize=(fig_w, fig_h), dpi=dpi)
    ax.imshow(bg_rgb, interpolation="none")

    for m, n in cells:
        ax.text(
            n, m, f"T{thr_ids[m,n]}\nV{val_ids[m,n]}",
            ha="center", va="center",
            fontsize=font_size, weight="bold",
        )

    # Ticks sit on the cell boundaries (half-integers) so the major grid
    # draws the cell outlines.
    ax.set_xticks(np.arange(N + 1) - 0.5)
    ax.set_yticks(np.arange(M + 1) - 0.5)
    ax.set_xticklabels([str(i) for i in range(N + 1)])
    ax.set_yticklabels([str(i) for i in range(M + 1)])
    ax.tick_params(axis='both', which='both', length=6, width=1)
    ax.tick_params(axis='x', which='both', top=True, bottom=False, labeltop=True, labelbottom=False)
    ax.tick_params(axis='y', which='both', left=True, right=False)
    ax.grid(which="major", color="black", linewidth=grid_lw)
    ax.set_xlim(-.5, N - .5)
    ax.set_ylim(M - .5, -.5)  # inverted so row 0 is at the top

    # Format title with colon notation
    ax.set_title(f"tv_layout = {shape}:{stride}", fontsize=font_size + 2, pad=12)
    plt.tight_layout()
    return fig
# Manual smoke test, kept disabled (uncomment both lines to render one layout and exit):
# visualize_tv_layout((32, 16), (((4,8,2,2),((2,2),(1,1))),((64,1,16,256),((32,8),(0,0)))), dpi=100, max_rows=16, max_cols=16)
# exit(0)
def gradio_visualize(tiler_mn_str, tv_layout_str, dpi, max_rows, max_cols):
    """Parse the text inputs from the UI and render the layout figure.

    ``tiler_mn_str`` is read as a Python literal. ``tv_layout_str`` is
    either ``shape:stride`` (colon notation, e.g. ``(128,64):(1,128)``) or a
    single literal ``(shape, stride)`` tuple. Any exception raised while
    parsing or drawing is caught and rendered as a red message on an
    otherwise empty figure, so a figure is always returned.
    """
    try:
        tile_dims = ast.literal_eval(tiler_mn_str)

        if ':' not in tv_layout_str:
            # Plain nested-tuple form: the literal already is (shape, stride).
            layout = ast.literal_eval(tv_layout_str)
        else:
            # Colon form: exactly one ':' separating shape from stride.
            pieces = tv_layout_str.split(':')
            if len(pieces) != 2:
                raise ValueError("Colon format must be shape:stride")
            shape_text, stride_text = pieces
            layout = (
                ast.literal_eval(shape_text),
                ast.literal_eval(stride_text),
            )

        return visualize_tv_layout(
            tile_dims,
            layout,
            dpi=dpi,
            max_rows=max_rows,
            max_cols=max_cols,
        )
    except Exception as err:
        # UI boundary: show the failure inside the plot area instead of
        # letting the exception propagate.
        error_fig, error_ax = plt.subplots(figsize=(8, 4))
        error_ax.text(
            0.5, 0.5, f"Error: {err!s}",
            ha='center', va='center', fontsize=12, color='red',
        )
        error_ax.axis('off')
        return error_fig
# Gradio UI: parameter form in the first column, rendered plot in the second.
with gr.Blocks(title="CuTe TV Layout Visualizer") as demo:
    gr.Markdown("# CuTe TV Layout Visualizer")
    gr.Markdown("Visualize thread/value (T/V) layouts for CuTe tensor operations.")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Layout Parameters")
            tiler_mn = gr.Textbox(
                value="(8, 8)",
                placeholder="(8, 8)",
                label="Tiler Dimensions (M, N)",
            )
            tv_layout = gr.Textbox(
                value="((2, 2, 2), (2, 2, 2)):((1, 16, 4), (8, 2, 32))",
                lines=2,
                label="TV Layout",
            )
            dpi = gr.Number(value=200, precision=0, label="DPI")
            max_rows = gr.Number(
                value=None,
                precision=0,
                label="Max Rows (leave empty for no limit)",
            )
            max_cols = gr.Number(
                value=None,
                precision=0,
                label="Max Cols (leave empty for no limit)",
            )
            visualize_btn = gr.Button("Visualize", variant="primary")

        with gr.Column():
            output_plot = gr.Plot(label="TV Layout Visualization")

    # The button and the example rows feed the same five inputs, in the
    # order expected by gradio_visualize.
    _form_inputs = [tiler_mn, tv_layout, dpi, max_rows, max_cols]

    visualize_btn.click(
        fn=gradio_visualize,
        inputs=_form_inputs,
        outputs=output_plot,
    )

    # Preset examples (tiler, layout, dpi, max_rows, max_cols)
    _example_rows = [
        ["(8, 8)", "((2, 2, 2), (2, 2, 2)):((1, 16, 4), (8, 2, 32))", 200, None, None],
        ["(4, 8)", "((4, 2), 4):((1, 16), 4)", 200, None, None],
        ["(8, 4)", "((4, 2), 4):((8, 4), 1)", 200, None, None],
    ]
    gr.Examples(examples=_example_rows, inputs=_form_inputs)

if __name__ == "__main__":
    demo.launch()