# vis data
# Update app.py
# 744a850 verified
import datetime
import html
import math
import numbers
import threading
import time
import traceback

import gradio as gr
import plotly.graph_objects as go

from utils import *  # Ensure get_wandb_runs and get_scores are defined here.
# Imported after the star import so this get_model_info takes precedence over any in utils.
from chain import get_model_info
# Weights & Biases project queried for validator runs, restricted to runs whose State is "running".
project_name = "ai-factory-validators"
filters = dict(State={"$eq": "running"})

# Number of samples averaged by the moving-average trace drawn on the loss plot.
window_size = 32

# Shared history of (timestamp, lowest avg_loss) pairs; update_results mutates it in place.
loss_history = []

# Held by update_results so concurrent callers (background thread, "Refresh Now" button) run one at a time.
update_lock = threading.Lock()
def moving_average(data, window=10):
    """Return the trailing moving average of ``data`` as a list of the same length.

    Element ``k`` of the result is the mean of ``data[max(0, k - window + 1) : k + 1]``.
    The window is clipped at the start of the sequence, so the first ``window - 1``
    outputs average over fewer points. Empty (or falsy) input yields ``[]``.
    """
    if not data:
        return []
    # Trailing window ending at each index, truncated at the front of the data.
    segments = (data[max(0, end - window + 1):end + 1] for end in range(len(data)))
    return [sum(segment) / len(segment) for segment in segments]
# Header row shared by every competition's leaderboard table.
_LEADERBOARD_TABLE_HEADER = (
    "<table border='1' style='border-collapse: collapse; width: 100%;'>"
    "<tr><th>UID</th><th>Avg Loss</th><th>Win Rate</th><th>Model Name</th></tr>"
)


def _is_finite_number(value):
    """Return True if ``value`` is a finite real number (bools excluded), i.e. safe to compare and plot."""
    return (
        isinstance(value, numbers.Real)
        and not isinstance(value, bool)
        and math.isfinite(value)
    )


def _format_number(value, spec):
    """Format ``value`` with ``spec`` for an HTML table cell; ``None`` is rendered as "N/A".

    Values that cannot take a numeric format spec are shown as escaped text instead of raising.
    """
    if value is None:
        return "N/A"
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        return html.escape(str(value))


def _group_scores_by_competition(scores):
    """Group per-UID score dicts into ``{competition_id: [row, ...]}``, each list sorted by UID.

    Scores without a ``competition_id`` are grouped under ``"unknown"``.
    """
    tables = {}
    for uid, data in scores.items():
        comp_id = data.get("competition_id", "unknown")
        tables.setdefault(comp_id, []).append({
            "uid": uid,
            "avg_loss": data.get("avg_loss"),
            "win_rate": data.get("win_rate"),
            "model": get_model_info(uid),
        })
    for rows in tables.values():
        rows.sort(key=lambda row: row["uid"])
    return tables


def _lowest_avg_loss(scores):
    """Return the smallest finite ``avg_loss`` across ``scores``, or None when there is none.

    Missing, ``None`` and non-finite losses are ignored so they never reach the history/plot.
    """
    finite_losses = [
        data.get("avg_loss")
        for data in scores.values()
        if _is_finite_number(data.get("avg_loss"))
    ]
    return min(finite_losses, default=None)


def _record_lowest_loss(min_loss, now, max_len=10000):
    """Append ``(now, min_loss)`` to ``loss_history`` when the lowest loss has changed.

    ``None`` (no usable loss this round) is never recorded: a single ``None`` entry
    would make every later ``moving_average`` call raise. The history keeps only the
    ``max_len`` most recent entries.
    """
    if min_loss is None:
        return
    if not loss_history or loss_history[-1][1] != min_loss:
        loss_history.append((now, min_loss))
        if len(loss_history) > max_len:
            # Trim in place so every reference to the global list sees the cap.
            del loss_history[:-max_len]


def _build_loss_figure():
    """Plot the recorded lowest avg_loss history and its moving average as a Plotly figure."""
    times = [stamp for stamp, _ in loss_history]
    losses = [loss for _, loss in loss_history]
    ma_losses = moving_average(losses, window=window_size)
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=times, y=losses, mode='lines+markers', name='Lowest avg_loss'))
    fig.add_trace(go.Scatter(x=times, y=ma_losses, mode='lines', name=f'Moving Average (window={window_size})'))
    fig.update_layout(
        title="Lowest Avg Loss Over Time",
        xaxis_title="Time",
        yaxis_title="Lowest Avg Loss",
        template="plotly_white",
        height=400,
    )
    return fig


def _build_leaderboard_html(tables):
    """Render one HTML table per competition, highlighting and crowning the best win rate(s)."""
    parts = ["<h1>AI Factory Leaderboard</h1>"]
    for comp_id, rows in tables.items():
        win_rates = [row["win_rate"] for row in rows if _is_finite_number(row["win_rate"])]
        # None when no row has a usable win rate: then nobody gets the crown.
        best_win_rate = max(win_rates, default=None)
        comp_title = f"Competition ID: {comp_id}"
        if comp_id == 0:
            comp_title += " (Research Track)"
        parts.append(f"<h3>{html.escape(comp_title)}</h3>")
        parts.append(_LEADERBOARD_TABLE_HEADER)
        for row in rows:
            is_best = best_win_rate is not None and row["win_rate"] == best_win_rate
            style = "background-color: #ffeb99;" if is_best else ""  # Light yellow background.
            crown = " 👑" if is_best else ""
            # Model names are escaped: they are external data and must not inject markup.
            # NOTE(review): assumes get_model_info returns plain text, not intentional HTML — confirm.
            parts.append(
                f"<tr style='{style}'>"
                f"<td>{html.escape(str(row['uid']))}</td>"
                f"<td>{_format_number(row['avg_loss'], '.4f')}</td>"
                f"<td>{_format_number(row['win_rate'], '.2f')}</td>"
                f"<td>{html.escape(str(row['model']))}{crown}</td>"
                "</tr>"
            )
        parts.append("</table><br>")
    return "".join(parts)


def update_results():
    """Fetch the latest validator scores and rebuild the leaderboard HTML and loss plot.

    Runs entirely under ``update_lock``, so the background thread and the manual
    refresh button never fetch or mutate ``loss_history`` at the same time.

    Returns:
        tuple: ``(html_content, fig)``. ``html_content`` is the leaderboard HTML string.
        ``fig`` is a Plotly figure of the lowest avg_loss over time, together with its moving average.
    """
    with update_lock:
        runs = get_wandb_runs(project_name, filters)
        # NOTE(review): UIDs 0..255 are queried; presumably the full UID range — confirm.
        scores = get_scores(list(range(256)), runs)
        tables = _group_scores_by_competition(scores)
        _record_lowest_loss(_lowest_avg_loss(scores), datetime.datetime.now())
        fig = _build_loss_figure()
        html_content = _build_leaderboard_html(tables)
        return html_content, fig
# Most recent (HTML, figure) outputs: written by background_update, read by get_latest_results.
# They stay ("", None) until the first successful refresh.
latest_html, latest_fig = "", None
def background_update(interval=10):
    """Refresh the cached leaderboard outputs forever, sleeping ``interval`` seconds between attempts.

    Each iteration calls ``update_results`` and publishes its ``(html, figure)`` pair
    to the module globals ``latest_html`` / ``latest_fig``. A failure is reported with
    its message and full traceback, and the loop keeps running. The previously published
    outputs stay in place until a later refresh succeeds.

    Args:
        interval: Seconds to sleep after each attempt (default 10, the original period).
    """
    global latest_html, latest_fig
    while True:
        try:
            html_content, fig = update_results()
        except Exception as e:
            # Best-effort loop: keep the thread alive, but keep the stack trace.
            # The one-line message alone does not say where the failure happened.
            print("Error during background update:", e)
            traceback.print_exc()
        else:
            latest_html, latest_fig = html_content, fig
        time.sleep(interval)
# Launch the periodic refresher; daemon=True so it does not keep the process alive on exit.
_refresher = threading.Thread(target=background_update, daemon=True)
_refresher.start()
def get_latest_results():
    """Return the cached ``(html, figure)`` pair most recently published by the background updater.

    No fetching happens here. Before the first successful refresh the pair is ``("", None)``.
    """
    snapshot = (latest_html, latest_fig)
    return snapshot
def _refresh_now():
    """Fetch fresh results for the "Refresh Now" button and publish them as the latest cache.

    Publishing them keeps the next periodic tick from replacing the fresh view with
    older cached output.
    """
    global latest_html, latest_fig
    html_content, fig = update_results()
    latest_html, latest_fig = html_content, fig
    return html_content, fig


with gr.Blocks() as demo:
    # Hide any element with id "refresh_button" via injected CSS.
    gr.HTML("<style>#refresh_button {display: none;}</style>")
    # Outputs: the leaderboard HTML (which carries its own <h1> title) and the loss plot.
    tables_output = gr.HTML()
    graph_output = gr.Plot()
    outputs = [tables_output, graph_output]

    # Manual refresh fetches fresh data immediately instead of waiting for the background loop.
    manual_refresh = gr.Button("Refresh Now")
    manual_refresh.click(fn=_refresh_now, inputs=[], outputs=outputs)

    # Show the cached results as soon as a page loads.
    demo.load(fn=get_latest_results, inputs=[], outputs=outputs)

    # Poll the cache every 10 seconds. A component's `every` only re-evaluates a callable
    # `value`, so a hidden static Textbox cannot drive polling; use a Timer instead.
    if hasattr(gr, "Timer"):
        refresh_timer = gr.Timer(10)
        refresh_timer.tick(fn=get_latest_results, inputs=[], outputs=outputs)
    else:
        # Gradio releases without gr.Timer: schedule through the event's `every` argument.
        demo.load(fn=get_latest_results, inputs=[], outputs=outputs, every=10)

# Periodic events need the queue on older Gradio releases; enabling it is harmless on newer ones.
demo.queue()
demo.launch(share=True)