import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so matplotlib can render inside the Gradio server
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import plotly.express as px

# Predefined models and their key architecture facts
models = {
    "bert-base-uncased": {"type": "BERT", "layers": 12, "heads": 12, "params": 110e6},
    "gpt2": {"type": "GPT-2", "layers": 12, "heads": 12, "params": 117e6},
    "distilbert-base-uncased": {"type": "DistilBERT", "layers": 6, "heads": 12, "params": 66e6},
    "roberta-base": {"type": "RoBERTa", "layers": 12, "heads": 12, "params": 125e6},
}

# Cache of (tokenizer, model) pairs so each checkpoint is loaded only once
model_cache = {}


def load_model(name):
    if name not in model_cache:
        tokenizer = AutoTokenizer.from_pretrained(name)
        model = AutoModel.from_pretrained(name, output_attentions=True)
        model_cache[name] = (tokenizer, model)
    return model_cache[name]


def get_model_info(name):
    """Return the architecture summary shown in the 'Model Details' panel."""
    info = models[name]
    return (
        info["type"],
        str(info["layers"]),
        str(info["heads"]),
        f"{int(info['params'] / 1e6)}M",
    )


def show_model_chart():
    """Bar chart comparing parameter counts across the predefined models."""
    data = [{"Model": k, "Size (Million Parameters)": v["params"] / 1e6} for k, v in models.items()]
    fig = px.bar(data, x="Model", y="Size (Million Parameters)", title="Model Size Comparison")
    return fig


def tokenize_input(text, model_name):
    """Show how the selected model's tokenizer splits the input text."""
    tokenizer, _ = load_model(model_name)
    return tokenizer.tokenize(text)


def visualize_embeddings(text, model_name):
    """Project the final-layer token embeddings to 2D with PCA and scatter-plot them."""
    tokenizer, model = load_model(model_name)
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    embeddings = outputs.last_hidden_state.squeeze(0).numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    pca = PCA(n_components=2)
    reduced = pca.fit_transform(embeddings)
    fig = px.scatter(
        x=reduced[:, 0],
        y=reduced[:, 1],
        text=tokens,
        labels={"x": "PCA1", "y": "PCA2"},
        title="Token Embeddings (PCA Projection)",
    )
    fig.update_traces(textposition="top center")
    return fig


def plot_attention(text, model_name, layer_idx, head_idx):
    """Render one attention head's token-to-token weights as a heatmap image."""
    tokenizer, model = load_model(model_name)
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    # Sliders pass floats and always range up to 11, so cast and clamp to what the
    # selected model actually has (e.g. DistilBERT only has 6 layers).
    layer_idx = min(int(layer_idx), len(outputs.attentions) - 1)
    head_idx = min(int(head_idx), outputs.attentions[0].shape[1] - 1)
    attn = outputs.attentions[layer_idx][0][head_idx].numpy()
    fig, ax = plt.subplots(figsize=(8, 6))
    cax = ax.matshow(attn, cmap="viridis")
    fig.colorbar(cax)
    ax.set_xticks(np.arange(len(tokens)))
    ax.set_yticks(np.arange(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticklabels(tokens)
    ax.set_title(f"Layer {layer_idx} Head {head_idx}")
    plt.tight_layout()
    plt.savefig("attn.png")
    plt.close(fig)
    return "attn.png"


with gr.Blocks() as demo:
    gr.Markdown(
        "## 🚀 Transformer Model Explorer\n"
        "Explore different transformer models, their architectures, tokenization, and attention mechanisms."
    )

    with gr.Row():
        model_selector = gr.Dropdown(
            list(models.keys()), label="Choose a Transformer Model", value="bert-base-uncased"
        )

    with gr.Column():
        gr.Markdown("### 🛠️ Model Details")
        model_type = gr.Text(label="Model Type")
        num_layers = gr.Text(label="Number of Layers")
        num_heads = gr.Text(label="Number of Attention Heads")
        num_params = gr.Text(label="Total Parameters")
        model_selector.change(
            fn=get_model_info,
            inputs=model_selector,
            outputs=[model_type, num_layers, num_heads, num_params],
        )

        gr.Markdown("### 📊 Model Size Comparison")
        chart = gr.Plot()

    with gr.Column():
        gr.Markdown("### ✨ Tokenization Visualization")
        input_text = gr.Textbox(label="Enter Text:", value="Hello, how are you?")
        token_output = gr.JSON()
        tokenize_btn = gr.Button("Tokenize")
        tokenize_btn.click(fn=tokenize_input, inputs=[input_text, model_selector], outputs=token_output)

        gr.Markdown("### 🌟 Token Embeddings Visualization")
        embedding_plot = gr.Plot()
        embed_btn = gr.Button("Show Embeddings (PCA)")
        embed_btn.click(fn=visualize_embeddings, inputs=[input_text, model_selector], outputs=embedding_plot)

        gr.Markdown("### 🔍 Attention Map")
        layer_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Layer")
        head_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Head")
        attn_plot = gr.Image(type="filepath")
        attn_btn = gr.Button("Show Attention")
        attn_btn.click(
            fn=plot_attention,
            inputs=[input_text, model_selector, layer_slider, head_slider],
            outputs=attn_plot,
        )

    # Fill the details panel and the size-comparison chart when the page first loads,
    # so they are not empty until the user changes the dropdown.
    demo.load(
        fn=get_model_info,
        inputs=model_selector,
        outputs=[model_type, num_layers, num_heads, num_params],
    )
    demo.load(fn=show_model_chart, outputs=chart)

demo.launch()
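
# Usage sketch (assumes this script is saved as app.py; package names are inferred
# from the imports above):
#   pip install gradio transformers torch numpy matplotlib scikit-learn plotly
#   python app.py
# Gradio prints a local URL (http://127.0.0.1:7860 by default) where the demo runs.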