# Author: Pranesh64
# Last change: "Update app.py" (commit 217598f, verified)
# ==========================================================
# TensorFlow Computation Graph Visualizer (Advanced)
# - Standard TensorFlow (Keras) based
# - Gradio 5 compatible (no theme=)
# - CPU-friendly (disables GPU usage)
# ==========================================================
import io
import math
import os
import tempfile
import traceback
import warnings
from typing import Any, Dict, List, Optional

import gradio as gr
import networkx as nx
import numpy as np
import plotly.graph_objects as go
from PIL import Image
# Silence all Python warnings (TensorFlow/Keras emit a lot of deprecation
# chatter) so the app log stays readable.
warnings.filterwarnings("ignore")

# Try importing TensorFlow. The app still starts without it: callers check
# TF_AVAILABLE, and TF_IMPORT_ERROR holds the reason when the import failed
# (it is None on success, so it is always safe to reference).
TF_IMPORT_ERROR: Optional[str] = None
try:
    import tensorflow as tf
    from tensorflow import keras

    # Hide all GPUs so the app runs CPU-only. Best-effort: TensorFlow refuses
    # to change visible devices once its runtime is initialised, and a failure
    # here should not disable the whole app.
    try:
        tf.config.set_visible_devices([], "GPU")
    except Exception:
        pass
    TF_AVAILABLE = True
except Exception as e:
    TF_AVAILABLE = False
    TF_IMPORT_ERROR = f"{type(e).__name__}: {e}"
# -------------------- Helpers --------------------
def _load_uploaded_model(fileobj: Any):
    """Load a Keras model from an uploaded file.

    Accepts, in order:
      * a filesystem path (``str`` / ``os.PathLike``), which is what
        ``gr.File`` passes to callbacks by default in Gradio 4/5;
      * an object whose ``.name`` is an existing file path (tempfile-style
        upload wrappers);
      * a readable binary file-like object such as ``io.BytesIO``.
    """
    if isinstance(fileobj, (str, os.PathLike)):
        return keras.models.load_model(os.fspath(fileobj))

    on_disk = getattr(fileobj, "name", None)
    if isinstance(on_disk, str) and os.path.isfile(on_disk):
        return keras.models.load_model(on_disk)

    # In-memory upload: load_model wants a path, so spill the bytes to a
    # private temp file. mkstemp gives every call its own file (a fixed path
    # lets concurrent sessions overwrite each other's upload), and the
    # `finally` removes it even when loading fails.
    if hasattr(fileobj, "seek"):
        fileobj.seek(0)
    fd, tmp_path = tempfile.mkstemp(suffix=".h5")
    try:
        with os.fdopen(fd, "wb") as fh:
            fh.write(fileobj.read())
        return keras.models.load_model(tmp_path)
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            pass


def safe_load_keras_model(fileobj: Any, chosen: str):
    """Return ``(model, description)`` for an uploaded model or a built-in example.

    Args:
        fileobj: The uploaded model (path, file wrapper or binary file-like).
            A falsy value means "use the example named by *chosen*".
        chosen: ``"small_cnn"`` or ``"toy_resnet"``; any other value falls
            back to ``"small_cnn"``.

    Returns:
        A ``(keras.Model, str)`` tuple; the string is a human-readable label.

    Raises:
        RuntimeError: If TensorFlow could not be imported.
    """
    if not TF_AVAILABLE:
        raise RuntimeError("TensorFlow not available")
    if fileobj:
        return _load_uploaded_model(fileobj), "uploaded .h5 model"

    if chosen == "small_cnn":
        # Starting from keras.Input defines the input shape up front, so the
        # model is built on construction (no separate build() call needed).
        model = keras.Sequential([
            keras.Input(shape=(64, 64, 3)),
            keras.layers.Conv2D(16, 3, activation="relu", padding="same"),
            keras.layers.MaxPool2D(),
            keras.layers.Conv2D(32, 3, activation="relu", padding="same"),
            keras.layers.MaxPool2D(),
            keras.layers.Conv2D(64, 3, activation="relu", padding="same"),
            keras.layers.GlobalAveragePooling2D(),
            keras.layers.Dense(64, activation="relu"),
            keras.layers.Dense(10, activation="softmax"),
        ])
        return model, "Small CNN (example)"

    if chosen == "toy_resnet":
        inputs = keras.Input(shape=(64, 64, 3))
        x = keras.layers.Conv2D(32, 3, padding="same", activation="relu")(inputs)
        # Two residual blocks: conv -> conv, added back onto the block input.
        for _ in range(2):
            y = keras.layers.Conv2D(32, 3, padding="same", activation="relu")(x)
            y = keras.layers.Conv2D(32, 3, padding="same")(y)
            x = keras.layers.add([x, y])
            x = keras.layers.ReLU()(x)
        x = keras.layers.GlobalAveragePooling2D()(x)
        outputs = keras.layers.Dense(5, activation="softmax")(x)
        return keras.Model(inputs, outputs), "Toy ResNet-like (example)"

    # Unknown example name: fall back to the default example.
    return safe_load_keras_model(None, "small_cnn")
def model_summary_str(model):
    """Capture ``model.summary()`` output and return it as one string.

    The sink accepts and honours a ``line_break`` keyword. Keras 3 passes it
    to ``print_fn``, while Keras 2 passes only the line. A one-argument
    callback would raise ``TypeError`` on Keras 3.
    """
    chunks = []

    def _sink(line, *_args, line_break=True, **_kwargs):
        chunks.append(str(line) + ("\n" if line_break else ""))

    model.summary(print_fn=_sink)
    return "".join(chunks)
# -------------------- Graph builder --------------------
def _node_owner(node):
    """Return the layer that owns a Keras ``Node``, checking ``.operation`` then ``.layer``."""
    owner = getattr(node, "operation", None)
    if owner is None:
        owner = getattr(node, "layer", None)
    return owner


def _inbound_layer_names(layer) -> List[str]:
    """Return the names of the layers feeding *layer*, read from its inbound nodes."""
    names: List[str] = []
    for node in getattr(layer, "_inbound_nodes", None) or []:
        sources = getattr(node, "inbound_layers", None)
        if sources is None:
            # Node has no `inbound_layers` (e.g. newer Keras): derive the
            # sources from the parent nodes' owning layers instead.
            sources = [_node_owner(p) for p in getattr(node, "parent_nodes", None) or []]
        elif isinstance(sources, dict):
            sources = list(sources.values())
        elif not isinstance(sources, (list, tuple)):
            # Some Keras versions return a bare layer (not a list) for a
            # single-input call; iterating it directly would raise TypeError.
            sources = [sources]
        names.extend(src.name for src in sources if src is not None and hasattr(src, "name"))
    return names


def _layer_shape(layer, which: str):
    """Best-effort shape of a layer's ``which`` side ("input" or "output"); ``None`` if unknown."""
    # `<which>_shape` may not exist (e.g. on Keras 3) or may raise for
    # layers with several inbound nodes; fall back to the symbolic tensor(s).
    try:
        return getattr(layer, f"{which}_shape")
    except (AttributeError, RuntimeError, ValueError):
        pass
    try:
        tensors = getattr(layer, which)
    except (AttributeError, RuntimeError, ValueError):
        return None
    if isinstance(tensors, (list, tuple)):
        return [getattr(t, "shape", None) for t in tensors]
    return getattr(tensors, "shape", None)


def build_layer_graph(model):
    """Build a directed layer-connectivity graph for a Keras model.

    Each entry of ``model.layers`` becomes one node, keyed by layer name,
    with the attributes ``class_name``, ``input_shape``, ``output_shape``,
    ``params`` and ``inbound_layers``. A ``src -> dst`` edge is added for
    every inbound layer of ``dst`` that is itself a node.
    """
    G = nx.DiGraph()
    for layer in model.layers:
        G.add_node(
            layer.name,
            class_name=layer.__class__.__name__,
            input_shape=_layer_shape(layer, "input"),
            output_shape=_layer_shape(layer, "output"),
            params=layer.count_params(),
            inbound_layers=_inbound_layer_names(layer),
        )
    # Edges are resolved after all nodes exist. The `src in G` check drops
    # sources that are not in model.layers.
    edges = [
        (src, name)
        for name, data in G.nodes(data=True)
        for src in data["inbound_layers"]
        if src in G
    ]
    G.add_edges_from(edges)
    return G
def nx_to_plotly_fig(G):
    """Render a layer graph as a Plotly figure.

    Node positions come from a spring layout with a fixed seed, so the same
    graph always gets the same picture. All edges are drawn as a single
    line trace, with ``None`` breaks between segments. Nodes are drawn as
    markers labelled with their names, and both axes are hidden.
    """
    coords = nx.spring_layout(G, seed=42)

    # Each edge contributes (start, end, None); the None breaks the polyline
    # so edges are not joined to one another.
    edge_x = [val for src, dst in G.edges() for val in (coords[src][0], coords[dst][0], None)]
    edge_y = [val for src, dst in G.edges() for val in (coords[src][1], coords[dst][1], None)]

    labels = list(G.nodes())
    node_x = [coords[name][0] for name in labels]
    node_y = [coords[name][1] for name in labels]

    edge_trace = go.Scatter(x=edge_x, y=edge_y, mode="lines")
    node_trace = go.Scatter(x=node_x, y=node_y, mode="markers+text", text=labels)

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(height=600, showlegend=False)
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    return fig
# -------------------- Inspect --------------------
def node_inspect_callback(state, node_name):
    """Describe one layer of the loaded model and visualise its first weight tensor.

    Args:
        state: Session dict holding the model under ``"model"``, or a falsy
            value when nothing has been loaded.
        node_name: Layer name selected in the dropdown; may be ``None`` or
            empty when the dropdown is cleared.

    Returns:
        ``(markdown, image, histogram_figure)``.

        * ``histogram_figure`` is a 50-bin histogram of the layer's first
          weight array, or ``None`` if the layer has no weights.
        * ``image`` is a 256x256 grayscale preview of output channel 0
          (averaged over the third axis) when that array is 4-D, and
          ``None`` otherwise.
    """
    if not state:
        return "No model loaded.", None, None
    if not node_name:
        return "Select a layer to inspect.", None, None
    model = state["model"]
    try:
        layer = model.get_layer(node_name)
    except ValueError:
        return f"Layer `{node_name}` not found in the loaded model.", None, None

    def _shape(which):
        # `<which>_shape` may be missing (e.g. on Keras 3) or raise for
        # multi-node layers; fall back to the shape(s) of `layer.<which>`.
        try:
            return getattr(layer, f"{which}_shape")
        except (AttributeError, RuntimeError, ValueError):
            pass
        try:
            tensors = getattr(layer, which)
        except (AttributeError, RuntimeError, ValueError):
            return None
        if isinstance(tensors, (list, tuple)):
            return [getattr(t, "shape", None) for t in tensors]
        return getattr(tensors, "shape", None)

    weights = layer.get_weights()
    hist_fig = None
    img = None
    if weights:
        w = np.asarray(weights[0])
        counts, edges = np.histogram(w.ravel(), bins=50)
        hist_fig = go.Figure(go.Bar(x=edges[:-1], y=counts))
        if w.ndim == 4:
            # Assumes a Conv2D-style (kh, kw, in, out) kernel: take output
            # channel 0 and average over input channels -> (kh, kw).
            ch = w[:, :, :, 0].mean(axis=2)
            # np.ptp instead of ndarray.ptp(), which NumPy 2.0 removed.
            ch = (ch - ch.min()) / (np.ptp(ch) + 1e-6)
            # NEAREST keeps each kernel cell as a crisp block when upscaling.
            img = Image.fromarray((ch * 255).astype("uint8")).resize(
                (256, 256), Image.NEAREST
            )

    txt = (
        f"**Layer:** {layer.name}\n\n"
        f"- Type: `{layer.__class__.__name__}`\n"
        f"- Input: `{_shape('input')}`\n"
        f"- Output: `{_shape('output')}`\n"
        f"- Params: `{layer.count_params()}`"
    )
    return txt, img, hist_fig
# -------------------- UI --------------------
with gr.Blocks() as demo:
    gr.Markdown("# 🔎 TensorFlow Computation Graph Visualizer")
    with gr.Row():
        with gr.Column(scale=1):
            model_file = gr.File(label="Upload .h5", file_types=[".h5"])
            example = gr.Dropdown(
                ["small_cnn", "toy_resnet"], value="small_cnn", label="Example model"
            )
            load_btn = gr.Button("Load model")
            summary = gr.Textbox(label="Model summary", lines=12)
            params = gr.Textbox(label="Total parameters")
            error = gr.Markdown()
        with gr.Column(scale=2):
            graph_plot = gr.Plot()
            layer_select = gr.Dropdown(label="Select layer to inspect", choices=[])
            node_info = gr.Markdown()
            weights_img = gr.Image()
            weights_hist = gr.Plot()
    # Per-session {"model": keras.Model, "graph": nx.DiGraph}; None until loaded.
    state = gr.State()

    def on_load(file, ex):
        """Load the uploaded or example model and refresh every output panel.

        Returns values for, in order: state, graph_plot, summary, params,
        error and layer_select. On failure the panels are cleared and the
        error is shown in ``error`` instead of crashing the handler.
        """
        try:
            model, _ = safe_load_keras_model(file, ex)
            G = build_layer_graph(model)
            fig = nx_to_plotly_fig(G)
            summary_text = model_summary_str(model)
            n_params = str(model.count_params())
        except Exception as exc:  # UI boundary: report instead of crashing
            traceback.print_exc()
            detail = f"{type(exc).__name__}: {exc}"
            if not TF_AVAILABLE:
                detail += f" ({TF_IMPORT_ERROR})"
            return (
                None,
                None,
                "",
                "",
                f"**Failed to load model:** `{detail}`",
                gr.update(choices=[], value=None),
            )

        names = list(G.nodes())
        return (
            {"model": model, "graph": G},
            fig,
            summary_text,
            n_params,
            "",
            # A plain list would set the dropdown's *value*; the layer names
            # must go into `choices`. Preselect the first layer.
            gr.update(choices=names, value=names[0] if names else None),
        )

    load_btn.click(
        on_load,
        inputs=[model_file, example],
        outputs=[state, graph_plot, summary, params, error, layer_select],
    )
    layer_select.change(
        node_inspect_callback,
        inputs=[state, layer_select],
        outputs=[node_info, weights_img, weights_hist],
    )

if __name__ == "__main__":
    demo.launch()