|
import json |
|
import os |
|
import time |
|
from dataclasses import dataclass |
|
from datetime import datetime |
|
from zoneinfo import ZoneInfo |
|
|
|
import gradio as gr |
|
import plotly.graph_objects as go |
|
import wandb |
|
from substrateinterface import Keypair |
|
from wandb.apis.public import Run |
|
|
|
# Weights & Biases path that validator runs are queried from
# (presumably "entity/project" — confirm against deployment config).
WANDB_RUN_PATH: str = os.environ["WANDB_RUN_PATH"]
# Validator UID preselected in the "Source Validator" dropdown.
SOURCE_VALIDATOR_UID: int = int(os.environ["SOURCE_VALIDATOR_UID"])

# Seconds to sleep between background refreshes (30 minutes).
REFRESH_RATE = 30 * 60
# Score at which the red "Baseline" line is drawn on the graph.
BASELINE = 0.0
# How many of the newest runs feed the graph (named in days; presumably one
# validator run per day — confirm).
GRAPH_HISTORY_DAYS = 30
# Maximum number of UIDs plotted, keeping those with the highest best score.
MAX_GRAPH_ENTRIES = 10
|
|
|
# Shared W&B public-API client used when fetching validator runs.
wandb_api = wandb.Api()

# Font rule for elements carrying the "typewriter" class.
# NOTE(review): no component visible in this file sets that class.
_TYPEWRITER_CSS = ".typewriter {font-family: 'JMH Typewriter', sans-serif;}"
demo = gr.Blocks(css=_TYPEWRITER_CSS)

# Valid validator runs keyed by validator UID, each list newest first.
# Rebuilt by fetch_wandb_data on every refresh and read by the UI handlers.
runs: dict[int, list[Run]] = {}
|
|
|
|
|
@dataclass
class LeaderboardEntry:
    """One leaderboard row, parsed from a single entry of a W&B run summary.

    ``create_leaderboard`` builds these from summary values keyed by UID and
    renders them as table rows.
    """

    # UID that keys this entry in the run summary (converted with ``int``).
    uid: int
    # Model identifier from the entry's "model" field; rendered as markdown.
    model: str
    # Score from the entry's "score" field; the table is ordered by it, highest first.
    score: float
    # Hotkey from the entry's "hotkey" field; rendered as markdown.
    hotkey: str
    # Taken from the entry's "multiday_winner" field; shown as "Previous day winner".
    previous_day_winner: bool
    # Rank from the entry's "rank" field; the table shows ``rank + 1``,
    # which suggests it is 0-based (lower is better) — confirm with the validator.
    rank: int
|
|
|
|
|
@dataclass
class GraphEntry:
    """Score history for one UID, drawn as a single trace on the score graph.

    ``dates``, ``scores`` and ``models`` are parallel lists holding one element
    per run in which the UID appeared, in chronological order.
    """

    # ``created_at`` timestamps of the runs in which this UID appeared.
    dates: list[datetime]
    # Best score reached by each date (a running maximum, not the raw score).
    scores: list[float]
    # Model reported in each run (the current one, which may differ from the
    # model that achieved the running maximum).
    models: list[str]
    # Highest score seen across all runs processed so far.
    max_score: float
|
|
|
|
|
def is_valid_run(run: Run):
    """Check that a run's config is signed by the validator hotkey it names.

    The config must contain "hotkey", "uid", "contest" and "signature". The
    signature is verified against the message "<uid>:<hotkey>:<contest>" using a
    ``Keypair`` built from that hotkey.

    Returns:
        The verification result, or ``False`` when a required key is missing or
        building the keypair / verifying raises for any reason.
    """
    config = run.config

    if any(key not in config for key in ["hotkey", "uid", "contest", "signature"]):
        return False

    message = f"{config['uid']}:{config['hotkey']}:{config['contest']}"

    try:
        keypair = Keypair(config["hotkey"])
        return keypair.verify(message, config["signature"])
    except Exception:
        # Malformed hotkeys/signatures are treated as an invalid run.
        return False
|
|
|
|
|
def get_graph_entries(runs: list[Run]) -> dict[int, GraphEntry]:
    """Collect per-UID running-best score histories for the score graph.

    Only the first ``GRAPH_HISTORY_DAYS`` runs are used. ``runs`` is expected newest
    first (as fetched with ``order="-created_at"``), so they are walked in reverse
    to build every history chronologically.

    Summary keys starting with "_" are ignored. Entries whose key is not an integer
    or whose value lacks "score"/"model" are skipped instead of aborting the graph.

    Args:
        runs: Validator runs, newest first.

    Returns:
        At most ``MAX_GRAPH_ENTRIES`` entries keyed by UID, ordered by best score,
        highest first. Each plotted score is the best score that UID had reached
        by that run's date.
    """
    entries: dict[int, GraphEntry] = {}

    for run in reversed(runs[:GRAPH_HISTORY_DAYS]):
        date = datetime.strptime(run.created_at, "%Y-%m-%dT%H:%M:%S")

        for key, value in run.summary.items():
            # Presumably W&B bookkeeping keys such as "_step" / "_runtime".
            if key.startswith("_"):
                continue

            # Tolerate malformed entries the same way create_leaderboard does for
            # the same summary data, so one bad entry cannot break the whole graph.
            try:
                uid = int(key)
                score = value["score"]
                model = value["model"]
            except (KeyError, TypeError, ValueError):
                continue

            entry = entries.get(uid)
            if entry is None:
                entries[uid] = GraphEntry([date], [score], [model], score)
                continue

            entry.max_score = max(entry.max_score, score)
            entry.dates.append(date)
            entry.scores.append(entry.max_score)
            entry.models.append(model)

    ranked = sorted(entries.items(), key=lambda item: item[1].max_score, reverse=True)
    return dict(ranked[:MAX_GRAPH_ENTRIES])
|
|
|
|
|
def create_graph(runs: list[Run]) -> go.Figure:
    """Plot the running-best score of the top UIDs over time, plus a baseline.

    Args:
        runs: Validator runs, newest first (passed to ``get_graph_entries``).

    Returns:
        A dark-themed Plotly figure with one line per UID and a red "Baseline"
        line at ``BASELINE``. When the runs contain no score entries, the figure
        has no traces instead of raising.
    """
    entries = get_graph_entries(runs)
    fig = go.Figure()

    for uid, data in entries.items():
        fig.add_trace(go.Scatter(
            x=data.dates,
            y=data.scores,
            customdata=data.models,
            mode="lines+markers",
            name=uid,
            hovertemplate=(
                "<b>Date:</b> %{x|%Y-%m-%d}<br>"
                "<b>Score:</b> %{y}<br>"
                "<b>Model:</b> %{customdata}<br>"
            ),
        ))

    # Union of every plotted date, so the baseline spans the whole x-axis even
    # when UIDs cover different periods. Empty when there is nothing to plot.
    date_range = sorted({date for data in entries.values() for date in data.dates})

    if date_range:
        fig.add_trace(go.Scatter(
            x=date_range,
            y=[BASELINE] * len(date_range),
            line=dict(color="#ff0000", width=3),
            mode="lines",
            name="Baseline",
        ))

    background_color = gr.themes.default.colors.slate.c800

    fig.update_layout(
        title="Score Improvements",
        yaxis_title="Score",
        plot_bgcolor=background_color,
        paper_bgcolor=background_color,
        template="plotly_dark",
    )

    return fig
|
|
|
|
|
def create_leaderboard(runs: list[Run]) -> list[tuple]:
    """Build leaderboard rows from the newest run that has submission entries.

    Runs are scanned in the given order (newest first, as fetched by
    fetch_wandb_data). The first run whose summary has any key not starting with
    "_" is used and older runs are ignored. Within that run, entries with a
    non-integer key or a missing field are skipped.

    Args:
        runs: Validator runs, newest first.

    Returns:
        Rows of ``(rank + 1, uid, model, score, hotkey, previous_day_winner)``.
        Rows are sorted by score (highest first); equal scores are ordered by
        rank (best, i.e. lowest, first).
    """
    entries: dict[int, LeaderboardEntry] = {}

    for run in runs:
        submissions = [
            (key, value)
            for key, value in run.summary.items()
            if not key.startswith("_")
        ]
        if not submissions:
            continue

        for key, value in submissions:
            try:
                uid = int(key)
                entries[uid] = LeaderboardEntry(
                    uid=uid,
                    rank=value["rank"],
                    model=value["model"],
                    score=value["score"],
                    hotkey=value["hotkey"],
                    previous_day_winner=value["multiday_winner"],
                )
            except (KeyError, TypeError, ValueError):
                # Non-numeric key, non-mapping value, or missing field: leave it out.
                continue

        # Only the newest run with data is shown, even if all its entries were skipped.
        break

    # Score descending, then rank ascending: a plain reverse=True on (score, rank)
    # would put the worse rank first among equal scores.
    ranked = sorted(entries.values(), key=lambda entry: (-entry.score, entry.rank))

    return [
        (entry.rank + 1, entry.uid, entry.model, entry.score, entry.hotkey, entry.previous_day_winner)
        for entry in ranked
    ]
|
|
|
|
|
def get_run_validator_uid(run: Run) -> int:
    """Return the validator UID stored in a run's raw JSON config.

    ``run.json_config`` is decoded and the ``value`` of its "uid" entry is coerced
    with ``int``, so numeric strings are accepted as well.
    """
    return int(json.loads(run.json_config)["uid"]["value"])
|
|
|
|
|
def fetch_wandb_data():
    """Reload all validly signed validator runs from W&B into the global ``runs``.

    Runs are requested newest first with config type "validator", filtered with
    ``is_valid_run``, and grouped by validator UID. The result is ordered by UID.

    The new mapping is built in a local dict and published with one assignment at
    the end. Gradio handlers read ``runs`` from the server thread while this runs
    in the refresh loop, so they never observe an empty or half-built mapping.
    If the fetch raises partway through, the previously published data stays in
    place; the exception still propagates to the caller.
    """
    global runs

    wandb_runs = wandb_api.runs(
        WANDB_RUN_PATH,
        filters={"config.type": "validator"},
        order="-created_at",
    )

    fetched: dict[int, list[Run]] = {}

    for run in wandb_runs:
        if not is_valid_run(run):
            continue

        fetched.setdefault(get_run_validator_uid(run), []).append(run)

    runs = dict(sorted(fetched.items(), key=lambda item: item[0]))
|
|
|
|
|
def refresh():
    """Re-fetch validator runs from W&B and rebuild the whole Gradio layout.

    The fetch happens before ``demo.clear()``, so a failed fetch leaves the current
    layout untouched. The "Source Validator" dropdown selects whose runs feed the
    graph and leaderboard. It defaults to ``SOURCE_VALIDATOR_UID`` and falls back
    to the lowest available UID when that validator has no valid runs.
    """
    fetch_wandb_data()
    demo.clear()

    with demo:
        gr.Image(
            "cover.png",
            show_label=False,
            show_download_button=False,
            interactive=False,
            show_fullscreen_button=False,
            show_share_button=False,
        )

        gr.Label(
            "SN39 EdgeMaxxing Leaderboard",
            show_label=False,
        )

        gr.Text(
            "This leaderboard for SN39 tracks the results and top model submissions from current and previous contests.",
            show_label=False,
            text_align="center",
        )

        with gr.Accordion("Contest #1 Submission Leader: New Dream SDXL on NVIDIA RTX 4090s"):
            choices = list(runs.keys())

            # Indexing runs[SOURCE_VALIDATOR_UID] directly would raise KeyError and
            # abort the rebuild whenever that validator has no valid runs.
            if SOURCE_VALIDATOR_UID in runs:
                default_uid = SOURCE_VALIDATOR_UID
            else:
                default_uid = choices[0] if choices else None

            dropdown = gr.Dropdown(
                choices,
                value=default_uid,
                interactive=True,
                label="Source Validator",
            )

            graph = gr.Plot()

            leaderboard = gr.components.Dataframe(
                create_leaderboard(runs.get(default_uid, [])),
                headers=["Rank", "Uid", "Model", "Score", "Hotkey", "Previous day winner"],
                datatype=["number", "number", "markdown", "number", "markdown", "bool"],
                elem_id="leaderboard-table",
                interactive=False,
                visible=True,
            )

        # Handlers look up the global ``runs`` at call time, so they pick up data
        # published by later refreshes; unknown UIDs fall back to no runs.
        demo.load(lambda uid: create_graph(runs.get(uid, [])), [dropdown], [graph])

        dropdown.change(lambda uid: create_graph(runs.get(uid, [])), [dropdown], [graph])
        dropdown.change(lambda uid: create_leaderboard(runs.get(uid, [])), [dropdown], [leaderboard])
|
|
|
|
|
if __name__ == "__main__":
    import traceback

    refresh()
    demo.launch(prevent_thread_lock=True)

    # With prevent_thread_lock=True, launch() returns immediately and the server
    # stops once the script finishes, so this loop is what keeps the app alive.
    # A failed refresh (e.g. a W&B/network error) is logged and retried on the
    # next cycle instead of propagating and ending the process.
    while True:
        time.sleep(REFRESH_RATE)

        now = datetime.now(tz=ZoneInfo("America/New_York"))
        print(f"Refreshing Leaderboard at {now.strftime('%Y-%m-%d %H:%M:%S')}")

        try:
            refresh()
        except Exception:
            traceback.print_exc()
|
|