Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| /** | |
| * Token Flow Visualizer Component | |
| * | |
| * Visualizes how tokens flow through transformer layers, | |
| * showing attention paths and information propagation | |
| * | |
| * @component | |
| */ | |
| "use client"; | |
| import { useState, useEffect, useRef } from "react"; | |
| import * as d3 from "d3"; | |
| import { getApiUrl, getWsUrl } from "@/lib/config"; | |
| import { | |
| GitBranch, | |
| Activity, | |
| Layers, | |
| Play, | |
| Pause, | |
| RotateCcw, | |
| ZoomIn, | |
| ZoomOut, | |
| Download, | |
| Info, | |
| HelpCircle, | |
| X, | |
| Zap, | |
| RefreshCw | |
| } from "lucide-react"; | |
/**
 * A single token of the sequence being visualized.
 */
interface Token {
  /** Identifier of the form `token_<position>`. */
  id: string;
  /** Token text as displayed inside the token boxes. */
  text: string;
  /** Zero-based index of the token within the sequence. */
  position: number;
  /** Optional embedding vector; not read anywhere in this component. */
  embedding?: number[];
}
/**
 * Snapshot of one transformer layer, built from the latest attention trace seen for it.
 */
interface LayerData {
  /** Sequential index assigned when the layer list is built. */
  layerIndex: number;
  /** Layer key derived from the trace's `layer` field (falls back to 'unknown'). */
  layerName: string;
  /** State of every token at this layer. */
  tokens: TokenState[];
  /** Attention weights: row = attending (source) token, column = attended (target) token. */
  attention: number[][];
  /** Timestamp from the trace, or the build time when the trace carries none. */
  timestamp: number;
}
/**
 * Per-layer state of a single token.
 */
interface TokenState {
  /** Identifier of the form `token_<position>`. */
  tokenId: string;
  /** Token text. */
  text: string;
  /** Zero-based index of the token within the sequence. */
  position: number;
  /** Activation level. NOTE(review): not derived from activation traces yet. */
  activation: number;
  /** Attention this token receives: the column sum of the layer's attention matrix. */
  attention_received: number;
  /** Attention this token gives: the row sum of the layer's attention matrix. */
  attention_given: number;
  /** Mean of received and given attention; drives the box colour and indicator size. */
  importance: number;
}
/** One end of a flow edge: a token position inside a layer column. */
interface FlowEndpoint {
  /** Index of the layer in the rendered layer list. */
  layer: number;
  /** Token position within that layer. */
  token: number;
}

/**
 * An edge drawn between tokens of two layers.
 */
interface FlowConnection {
  /** Where the edge starts. */
  source: FlowEndpoint;
  /** Where the edge ends. */
  target: FlowEndpoint;
  /** Edge weight; scales stroke width and opacity. */
  strength: number;
  /** Edge category; selects the stroke colour. */
  type: 'attention' | 'residual' | 'feedforward';
}
/** Attention weights at or below this value are not drawn as flow edges. */
const ATTENTION_EDGE_THRESHOLD = 0.1;
/** Fixed strength assigned to residual (same-position) edges between adjacent layers. */
const RESIDUAL_EDGE_STRENGTH = 0.5;
/** Delay between animation steps while playback is active, in milliseconds. */
const ANIMATION_STEP_MS = 1000;
/** Delay before retrying a dropped or failed WebSocket connection, in milliseconds. */
const WS_RECONNECT_DELAY_MS = 3000;
/** Token box size and vertical spacing in the SVG, in user units. */
const TOKEN_WIDTH = 80;
const TOKEN_HEIGHT = 40;
const TOKEN_GAP = 10;

/** A raw trace message from the backend; its schema is only checked where it is read. */
type TraceRecord = Record<string, unknown>;

/** True for `attention` traces that carry a (truthy) `weights` payload. */
function isAttentionTrace(trace: TraceRecord): boolean {
  return trace.type === 'attention' && Boolean(trace.weights);
}

/**
 * Grouping key for a trace's `layer` field. Only missing/empty values map to
 * 'unknown', so a numeric layer 0 keeps its own group.
 */
function getLayerKey(trace: TraceRecord): string {
  const raw = trace.layer;
  return raw === undefined || raw === null || raw === '' ? 'unknown' : String(raw);
}

/**
 * Numeric sort key for a layer name: the first run of digits
 * (e.g. "model.layers.12.self_attn" -> 12), or 0 when there is none.
 * Using only the first run avoids gluing unrelated digits together
 * ("layers.1.mlp.fc2" would otherwise sort as 12).
 */
function getLayerNumber(layerName: string): number {
  const match = /\d+/.exec(layerName);
  return match ? Number(match[0]) : 0;
}

/** Coerces an untyped `weights` payload into a matrix of finite numbers; other cells become 0. */
function toWeightMatrix(value: unknown): number[][] {
  if (!Array.isArray(value)) return [];
  return value.map((row: unknown) =>
    Array.isArray(row)
      ? row.map((cell: unknown) => (typeof cell === 'number' && Number.isFinite(cell) ? cell : 0))
      : []
  );
}

/** Reads a trace's `tokens` field as strings, or returns null when it is not an array. */
function getTraceTokenTexts(trace: TraceRecord): string[] | null {
  return Array.isArray(trace.tokens) ? trace.tokens.map((token: unknown) => String(token)) : null;
}

/**
 * Builds one LayerData per distinct layer from the most recent attention trace
 * for that layer, ordered by the layer number embedded in its name.
 *
 * @param attentionTraces - traces already filtered with `isAttentionTrace`
 * @param tokens - streamed tokens; when empty, each trace's own `tokens` list is used
 * @returns layers sorted by layer number, with `layerIndex` equal to the column position
 */
function buildLayerData(attentionTraces: TraceRecord[], tokens: Token[]): LayerData[] {
  // A Map keeps first-insertion order; later traces for the same layer overwrite the value.
  const latestByLayer = new Map<string, TraceRecord>();
  for (const trace of attentionTraces) {
    latestByLayer.set(getLayerKey(trace), trace);
  }

  const streamedTexts = tokens.map(token => token.text);

  const unsorted = Array.from(latestByLayer.entries()).map(([layerName, latestTrace]) => {
    const weights = toWeightMatrix(latestTrace.weights);
    const tokenTexts =
      streamedTexts.length > 0 ? streamedTexts : (getTraceTokenTexts(latestTrace) ?? []);

    const tokenStates: TokenState[] = tokenTexts.map((text, tokenIdx) => {
      // Attention received = column sum; attention given = row sum.
      const attention_received = weights.reduce((sum, row) => sum + (row[tokenIdx] ?? 0), 0);
      const attention_given = (weights[tokenIdx] ?? []).reduce((sum, value) => sum + value, 0);
      return {
        tokenId: `token_${tokenIdx}`,
        text,
        position: tokenIdx,
        // NOTE(review): deterministic placeholder; activation traces are not mapped onto tokens yet.
        activation: 0,
        attention_received,
        attention_given,
        importance: (attention_received + attention_given) / 2
      };
    });

    return {
      layerName,
      tokens: tokenStates,
      attention: weights,
      timestamp: typeof latestTrace.timestamp === 'number' ? latestTrace.timestamp : Date.now()
    };
  });

  return unsorted
    .sort((a, b) => getLayerNumber(a.layerName) - getLayerNumber(b.layerName))
    // Assign indices after sorting so layerIndex matches the rendered column.
    .map((layer, layerIndex): LayerData => ({ layerIndex, ...layer }));
}

/**
 * Builds the edges drawn between each pair of adjacent layers: attention edges
 * for weights above ATTENTION_EDGE_THRESHOLD, and one residual edge per token
 * position present in both layers.
 */
function buildFlowConnections(
  layerData: LayerData[],
  includeAttention: boolean,
  includeResidual: boolean
): FlowConnection[] {
  const connections: FlowConnection[] = [];
  layerData.forEach((currentLayer, i) => {
    const nextLayer = layerData[i + 1];
    if (!nextLayer) return;

    if (includeAttention) {
      currentLayer.attention.forEach((row, srcToken) => {
        row.forEach((weight, tgtToken) => {
          if (weight > ATTENTION_EDGE_THRESHOLD) {
            connections.push({
              source: { layer: i, token: srcToken },
              target: { layer: i + 1, token: tgtToken },
              strength: weight,
              type: 'attention'
            });
          }
        });
      });
    }

    if (includeResidual) {
      const sharedCount = Math.min(currentLayer.tokens.length, nextLayer.tokens.length);
      for (let idx = 0; idx < sharedCount; idx++) {
        connections.push({
          source: { layer: i, token: idx },
          target: { layer: i + 1, token: idx },
          strength: RESIDUAL_EDGE_STRENGTH,
          type: 'residual'
        });
      }
    }
  });
  return connections;
}

/**
 * Token Flow Visualizer.
 *
 * Attention traces arrive from three sources: the WebSocket at `getWsUrl()`,
 * the response of `POST ${getApiUrl()}/generate`, or the `demo-completed`
 * window event. The component groups them by layer and renders tokens as
 * layer columns, with attention/residual edges between adjacent layers.
 */
export default function TokenFlowVisualizer() {
  const [tokens, setTokens] = useState<Token[]>([]);
  const [layers, setLayers] = useState<LayerData[]>([]);
  const [flowConnections, setFlowConnections] = useState<FlowConnection[]>([]);
  const [selectedToken, setSelectedToken] = useState<number | null>(null);
  const [isPlaying, setIsPlaying] = useState(false);
  const [currentStep, setCurrentStep] = useState(0);
  const [zoom, setZoom] = useState(1);
  const [showResidual, setShowResidual] = useState(true);
  const [showAttention, setShowAttention] = useState(true);
  const [showExplanation, setShowExplanation] = useState(false);
  const [isConnected, setIsConnected] = useState(false);
  const [prompt, setPrompt] = useState("def fibonacci(n):\n '''Calculate fibonacci number'''");
  const [isGenerating, setIsGenerating] = useState(false);
  const [traces, setTraces] = useState<TraceRecord[]>([]);
  const svgRef = useRef<SVGSVGElement>(null);
  const wsRef = useRef<WebSocket | null>(null);

  // WebSocket: stream traces and generated tokens; reconnect after drops until unmount.
  useEffect(() => {
    let mounted = true;
    let reconnectTimeout: ReturnType<typeof setTimeout> | undefined;

    function scheduleReconnect() {
      reconnectTimeout = setTimeout(() => {
        if (mounted) connectWS();
      }, WS_RECONNECT_DELAY_MS);
    }

    function handleMessage(event: MessageEvent) {
      if (!mounted) return;
      let parsed: unknown;
      try {
        parsed = JSON.parse(event.data);
      } catch {
        // Skip non-JSON messages
        return;
      }
      // Ignore JSON that is not an object (e.g. `null`) instead of throwing on `.type`.
      if (typeof parsed !== 'object' || parsed === null) return;
      const message = parsed as TraceRecord;

      if (message.type === 'attention' || message.type === 'activation') {
        setTraces(prev => [...prev, message]);
      } else if (message.type === 'generated_token') {
        const text = String(message.token ?? '');
        setTokens(prev => [...prev, { id: `token_${prev.length}`, text, position: prev.length }]);
      }
    }

    function connectWS() {
      if (!mounted) return;
      try {
        const ws = new WebSocket(getWsUrl());
        ws.onopen = () => {
          if (!mounted) return;
          console.log('TokenFlow: WebSocket connected');
          setIsConnected(true);
        };
        ws.onmessage = handleMessage;
        ws.onerror = () => {
          if (mounted) setIsConnected(false);
        };
        ws.onclose = () => {
          if (!mounted) return;
          console.log('TokenFlow: WebSocket disconnected, will reconnect...');
          setIsConnected(false);
          scheduleReconnect();
        };
        wsRef.current = ws;
      } catch {
        console.log('WebSocket connection attempt failed, will retry...');
        if (mounted) {
          setIsConnected(false);
          scheduleReconnect();
        }
      }
    }

    connectWS();

    return () => {
      mounted = false;
      if (reconnectTimeout !== undefined) clearTimeout(reconnectTimeout);
      wsRef.current?.close();
      wsRef.current = null;
    };
  }, []);

  // Window events dispatched by LocalControlPanel's demo runner.
  useEffect(() => {
    // Returns the event's `detail` when it is an object, otherwise an empty record.
    const detailOf = (event: Event): TraceRecord => {
      const detail: unknown = 'detail' in event ? (event as CustomEvent).detail : undefined;
      return typeof detail === 'object' && detail !== null ? (detail as TraceRecord) : {};
    };

    const handleDemoPromptSelected = (event: Event) => {
      const { prompt: demoPrompt, demoId } = detailOf(event);
      console.log('TokenFlow: Demo prompt selected -', demoId);
      if (typeof demoPrompt === 'string' && demoPrompt) {
        setPrompt(demoPrompt);
      }
    };

    const handleDemoStarting = (event: Event) => {
      const { demoId } = detailOf(event);
      console.log('TokenFlow: Demo starting, clearing data -', demoId);
      // Clear all data when demo starts
      setTokens([]);
      setLayers([]);
      setFlowConnections([]);
      setTraces([]);
      setSelectedToken(null);
    };

    const handleDemoCompleted = (event: Event) => {
      const data: unknown = 'detail' in event ? (event as CustomEvent).detail : undefined;
      console.log('TokenFlow: Demo completed', data);
      const { traces: demoTraces } = detailOf(event);
      // Only accept an array; anything else would crash `traces.filter` downstream.
      if (Array.isArray(demoTraces)) {
        setTraces(demoTraces);
      }
    };

    window.addEventListener('demo-prompt-selected', handleDemoPromptSelected);
    window.addEventListener('demo-starting', handleDemoStarting);
    window.addEventListener('demo-completed', handleDemoCompleted);
    return () => {
      window.removeEventListener('demo-prompt-selected', handleDemoPromptSelected);
      window.removeEventListener('demo-starting', handleDemoStarting);
      window.removeEventListener('demo-completed', handleDemoCompleted);
    };
  }, []);

  // Derive tokens (when none were streamed) and per-layer state from the collected traces.
  useEffect(() => {
    const attentionTraces = traces.filter(isAttentionTrace);
    if (attentionTraces.length === 0 && tokens.length === 0) return;

    if (tokens.length === 0) {
      const traceWithTokens = attentionTraces.find(
        trace => Array.isArray(trace.tokens) && trace.tokens.length > 0
      );
      const texts = traceWithTokens ? getTraceTokenTexts(traceWithTokens) : null;
      // Only write non-empty lists: a fresh [] would change `tokens` and re-run this effect forever.
      if (texts && texts.length > 0) {
        setTokens(texts.map((text, idx) => ({ id: `token_${idx}`, text, position: idx })));
      }
    }

    setLayers(buildLayerData(attentionTraces, tokens));
  }, [traces, tokens]);

  // Edges also depend on the view toggles, so unticking a checkbox takes effect immediately.
  useEffect(() => {
    setFlowConnections(buildFlowConnections(layers, showAttention, showResidual));
  }, [layers, showAttention, showResidual]);

  // Playback: advance the layer counter every ANIMATION_STEP_MS while playing.
  useEffect(() => {
    if (!isPlaying) return undefined;
    const stepCount = layers.length || 1;
    const intervalId = setInterval(() => {
      setCurrentStep(prev => (prev + 1) % stepCount);
    }, ANIMATION_STEP_MS);
    return () => clearInterval(intervalId);
  }, [isPlaying, layers.length]);

  // D3 rendering: redraws the whole SVG whenever any of its inputs change.
  useEffect(() => {
    const svgElement = svgRef.current;
    if (!svgElement || layers.length === 0) return;

    const margin = { top: 60, right: 200, bottom: 60, left: 100 };
    const baseWidth = 1300;
    const baseHeight = 600;
    const innerWidth = baseWidth - margin.left - margin.right;
    const layerWidth = innerWidth / layers.length;
    const rowStride = TOKEN_HEIGHT + TOKEN_GAP;
    const maxTokenCount = layers.reduce((max, layer) => Math.max(max, layer.tokens.length), 0);
    // Grow the canvas with the (zoomed) content so long sequences scroll inside the
    // container instead of being clipped at the fixed 1300x600 size.
    const width = Math.max(baseWidth, margin.left + innerWidth * zoom + margin.right);
    const height = Math.max(baseHeight, margin.top + maxTokenCount * rowStride * zoom + margin.bottom);

    const svg = d3.select(svgElement);
    // Clear previous visualization
    svg.selectAll("*").remove();
    svg
      .attr("width", width)
      .attr("height", height)
      .attr("viewBox", `0 0 ${width} ${height}`);

    const g = svg.append("g")
      .attr("transform", `translate(${margin.left},${margin.top}) scale(${zoom})`);

    // Centre of a token box in the zoomed group's coordinate space.
    const tokenCenter = (layerIdx: number, tokenIdx: number) => ({
      x: layerIdx * layerWidth + layerWidth / 2,
      y: tokenIdx * rowStride + TOKEN_HEIGHT / 2
    });

    const layerGroups = g.selectAll(".layer-group")
      .data(layers)
      .enter()
      .append("g")
      .attr("class", "layer-group")
      .attr("transform", (_d, i) => `translate(${i * layerWidth}, 0)`);

    layerGroups.append("text")
      .attr("x", layerWidth / 2)
      .attr("y", -20)
      .attr("text-anchor", "middle")
      .attr("fill", "#9ca3af")
      .attr("font-size", "12px")
      .attr("font-weight", "bold")
      .text(d => d.layerName);

    const connectionPaths = g.selectAll(".flow-connection")
      .data(flowConnections)
      .enter()
      .append("path")
      .attr("class", d => `flow-connection flow-${d.type}`)
      .attr("d", d => {
        const source = tokenCenter(d.source.layer, d.source.token);
        const target = tokenCenter(d.target.layer, d.target.token);
        const midX = (source.x + target.x) / 2;
        return `M ${source.x} ${source.y} Q ${midX} ${source.y} ${midX} ${(source.y + target.y) / 2} T ${target.x} ${target.y}`;
      })
      .attr("stroke", d => {
        if (d.type === 'attention') return "#3b82f6";
        if (d.type === 'residual') return "#10b981";
        return "#8b5cf6";
      })
      .attr("stroke-width", d => Math.max(0.5, d.strength * 3))
      .attr("stroke-opacity", d => d.strength * 0.6)
      .attr("fill", "none")
      // Paths are drawn above the token boxes; let clicks pass through to the tokens.
      .attr("pointer-events", "none");

    if (isPlaying) {
      connectionPaths
        .attr("stroke-dasharray", "5,5")
        .append("animate")
        .attr("attributeName", "stroke-dashoffset")
        .attr("from", "10")
        .attr("to", "0")
        .attr("dur", "1s")
        .attr("repeatCount", "indefinite");
    }

    const tokenGroups = layerGroups.selectAll(".token-group")
      .data(d => d.tokens)
      .enter()
      .append("g")
      .attr("class", "token-group")
      .attr("transform", (_d, i) => `translate(${layerWidth / 2 - TOKEN_WIDTH / 2}, ${i * rowStride})`);

    tokenGroups.append("rect")
      .attr("width", TOKEN_WIDTH)
      .attr("height", TOKEN_HEIGHT)
      .attr("rx", 6)
      .attr("fill", d => d3.interpolateYlOrRd(d.importance || 0))
      .attr("stroke", d => selectedToken === d.position ? "#3b82f6" : "#4b5563")
      .attr("stroke-width", d => selectedToken === d.position ? 2 : 1)
      .style("cursor", "pointer")
      .on("click", (_event, d) => {
        setSelectedToken(d.position === selectedToken ? null : d.position);
      });

    tokenGroups.append("text")
      .attr("x", TOKEN_WIDTH / 2)
      .attr("y", TOKEN_HEIGHT / 2)
      .attr("text-anchor", "middle")
      .attr("dominant-baseline", "middle")
      .attr("fill", d => d.importance > 0.5 ? "#fff" : "#1f2937")
      .attr("font-size", "11px")
      .attr("font-family", "monospace")
      .attr("pointer-events", "none")
      .text(d => d.text.substring(0, 8));

    // Importance indicator; clamp so a heavily attended token cannot cover its neighbours.
    tokenGroups.append("circle")
      .attr("cx", TOKEN_WIDTH - 10)
      .attr("cy", 10)
      .attr("r", d => Math.max(2, Math.min(1, d.importance) * 6))
      .attr("fill", "#fbbf24")
      .attr("opacity", 0.8);

    svg.append("text")
      .attr("x", width / 2)
      .attr("y", 30)
      .attr("text-anchor", "middle")
      .attr("font-size", "16px")
      .attr("font-weight", "bold")
      .attr("fill", "#fff")
      .text("Token Flow Through Transformer Layers");

    // Legend sits in the right margin, clear of the (zoomed) content.
    const legend = svg.append("g")
      .attr("transform", `translate(${width - 180}, 100)`);

    const legendItems = [
      { color: "#3b82f6", label: "Attention Flow", type: "attention" },
      { color: "#10b981", label: "Residual Connection", type: "residual" },
      { color: "#fbbf24", label: "Token Importance", type: "importance" }
    ];

    legendItems.forEach((item, i) => {
      const legendItem = legend.append("g")
        .attr("transform", `translate(0, ${i * 25})`);

      if (item.type === "importance") {
        legendItem.append("circle")
          .attr("cx", 10)
          .attr("cy", 10)
          .attr("r", 6)
          .attr("fill", item.color);
      } else {
        legendItem.append("line")
          .attr("x1", 0)
          .attr("y1", 10)
          .attr("x2", 20)
          .attr("y2", 10)
          .attr("stroke", item.color)
          .attr("stroke-width", 2);
      }

      legendItem.append("text")
        .attr("x", 30)
        .attr("y", 10)
        .attr("dominant-baseline", "middle")
        .attr("fill", "#9ca3af")
        .attr("font-size", "11px")
        .text(item.label);
    });
  }, [layers, flowConnections, selectedToken, zoom, isPlaying]);

  /** Starts or stops playback; starting advances one step immediately for instant feedback. */
  const toggleAnimation = () => {
    if (!isPlaying) {
      setCurrentStep(prev => (prev + 1) % (layers.length || 1));
    }
    setIsPlaying(!isPlaying);
  };

  /** Stops playback and clears the step counter and selection (the interval effect cleans up). */
  const reset = () => {
    setCurrentStep(0);
    setSelectedToken(null);
    setIsPlaying(false);
  };

  /** Builds the contextual explanation shown in the side panel. */
  const generateExplanation = () => {
    if (layers.length === 0) {
      return {
        title: "No Token Flow Data",
        description: "Run a model to see how tokens flow through transformer layers.",
        details: []
      };
    }

    const numLayers = layers.length;
    const numTokens = tokens.length;
    const activeConnections = flowConnections.filter(c => c.strength > ATTENTION_EDGE_THRESHOLD).length;
    const totalConnections = flowConnections.length;
    const connectionDensity = totalConnections > 0 ? ((activeConnections / totalConnections) * 100).toFixed(1) : "0";

    return {
      title: `Token Flow Analysis: ${numTokens} tokens, ${numLayers} layers`,
      description: `Visualizing information flow through the transformer's attention mechanism.`,
      details: [
        {
          heading: "What is Token Flow?",
          content: `This visualization shows how tokens are processed through transformer layers in real-time. Each column represents a layer, each box is a token at that layer. The visualization builds progressively as tokens are generated.`
        },
        {
          heading: "Reading the Flow",
          content: `Tokens flow from left (layer 0, input) to right (final layer, output). Each column is a transformer layer processing all tokens. Color intensity (yellow→orange→red) shows token importance/activation strength.`
        },
        {
          heading: "Real-time Generation",
          content: `The visualization starts with one column and expands horizontally as new tokens are generated. You're watching the model build its understanding token by token, layer by layer.`
        },
        {
          heading: "Current Network Stats",
          content: `${numTokens} tokens × ${numLayers} layers = ${numTokens * numLayers} nodes. ${connectionDensity}% of possible connections are active (strength > 0.1). Blue lines show attention flow between tokens.`
        },
        {
          heading: "Connection Types",
          content: `Blue lines: Attention connections showing which tokens attend to which. Green lines: Residual connections (when enabled). Line thickness indicates connection strength.`
        },
        {
          heading: "Color Meaning",
          content: `Token color represents activation level: Light yellow (low activation), Orange (medium), Red (high activation). This shows which tokens are most important at each processing stage.`
        }
      ]
    };
  };

  /** Downloads the current SVG and releases the temporary object URL afterwards. */
  const exportVisualization = () => {
    const svgElement = svgRef.current;
    if (!svgElement) return;

    const svgData = new XMLSerializer().serializeToString(svgElement);
    const svgBlob = new Blob([svgData], { type: "image/svg+xml;charset=utf-8" });
    const svgUrl = URL.createObjectURL(svgBlob);

    const link = document.createElement("a");
    link.href = svgUrl;
    link.download = `token_flow_${Date.now()}.svg`;
    document.body.appendChild(link);
    link.click();
    link.remove();
    // Revoke on the next tick so the browser has already started the download.
    setTimeout(() => URL.revokeObjectURL(svgUrl), 0);
  };

  /** Clears current data, requests a traced generation over HTTP, and loads the returned traces. */
  const handleGenerate = async () => {
    setIsGenerating(true);
    setTokens([]);
    setLayers([]);
    setFlowConnections([]);
    setTraces([]);
    setSelectedToken(null);

    try {
      const response = await fetch(`${getApiUrl()}/generate`, {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({
          prompt,
          max_tokens: 50,
          temperature: 0.7,
          extract_traces: true,
          sampling_rate: 0.3
        })
      });

      if (!response.ok) throw new Error(`HTTP error! status: ${response.status}`);

      const data = (await response.json()) as { traces?: unknown } | null;
      if (Array.isArray(data?.traces)) {
        setTraces(data.traces);
      }
    } catch (error) {
      console.error('Generation error:', error);
      alert(`Failed to generate: ${error}`);
    } finally {
      setIsGenerating(false);
    }
  };

  const explanation = generateExplanation();
  // Keep the label in range if the layer count shrank after playback advanced the step.
  const displayedStep = layers.length > 0 ? Math.min(currentStep, layers.length - 1) : 0;
  const selectedTokenData = selectedToken !== null ? tokens[selectedToken] : undefined;

  return (
    <div className="bg-gray-900 rounded-xl p-6">
      <div className="flex items-center justify-between mb-6">
        <div>
          <h2 className="text-2xl font-bold flex items-center gap-2">
            <GitBranch className="w-6 h-6 text-purple-400" />
            Token Flow Visualizer
          </h2>
          <p className="text-gray-400 mt-1">
            Track how information flows through transformer layers
          </p>
        </div>

        <div className="flex items-center gap-4">
          <div className={`flex items-center gap-2 px-3 py-1 rounded-full ${
            isConnected ? 'bg-green-900/30 text-green-400' : 'bg-red-900/30 text-red-400'
          }`}>
            <Activity className={`w-4 h-4 ${isConnected ? 'animate-pulse' : ''}`} />
            {isConnected ? 'Connected' : 'Disconnected'}
          </div>
        </div>
      </div>

      {/* Generation Controls */}
      <div className="mb-6">
        <div className="flex gap-4">
          <input
            type="text"
            value={prompt}
            onChange={(e) => setPrompt(e.target.value)}
            className="flex-1 px-4 py-2 bg-gray-800 text-white rounded-lg border border-gray-700 focus:border-blue-500 focus:outline-none font-mono text-sm"
            placeholder="Enter prompt to analyze token flow..."
          />
          <button
            onClick={() => void handleGenerate()}
            disabled={isGenerating || !isConnected}
            className="px-6 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700 transition-colors disabled:opacity-50 flex items-center gap-2"
          >
            {isGenerating ? (
              <>
                <RefreshCw className="w-4 h-4 animate-spin" />
                Analyzing...
              </>
            ) : (
              <>
                Generate & Visualize
                <Zap className="w-4 h-4" />
              </>
            )}
          </button>
        </div>
      </div>

      {/* Controls */}
      <div className="flex flex-wrap items-center gap-4 mb-4">
        {/* Playback Controls */}
        <div className="flex items-center gap-2">
          <button
            onClick={toggleAnimation}
            className="p-2 bg-gray-800 text-white rounded-lg hover:bg-gray-700 transition-colors"
            title={isPlaying ? "Pause" : "Play"}
          >
            {isPlaying ? <Pause className="w-4 h-4" /> : <Play className="w-4 h-4" />}
          </button>
          <button
            onClick={reset}
            className="p-2 bg-gray-800 text-white rounded-lg hover:bg-gray-700 transition-colors"
            title="Reset"
          >
            <RotateCcw className="w-4 h-4" />
          </button>
          <span className="text-sm text-gray-400 px-2">
            Layer {displayedStep + 1} / {layers.length || 1}
          </span>
        </div>

        {/* View Options */}
        <div className="flex items-center gap-2">
          <label className="flex items-center gap-2 text-sm text-gray-400">
            <input
              type="checkbox"
              checked={showAttention}
              onChange={(e) => setShowAttention(e.target.checked)}
              className="rounded"
            />
            Attention
          </label>
          <label className="flex items-center gap-2 text-sm text-gray-400">
            <input
              type="checkbox"
              checked={showResidual}
              onChange={(e) => setShowResidual(e.target.checked)}
              className="rounded"
            />
            Residual
          </label>
        </div>

        {/* Zoom Controls */}
        <div className="flex items-center gap-2">
          <button
            onClick={() => setZoom(z => Math.max(0.5, z - 0.1))}
            className="p-2 bg-gray-800 text-white rounded-lg hover:bg-gray-700 transition-colors"
            title="Zoom Out"
          >
            <ZoomOut className="w-4 h-4" />
          </button>
          <span className="text-sm text-gray-400 min-w-[50px] text-center">
            {(zoom * 100).toFixed(0)}%
          </span>
          <button
            onClick={() => setZoom(z => Math.min(2, z + 0.1))}
            className="p-2 bg-gray-800 text-white rounded-lg hover:bg-gray-700 transition-colors"
            title="Zoom In"
          >
            <ZoomIn className="w-4 h-4" />
          </button>
        </div>

        {/* Export */}
        <button
          onClick={exportVisualization}
          className="ml-auto px-3 py-1.5 bg-gray-800 text-white rounded-lg hover:bg-gray-700 transition-colors flex items-center gap-2"
        >
          <Download className="w-4 h-4" />
          Export
        </button>
      </div>

      {/* Main Content Area with Side Panel */}
      <div className="flex gap-4">
        {/* Visualization Container */}
        <div className="flex-1 min-w-0 transition-all duration-500 ease-in-out">
          <div className="bg-gray-800 rounded-lg p-4 overflow-auto relative">
            {/* Help Toggle Button */}
            <button
              onClick={() => setShowExplanation(!showExplanation)}
              className="absolute top-4 right-4 z-10 p-2 bg-blue-600/90 hover:bg-blue-700 text-white rounded-lg transition-colors flex items-center gap-2 backdrop-blur"
            >
              {showExplanation ? <X className="w-5 h-5" /> : <HelpCircle className="w-5 h-5" />}
              <span className="text-sm font-medium">
                {showExplanation ? 'Hide Info' : 'What am I seeing?'}
              </span>
            </button>

            {layers.length > 0 ? (
              <svg ref={svgRef} />
            ) : (
              <div className="flex items-center justify-center h-96 text-gray-500">
                <div className="text-center">
                  <GitBranch className="w-12 h-12 mx-auto mb-4 opacity-50" />
                  <p className="text-lg mb-2">No Token Flow Data</p>
                  <p className="text-sm">Run a model to visualize token flow through layers</p>
                </div>
              </div>
            )}
          </div>
        </div>

        {/* Explanation Side Panel */}
        <div className={`${showExplanation ? 'w-96' : 'w-0'} transition-all duration-500 ease-in-out overflow-hidden`}>
          <div className="w-96 h-[600px] bg-gray-900 rounded-lg border border-gray-700">
            {/* Panel Header */}
            <div className="bg-gray-800 px-4 py-3 border-b border-gray-700">
              <div className="flex items-center gap-2">
                <Info className="w-5 h-5 text-blue-400" />
                <h3 className="text-lg font-semibold text-white">Understanding Token Flow</h3>
              </div>
            </div>

            {/* Panel Content */}
            <div className="px-4 py-4 overflow-y-auto h-[calc(600px-60px)]">
              {/* Main Description */}
              <div className="mb-4 p-3 bg-purple-900/20 border border-purple-800 rounded-lg">
                <h4 className="text-sm font-semibold text-purple-400 mb-1">{explanation.title}</h4>
                <p className="text-xs text-gray-300">{explanation.description}</p>
              </div>

              {/* Explanation Sections */}
              <div className="space-y-3">
                {explanation.details.map((section, idx) => (
                  <div key={idx} className="bg-gray-800 rounded-lg p-3">
                    <h5 className="font-medium text-sm text-white mb-1 flex items-center gap-1">
                      <Zap className="w-3 h-3 text-yellow-400" />
                      {section.heading}
                    </h5>
                    <p className="text-xs text-gray-300 leading-relaxed">{section.content}</p>
                  </div>
                ))}
              </div>

              {/* Visual Guide */}
              <div className="mt-4 p-3 bg-blue-900/20 border border-blue-800 rounded-lg">
                <h4 className="font-medium text-sm text-blue-400 mb-2">Visual Elements</h4>
                <div className="space-y-2 text-xs">
                  <div className="flex items-start gap-2">
                    <span className="text-blue-300">•</span>
                    <span className="text-gray-300">Nodes = Tokens at each layer</span>
                  </div>
                  <div className="flex items-start gap-2">
                    <span className="text-blue-300">•</span>
                    <span className="text-gray-300">Lines = Attention connections</span>
                  </div>
                  <div className="flex items-start gap-2">
                    <span className="text-blue-300">•</span>
                    <span className="text-gray-300">Thickness = Connection strength</span>
                  </div>
                  <div className="flex items-start gap-2">
                    <span className="text-blue-300">•</span>
                    <span className="text-gray-300">Color intensity = Token importance</span>
                  </div>
                </div>
              </div>

              {/* Current Metrics */}
              {layers.length > 0 && (
                <div className="mt-4 p-3 bg-gray-800 rounded-lg">
                  <h4 className="font-medium text-sm text-gray-300 mb-2">Current Metrics</h4>
                  <div className="space-y-1 text-xs">
                    <div className="flex justify-between">
                      <span className="text-gray-400">Tokens:</span>
                      <span className="text-white">{tokens.length}</span>
                    </div>
                    <div className="flex justify-between">
                      <span className="text-gray-400">Layers:</span>
                      <span className="text-white">{layers.length}</span>
                    </div>
                    <div className="flex justify-between">
                      <span className="text-gray-400">Total Nodes:</span>
                      <span className="text-white">{tokens.length * layers.length}</span>
                    </div>
                    <div className="flex justify-between">
                      <span className="text-gray-400">Active Connections:</span>
                      <span className="text-blue-400">{flowConnections.filter(c => c.strength > ATTENTION_EDGE_THRESHOLD).length}</span>
                    </div>
                  </div>
                </div>
              )}

              {/* Tips */}
              <div className="mt-4 p-3 bg-gray-800 rounded-lg">
                <h4 className="font-medium text-sm text-gray-300 mb-2">💡 Tips</h4>
                <ul className="text-xs text-gray-400 space-y-1">
                  <li>• Click tokens to trace their path</li>
                  <li>• Use animation to see flow evolution</li>
                  <li>• Zoom for different perspectives</li>
                  <li>• Toggle connection types with controls</li>
                </ul>
              </div>
            </div>
          </div>
        </div>
      </div>

      {/* Info Panel */}
      {selectedToken !== null && selectedTokenData && (
        <div className="mt-4 p-4 bg-gray-800 rounded-lg">
          <h3 className="text-lg font-semibold mb-3 flex items-center gap-2">
            <Info className="w-5 h-5 text-blue-400" />
            Selected Token: "{selectedTokenData.text}"
          </h3>
          <div className="grid grid-cols-2 md:grid-cols-4 gap-4 text-sm">
            <div>
              <span className="text-gray-400">Position:</span>
              <div className="font-mono text-white mt-1">{selectedToken}</div>
            </div>
            <div>
              <span className="text-gray-400">Layers Processed:</span>
              <div className="font-mono text-white mt-1">{layers.length}</div>
            </div>
            <div>
              <span className="text-gray-400">Connections:</span>
              <div className="font-mono text-white mt-1">
                {flowConnections.filter(c =>
                  c.source.token === selectedToken || c.target.token === selectedToken
                ).length}
              </div>
            </div>
            <div>
              <span className="text-gray-400">Max Importance:</span>
              <div className="font-mono text-blue-400 mt-1">
                {Math.max(...layers.map(l =>
                  l.tokens[selectedToken]?.importance ?? 0
                )).toFixed(3)}
              </div>
            </div>
          </div>
        </div>
      )}

      {/* Instructions */}
      {layers.length === 0 && (
        <div className="mt-4 p-4 bg-yellow-900/20 border border-yellow-700 rounded-lg">
          <h4 className="text-yellow-400 font-semibold mb-2">How to Use</h4>
          <ol className="text-sm text-gray-300 space-y-1 list-decimal list-inside">
            <li>Run a model to generate attention traces</li>
            <li>Token flow will automatically visualize</li>
            <li>Click tokens to see their flow details</li>
            <li>Use controls to animate the flow</li>
          </ol>
        </div>
      )}
    </div>
  );
}