rusticluftig committed
Commit 8c72a30
1 Parent(s): b33f70b

Add a line plot of loss over time

Files changed (1): app.py (+75 -19)
app.py CHANGED
@@ -5,7 +5,6 @@ import bittensor as bt
 from typing import Dict, List, Any, Optional, Tuple
 from bittensor.extrinsics.serving import get_metadata
 from dataclasses import dataclass
-import requests
 import wandb
 import math
 import os
@@ -16,6 +15,7 @@ import pandas as pd
 from dotenv import load_dotenv
 from huggingface_hub import HfApi
 from apscheduler.schedulers.background import BackgroundScheduler
+import pandas as pd
 
 load_dotenv()
 
@@ -121,10 +121,12 @@ def get_subnet_data(
         hotkey = metagraph.hotkeys[uid]
         metadata = None
         try:
-            metadata = run_with_retries(functools.partial(get_metadata, subtensor, metagraph.netuid, hotkey))
+            metadata = run_with_retries(
+                functools.partial(get_metadata, subtensor, metagraph.netuid, hotkey)
+            )
         except:
             print(f"Failed to get metadata for UID {uid}: {traceback.format_exc()}")
-
+
         if not metadata:
             continue
 
@@ -155,10 +157,8 @@ def is_floatable(x) -> bool:
     ) or isinstance(x, int)
 
 
-def get_scores(
-    uids: List[int],
-) -> Dict[int, Dict[str, Optional[float]]]:
-    runs = []
+def get_wandb_runs() -> List:
+    """Get the latest runs from Wandb, retrying infinitely until we get them."""
     while True:
         api = wandb.Api(api_key=WANDB_TOKEN)
         runs = list(
@@ -168,15 +168,20 @@ def get_scores(
             )
         )
         if len(runs) > 0:
-            break
+            return runs
         # WandDB API is quite unreliable. Wait another minute and try again.
         print("Failed to get runs from Wandb. Trying again in 60 seconds.")
         time.sleep(60)
-
+
+
+def get_scores(
+    uids: List[int],
+    wandb_runs: List,
+) -> Dict[int, Dict[str, Optional[float]]]:
     result = {}
     previous_timestamp = None
     # Iterate through the runs until we've processed all the uids.
-    for i, run in enumerate(runs):
+    for i, run in enumerate(wandb_runs):
         if not "original_format_json" in run.summary:
             continue
         data = json.loads(run.summary["original_format_json"])
@@ -208,6 +213,30 @@ def get_scores(
     return result
 
 
+def get_losses_over_time(wandb_runs: List) -> pd.DataFrame:
+    """Returns a dataframe of the best average model loss over time."""
+    timestamps = []
+    best_losses = []
+
+    for run in wandb_runs:
+        if "original_format_json" not in run.summary:
+            continue
+        data = json.loads(run.summary["original_format_json"])
+        all_uid_data = data["uid_data"]
+        timestamp = datetime.datetime.fromtimestamp(data["timestamp"])
+        best_loss = math.inf
+        for _, uid_data in all_uid_data.items():
+            loss = uid_data.get("average_loss", math.inf)
+            # Filter out the numbers from the exploit.
+            if loss < best_loss and (loss > 2.5 or timestamp > datetime.datetime(2024, 2, 8)):
+                best_loss = uid_data["average_loss"]
+        if best_loss != math.inf:
+            timestamps.append(timestamp)
+            best_losses.append(best_loss)
+
+    return pd.DataFrame({"timestamp": timestamps, "best_loss": best_losses})
+
+
 def format_score(uid: int, scores, key) -> Optional[float]:
     if uid in scores:
         if key in scores[uid]:
@@ -218,9 +247,11 @@ def format_score(uid: int, scores, key) -> Optional[float]:
 
 
 def next_epoch(subtensor: bt.subtensor, block: int) -> int:
-    return block + subtensor.get_subnet_hyperparameters(
-        NETUID
-    ).tempo - subtensor.blocks_since_epoch(NETUID, block)
+    return (
+        block
+        + subtensor.get_subnet_hyperparameters(NETUID).tempo
+        - subtensor.blocks_since_epoch(NETUID, block)
+    )
 
 
 def get_next_update_div(current_block: int, next_update_block: int) -> str:
@@ -232,9 +263,11 @@ def get_next_update_div(current_block: int, next_update_block: int) -> str:
     delta = next_update_time - now
     return f"""<div align="center" style="font-size: larger;">Next reward update: <b>{blocks_to_go}</b> blocks (~{int(delta.total_seconds() // 60)} minutes)</div>"""
 
+
 def get_last_updated_div() -> str:
     return f"""<div>Last Updated: {datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")} (UTC)</div>"""
 
+
 def leaderboard_data(
     leaderboard: List[ModelData],
     scores: Dict[int, Dict[str, Optional[float]]],
@@ -254,6 +287,7 @@ def leaderboard_data(
         if (c.uid in scores and scores[c.uid]["fresh"]) or show_stale
     ]
 
+
 def restart_space():
     API.restart_space(repo_id=REPO_ID, token=H4_TOKEN)
 
@@ -264,7 +298,9 @@ def main():
     model_data: List[ModelData] = get_subnet_data(subtensor, metagraph)
     model_data.sort(key=lambda x: x.incentive, reverse=True)
 
-    scores = get_scores([x.uid for x in model_data])
+    wandb_runs = get_wandb_runs()
+
+    scores = get_scores([x.uid for x in model_data], wandb_runs)
 
     current_block = metagraph.block.item()
     next_epoch_block = next_epoch(subtensor, current_block)
@@ -303,13 +339,34 @@ def main():
                 visible=True,
             )
             gr.HTML(EVALUATION_DETAILS)
-            show_stale.change(lambda stale: leaderboard_data(model_data, scores, stale), inputs=[show_stale], outputs=leaderboard_table)
+            show_stale.change(
+                lambda stale: leaderboard_data(model_data, scores, stale),
+                inputs=[show_stale],
+                outputs=leaderboard_table,
+            )
+
+            gr.LinePlot(
+                get_losses_over_time(wandb_runs),
+                x="timestamp",
+                x_title="Date",
+                y="best_loss",
+                y_title="Average Loss",
+                tooltip="best_loss",
+                interactive=True,
+                visible=True,
+                width=1024,
+                title="Best Average Loss Over Time",
+            )
 
         with gr.Accordion("Validator Stats"):
             gr.components.Dataframe(
                 value=[
                     [uid, int(validator_df[uid][1]), round(validator_df[uid][0], 4)]
-                    + [validator_df[uid][-1].get(c.uid) for c in model_data if c.incentive]
+                    + [
+                        validator_df[uid][-1].get(c.uid)
+                        for c in model_data
+                        if c.incentive
+                    ]
                     for uid, _ in sorted(
                         zip(
                             validator_df.keys(),
@@ -332,8 +389,6 @@ def main():
            )
        gr.HTML(value=get_last_updated_div())
 
-
-
     scheduler = BackgroundScheduler()
     scheduler.add_job(
         restart_space, "interval", seconds=60 * 30
@@ -341,5 +396,6 @@ def main():
     scheduler.start()
 
     demo.launch()
-
+
+
 main()
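
To try the new plot in isolation: the change follows the usual Gradio pattern of passing a pandas DataFrame to gr.LinePlot. Below is a minimal standalone sketch of that pattern, using a small hand-built DataFrame with hypothetical loss values in place of the wandb-derived frame that get_losses_over_time builds in app.py; the LinePlot arguments mirror the ones added in this commit.

# Minimal, self-contained sketch (not part of the commit): renders the same
# LinePlot configuration with hypothetical sample data instead of the
# wandb-derived dataframe returned by get_losses_over_time().
import datetime

import gradio as gr
import pandas as pd

# Hypothetical loss points standing in for real validator run data.
sample = pd.DataFrame(
    {
        "timestamp": [
            datetime.datetime(2024, 2, 1),
            datetime.datetime(2024, 2, 8),
            datetime.datetime(2024, 2, 15),
        ],
        "best_loss": [3.1, 2.9, 2.7],
    }
)

with gr.Blocks() as demo:
    gr.LinePlot(
        sample,
        x="timestamp",
        x_title="Date",
        y="best_loss",
        y_title="Average Loss",
        tooltip="best_loss",
        interactive=True,
        visible=True,
        width=1024,
        title="Best Average Loss Over Time",
    )

demo.launch()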