ror's picture
ror HF Staff
error bars
d73296f
raw
history blame
6.57 kB
from math import e
import pandas as pd
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib
import io
import base64
from data import ModelBenchmarkData
# Configure matplotlib for headless rendering: 'Agg' is a non-interactive
# raster backend, so figures are only ever rendered to in-memory images
# (see create_matplotlib_bar_charts) and no GUI window is opened.
matplotlib.use('Agg')
# Disable pyplot's interactive mode so figures are not redrawn/shown implicitly.
plt.ioff()
# Module-wide benchmark data source shared by every plotting function below.
# NOTE(review): presumably parses "data.json" at construction time — confirm in data.py.
DATA = ModelBenchmarkData("data.json")
def refresh_plot_data():
    """Fetch the median wall-time TTFT/TPOT summary and wrap it in a DataFrame.

    The raw summary is echoed to stdout before conversion.

    Returns:
        pandas.DataFrame built directly from the summary returned by
        ``DATA.get_ttft_tpot_data``.
    """
    summary = DATA.get_ttft_tpot_data(estimator="median", use_cuda_time=False)
    print(summary)
    frame = pd.DataFrame(summary)
    return frame
def load_css():
    """Load the dashboard stylesheet.

    Reads ``styles.css`` from the current working directory. The file is
    decoded explicitly as UTF-8 so the result does not depend on the
    platform's default locale encoding.

    Returns:
        The stylesheet text, or a minimal dark-theme fallback rule when
        ``styles.css`` does not exist.
    """
    try:
        with open("styles.css", "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        # Keep the UI usable (white on black) even without the stylesheet.
        return "body { background: #000; color: #fff; }"
def _get_color_for_config(label):
    """Map a configuration label to its bar color.

    Matching is case-insensitive on substrings of the label: 'eager' is
    checked before 'sdpa', and '_compiled' selects the stronger shade.
    """
    lowered = label.lower()
    is_compiled = '_compiled' in lowered
    if 'eager' in lowered:
        # Red for eager compiled, light red for eager uncompiled
        return '#FF4444' if is_compiled else '#FF6B6B'
    if 'sdpa' in lowered:
        # Blue for SDPA compiled, light blue for SDPA uncompiled
        return '#4A90E2' if is_compiled else '#7BB3F0'
    return '#FFD700'  # Yellow for others


def _wall_time_std(measurements):
    """Return the population standard deviation of the 'wall_time' entries.

    Returns 0.0 for an empty sequence instead of letting numpy produce NaN
    (and a RuntimeWarning), so the error bar is simply absent.
    """
    import numpy as np

    wall_times = [m['wall_time'] for m in measurements]
    if not wall_times:
        return 0.0
    return float(np.std(wall_times))


def _draw_metric_panel(ax, positions, values, errors, colors, ylabel, title):
    """Draw one dark-themed bar chart with white error bars on ``ax``."""
    ax.set_facecolor('#000000')
    ax.bar(positions, values,
           color=colors, width=1.0, edgecolor='white', linewidth=1)
    ax.errorbar(
        positions, values, yerr=errors,
        fmt='none', ecolor='white', alpha=0.8,
        elinewidth=1.5, capthick=1.5, capsize=4,
    )
    ax.set_ylabel(ylabel, color='white', fontsize=14)
    ax.set_title(title, color='white', fontsize=16, pad=20)
    # Bars are identified through the shared figure legend, not tick labels.
    ax.set_xticks([])
    ax.tick_params(colors='white')
    ax.grid(True, alpha=0.3, color='white')


def create_matplotlib_bar_charts():
    """Create side-by-side matplotlib bar charts for TTFT and TPOT data.

    Bar heights are the median wall-time values returned by
    ``DATA.get_ttft_tpot_data``; error bars are the standard deviation of
    the raw 'wall_time' samples found in ``DATA.data`` for the same label
    (0 when the label has no raw entry).

    Returns:
        An HTML snippet (str) embedding the rendered figure as a
        base64-encoded PNG inside a full-height container.
    """
    data = DATA.get_ttft_tpot_data(estimator="median", use_cuda_time=False)

    # Prepare data
    # NOTE(review): assumes 'label', 'ttft' and 'tpot' are parallel sequences — confirm in data.py.
    labels = data['label']
    ttft_values = data['ttft']
    tpot_values = data['tpot']

    # Calculate error bars (standard deviation) for each configuration
    raw_data = DATA.data
    ttft_errors = []
    tpot_errors = []
    for label in labels:
        if label in raw_data:
            ttft_errors.append(_wall_time_std(raw_data[label]['ttft']))
            tpot_errors.append(_wall_time_std(raw_data[label]['tpot']))
        else:
            ttft_errors.append(0)
            tpot_errors.append(0)

    colors = [_get_color_for_config(label) for label in labels]
    positions = range(len(labels))

    # Create figure with dark theme - larger for more screen space
    plt.style.use('dark_background')
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 16))
    try:
        fig.patch.set_facecolor('#000000')

        # TTFT Plot (left)
        _draw_metric_panel(
            ax1, positions, ttft_values, ttft_errors, colors,
            'TTFT (seconds)', 'Time to first token (lower is better)',
        )
        # TPOT Plot (right)
        _draw_metric_panel(
            ax2, positions, tpot_values, tpot_errors, colors,
            'TPOT (seconds)', 'Time per output token (lower is better)',
        )

        # Add common legend with full (untruncated) labels.
        # facecolor (not color) is used so the white edge is not overridden.
        legend_handles = [
            plt.Rectangle((0, 0), 1, 1, facecolor=color, edgecolor='white')
            for color in colors
        ]
        fig.legend(legend_handles, labels, loc='lower center', ncol=1,
                   bbox_to_anchor=(0.5, -0.05), facecolor='black', edgecolor='white',
                   labelcolor='white', fontsize=12)

        # Tight layout with spacing between subplots and extra bottom space for legend.
        # Figure methods are used instead of pyplot's implicit "current figure"
        # so overlapping calls cannot act on each other's figures.
        fig.tight_layout()
        fig.subplots_adjust(wspace=0.3, bottom=0.075)

        # Save plot to bytes with high DPI for crisp text
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', facecolor='#000000',
                    bbox_inches='tight', dpi=130)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    # Convert to base64 for HTML embedding
    img_data = base64.b64encode(buffer.getvalue()).decode()

    # Return HTML with embedded image - full height
    html = (
        '\n<div style="width: 100%; height: 100vh; background: #000; display: flex; justify-content: center; align-items: center;">\n'
        f'<img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain;" />\n'
        '</div>\n'
    )
    return html
def refresh_plot():
    """Re-render the charts and produce the refreshed sidebar description.

    Returns:
        A ``(chart_html, description_markdown)`` tuple, in the order the
        click handler's outputs expect.
    """
    chart_html = create_matplotlib_bar_charts()
    refreshed_description = (
        "**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>"
        "**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*<br>"
        "*(Data refreshed)*"
    )
    return chart_html, refreshed_description
# Build the Gradio UI: a sidebar (title, description, refresh button) next to
# the main area holding the rendered charts.
_SIDEBAR_DESCRIPTION = (
    "**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>"
    "**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*"
)

with gr.Blocks(
    title="Random Data Dashboard",
    css=load_css(),
    fill_height=True,
    fill_width=True,
) as demo:
    with gr.Row():
        # Left column: branding, static description and the refresh button
        with gr.Column(scale=1, elem_classes=["sidebar"]):
            gr.Markdown("# 🤖 TCID", elem_classes=["sidebar-title"])
            description = gr.Markdown(
                _SIDEBAR_DESCRIPTION,
                elem_classes=["sidebar-description"],
            )
            summary_btn = gr.Button(
                "summary\n📊",
                variant="primary",
                size="lg",
                elem_classes=["summary-button"],
            )

        # Right column: charts are rendered once at build time as embedded HTML
        with gr.Column(elem_classes=["main-content"]):
            initial_chart_html = create_matplotlib_bar_charts()
            plot = gr.HTML(initial_chart_html, elem_classes=["plot-container"])

    # Clicking the button regenerates both the chart and the description
    summary_btn.click(fn=refresh_plot, outputs=[plot, description])

if __name__ == "__main__":
    demo.launch()