|
import os |
|
import time |
|
from dataclasses import dataclass |
|
from datetime import datetime |
|
from zoneinfo import ZoneInfo |
|
|
|
import gradio as gr |
|
import plotly.graph_objects as go |
|
import schedule |
|
import wandb |
|
from substrateinterface import Keypair |
|
from wandb.apis.public import Run, Runs |
|
|
|
# Shared wandb public-API client used to query validator runs.
wandb_api = wandb.Api()

# The Gradio app; its contents are (re)built by refresh_leaderboard().
demo = gr.Blocks(css=".typewriter {font-family: 'JMH Typewriter', sans-serif;}")

# Required configuration, read at import time (a missing variable raises KeyError).
# UID of the validator whose wandb runs feed the leaderboard and graph.
SOURCE_VALIDATOR_UID: int = int(os.environ["SOURCE_VALIDATOR_UID"])
# wandb path passed to wandb_api.runs() (presumably "entity/project" — confirm).
WANDB_RUN_PATH: str = os.environ["WANDB_RUN_PATH"]

# Score drawn as the horizontal "Baseline" reference line on the graph.
BASELINE: float = 0.0
# Number of most recent runs considered for the graph (assumes roughly one run per day — TODO confirm).
GRAPH_HISTORY_DAYS: int = 30
# Maximum number of uids, ranked by best score, drawn on the graph.
MAX_GRAPH_ENTRIES: int = 5
|
|
|
|
|
@dataclass
class LeaderboardEntry:
    """One row of the leaderboard table, parsed from a validator run's summary."""

    # Miner uid, parsed from the run-summary key.
    uid: int
    # Submitted model identifier; rendered as markdown in the table.
    model: str
    # Score reported by the validator for this uid.
    score: float
    # Miner hotkey; rendered as markdown in the table.
    hotkey: str
    # Copied from the summary's "multiday_winner" field.
    previous_day_winner: bool
    # Rank reported by the validator; displayed as rank + 1 (presumably 0-based — confirm).
    rank: int
|
|
|
|
|
@dataclass
class GraphEntry:
    """Score history of one uid, drawn as a single trace on the score graph."""

    # Creation timestamps of the runs that contributed a point, oldest first.
    dates: list[datetime]
    # Best score seen so far at each date (a running maximum, not the raw score).
    scores: list[float]
    # Model reported at each date; shown in the hover tooltip.
    models: list[str]
    # Highest score seen across all considered runs; used to pick the top uids.
    max_score: float
|
|
|
|
|
def is_valid_run(run: Run):
    """Check that a wandb run carries a correctly signed validator config.

    The run qualifies only if its config contains all of ``hotkey``, ``uid``,
    ``contest`` and ``signature``, and ``signature`` verifies against the
    message ``"{uid}:{hotkey}:{contest}"`` using a ``Keypair`` built from the
    hotkey. Any exception raised while building the keypair or verifying the
    signature is treated as an invalid run.
    """
    config = run.config
    required = ("hotkey", "uid", "contest", "signature")

    # Guard clause: reject runs missing any required config field.
    if any(field not in config for field in required):
        return False

    hotkey = config["hotkey"]
    message = f"{config['uid']}:{hotkey}:{config['contest']}"

    try:
        keypair = Keypair(hotkey)
        return keypair.verify(message, config["signature"])
    except Exception:
        # Malformed hotkey/signature or a verification error: not a trusted run.
        return False
|
|
|
|
|
def get_graph_entries(runs: Runs) -> dict[int, GraphEntry]:
    """Build per-uid score histories from the most recent validator runs.

    Takes the first ``GRAPH_HISTORY_DAYS`` runs of *runs* (the caller queries
    them newest first) and walks them oldest to newest, skipping runs that
    fail :func:`is_valid_run`. Every non-underscore summary key is treated as
    a uid whose value holds ``"score"`` and ``"model"``. Malformed entries are
    skipped rather than aborting the whole graph.

    Args:
        runs: Indexable/sliceable sequence of wandb runs, newest first.

    Returns:
        At most ``MAX_GRAPH_ENTRIES`` entries keyed by uid, ordered by
        descending best score. Each entry's ``scores`` holds the running
        maximum of that uid's score at each date.
    """
    graph_entries: dict[int, GraphEntry] = {}

    for run in reversed(runs[:GRAPH_HISTORY_DAYS]):
        if not is_valid_run(run):
            continue

        date = datetime.strptime(run.created_at, "%Y-%m-%dT%H:%M:%S")

        for key, value in run.summary.items():
            # Underscore-prefixed keys are wandb bookkeeping, not uids.
            if key.startswith("_"):
                continue

            try:
                uid = int(key)
                score = value["score"]
                model = value["model"]
            except (KeyError, TypeError, ValueError):
                # Not a uid -> {score, model} entry. Skip it instead of failing
                # the whole refresh, matching how the leaderboard table is built.
                continue

            data = graph_entries.get(uid)
            if data is None:
                graph_entries[uid] = GraphEntry([date], [score], [model], score)
                continue

            if score > data.max_score:
                data.max_score = score

            # NOTE(review): the plotted score is the running best, but the model
            # recorded is this run's model, which may not be the one that set the
            # best score. Confirm that pairing is intended.
            data.dates.append(date)
            data.scores.append(data.max_score)
            data.models.append(model)

    ranked = sorted(graph_entries.items(), key=lambda item: item[1].max_score, reverse=True)
    return dict(ranked[:MAX_GRAPH_ENTRIES])
|
|
|
|
|
def create_graph(graph_entries: dict[int, GraphEntry]):
    """Add the "Score Improvements" line chart to the current Gradio context.

    Draws one trace per uid in *graph_entries*, plus a horizontal ``BASELINE``
    reference line spanning every date any uid was plotted at. The figure is
    then wrapped in a ``gr.Plot`` component. That component is not returned,
    so this is meant to be called inside a Gradio Blocks/layout context (as
    ``refresh_leaderboard`` does).

    With no entries, the chart is rendered empty (no baseline) instead of
    raising.
    """
    fig = go.Figure()

    for uid, data in graph_entries.items():
        fig.add_trace(go.Scatter(
            x=data.dates,
            y=data.scores,
            customdata=data.models,
            mode="lines+markers",
            name=str(uid),
            hovertemplate=(
                "<b>Date:</b> %{x|%Y-%m-%d}<br>"
                "<b>Score:</b> %{y}<br>"
                "<b>Model:</b> %{customdata}<br>"
            ),
        ))

    # Span the baseline over the union of all plotted dates. Using only the
    # longest single history could miss dates covered by other uids, and
    # max() over an empty dict raises ValueError.
    date_range = sorted({date for data in graph_entries.values() for date in data.dates})

    if date_range:
        fig.add_trace(go.Scatter(
            x=date_range,
            y=[BASELINE] * len(date_range),
            line=dict(color="#ff0000", width=3),
            mode="lines",
            name="Baseline",
        ))

    background_color = gr.themes.default.colors.slate.c800

    fig.update_layout(
        title="Score Improvements",
        yaxis_title="Score",
        plot_bgcolor=background_color,
        paper_bgcolor=background_color,
        template="plotly_dark",
    )

    gr.Plot(fig)
|
|
|
|
|
def _get_leaderboard_entries(runs: Runs) -> dict[int, LeaderboardEntry]:
    """Parse leaderboard rows from the newest valid run that yields usable data.

    Runs are visited in the given order (newest first for the caller's query),
    and runs failing :func:`is_valid_run` are skipped. The first run producing
    at least one parsable uid entry is used. Malformed entries are logged and
    skipped, and a run whose entries are all malformed no longer hides older
    valid runs.
    """
    entries: dict[int, LeaderboardEntry] = {}

    for run in runs:
        if not is_valid_run(run):
            continue

        for key, value in run.summary.items():
            # Underscore-prefixed keys are wandb bookkeeping, not uids.
            if key.startswith("_"):
                continue

            try:
                uid = int(key)

                entries[uid] = LeaderboardEntry(
                    uid=uid,
                    rank=value["rank"],
                    model=value["model"],
                    score=value["score"],
                    hotkey=value["hotkey"],
                    previous_day_winner=value["multiday_winner"],
                )
            except (KeyError, TypeError, ValueError) as e:
                # Non-int key, non-dict value, or a missing field.
                print(f"Skipping malformed leaderboard entry {key!r}: {e!r}")
                continue

        # Only stop once a run actually produced rows.
        if entries:
            break

    return entries


def refresh_leaderboard():
    """Rebuild the Gradio page from the latest validator runs on wandb.

    All data is fetched and parsed before ``demo`` is cleared. A wandb/network
    or parsing failure therefore propagates without wiping the currently
    displayed leaderboard.
    """
    now = datetime.now(tz=ZoneInfo("America/New_York"))
    print(f"Refreshing Leaderboard at {now.strftime('%Y-%m-%d %H:%M:%S')}")

    runs: Runs = wandb_api.runs(
        WANDB_RUN_PATH,
        filters={"config.type": "validator", "config.uid": SOURCE_VALIDATOR_UID},
        order="-created_at",
    )

    entries = _get_leaderboard_entries(runs)
    graph_entries = get_graph_entries(runs)

    # NOTE(review): with reverse=True, score ties are ordered by *higher* rank
    # first. Confirm that matches the validator's rank direction.
    leaderboard: list[tuple] = [
        (entry.rank + 1, entry.uid, entry.model, entry.score, entry.hotkey, entry.previous_day_winner)
        for entry in sorted(entries.values(), key=lambda entry: (entry.score, entry.rank), reverse=True)
    ]

    demo.clear()

    with demo:
        with gr.Accordion("Contest #1 Submission Leader: New Dream SDXL on NVIDIA RTX 4090s"):
            # No valid runs means nothing to plot.
            if graph_entries:
                create_graph(graph_entries)

            gr.components.Dataframe(
                value=leaderboard,
                headers=["Rank", "Uid", "Model", "Score", "Hotkey", "Previous day winner"],
                datatype=["number", "number", "markdown", "number", "markdown", "bool"],
                elem_id="leaderboard-table",
                interactive=False,
                visible=True,
            )
|
|
|
|
|
def main():
    """Build the leaderboard, serve it, and refresh it every 30 minutes.

    The initial refresh runs unguarded so startup misconfiguration fails
    loudly. Scheduled refreshes are wrapped: ``schedule`` does not catch job
    exceptions, so without the wrapper one transient failure (e.g. a wandb
    outage) would propagate out of ``schedule.run_pending()`` and end the
    refresh loop for good.
    """
    import traceback

    def refresh_safely():
        try:
            refresh_leaderboard()
        except Exception:
            # Log and keep the scheduler alive; the next scheduled run retries.
            traceback.print_exc()

    refresh_leaderboard()
    schedule.every(30).minutes.do(refresh_safely)

    demo.launch(prevent_thread_lock=True)

    while True:
        schedule.run_pending()
        time.sleep(1)
|
|
|
|
|
# Launch only when run as a script, so importing this module does not start
# the server and block forever in main()'s scheduling loop.
if __name__ == "__main__":
    main()
|
|