Spaces:
Sleeping
Sleeping
| # ========================================================== | |
| # TensorFlow Computation Graph Visualizer (Advanced) | |
| # - Standard TensorFlow (Keras) based | |
| # - Gradio 5 compatible (no theme=) | |
| # - CPU-friendly (disables GPU usage) | |
| # ========================================================== | |
| import io | |
| import os | |
| import math | |
| import traceback | |
| import warnings | |
| from typing import Any, Dict, List, Optional | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import plotly.graph_objects as go | |
| import networkx as nx | |
# Silence every warning category process-wide.
warnings.filterwarnings("ignore")

# TensorFlow is treated as optional at import time: a failed import is
# recorded in TF_AVAILABLE / TF_IMPORT_ERROR instead of aborting the module.
# TF_IMPORT_ERROR is always bound (None on success) so that reading it can
# never raise NameError.
TF_IMPORT_ERROR = None
try:
    import tensorflow as tf
    from tensorflow import keras

    try:
        # Make no GPU visible to TensorFlow so the app runs on CPU only.
        tf.config.set_visible_devices([], "GPU")
    except Exception:
        # Deliberate best effort: failing to hide GPUs must not prevent
        # the app from starting.
        pass
    TF_AVAILABLE = True
except Exception as e:
    TF_AVAILABLE = False
    TF_IMPORT_ERROR = str(e)
| # -------------------- Helpers -------------------- | |
def _build_small_cnn():
    """Build the bundled three-convolution example classifier."""
    model = keras.Sequential([
        # keras.Input replaces InputLayer(input_shape=...), whose
        # ``input_shape`` argument is deprecated in newer Keras releases.
        keras.Input(shape=(64, 64, 3)),
        keras.layers.Conv2D(16, 3, activation="relu", padding="same"),
        keras.layers.MaxPool2D(),
        keras.layers.Conv2D(32, 3, activation="relu", padding="same"),
        keras.layers.MaxPool2D(),
        keras.layers.Conv2D(64, 3, activation="relu", padding="same"),
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(10, activation="softmax"),
    ])
    model.build((None, 64, 64, 3))
    return model, "Small CNN (example)"


def _build_toy_resnet():
    """Build the bundled example with two additive skip connections."""
    inputs = keras.Input(shape=(64, 64, 3))
    x = keras.layers.Conv2D(32, 3, padding="same", activation="relu")(inputs)
    for _ in range(2):
        y = keras.layers.Conv2D(32, 3, padding="same", activation="relu")(x)
        y = keras.layers.Conv2D(32, 3, padding="same")(y)
        x = keras.layers.add([x, y])
        x = keras.layers.ReLU()(x)
    x = keras.layers.GlobalAveragePooling2D()(x)
    outputs = keras.layers.Dense(5, activation="softmax")(x)
    model = keras.Model(inputs, outputs)
    model.build((None, 64, 64, 3))
    return model, "Toy ResNet-like (example)"


def safe_load_keras_model(fileobj: Optional[io.BytesIO], chosen: str):
    """Load an uploaded Keras model, or build one of the bundled examples.

    Args:
        fileobj: The uploaded model, or a falsy value to build an example
            instead. Either a filesystem path (``str`` / ``os.PathLike``)
            or a readable binary file-like object is accepted.
        chosen: Example to build when nothing was uploaded: ``"small_cnn"``
            or ``"toy_resnet"``. Any other value falls back to
            ``"small_cnn"``.

    Returns:
        A ``(model, description)`` tuple, where ``description`` is a short
        human-readable label for the model.

    Raises:
        RuntimeError: If TensorFlow could not be imported.
    """
    if not TF_AVAILABLE:
        raise RuntimeError("TensorFlow not available")

    if fileobj:
        if isinstance(fileobj, (str, os.PathLike)):
            # The upload is already a file on disk: load it in place.
            model = keras.models.load_model(os.fspath(fileobj))
            return model, "uploaded .h5 model"

        import tempfile  # local import: only this branch needs it

        if hasattr(fileobj, "seek"):
            fileobj.seek(0)
        # A unique temporary file keeps concurrent uploads from overwriting
        # each other, which a single fixed path could not guarantee.
        with tempfile.NamedTemporaryFile(suffix=".h5", delete=False) as tmp:
            tmp.write(fileobj.read())
            tmp_path = tmp.name
        try:
            model = keras.models.load_model(tmp_path)
        finally:
            os.remove(tmp_path)
        return model, "uploaded .h5 model"

    if chosen == "toy_resnet":
        return _build_toy_resnet()
    # "small_cnn" and every unrecognised name share the same default.
    return _build_small_cnn()
def model_summary_str(model):
    """Return the text produced by ``model.summary()`` as one string.

    Args:
        model: Any object exposing ``summary(print_fn=...)``.

    Returns:
        Everything ``summary`` passed to ``print_fn``, concatenated. Each
        chunk is followed by a newline unless the caller passed
        ``line_break=False``.
    """
    stream = io.StringIO()

    def _collect(text, line_break=True, **_ignored):
        # NOTE(review): some Keras releases appear to call print_fn with an
        # extra ``line_break`` keyword; accepting it (and any other keyword)
        # avoids a TypeError there -- confirm against the installed version.
        stream.write(text)
        if line_break:
            stream.write("\n")

    model.summary(print_fn=_collect)
    return stream.getvalue()
| # -------------------- Graph builder -------------------- | |
def _inbound_layer_names(layer):
    """Return the names of the layers that feed ``layer``, one per connection."""
    names = []
    for node in getattr(layer, "_inbound_nodes", None) or []:
        sources = getattr(node, "inbound_layers", None)
        if sources is None:
            # NOTE(review): nodes without ``inbound_layers`` are assumed to
            # expose ``parent_nodes`` whose ``operation`` is the producing
            # layer (Keras 3 style) -- confirm against the installed Keras.
            sources = [
                getattr(parent, "operation", None)
                for parent in getattr(node, "parent_nodes", None) or []
            ]
        elif not isinstance(sources, (list, tuple)):
            # A lone inbound layer may be handed back bare instead of
            # inside a list; iterating it directly would raise TypeError.
            sources = [sources]
        names.extend(
            src.name for src in sources if getattr(src, "name", None) is not None
        )
    return names


def build_layer_graph(model):
    """Build a directed graph with one node per entry of ``model.layers``.

    Each node is keyed by the layer name and carries the attributes
    ``class_name``, ``input_shape``, ``output_shape``, ``params`` and
    ``inbound_layers``. An edge ``src -> dst`` is added for every inbound
    layer name of ``dst`` that is itself a node of the graph.

    Args:
        model: Object exposing ``layers``; every layer needs ``name`` and
            ``count_params()``.

    Returns:
        The populated ``networkx.DiGraph``.
    """
    G = nx.DiGraph()
    for layer in model.layers:
        G.add_node(
            layer.name,
            class_name=layer.__class__.__name__,
            # Either shape attribute may be unavailable on a layer; None is
            # stored in that case.
            input_shape=getattr(layer, "input_shape", None),
            output_shape=getattr(layer, "output_shape", None),
            params=layer.count_params(),
            inbound_layers=_inbound_layer_names(layer),
        )

    # Edges are added only after every node exists, so a source that is
    # listed later in model.layers is still connected.
    for name, attrs in G.nodes(data=True):
        for src in attrs["inbound_layers"]:
            if src in G:
                G.add_edge(src, name)
    return G
def nx_to_plotly_fig(G):
    """Render graph ``G`` as a Plotly figure using a seeded spring layout.

    Edges are drawn as one line trace and nodes as a second
    markers-plus-text trace labelled with the node keys. Both axes and the
    legend are hidden, and the figure height is fixed at 600.

    Args:
        G: The networkx graph to draw.

    Returns:
        The resulting ``plotly.graph_objects.Figure``.
    """
    layout = nx.spring_layout(G, seed=42)

    # All edges share one trace; a None after each pair of endpoints
    # breaks the line so consecutive edges are not joined together.
    line_xs, line_ys = [], []
    for src, dst in G.edges():
        src_x, src_y = layout[src]
        dst_x, dst_y = layout[dst]
        line_xs.extend((src_x, dst_x, None))
        line_ys.extend((src_y, dst_y, None))

    names = list(G.nodes())
    point_xs, point_ys = [], []
    for name in names:
        px, py = layout[name]
        point_xs.append(px)
        point_ys.append(py)

    edge_trace = go.Scatter(x=line_xs, y=line_ys, mode="lines")
    node_trace = go.Scatter(
        x=point_xs, y=point_ys, mode="markers+text", text=names
    )

    fig = go.Figure()
    fig.add_trace(edge_trace)
    fig.add_trace(node_trace)
    fig.update_layout(height=600, showlegend=False)
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    return fig
| # -------------------- Inspect -------------------- | |
def node_inspect_callback(state, node_name):
    """Describe one layer of the loaded model for the inspection panel.

    Args:
        state: Dict holding the loaded model under ``"model"``, or a falsy
            value when nothing has been loaded yet.
        node_name: Name of the layer to inspect; may be empty or ``None``
            when no layer is selected.

    Returns:
        A ``(markdown, image, histogram)`` tuple. ``image`` and
        ``histogram`` are ``None`` when the layer has no weights, and
        ``image`` is also ``None`` unless the first weight array is 4-D.
    """
    if not state:
        return "No model loaded.", None, None
    if not node_name:
        # An empty selection would otherwise be passed to get_layer().
        return "No layer selected.", None, None

    model = state["model"]
    layer = model.get_layer(node_name)
    weights = layer.get_weights()

    hist_fig = None
    img = None
    if weights:
        w = weights[0]
        counts, bin_edges = np.histogram(w.flatten(), bins=50)
        # One bar per bin, positioned at the bin's left edge.
        hist_fig = go.Figure(go.Bar(x=bin_edges[:-1], y=counts))
        if w.ndim == 4:
            # Take index 0 of the last axis, then average over axis 2.
            # Presumably a conv kernel laid out (h, w, in, out) -- TODO confirm.
            ch = w[:, :, :, 0].mean(axis=2)
            # np.ptp() instead of the ndarray.ptp() method, which NumPy 2
            # no longer provides; the epsilon avoids dividing by zero on a
            # constant array.
            ch = (ch - ch.min()) / (np.ptp(ch) + 1e-6)
            img = Image.fromarray((ch * 255).astype("uint8")).resize((256, 256))

    # The shape attributes may be unavailable on a layer; show None rather
    # than failing the whole callback.
    input_shape = getattr(layer, "input_shape", None)
    output_shape = getattr(layer, "output_shape", None)
    txt = (
        f"**Layer:** {layer.name}\n\n"
        f"- Type: `{layer.__class__.__name__}`\n"
        f"- Input: `{input_shape}`\n"
        f"- Output: `{output_shape}`\n"
        f"- Params: `{layer.count_params()}`"
    )
    return txt, img, hist_fig
| # -------------------- UI -------------------- | |
# Layout: load controls and model summary on the left, graph and layer
# inspection on the right.
with gr.Blocks() as demo:
    gr.Markdown("# 🔎 TensorFlow Computation Graph Visualizer")
    with gr.Row():
        with gr.Column(scale=1):
            model_file = gr.File(label="Upload .h5")
            example = gr.Dropdown(["small_cnn", "toy_resnet"], value="small_cnn")
            load_btn = gr.Button("Load model")
            summary = gr.Textbox(lines=12)
            params = gr.Textbox()
            error = gr.Markdown()
        with gr.Column(scale=2):
            graph_plot = gr.Plot()
            layer_select = gr.Dropdown(label="Select layer to inspect")
            node_info = gr.Markdown()
            weights_img = gr.Image()
            weights_hist = gr.Plot()
    # Holds {"model": ..., "graph": ...} for the currently loaded model.
    state = gr.State()

    def on_load(file, ex):
        """Load the chosen model and refresh every dependent component.

        Args:
            file: Value of the upload component (falsy when nothing was
                uploaded).
            ex: Name of the example selected in the dropdown.

        Returns:
            A 6-tuple matching the ``outputs`` of ``load_btn.click``:
            state, graph figure, summary text, parameter count, error
            markdown, and an update for the layer dropdown.
        """
        try:
            model, _ = safe_load_keras_model(file, ex)
            G = build_layer_graph(model)
            fig = nx_to_plotly_fig(G)
            summary_text = model_summary_str(model)
            param_count = str(model.count_params())
        except Exception as exc:
            # UI boundary: log the full traceback server-side and report
            # the failure in the page instead of an unhandled error.
            traceback.print_exc()
            return (
                None,
                None,
                "",
                "",
                f"**Failed to load model:** `{type(exc).__name__}: {exc}`",
                gr.update(choices=[], value=None),
            )

        layer_names = list(G.nodes())
        return (
            {"model": model, "graph": G},
            fig,
            summary_text,
            param_count,
            "",
            # Returning a bare list would set the dropdown's *value*; an
            # explicit update is needed to replace its *choices*. The first
            # layer is preselected so the selection is never left pointing
            # at a layer of a previously loaded model.
            gr.update(
                choices=layer_names,
                value=layer_names[0] if layer_names else None,
            ),
        )

    load_btn.click(
        on_load,
        inputs=[model_file, example],
        outputs=[state, graph_plot, summary, params, error, layer_select],
    )
    layer_select.change(
        node_inspect_callback,
        inputs=[state, layer_select],
        outputs=[node_info, weights_img, weights_hist],
    )

demo.launch()