import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
from sklearn.decomposition import PCA
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM

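# Model registry shown in the UI: architecture family, depth, number of attention
# heads, and approximate parameter counts (used only for the size-comparison chart).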
models = {
    "bert-base-uncased": {"type": "BERT", "layers": 12, "heads": 12, "params": 110e6},
    "gpt2": {"type": "GPT-2", "layers": 12, "heads": 12, "params": 117e6},
    "distilbert-base-uncased": {"type": "DistilBERT", "layers": 6, "heads": 12, "params": 66e6},
    "roberta-base": {"type": "RoBERTa", "layers": 12, "heads": 12, "params": 125e6},
    "albert-base-v2": {"type": "ALBERT", "layers": 12, "heads": 12, "params": 12e6},
    "bert-base-multilingual-cased": {"type": "Multilingual BERT", "layers": 12, "heads": 12, "params": 177e6},
}

model_cache = {}

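# load_model caches the (tokenizer, model) pair so each checkpoint is downloaded and
# instantiated only once per session; output_attentions=True makes every forward pass
# return the per-layer attention tensors that plot_attention visualises.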
def load_model(name):
    if name in model_cache:
        return model_cache[name]

    tokenizer = AutoTokenizer.from_pretrained(name)

    if "gpt" in name.lower():
        # Causal LMs (GPT-2 here) are loaded with their LM head so logits are available.
        model = AutoModelForCausalLM.from_pretrained(name, output_attentions=True)
    else:
        # Encoder models (BERT, RoBERTa, DistilBERT, ALBERT, mBERT) load via AutoModel.
        model = AutoModel.from_pretrained(name, output_attentions=True)

    model_cache[name] = (tokenizer, model)
    return tokenizer, model

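# get_model_info and show_model_chart only read the static registry above; they never
# touch the Hugging Face Hub, so the metadata panel and size chart update instantly.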
def get_model_info(name):
    try:
        info = models[name]
        return (
            info["type"],
            info["layers"],
            info["heads"],
            f"{int(info['params'] / 1e6)}M",
        )
    except KeyError:
        return "Unknown", "-", "-", "-"

def show_model_chart():
    try:
        data = [{"Model": k, "Size (Million Parameters)": v["params"] / 1e6} for k, v in models.items()]
        fig = px.bar(data, x="Model", y="Size (Million Parameters)", title="Model Size Comparison")
        return fig
    except Exception:
        return px.bar(title="Model Chart Unavailable")

def tokenize_input(text, model_name):
    tokenizer, _ = load_model(model_name)
    return tokenizer.tokenize(text)

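# visualize_embeddings projects each token's vector down to 2-D with PCA. Encoder
# models supply final-layer hidden states; GPT-2, loaded as a causal LM, returns
# vocabulary logits instead, so its plot shows logit vectors rather than hidden states.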
def visualize_embeddings(text, model_name):
    tokenizer, model = load_model(model_name)
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)

    if hasattr(outputs, "last_hidden_state"):
        embeddings = outputs.last_hidden_state.squeeze(0).numpy()
    elif hasattr(outputs, "logits"):
        embeddings = outputs.logits.squeeze(0).numpy()
    else:
        return px.scatter(title="No Embeddings Available")

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    if embeddings.shape[0] != len(tokens):
        return px.scatter(title="Token count mismatch")
    if embeddings.shape[0] < 2:
        # PCA with two components needs at least two token vectors.
        return px.scatter(title="Input too short to project")

    pca = PCA(n_components=2)
    reduced = pca.fit_transform(embeddings)

    fig = px.scatter(
        x=reduced[:, 0], y=reduced[:, 1],
        text=tokens,
        labels={"x": "PCA1", "y": "PCA2"},
        title="Token Embeddings (PCA Projection)",
    )
    fig.update_traces(textposition="top center")
    return fig

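# plot_attention renders one head of one layer as a heatmap and saves it to a PNG for
# the gr.Image component. Slider values may arrive as floats, so they are cast to int
# before indexing; out-of-range indices (e.g. layer 7 on 6-layer DistilBERT) clear the
# image rather than raising.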
def plot_attention(text, model_name, layer_idx, head_idx):
    tokenizer, model = load_model(model_name)
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)

    if not hasattr(outputs, "attentions") or outputs.attentions is None:
        return None

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    layer_idx, head_idx = int(layer_idx), int(head_idx)
    try:
        attn = outputs.attentions[layer_idx][0][head_idx].detach().numpy()
    except IndexError:
        return None

    fig, ax = plt.subplots(figsize=(8, 6))
    cax = ax.matshow(attn, cmap="viridis")
    fig.colorbar(cax)
    ax.set_xticks(np.arange(len(tokens)))
    ax.set_yticks(np.arange(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticklabels(tokens)
    ax.set_title(f"Layer {layer_idx} Head {head_idx}")
    plt.tight_layout()
    plt.savefig("attention.png")
    plt.close(fig)
    return "attention.png"

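# update_all fans the four inputs out to every helper and returns the eight values in
# the order the Gradio output components expect.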
def update_all(model_name, text, layer, head):
    info = get_model_info(model_name)
    chart = show_model_chart()
    tokens = tokenize_input(text, model_name)
    embedding = visualize_embeddings(text, model_name)
    attention = plot_attention(text, model_name, layer, head)
    return info + (chart, tokens, embedding, attention)

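# UI layout: model picker and text box on top, model metadata and the size chart side
# by side, then tokenization, PCA, and attention sections. Every control re-runs
# update_all so all eight outputs stay in sync.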
with gr.Blocks() as demo:
    gr.Markdown(
        "## Transformer Model Explorer\n"
        "Explore transformer internals interactively; every panel below updates "
        "automatically as you change the model, text, or sliders."
    )

    with gr.Row():
        model_selector = gr.Dropdown(choices=list(models.keys()), value="bert-base-uncased", label="Choose Model")
        input_text = gr.Textbox(label="Enter Text", value="Hello, how are you?")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Model Details")
            model_type = gr.Text(label="Model Type", interactive=False)
            num_layers = gr.Text(label="Number of Layers", interactive=False)
            num_heads = gr.Text(label="Number of Attention Heads", interactive=False)
            num_params = gr.Text(label="Total Parameters", interactive=False)
        with gr.Column():
            gr.Markdown("### Model Size Comparison")
            model_plot = gr.Plot()

    gr.Markdown("### Tokenization")
    token_output = gr.JSON()

    gr.Markdown("### Embeddings (PCA)")
    embedding_plot = gr.Plot()

    gr.Markdown("### Attention Map")
    layer_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Layer")
    head_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Head")
    attn_plot = gr.Image(type="filepath", label="Attention Map")

    # Re-run everything whenever any input changes.
    for component in [model_selector, input_text, layer_slider, head_slider]:
        component.change(
            fn=update_all,
            inputs=[model_selector, input_text, layer_slider, head_slider],
            outputs=[
                model_type, num_layers, num_heads, num_params,
                model_plot, token_output, embedding_plot, attn_plot,
            ],
        )

    # Populate every output once when the page first loads.
    demo.load(
        fn=update_all,
        inputs=[model_selector, input_text, layer_slider, head_slider],
        outputs=[
            model_type, num_layers, num_heads, num_params,
            model_plot, token_output, embedding_plot, attn_plot,
        ],
    )


demo.launch()