rusticluftig's picture
Update subnet ID and other links
d60014f
raw
history blame
No virus
19 kB
import argparse
import datetime
import functools
import json
import math
import os
import time
import traceback
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import bittensor as bt
import numpy as np
import pandas as pd
import wandb
from bittensor.extrinsics.serving import get_metadata
from dotenv import load_dotenv
from wandb.apis.public.history import HistoryScan
import competitions
# Subnet this leaderboard reads from (passed to subtensor.metagraph / hyperparameter queries).
NETUID = 37
# Retry policy used by run_with_retries: seconds between attempts, and attempt count.
DELAY_SECS = 3
RETRIES = 3

# load_dotenv() must run before any of the os.getenv reads below.
load_dotenv()
WANDB_TOKEN = os.getenv("WANDB_API_KEY")
SUBTENSOR_ENDPOINT = os.getenv("SUBTENSOR_ENDPOINT")
# Any non-empty value enables fetching benchmarks in load_state_vars.
BENCHMARK_FLAG = os.getenv("BENCHMARK_FLAG")

VALIDATOR_WANDB_PROJECT = "rusticluftig/finetuning"
# Empty means "no benchmark project configured" (see get_benchmarks).
BENCHMARK_WANDB_PROJECT = ""
@dataclass(frozen=True)
class ModelData:
    """A miner's on-chain model commitment joined with its metagraph stats."""

    uid: int
    hotkey: str
    competition_id: int
    namespace: str
    name: str
    commit: str
    # Hash of (hash(model) + hotkey)
    secure_hash: str
    block: int
    incentive: float
    emission: float

    @classmethod
    def from_compressed_str(
        cls,
        uid: int,
        hotkey: str,
        cs: str,
        block: int,
        incentive: float,
        emission: float,
    ):
        """Returns an instance of this class from a compressed string representation.

        Args:
            uid: The miner's UID.
            hotkey: The miner's hotkey.
            cs: Colon-separated commitment of the form
                "<namespace>:<name>:<commit>:<secure_hash>:<competition_id>".
                Any fields beyond the fifth are ignored.
            block: Block at which the commitment was made.
            incentive: The miner's incentive.
            emission: The miner's emission.

        Raises:
            IndexError: If `cs` has fewer than five fields.
            ValueError: If the competition id field is not an integer.
        """
        tokens = cs.split(":")
        # Use `cls` so subclasses of ModelData get instances of themselves.
        return cls(
            uid=uid,
            hotkey=hotkey,
            namespace=tokens[0],
            name=tokens[1],
            commit=tokens[2],
            secure_hash=tokens[3],
            competition_id=int(tokens[4]),
            block=block,
            incentive=incentive,
            emission=emission,
        )
def run_with_retries(func, *args, **kwargs):
    """Runs `func(*args, **kwargs)`, retrying in the event of a failure.

    Makes up to RETRIES attempts, logging each failure's traceback and sleeping
    DELAY_SECS seconds between attempts.

    Returns:
        Whatever `func` returns on the first successful attempt.

    Raises:
        The exception raised by the final attempt, if every attempt fails.
        RuntimeError: If RETRIES is less than 1, so no attempt was made.
    """
    for attempt in range(RETRIES):
        try:
            return func(*args, **kwargs)
        # RuntimeError is already an Exception subclass; catching Exception covers it.
        except Exception:
            bt.logging.error(f"Failed to run function: {traceback.format_exc()}")
            if attempt == RETRIES - 1:
                raise
            time.sleep(DELAY_SECS)
    raise RuntimeError("Should never happen")
def get_subtensor_and_metagraph() -> Tuple[bt.subtensor, bt.metagraph]:
    """Connects to the chain and fetches the full (non-lite) metagraph for NETUID.

    Uses SUBTENSOR_ENDPOINT as the chain endpoint when it is set, otherwise the
    "finney" network. The whole connect-and-fetch is retried via run_with_retries.
    """

    def _connect() -> Tuple[bt.subtensor, bt.metagraph]:
        if not SUBTENSOR_ENDPOINT:
            subtensor = bt.subtensor("finney")
        else:
            parser = argparse.ArgumentParser()
            bt.subtensor.add_args(parser)
            config = bt.config(
                parser=parser,
                args=["--subtensor.chain_endpoint", SUBTENSOR_ENDPOINT],
            )
            subtensor = bt.subtensor(config=config)
        return subtensor, subtensor.metagraph(NETUID, lite=False)

    return run_with_retries(_connect)
def get_subnet_data(
    subtensor: bt.subtensor, metagraph: bt.metagraph
) -> List[ModelData]:
    """Returns ModelData for every UID on the subnet with a parsable model commitment.

    UIDs whose chain metadata cannot be fetched, or whose commitment cannot be
    decoded and parsed, are logged and skipped. A single malformed commitment
    therefore cannot abort the whole scan.
    """
    result = []
    for uid in metagraph.uids.tolist():
        hotkey = metagraph.hotkeys[uid]
        try:
            metadata = run_with_retries(
                functools.partial(get_metadata, subtensor, metagraph.netuid, hotkey)
            )
        except Exception:
            bt.logging.warning(
                f"Failed to get metadata for UID {uid}: {traceback.format_exc()}"
            )
            continue
        if not metadata:
            continue

        block = metadata["block"]
        incentive = np.nan_to_num(metagraph.incentive[uid]).item()
        emission = (
            np.nan_to_num(metagraph.emission[uid]).item() * 20
        )  # convert to daily TAO

        # Decoding and parsing live in one try: any of these steps can fail on a
        # miner-supplied commitment, and such a UID should simply be skipped.
        try:
            commitment = metadata["info"]["fields"][0]
            # Take the commitment's first value and drop its 2-character prefix
            # (presumably "0x") before hex-decoding.
            hex_data = commitment[list(commitment.keys())[0]][2:]
            chain_str = bytes.fromhex(hex_data).decode()
            model_data = ModelData.from_compressed_str(
                uid, hotkey, chain_str, block, incentive, emission
            )
        except Exception:
            bt.logging.debug(
                f"Skipping UID {uid}: unparsable commitment: {traceback.format_exc()}"
            )
            continue
        result.append(model_data)
    return result
def get_wandb_runs(project: str, filters: Optional[Dict[str, Any]]) -> List:
    """Get the latest runs from Wandb, retrying infinitely until we get them.

    Both failed API calls and empty results are logged and retried every 60 seconds.
    This never returns an empty list.

    Args:
        project (str): Wandb project path, e.g. "entity/project".
        filters (Optional[Dict[str, Any]]): Wandb run filters, or None for no filtering.

    Returns:
        List: List of runs matching the provided filters, newest run (by creation time) first.
    """
    while True:
        try:
            api = wandb.Api(api_key=WANDB_TOKEN)
            runs = list(
                api.runs(
                    project,
                    filters=filters,
                    order="-created_at",
                )
            )
        except Exception:
            # The docstring promises infinite retries, so API errors must not escape.
            bt.logging.error(
                f"Error querying runs from Wandb: {traceback.format_exc()}"
            )
            runs = []
        if runs:
            return runs
        # The Wandb API is quite unreliable. Wait another minute and try again.
        bt.logging.error("Failed to get runs from Wandb. Trying again in 60 seconds.")
        time.sleep(60)
def get_scores(
    uids: List[int],
    wandb_runs: List,
) -> Dict[int, Dict[str, Optional[float]]]:
    """Returns the most recent scores for the provided UIDs.

    Args:
        uids (List[int]): List of UIDs to get scores for.
        wandb_runs (List): List of validator runs from Wandb. Requires the runs are
            provided in descending order. Each run's `summary` is read for an
            "original_format_json" entry; runs without one are skipped.

    Returns:
        Dict mapping each UID found in some run to its "avg_loss", "win_rate",
        "win_total", "weight" and "competition_id" (None when not logged), plus
        "fresh". "fresh" is True only when the data came from the first run in
        `wandb_runs`. UIDs not reported by any run are omitted.

    Raises:
        AssertionError: If the runs' logged timestamps are not strictly descending.
    """
    result = {}
    # Compare against unique UIDs so duplicates in `uids` don't defeat the early exit.
    num_unique_uids = len(set(uids))
    previous_timestamp = None
    # Iterate through the runs until we've processed all the uids.
    for i, run in enumerate(wandb_runs):
        if "original_format_json" not in run.summary:
            continue
        data = json.loads(run.summary["original_format_json"])
        all_uid_data = data["uid_data"]
        timestamp = data["timestamp"]

        # Make sure runs are indeed in descending time order. This is an explicit
        # check rather than an `assert` statement so it still runs under `python -O`.
        if previous_timestamp is not None and not timestamp < previous_timestamp:
            raise AssertionError(
                f"Timestamps are not in descending order: {timestamp} >= {previous_timestamp}"
            )
        previous_timestamp = timestamp

        # Only the most recent run is fresh. `i` also counts runs skipped above, so
        # if the newest run has not logged data yet, nothing is marked fresh.
        is_fresh = i == 0
        for uid in uids:
            if uid in result:
                continue
            uid_data = all_uid_data.get(str(uid))
            if uid_data is None:
                continue
            result[uid] = {
                "avg_loss": uid_data.get("average_loss", None),
                "win_rate": uid_data.get("win_rate", None),
                "win_total": uid_data.get("win_total", None),
                "weight": uid_data.get("weight", None),
                "competition_id": uid_data.get("competition_id", None),
                "fresh": is_fresh,
            }
        if len(result) == num_unique_uids:
            break
    return result
def get_validator_weights(
    metagraph: bt.metagraph,
) -> Dict[int, Tuple[float, int, Dict[int, float]]]:
    """Returns a dictionary of validator UIDs to (vtrust, stake, {uid: weight}).

    Only UIDs with positive validator trust and more than 10,000 stake are
    included. Each weight is rounded to 4 decimal places; self-weights and
    weights that round to 0 are left out of the inner dict.
    """
    all_uids = metagraph.uids.tolist()
    validators = {}
    for vali_uid in all_uids:
        vtrust = metagraph.validator_trust[vali_uid].item()
        stake = metagraph.stake[vali_uid].item()
        if not (vtrust > 0 and stake > 10_000):
            continue
        weights_by_uid = {}
        for miner_uid in all_uids:
            if miner_uid == vali_uid:
                continue
            rounded = round(metagraph.weights[vali_uid][miner_uid].item(), 4)
            if rounded > 0:
                weights_by_uid[miner_uid] = rounded
        validators[vali_uid] = (vtrust, stake, weights_by_uid)
    return validators
def get_losses_over_time(wandb_runs: List) -> pd.DataFrame:
    """Returns a dataframe of the best average model loss over time.

    For each run, its 10 most recent history steps are scanned. The lowest
    "average_loss" per competition among them is recorded at the latest
    timestamp logged in those steps. A run contributes a row only if at least
    one UID entry has a competition id and a loss below infinity.

    Returns:
        A DataFrame with a "timestamp" column plus one column per competition in
        competitions.COMPETITION_DETAILS (named by its `name`). A competition with
        no loss in a given run gets None.
    """
    timestamps = []
    datapoints_per_comp_id = {
        comp_id: [] for comp_id in competitions.COMPETITION_DETAILS
    }
    for run in wandb_runs:
        best_loss_per_competition_id: Dict[Any, float] = {}
        max_timestamp = None
        # HistoryScan treats max_step as exclusive (wandb's own scan_history uses
        # lastHistoryStep + 1). So the +1 is needed for the newest step to be included
        # in the 10 scanned.
        max_step = run.lastHistoryStep + 1
        min_step = max(0, max_step - 10)
        history_scan = HistoryScan(run.client, run, min_step, max_step, page_size=10)
        for step in history_scan:
            if "original_format_json" not in step:
                continue
            data = json.loads(step["original_format_json"])
            timestamp = datetime.datetime.fromtimestamp(data["timestamp"])
            max_timestamp = (
                timestamp if max_timestamp is None else max(max_timestamp, timestamp)
            )
            for uid_data in data["uid_data"].values():
                competition_id = uid_data.get("competition_id", None)
                # Compare against None explicitly so a competition id of 0 is kept.
                if competition_id is None:
                    continue
                loss = uid_data.get("average_loss", None)
                # A missing or null loss can't improve the best loss.
                if loss is None:
                    continue
                if loss < best_loss_per_competition_id.get(competition_id, math.inf):
                    best_loss_per_competition_id[competition_id] = loss

        # Now that we've processed the run's most recent steps, add a datapoint if
        # any competition had a usable loss.
        if best_loss_per_competition_id:
            timestamps.append(max_timestamp)
            # Iterate through all possible competitions and add the best loss for each.
            # Set None for any that aren't active during this run.
            for comp_id, losses in datapoints_per_comp_id.items():
                losses.append(best_loss_per_competition_id.get(comp_id, None))

    # Create a dictionary of competition names to lists of losses.
    output_columns = {
        competitions.COMPETITION_DETAILS[comp_id].name: losses
        for comp_id, losses in datapoints_per_comp_id.items()
    }
    return pd.DataFrame({"timestamp": timestamps, **output_columns})
def next_epoch(subtensor: bt.subtensor, block: int) -> int:
    """Returns `block` + NETUID's tempo - blocks elapsed since the last epoch at `block`.

    Presumably this is the block at which NETUID's next epoch begins.
    """
    tempo = subtensor.get_subnet_hyperparameters(NETUID).tempo
    elapsed = subtensor.blocks_since_epoch(NETUID, block)
    return block + tempo - elapsed
def is_floatable(x) -> bool:
    """Returns True if `x` is an int (bools included) or a finite float."""
    if isinstance(x, int):
        return True
    return isinstance(x, float) and math.isfinite(x)
def format_score(uid: int, scores, key) -> Optional[float]:
    """Returns `scores[uid][key]` rounded to 4 places.

    Returns None if the entry is missing or is not a finite number (see is_floatable).
    """
    if uid not in scores or key not in scores[uid]:
        return None
    value = scores[uid][key]
    return round(value, 4) if is_floatable(value) else None
def leaderboard_data(
    leaderboard: List[ModelData],
    scores: Dict[int, Dict[str, Optional[float]]],
    show_stale: bool,
) -> List[List[Any]]:
    """Builds one display row per model for the leaderboard table.

    Each row is [markdown link to the model's commit on Hugging Face, win rate,
    average loss, weight, UID, block], with scores formatted by format_score.
    Models without fresh scores are left out unless `show_stale` is set.
    """
    rows = []
    for model in leaderboard:
        has_fresh_scores = model.uid in scores and scores[model.uid]["fresh"]
        if not has_fresh_scores and not show_stale:
            continue
        link = f"[{model.namespace}/{model.name} ({model.commit[0:8]})](https://huggingface.co/{model.namespace}/{model.name}/commit/{model.commit})"
        rows.append(
            [
                link,
                format_score(model.uid, scores, "win_rate"),
                format_score(model.uid, scores, "avg_loss"),
                format_score(model.uid, scores, "weight"),
                model.uid,
                model.block,
            ]
        )
    return rows
def get_benchmarks() -> Tuple[Optional[pd.DataFrame], Optional[datetime.datetime]]:
    """Returns the latest benchmarks and the time they were run.

    Runs in BENCHMARK_WANDB_PROJECT are visited newest first (see get_wandb_runs).
    The first run whose last logged artifact contains a "benchmarks" table wins.

    Returns:
        (benchmarks dataframe, run start time), or (None, None) if no benchmark
        project is configured or no run has a benchmarks table.
    """
    if not BENCHMARK_WANDB_PROJECT:
        bt.logging.error("No benchmark project set.")
        return None, None

    for run in get_wandb_runs(project=BENCHMARK_WANDB_PROJECT, filters=None):
        artifacts = list(run.logged_artifacts())
        if not artifacts:
            continue
        table = artifacts[-1].get("benchmarks")
        if not table:
            continue
        benchmarks = table.get_dataframe()
        # startedAt is parsed as a naive datetime in this exact format; a value
        # without fractional seconds would raise ValueError here.
        started_at = datetime.datetime.strptime(
            run.metadata["startedAt"], "%Y-%m-%dT%H:%M:%S.%f"
        )
        return benchmarks, started_at

    bt.logging.error("Failed to get benchmarks from Wandb.")
    return None, None
def make_validator_dataframe(
    validator_df: Dict[int, Tuple[float, int, Dict[int, float]]],
    model_data: List["ModelData"],
) -> pd.DataFrame:
    """Builds a validator-by-model weight table for display.

    Args:
        validator_df: Validator UID -> (vtrust, stake, {miner uid: weight}), the
            shape returned by get_validator_weights. Despite the name, this is a
            dict, not a DataFrame.
        model_data: Models to show; only those with non-zero incentive get a column.

    Returns:
        One row per validator, sorted by stake (highest first), with columns
        "UID" (int), "Stake (τ)", "V-Trust" (rounded to 4 places) and one float
        column per model. Each model column holds that validator's weight on the
        model, or NaN when it set none.
    """
    scored_models = [c for c in model_data if c.incentive]
    columns = ["UID", "Stake (τ)", "V-Trust"] + [
        f"{c.namespace}/{c.name} ({c.commit[0:8]})" for c in scored_models
    ]
    # sorted() is stable (also with reverse=True), so equal stakes keep dict order.
    ordered_uids = sorted(
        validator_df, key=lambda uid: validator_df[uid][1], reverse=True
    )
    rows = [
        [uid, int(validator_df[uid][1]), round(validator_df[uid][0], 4)]
        + [validator_df[uid][-1].get(c.uid) for c in scored_models]
        for uid in ordered_uids
    ]
    # Model labels can repeat (e.g. two UIDs committing the same repo and commit).
    # A list of column labels allows duplicates, whereas a label -> dtype dict would
    # collapse them and no longer match the row width.
    frame = pd.DataFrame(rows, columns=columns, dtype=float)
    frame["UID"] = frame["UID"].astype(int)
    return frame
def make_metagraph_dataframe(metagraph: bt.metagraph, weights=False) -> pd.DataFrame:
    """Returns a DataFrame with one row per metagraph entry.

    The frame holds the per-neuron stat columns plus block, netuid, uid, hotkey
    and coldkey. When `weights` is true and the metagraph has a weight matrix,
    each row's weights are added as a list.

    NOTE(review): the "uid" column is the row position, which assumes the
    metagraph arrays are ordered by UID — confirm.
    """
    stat_columns = (
        "stake",
        "emission",
        "trust",
        "validator_trust",
        "dividends",
        "incentive",
        "R",
        "consensus",
        "validator_permit",
    )
    frame = pd.DataFrame({attr: getattr(metagraph, attr) for attr in stat_columns})
    frame["block"] = metagraph.block.item()
    frame["netuid"] = NETUID
    frame["uid"] = range(len(frame))
    axons = metagraph.axons
    frame["hotkey"] = [axon.hotkey for axon in axons]
    frame["coldkey"] = [axon.coldkey for axon in axons]
    if weights and metagraph.W is not None:
        # Store each row of the (presumably N x N) weight matrix as a plain list,
        # so it fits into a single dataframe cell.
        frame["weights"] = [row.tolist() for row in metagraph.W]
    return frame
def load_state_vars() -> Dict[str, Any]:
    """Loads all state the leaderboard needs, retrying until it succeeds.

    Any exception while fetching chain or Wandb data is logged, and the whole
    load is retried after 30 seconds. A KeyboardInterrupt is logged and re-raised.

    Returns:
        A dict with keys "metagraph", "model_data" (sorted by incentive, highest
        first), "vali_runs", "scores", "validator_df", "benchmarks" and
        "benchmark_timestamp". The two benchmark entries are None unless
        BENCHMARK_FLAG is set.
    """
    while True:
        try:
            subtensor, metagraph = get_subtensor_and_metagraph()
            bt.logging.success("Loaded subtensor and metagraph")

            model_data: List[ModelData] = get_subnet_data(subtensor, metagraph)
            model_data.sort(key=lambda x: x.incentive, reverse=True)
            bt.logging.success(f"Loaded {len(model_data)} models")

            vali_runs = get_wandb_runs(
                project=VALIDATOR_WANDB_PROJECT,
                # TODO: Update to point to the OTF vali on finetuning
                filters={"config.type": "validator", "config.uid": 0},
            )
            scores = get_scores([x.uid for x in model_data], vali_runs)

            # TODO: Re-enable once ""SubtensorModule.BlocksSinceEpoch" not found" issue is resolved.
            # current_block = metagraph.block.item()
            # next_epoch_block = next_epoch(subtensor, current_block)

            validator_df = get_validator_weights(metagraph)

            # Enable benchmark if the flag is set
            if BENCHMARK_FLAG:
                benchmarks, benchmark_timestamp = get_benchmarks()
            else:
                benchmarks, benchmark_timestamp = None, None
            break
        except KeyboardInterrupt:
            # Re-raise: breaking out here would fall through to the return below with
            # unassigned locals and fail with UnboundLocalError.
            bt.logging.error("Exiting...")
            raise
        except Exception:
            bt.logging.error(f"Failed to get data: {traceback.format_exc()}")
            time.sleep(30)

    return {
        "metagraph": metagraph,
        "model_data": model_data,
        "vali_runs": vali_runs,
        "scores": scores,
        "validator_df": validator_df,
        "benchmarks": benchmarks,
        "benchmark_timestamp": benchmark_timestamp,
    }
def test_load_state_vars():
    """Returns a dict shaped like load_state_vars()'s output, mostly from fixed data.

    The model list and validator weights are hard-coded. The metagraph and
    validator runs are still fetched live from the chain and Wandb.
    """
    # TODO: Change to finetuning data.
    subtensor = bt.subtensor("finney")
    metagraph = subtensor.metagraph(NETUID, lite=True)
    # ModelData requires competition_id; these fixtures predate it.
    # TODO(review): placeholder value — set to the real competition id.
    placeholder_competition_id = 1
    model_data = [
        ModelData(
            uid=253,
            hotkey="5DjoPAgZ54Zf6NsuiVYh8RjonnWWWREE2iXBNzM2VDBMQDPm",
            competition_id=placeholder_competition_id,
            namespace="jw-hf-test",
            name="jw2",
            commit="aad131f6b02219964e6dcf749c2a23e75a7ceca8",
            secure_hash="L1ImYzWJwV+9KSnZ2TYW0Iy2KMcVjJVTd30YJoRkpbw=",
            block=3131103,
            incentive=1.0,
            emission=209.06051635742188,
        ),
        ModelData(
            uid=1,
            hotkey="5CccVtjk4yamCao6QYgEg7jc8vktdj16RbLKNUftHfEsjuJS",
            competition_id=placeholder_competition_id,
            namespace="borggAI",
            name="bittensor-subnet9-models",
            commit="d373864bc6c972872edb8db95eed570958054bac",
            secure_hash="+drdTIKYEGYClW2FFVVID6A2Dh//4rLmExRFCJsH6Y4=",
            block=2081837,
            incentive=0.0,
            emission=0.0,
        ),
        ModelData(
            uid=2,
            hotkey="5HYwoXaczs3jAptbb5mk4aUCkgZqeNcNzJKxSec97GwasfLy",
            competition_id=placeholder_competition_id,
            namespace="jungiebeen",
            name="pretrain1",
            commit="4c0c6bfd0f92e243d6c8a82209142e7204c852c3",
            secure_hash="ld/agc0XIWICom/Cpj0fkQLcMogMNj/F65MJogK5RLY=",
            block=2467482,
            incentive=0.0,
            emission=0.0,
        ),
        ModelData(
            uid=3,
            hotkey="5Dnb6edh9yTeEp5aasRPZVPRAkxvQ6qnERVcXw22awMZ5rxm",
            competition_id=placeholder_competition_id,
            namespace="jungiebeen",
            name="pretrain2",
            commit="e827b7281c92224adb11124489cc45356553a87a",
            secure_hash="ld/agc0XIWICom/Cpj0fkQLcMogMNj/F65MJogK5RLY=",
            block=2467497,
            incentive=0.0,
            emission=0.0,
        ),
        ModelData(
            uid=4,
            hotkey="5FRfca8NbnH424WaX43PMhKBnbLA1bZpRRoXXiVs6HgsxN4K",
            competition_id=placeholder_competition_id,
            namespace="ZainAli60",
            name="mine_modeles",
            commit="8a4ed4ad1f1fb58d424fd22e8e9874b87d32917c",
            secure_hash="tVcbZAFoNIOF+Ntxq31OQ2NrLXf5iFCmmPUJlpkMYYo=",
            block=2508509,
            incentive=0.0,
            emission=0.0,
        ),
    ]
    vali_runs = get_wandb_runs(
        project=VALIDATOR_WANDB_PROJECT,
        filters={"config.type": "validator", "config.uid": 238},
    )
    scores = get_scores([x.uid for x in model_data], vali_runs)
    validator_df = {
        28: (1.0, 33273.4453125, {253: 1.0}),
        49: (
            0.9127794504165649,
            10401.677734375,
            {
                7: 0.0867,
                217: 0.0001,
                219: 0.0001,
                241: 0.0001,
                248: 0.0001,
                253: 0.9128,
            },
        ),
        78: (1.0, 26730.37109375, {253: 1.0}),
        116: (1.0, 629248.4375, {253: 1.0}),
        150: (1.0, 272634.53125, {253: 1.0}),
        161: (1.0, 280212.53125, {253: 1.0}),
        180: (1.0, 16838.0, {253: 1.0}),
        184: (1.0, 47969.3984375, {253: 1.0}),
        210: (1.0, 262846.28125, {253: 1.0}),
        213: (1.0, 119462.734375, {253: 1.0}),
        215: (1.0, 274747.46875, {253: 1.0}),
        234: (1.0, 38831.6953125, {253: 1.0}),
        236: (1.0, 183966.9375, {253: 1.0}),
        238: (1.0, 1293707.25, {253: 1.0}),
        240: (1.0, 106461.6015625, {253: 1.0}),
        243: (1.0, 320271.5, {253: 1.0}),
        244: (1.0, 116138.9609375, {253: 1.0}),
        247: (0.9527428150177002, 119812.390625, {7: 0.0472, 253: 0.9528}),
        249: (1.0, 478127.3125, {253: 1.0}),
        252: (1.0, 442395.03125, {253: 1.0}),
        254: (1.0, 46845.2109375, {253: 1.0}),
        255: (1.0, 28977.56640625, {253: 1.0}),
    }
    return {
        "metagraph": metagraph,
        "model_data": model_data,
        "vali_runs": vali_runs,
        "scores": scores,
        "validator_df": validator_df,
        # Keep the same keys as load_state_vars() so consumers can use either.
        "benchmarks": None,
        "benchmark_timestamp": None,
    }