# Page header captured together with the file (not part of the program):
#   author: AlexNijjar
#   commit: 94052c1 (verified) - "Implement Leaderboard Graph (#2)"
#   size:   5.84 kB
import os
import time
from dataclasses import dataclass
from datetime import datetime
from zoneinfo import ZoneInfo
import gradio as gr
import plotly.graph_objects as go
import schedule
import wandb
from substrateinterface import Keypair
from wandb.apis.public import Run, Runs
# Shared W&B API client, used by refresh_leaderboard() to query runs.
wandb_api = wandb.Api()
# Root Gradio container; refresh_leaderboard() clears and repopulates it.
# The custom CSS defines a `.typewriter` class that uses 'JMH Typewriter'.
demo = gr.Blocks(css=".typewriter {font-family: 'JMH Typewriter', sans-serif;}")

# Both environment variables are required: a missing one raises KeyError and a
# non-integer SOURCE_VALIDATOR_UID raises ValueError while the module loads.
# The UID is matched against `config.uid` when filtering runs.
SOURCE_VALIDATOR_UID = int(os.environ["SOURCE_VALIDATOR_UID"])
# Path handed to `wandb_api.runs(...)`.
WANDB_RUN_PATH = os.environ["WANDB_RUN_PATH"]

# Y value of the horizontal "Baseline" line drawn on the graph.
BASELINE = 0.0
# Number of leading runs (from the newest-first query) fed into the graph.
# It is applied as a slice over runs, so it only equals a number of days when
# there is one run per day.
GRAPH_HISTORY_DAYS = 30
# Maximum number of UIDs plotted; those with the highest max score are kept.
MAX_GRAPH_ENTRIES = 5
@dataclass
class LeaderboardEntry:
    """One row of the leaderboard table, built from a single run-summary entry."""

    # Integer parsed from the summary key the entry was stored under.
    uid: int
    # Taken from the entry's "model" value; rendered in a markdown column.
    model: str
    # Taken from the entry's "score" value; primary sort key of the table.
    score: float
    # Taken from the entry's "hotkey" value; rendered in a markdown column.
    hotkey: str
    # Taken from the entry's "multiday_winner" value.
    previous_day_winner: bool
    # Taken from the entry's "rank" value; the table displays `rank + 1`.
    rank: int
@dataclass
class GraphEntry:
    """Score history of one UID, plotted as a single trace on the graph.

    `dates`, `scores` and `models` are parallel lists that get one element per
    processed run in which the UID appears.
    """

    # Creation time of each run the UID appeared in.
    dates: list[datetime]
    # First element is the raw score of the first run seen; every later element
    # is the running maximum at that point, not the run's own score.
    scores: list[float]
    # Model reported for the UID in each run (shown in the hover tooltip).
    models: list[str]
    # Highest score seen so far for this UID.
    max_score: float
def is_valid_run(run: Run):
    """Tell whether ``run`` is signed by the hotkey named in its own config.

    The run config must contain "hotkey", "uid", "contest" and "signature".
    The signed message is ``"<uid>:<hotkey>:<contest>"`` and it is verified
    with a ``Keypair`` built from the hotkey.

    Returns:
        The result of the signature verification, or ``False`` when a required
        key is missing or anything raises during verification.
    """
    config = run.config

    # Guard clause: without all four fields there is nothing to verify.
    if any(field not in config for field in ("hotkey", "uid", "contest", "signature")):
        return False

    hotkey = config["hotkey"]
    message = f"{config['uid']}:{hotkey}:{config['contest']}"

    try:
        verifier = Keypair(hotkey)
        return verifier.verify(message, config["signature"])
    except Exception:
        # Malformed hotkeys or signatures count as "not valid" instead of
        # aborting the caller's loop over runs.
        return False
def get_graph_entries(runs: Runs) -> dict[int, GraphEntry]:
    """Build the per-UID score histories shown on the leaderboard graph.

    Only the first ``GRAPH_HISTORY_DAYS`` runs of ``runs`` are considered, and
    they are walked in reverse so each series is built in the opposite order
    of the query (the caller requests runs newest-first). Runs that fail
    ``is_valid_run`` are skipped. Summary keys starting with "_" are ignored;
    every other key is expected to be a UID whose value holds "score" and
    "model".

    Summary entries that do not have that shape (non-integer key, missing
    field, non-mapping value) are skipped rather than raising, which matches
    the tolerance ``refresh_leaderboard`` applies when it reads the same
    summaries for the table.

    Args:
        runs: Sliceable sequence of W&B runs.

    Returns:
        At most ``MAX_GRAPH_ENTRIES`` items keyed by UID, ordered by
        ``max_score`` descending.

    Raises:
        ValueError: If a run's ``created_at`` is not in
            ``%Y-%m-%dT%H:%M:%S`` form (an optional trailing "Z" is allowed).
    """
    graph_entries: dict[int, GraphEntry] = {}

    for run in reversed(runs[:GRAPH_HISTORY_DAYS]):
        if not is_valid_run(run):
            continue

        # NOTE(review): some wandb releases appear to append a "Z" to
        # created_at - confirm. Stripping it keeps the original format working
        # unchanged and avoids a ValueError on the suffixed form.
        date = datetime.strptime(run.created_at.removesuffix("Z"), "%Y-%m-%dT%H:%M:%S")

        for key, value in run.summary.items():
            if key.startswith("_"):
                continue

            try:
                uid = int(key)
                score = value["score"]
                model = value["model"]
            except (KeyError, TypeError, ValueError):
                # Not a well-formed submission entry; one stray summary key
                # must not take the whole graph down.
                continue

            data = graph_entries.get(uid)
            if data is None:
                graph_entries[uid] = GraphEntry([date], [score], [model], score)
                continue

            if score > data.max_score:
                data.max_score = score

            # The plotted value is the running maximum, so a trace never dips.
            data.dates.append(date)
            data.scores.append(data.max_score)
            data.models.append(model)

    ranked = sorted(graph_entries.items(), key=lambda entry: entry[1].max_score, reverse=True)
    return dict(ranked[:MAX_GRAPH_ENTRIES])
def create_graph(graph_entries: dict[int, GraphEntry]):
    """Build the "Score Improvements" figure and add it as a ``gr.Plot``.

    One "lines+markers" trace is drawn per UID, with the model name exposed in
    the hover tooltip, plus a red horizontal "Baseline" line at ``BASELINE``
    spanning the dates of the longest series.

    An empty ``graph_entries`` is supported: the figure is still created, just
    without UID traces and with an empty baseline, instead of ``max()``
    raising ``ValueError`` on the empty mapping.

    Args:
        graph_entries: Score history per UID, as produced by
            ``get_graph_entries``.
    """
    fig = go.Figure()

    for uid, data in graph_entries.items():
        fig.add_trace(go.Scatter(
            x=data.dates,
            y=data.scores,
            customdata=data.models,
            mode="lines+markers",
            name=uid,
            hovertemplate=(
                "<b>Date:</b> %{x|%Y-%m-%d}<br>" +
                "<b>Score:</b> %{y}<br>" +
                "<b>Model:</b> %{customdata}<br>"
            ),
        ))

    # The baseline spans the dates of whichever UID has the most data points.
    longest = max(graph_entries.values(), key=lambda entry: len(entry.dates), default=None)
    date_range = longest.dates if longest is not None else []

    fig.add_trace(go.Scatter(
        x=date_range,
        y=[BASELINE] * len(date_range),
        line=dict(color="#ff0000", width=3),
        mode="lines",
        name="Baseline",
    ))

    background_color = gr.themes.default.colors.slate.c800

    fig.update_layout(
        title="Score Improvements",
        yaxis_title="Score",
        plot_bgcolor=background_color,
        paper_bgcolor=background_color,
        template="plotly_dark"
    )

    # Instantiated for its side effect; the caller invokes this function from
    # inside its `with demo:` / `gr.Accordion` context.
    gr.Plot(fig)
def _latest_leaderboard_entries(runs: Runs) -> dict[int, LeaderboardEntry]:
    """Collect entries from the first valid run in ``runs`` that has submission data.

    Runs failing ``is_valid_run`` are skipped. Within a run, summary keys that
    start with "_" are ignored. Iteration stops at the first run that has at
    least one other key, even if some of its entries could not be parsed.
    """
    entries: dict[int, LeaderboardEntry] = {}

    for run in runs:
        if not is_valid_run(run):
            continue

        has_data = False
        for key, value in run.summary.items():
            if key.startswith("_"):
                continue

            has_data = True

            try:
                uid = int(key)
                entries[uid] = LeaderboardEntry(
                    uid=uid,
                    rank=value["rank"],
                    model=value["model"],
                    score=value["score"],
                    hotkey=value["hotkey"],
                    previous_day_winner=value["multiday_winner"],
                )
            except Exception as error:
                # Still best-effort (one bad entry must not hide the rest),
                # but the skip is now logged so it can be noticed.
                print(f"Skipping malformed leaderboard entry {key!r}: {error!r}")
                continue

        if has_data:
            break

    return entries


def refresh_leaderboard():
    """Rebuild the whole leaderboard UI inside ``demo`` from W&B data.

    Logs the refresh time (America/New_York) and clears ``demo``. It then
    queries the runs at ``WANDB_RUN_PATH`` whose config has
    ``type == "validator"`` and ``uid == SOURCE_VALIDATOR_UID``, ordered with
    ``-created_at``. Inside an accordion it renders the score graph followed
    by the leaderboard table.

    Table rows are sorted by ``(score, rank)`` in descending order and the rank
    column shows ``rank + 1``.
    """
    now = datetime.now(tz=ZoneInfo("America/New_York"))

    print(f"Refreshing Leaderboard at {now.strftime('%Y-%m-%d %H:%M:%S')}")

    demo.clear()

    with demo:
        with gr.Accordion("Contest #1 Submission Leader: New Dream SDXL on NVIDIA RTX 4090s"):
            runs: Runs = wandb_api.runs(
                WANDB_RUN_PATH,
                filters={"config.type": "validator", "config.uid": SOURCE_VALIDATOR_UID},
                order="-created_at",
            )

            entries = _latest_leaderboard_entries(runs)

            leaderboard: list[tuple] = [
                (entry.rank + 1, entry.uid, entry.model, entry.score, entry.hotkey, entry.previous_day_winner)
                for entry in sorted(entries.values(), key=lambda entry: (entry.score, entry.rank), reverse=True)
            ]

            # The graph is created before the table, so it is added first.
            create_graph(get_graph_entries(runs))

            gr.components.Dataframe(
                value=leaderboard,
                headers=["Rank", "Uid", "Model", "Score", "Hotkey", "Previous day winner"],
                datatype=["number", "number", "markdown", "number", "markdown", "bool"],
                elem_id="leaderboard-table",
                interactive=False,
                visible=True,
            )
def _refresh_leaderboard_safely():
    """Run ``refresh_leaderboard`` and log, rather than propagate, any failure.

    ``schedule`` does not catch exceptions raised by a job: they would escape
    ``schedule.run_pending()`` and end the loop in ``main``. Catching them here
    lets the job return normally so the next refresh is still scheduled.
    """
    try:
        refresh_leaderboard()
    except Exception as error:
        print(f"Scheduled leaderboard refresh failed: {error!r}")


def main():
    """Build the leaderboard, launch the Gradio app and refresh every 30 minutes.

    The initial refresh is deliberately not guarded, so a broken setup fails
    at startup. ``demo.launch`` is called with ``prevent_thread_lock=True`` so
    control returns here, and this function then drives the scheduler forever,
    polling once per second.
    """
    refresh_leaderboard()
    schedule.every(30).minutes.do(_refresh_leaderboard_safely)

    demo.launch(prevent_thread_lock=True)

    while True:
        schedule.run_pending()
        time.sleep(1)
# Entry point. There is no `if __name__ == "__main__"` guard, so this runs
# (and never returns) on import as well as on direct execution.
main()