# app.py — Transformer Model Explorer (Gradio demo)
import gradio as gr
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import plotly.express as px
# -------------------
# Model Metadata
# -------------------
# Static registry: Hugging Face model id -> display metadata for the UI.
# "params" is the approximate total parameter count (raw units, not millions).
models = {
    "bert-base-uncased":             {"type": "BERT",              "layers": 12, "heads": 12, "params": 110e6},
    "gpt2":                          {"type": "GPT-2",             "layers": 12, "heads": 12, "params": 117e6},
    "distilbert-base-uncased":       {"type": "DistilBERT",        "layers": 6,  "heads": 12, "params": 66e6},
    "roberta-base":                  {"type": "RoBERTa",           "layers": 12, "heads": 12, "params": 125e6},
    "albert-base-v2":                {"type": "ALBERT",            "layers": 12, "heads": 12, "params": 12e6},
    "bert-base-multilingual-cased":  {"type": "Multilingual BERT", "layers": 12, "heads": 12, "params": 177e6},
}

# Memoization store for loaded (tokenizer, model) pairs, keyed by model id.
model_cache = {}
def load_model(name):
    """Load and memoize a (tokenizer, model) pair for a Hugging Face model id.

    Attention weights are requested at load time (``output_attentions=True``)
    so downstream code can read ``outputs.attentions``.

    Args:
        name: Hugging Face model identifier (e.g. ``"bert-base-uncased"``).

    Returns:
        Tuple of ``(tokenizer, model)``; cached in ``model_cache`` after the
        first load so repeated UI events don't re-download weights.
    """
    if name in model_cache:
        return model_cache[name]
    tokenizer = AutoTokenizer.from_pretrained(name)
    # Causal LMs (e.g. GPT-2) are loaded through the LM-head auto class;
    # encoder models (BERT family) go through AutoModel.  A previous "t5"
    # branch duplicated the default case exactly and was removed.
    if "gpt" in name.lower():
        model = AutoModelForCausalLM.from_pretrained(name, output_attentions=True)
    else:
        model = AutoModel.from_pretrained(name, output_attentions=True)
    model_cache[name] = (tokenizer, model)
    return tokenizer, model
def get_model_info(name):
    """Return display metadata for *name* from the ``models`` registry.

    Returns:
        ``(type, layers, heads, params_str)`` where ``params_str`` is a
        human-readable size such as ``"110M"``; the placeholder tuple
        ``("Unknown", "-", "-", "-")`` is returned for unregistered names.
    """
    info = models.get(name)
    if info is None:
        return "Unknown", "-", "-", "-"
    params_millions = int(info["params"] / 1e6)
    return info["type"], info["layers"], info["heads"], f"{params_millions}M"
def show_model_chart():
    """Build a bar chart comparing registered model sizes in millions of params.

    Returns:
        A plotly figure; on any construction failure, an empty chart titled
        "Model Chart Unavailable" so the UI still renders.
    """
    try:
        data = [
            {"Model": model_id, "Size (Million Parameters)": meta["params"] / 1e6}
            for model_id, meta in models.items()
        ]
        return px.bar(data, x="Model", y="Size (Million Parameters)", title="Model Size Comparison")
    # Narrowed from a bare `except:`, which would also swallow
    # KeyboardInterrupt/SystemExit.
    except Exception:
        return px.bar(title="Model Chart Unavailable")
def tokenize_input(text, model_name):
    """Split *text* into subword token strings using the model's tokenizer."""
    tokenizer = load_model(model_name)[0]
    return tokenizer.tokenize(text)
def visualize_embeddings(text, model_name):
    """Scatter-plot a 2-D PCA projection of per-token model representations.

    Encoder models contribute ``last_hidden_state``; causal LMs fall back to
    ``logits`` (the only per-token tensor those outputs expose here).

    Returns:
        A plotly scatter figure, or an empty titled figure when no usable
        per-token tensor is available / token counts don't line up.
    """
    tokenizer, model = load_model(model_name)
    encoded = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        result = model(**encoded)

    if hasattr(result, "last_hidden_state"):
        vectors = result.last_hidden_state.squeeze(0).numpy()
    elif hasattr(result, "logits"):
        vectors = result.logits.squeeze(0).numpy()
    else:
        return px.scatter(title="No Embeddings Available")

    token_strs = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])
    # Guard: one row per token is required for the labelled scatter below.
    if vectors.shape[0] != len(token_strs):
        return px.scatter(title="Token count mismatch")

    projected = PCA(n_components=2).fit_transform(vectors)
    figure = px.scatter(
        x=projected[:, 0],
        y=projected[:, 1],
        text=token_strs,
        labels={'x': 'PCA1', 'y': 'PCA2'},
        title="Token Embeddings (PCA Projection)",
    )
    figure.update_traces(textposition='top center')
    return figure
def plot_attention(text, model_name, layer_idx, head_idx):
    """Render one attention head's weight matrix as a heatmap PNG.

    Args:
        text: input string to run through the model.
        model_name: Hugging Face model id, loadable via ``load_model``.
        layer_idx: attention layer to plot.  Gradio sliders deliver floats,
            so the value is coerced to ``int`` before tensor indexing
            (previously a float index raised TypeError, was silently caught
            by a bare ``except:``, and a stale image was returned).
        head_idx: attention head to plot (same coercion applies).

    Returns:
        Path to the saved image, always ``"attention.png"``.  NOTE: on
        failure the previously saved file at that path (if any) is what the
        UI will display.
    """
    tokenizer, model = load_model(model_name)
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    if not hasattr(outputs, "attentions") or outputs.attentions is None:
        return "attention.png"
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    try:
        attn = outputs.attentions[int(layer_idx)][0][int(head_idx)].detach().numpy()
    # Narrowed from a bare `except:`; these are the realistic failures for
    # out-of-range or malformed indices.
    except (IndexError, TypeError, ValueError):
        return "attention.png"
    fig, ax = plt.subplots(figsize=(8, 6))
    cax = ax.matshow(attn, cmap="viridis")
    fig.colorbar(cax)
    # Label both axes with the token strings so rows/columns are readable.
    ax.set_xticks(np.arange(len(tokens)))
    ax.set_yticks(np.arange(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticklabels(tokens)
    ax.set_title(f"Layer {layer_idx} Head {head_idx}")
    plt.tight_layout()
    plt.savefig("attention.png")
    plt.close()
    return "attention.png"
# Main update function
def update_all(model_name, text, layer, head):
    """Recompute every output panel for the current UI state.

    Returns an 8-tuple matching the ``outputs=`` list wired up in the UI:
    (type, layers, heads, params, size chart, tokens, embedding plot,
    attention image path).
    """
    model_type, n_layers, n_heads, n_params = get_model_info(model_name)
    return (
        model_type,
        n_layers,
        n_heads,
        n_params,
        show_model_chart(),
        tokenize_input(text, model_name),
        visualize_embeddings(text, model_name),
        plot_attention(text, model_name, layer, head),
    )
# -------------------
# Gradio UI
# -------------------
# Declarative layout: components are created first, then every input widget
# is wired to the single `update_all` refresh function below.
with gr.Blocks() as demo:
    gr.Markdown("## 🚀 Transformer Model Explorer\nExplore transformer internals interactively – no clicks needed!")
    with gr.Row():
        model_selector = gr.Dropdown(choices=list(models.keys()), value="bert-base-uncased", label="Choose Model")
        input_text = gr.Textbox(label="Enter Text", value="Hello, how are you?")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🛠️ Model Details")
            # Read-only metadata fields filled from `get_model_info`.
            model_type = gr.Text(label="Model Type", interactive=False)
            num_layers = gr.Text(label="Number of Layers", interactive=False)
            num_heads = gr.Text(label="Number of Attention Heads", interactive=False)
            num_params = gr.Text(label="Total Parameters", interactive=False)
        with gr.Column():
            gr.Markdown("### 📊 Model Size Comparison")
            model_plot = gr.Plot()
    gr.Markdown("### ✨ Tokenization")
    token_output = gr.JSON()
    gr.Markdown("### 🌟 Embeddings (PCA)")
    embedding_plot = gr.Plot()
    gr.Markdown("### 🔍 Attention Map")
    # Sliders assume at most 12 layers/heads (max index 11), matching the
    # largest entries in the `models` registry.  NOTE(review): slider values
    # arrive as floats in the update callbacks.
    layer_slider = gr.Slider(minimum=0, maximum=11, value=0, label="Layer")
    head_slider = gr.Slider(minimum=0, maximum=11, value=0, label="Head")
    # The attention map is served as a PNG file written by `plot_attention`.
    attn_plot = gr.Image(type="filepath", label="Attention Map")
    # Connect all reactive triggers: any change to any input recomputes
    # every output panel via `update_all`.
    for component in [model_selector, input_text, layer_slider, head_slider]:
        component.change(
            fn=update_all,
            inputs=[model_selector, input_text, layer_slider, head_slider],
            outputs=[
                model_type, num_layers, num_heads, num_params,
                model_plot, token_output, embedding_plot, attn_plot
            ]
        )
    # Trigger once at app load so the page isn't blank before any interaction.
    demo.load(
        fn=update_all,
        inputs=[model_selector, input_text, layer_slider, head_slider],
        outputs=[
            model_type, num_layers, num_heads, num_params,
            model_plot, token_output, embedding_plot, attn_plot
        ]
    )
demo.launch()