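# Transformer Model Explorer: a Gradio app for comparing a handful of small
# transformer checkpoints and visualizing their tokenization, token embeddings
# (PCA projection), and per-head attention maps.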
import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
import tempfile
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: figures are only saved to disk
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import plotly.express as px
# Predefined models with their architecture stats (approximate parameter counts)
models = {
"bert-base-uncased": {"type": "BERT", "layers": 12, "heads": 12, "params": 110e6},
"gpt2": {"type": "GPT-2", "layers": 12, "heads": 12, "params": 117e6},
"distilbert-base-uncased": {"type": "DistilBERT", "layers": 6, "heads": 12, "params": 66e6},
"roberta-base": {"type": "RoBERTa", "layers": 12, "heads": 12, "params": 125e6},
}
model_cache = {}
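# Lazily load tokenizer/model pairs and cache them so switching models in the
# UI doesn't reload weights from disk (or re-download them) on every call.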
def load_model(name):
if name not in model_cache:
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name, output_attentions=True)
model_cache[name] = (tokenizer, model)
return model_cache[name]
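# Static architecture details shown in the model-details panel.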
def get_model_info(name):
info = models[name]
return (
info["type"],
info["layers"],
info["heads"],
f"{int(info['params'] / 1e6)}M"
)
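# Bar chart comparing parameter counts across the predefined models.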
def show_model_chart():
data = [{"Model": k, "Size (Million Parameters)": v["params"] / 1e6} for k, v in models.items()]
fig = px.bar(data, x="Model", y="Size (Million Parameters)", title="Model Size Comparison")
return fig
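# Show the subword tokens (WordPiece for BERT-style models, BPE for GPT-2)
# produced for the entered text.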
def tokenize_input(text, model_name):
tokenizer, _ = load_model(model_name)
return tokenizer.tokenize(text)
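# Run one forward pass and project each token's final hidden state down to
# two dimensions with PCA so the embeddings can be plotted.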
def visualize_embeddings(text, model_name):
tokenizer, model = load_model(model_name)
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
        outputs = model(**inputs)
    embeddings = outputs.last_hidden_state.squeeze(0).numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    # PCA needs at least as many samples (tokens) as components.
    if len(tokens) < 2:
        raise gr.Error("Enter text that produces at least two tokens.")
    pca = PCA(n_components=2)
    reduced = pca.fit_transform(embeddings)
fig = px.scatter(
x=reduced[:, 0], y=reduced[:, 1],
text=tokens,
labels={'x': 'PCA1', 'y': 'PCA2'},
title="Token Embeddings (PCA Projection)"
)
fig.update_traces(textposition='top center')
return fig
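# Render one attention head's seq_len x seq_len weight matrix as a heatmap
# and return it as an image file for the gr.Image component.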
def plot_attention(text, model_name, layer_idx, head_idx):
tokenizer, model = load_model(model_name)
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    # The sliders are fixed at 0-11, but DistilBERT only has 6 layers, so clamp
    # the indices (and cast slider floats to ints) before indexing.
    layer_idx = min(int(layer_idx), len(outputs.attentions) - 1)
    head_idx = min(int(head_idx), outputs.attentions[layer_idx].shape[1] - 1)
    attn = outputs.attentions[layer_idx][0, head_idx].numpy()
fig, ax = plt.subplots(figsize=(8, 6))
cax = ax.matshow(attn, cmap="viridis")
fig.colorbar(cax)
ax.set_xticks(np.arange(len(tokens)))
ax.set_yticks(np.arange(len(tokens)))
ax.set_xticklabels(tokens, rotation=90)
ax.set_yticklabels(tokens)
ax.set_title(f"Layer {layer_idx} Head {head_idx}")
plt.tight_layout()
    # Save to a unique temp file so concurrent requests don't overwrite each other.
    out_path = tempfile.NamedTemporaryFile(suffix=".png", delete=False).name
    fig.savefig(out_path)
    plt.close(fig)
    return out_path
with gr.Blocks() as demo:
gr.Markdown("## πŸš€ Transformer Model Explorer\nExplore different transformer models, their architectures, tokenization, and attention mechanisms.")
with gr.Row():
model_selector = gr.Dropdown(list(models.keys()), label="Choose a Transformer Model", value="bert-base-uncased")
with gr.Column():
gr.Markdown("### πŸ› οΈ Model Details")
model_type = gr.Text(label="Model Type")
num_layers = gr.Text(label="Number of Layers")
num_heads = gr.Text(label="Number of Attention Heads")
num_params = gr.Text(label="Total Parameters")
        model_selector.change(fn=get_model_info, inputs=model_selector, outputs=[model_type, num_layers, num_heads, num_params])
        # Also fill in the details for the default model when the app first loads.
        demo.load(fn=get_model_info, inputs=model_selector, outputs=[model_type, num_layers, num_heads, num_params])
gr.Markdown("### πŸ“Š Model Size Comparison")
chart = gr.Plot()
        # gr.Plot has no render_fn attribute; populate the chart on app load instead.
        demo.load(fn=show_model_chart, outputs=chart)
with gr.Column():
gr.Markdown("### ✨ Tokenization Visualization")
input_text = gr.Textbox(label="Enter Text:", value="Hello, how are you?")
token_output = gr.JSON()
tokenize_btn = gr.Button("Tokenize")
tokenize_btn.click(fn=tokenize_input, inputs=[input_text, model_selector], outputs=token_output)
gr.Markdown("### 🌟 Token Embeddings Visualization")
embedding_plot = gr.Plot()
embed_btn = gr.Button("Show Embeddings (PCA)")
embed_btn.click(fn=visualize_embeddings, inputs=[input_text, model_selector], outputs=embedding_plot)
gr.Markdown("### πŸ” Attention Map")
layer_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Layer")
head_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Head")
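        # Keep the sliders within range for the selected model (DistilBERT, for
        # example, exposes only 6 layers). `sync_sliders` is a helper added here,
        # a minimal sketch assuming gr.update() is available (Gradio 3.x+);
        # plot_attention also clamps the indices as a backstop.
        def sync_sliders(name):
            info = models[name]
            return (
                gr.update(maximum=info["layers"] - 1, value=0),
                gr.update(maximum=info["heads"] - 1, value=0),
            )
        model_selector.change(fn=sync_sliders, inputs=model_selector, outputs=[layer_slider, head_slider])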
attn_plot = gr.Image(type="filepath")
attn_btn = gr.Button("Show Attention")
attn_btn.click(fn=plot_attention, inputs=[input_text, model_selector, layer_slider, head_slider], outputs=attn_plot)
demo.launch()