import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import plotly.express as px
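# A hedged note, inferred from the imports above rather than from the original
# repo files: for the Space to build, requirements.txt needs the matching pip
# packages -- gradio, transformers, torch, numpy, matplotlib, scikit-learn and
# plotly (plus pandas, which older plotly.express versions require for tabular
# input such as the list of dicts used below).
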
# Predefined models
models = {
    "bert-base-uncased": {"type": "BERT", "layers": 12, "heads": 12, "params": 110e6},
    "gpt2": {"type": "GPT-2", "layers": 12, "heads": 12, "params": 117e6},
    "distilbert-base-uncased": {"type": "DistilBERT", "layers": 6, "heads": 12, "params": 66e6},
    "roberta-base": {"type": "RoBERTa", "layers": 12, "heads": 12, "params": 125e6},
}

model_cache = {}


def load_model(name):
    # Lazily download and cache each (tokenizer, model) pair so switching
    # models in the UI doesn't reload them on every call.
    if name not in model_cache:
        tokenizer = AutoTokenizer.from_pretrained(name)
        model = AutoModel.from_pretrained(name, output_attentions=True)
        model_cache[name] = (tokenizer, model)
    return model_cache[name]

def get_model_info(name):
    info = models[name]
    return (
        info["type"],
        info["layers"],
        info["heads"],
        f"{int(info['params'] / 1e6)}M",
    )

def show_model_chart():
    data = [{"Model": k, "Size (Million Parameters)": v["params"] / 1e6} for k, v in models.items()]
    fig = px.bar(data, x="Model", y="Size (Million Parameters)", title="Model Size Comparison")
    return fig

def tokenize_input(text, model_name):
    tokenizer, _ = load_model(model_name)
    return tokenizer.tokenize(text)
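# Rough expected output (hedged -- exact subword splits depend on the tokenizer
# version actually installed):
#   tokenize_input("Hello, how are you?", "bert-base-uncased")
#       -> ['hello', ',', 'how', 'are', 'you', '?']
#   tokenize_input("Hello, how are you?", "gpt2")
#       -> ['Hello', ',', 'Ġhow', 'Ġare', 'Ġyou', '?']   (Ġ marks a leading space in GPT-2's BPE)
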
def visualize_embeddings(text, model_name):
    tokenizer, model = load_model(model_name)
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    embeddings = outputs.last_hidden_state.squeeze(0).numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    pca = PCA(n_components=2)
    reduced = pca.fit_transform(embeddings)
    fig = px.scatter(
        x=reduced[:, 0], y=reduced[:, 1],
        text=tokens,
        labels={'x': 'PCA1', 'y': 'PCA2'},
        title="Token Embeddings (PCA Projection)",
    )
    fig.update_traces(textposition='top center')
    return fig
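# Shape sketch (assuming a single input sequence): last_hidden_state is
# (1, seq_len, hidden_size), and hidden_size is 768 for all four models listed
# above, so PCA projects each token's 768-dim vector down to 2-D for plotting.
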
def plot_attention(text, model_name, layer_idx, head_idx):
    tokenizer, model = load_model(model_name)
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    # Slider values can arrive as floats, and the sliders always go up to 11;
    # cast and clamp so smaller models (e.g. DistilBERT's 6 layers) don't raise
    # an IndexError.
    layer_idx = min(int(layer_idx), len(outputs.attentions) - 1)
    head_idx = min(int(head_idx), outputs.attentions[0].shape[1] - 1)
    attn = outputs.attentions[layer_idx][0][head_idx].detach().numpy()
    fig, ax = plt.subplots(figsize=(8, 6))
    cax = ax.matshow(attn, cmap="viridis")
    fig.colorbar(cax)
    ax.set_xticks(np.arange(len(tokens)))
    ax.set_yticks(np.arange(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticklabels(tokens)
    ax.set_title(f"Layer {layer_idx} Head {head_idx}")
    plt.tight_layout()
    plt.savefig("attn.png")
    plt.close(fig)
    return "attn.png"
with gr.Blocks() as demo:
    gr.Markdown("## 🔎 Transformer Model Explorer\nExplore different transformer models, their architectures, tokenization, and attention mechanisms.")

    with gr.Row():
        model_selector = gr.Dropdown(list(models.keys()), label="Choose a Transformer Model", value="bert-base-uncased")

    with gr.Column():
        gr.Markdown("### 🛠️ Model Details")
        model_type = gr.Text(label="Model Type")
        num_layers = gr.Text(label="Number of Layers")
        num_heads = gr.Text(label="Number of Attention Heads")
        num_params = gr.Text(label="Total Parameters")
        model_selector.change(fn=get_model_info, inputs=model_selector, outputs=[model_type, num_layers, num_heads, num_params])

        gr.Markdown("### 📊 Model Size Comparison")
        chart = gr.Plot()
        # gr.Plot has no render_fn attribute; render the chart when the app loads instead.
        demo.load(fn=show_model_chart, inputs=None, outputs=chart)

    with gr.Column():
        gr.Markdown("### ✨ Tokenization Visualization")
        input_text = gr.Textbox(label="Enter Text:", value="Hello, how are you?")
        token_output = gr.JSON()
        tokenize_btn = gr.Button("Tokenize")
        tokenize_btn.click(fn=tokenize_input, inputs=[input_text, model_selector], outputs=token_output)

        gr.Markdown("### 📈 Token Embeddings Visualization")
        embedding_plot = gr.Plot()
        embed_btn = gr.Button("Show Embeddings (PCA)")
        embed_btn.click(fn=visualize_embeddings, inputs=[input_text, model_selector], outputs=embedding_plot)

        gr.Markdown("### 🔍 Attention Map")
        layer_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Layer")
        head_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Head")
        attn_plot = gr.Image(type="filepath")
        attn_btn = gr.Button("Show Attention")
        attn_btn.click(fn=plot_attention, inputs=[input_text, model_selector, layer_slider, head_slider], outputs=attn_plot)
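    # Optional extra (a sketch, not in the original app): fill in the model
    # details for the default dropdown selection when the page first loads,
    # mirroring the demo.load() call used for the size chart above.
    demo.load(fn=get_model_info, inputs=model_selector, outputs=[model_type, num_layers, num_heads, num_params])
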
demo.launch()