cassova commited on
Commit
80820db
·
1 Parent(s): 338a112

black + fixed timezone on dates

Browse files
Files changed (1) hide show
  1. utils.py +436 -192
utils.py CHANGED
@@ -14,67 +14,72 @@ import plotly.express as px
14
  # TODO: Store relevant wandb data in a database for faster access
15
 
16
 
17
- MIN_STEPS = 10 # minimum number of steps in wandb run in order to be worth analyzing
18
  NETUID = 1
19
- BASE_PATH = 'macrocosmos/prompting-validators'
20
- NETWORK = 'finney'
21
- KEYS = ['_step','_timestamp','task','query','reference','challenge','topic','subtopic']
 
 
 
 
 
 
 
 
 
22
  ABBREV_CHARS = 8
23
- ENTITY_CHOICES = ('identity', 'hotkey', 'coldkey')
24
- LOCAL_WANDB_PATH = './data/wandb'
25
- USERNAME = 'taostats'
26
 
27
  # Initialize wandb with anonymous login
28
- wandb.login(anonymous='must')
29
  api = wandb.Api(timeout=600)
30
 
31
  IDENTITIES = {
32
- '5F4tQyWrhfGVcNhoqeiNsR6KjD4wMZ2kfhLj4oHYuyHbZAc3': 'opentensor',
33
- '5Hddm3iBFD2GLT5ik7LZnT3XJUnRnN8PoeCFgGQgawUVKNm8': 'taostats',
34
- '5HEo565WAy4Dbq3Sv271SAi7syBSofyfhhwRNjFNSM2gP9M2': 'foundry',
35
- '5HK5tp6t2S59DywmHRWPBVJeJ86T61KjurYqeooqj8sREpeN': 'bittensor-guru',
36
- '5FFApaS75bv5pJHfAp2FVLBj9ZaXuFDjEypsaBNc1wCfe52v': 'roundtable-21',
37
- '5EhvL1FVkQPpMjZX4MAADcW42i3xPSF1KiCpuaxTYVr28sux': 'tao-validator',
38
- '5FKstHjZkh4v3qAMSBa1oJcHCLjxYZ8SNTSz1opTv4hR7gVB': 'datura',
39
- '5DvTpiniW9s3APmHRYn8FroUWyfnLtrsid5Mtn5EwMXHN2ed': 'first-tensor',
40
- '5HbLYXUBy1snPR8nfioQ7GoA9x76EELzEq9j7F32vWUQHm1x': 'tensorplex',
41
- '5CsvRJXuR955WojnGMdok1hbhffZyB4N5ocrv82f3p5A2zVp': 'owl-ventures',
42
- '5CXRfP2ekFhe62r7q3vppRajJmGhTi7vwvb2yr79jveZ282w': 'rizzo',
43
- '5HNQURvmjjYhTSksi8Wfsw676b4owGwfLR2BFAQzG7H3HhYf': 'neural-internet'
44
  }
45
 
46
  EXTRACTORS = {
47
- 'state': lambda x: x.state,
48
- 'run_id': lambda x: x.id,
49
- 'run_path': lambda x: os.path.join(BASE_PATH, x.id),
50
- 'user': lambda x: x.user.name[:16],
51
- 'username': lambda x: x.user.username[:16],
52
- 'created_at': lambda x: pd.Timestamp(x.created_at),
53
- 'last_event_at': lambda x: pd.Timestamp(x.summary.get('_timestamp'), unit='s'),
54
-
55
  # 'netuid': lambda x: x.config.get('netuid'),
56
  # 'mock': lambda x: x.config.get('neuron').get('mock'),
57
  # 'sample_size': lambda x: x.config.get('neuron').get('sample_size'),
58
  # 'timeout': lambda x: x.config.get('neuron').get('timeout'),
59
  # 'epoch_length': lambda x: x.config.get('neuron').get('epoch_length'),
60
  # 'disable_set_weights': lambda x: x.config.get('neuron').get('disable_set_weights'),
61
-
62
  # This stuff is from the last logged event
63
- 'num_steps': lambda x: x.summary.get('_step'),
64
- 'runtime': lambda x: x.summary.get('_runtime'),
65
- 'query': lambda x: x.summary.get('query'),
66
- 'challenge': lambda x: x.summary.get('challenge'),
67
- 'reference': lambda x: x.summary.get('reference'),
68
- 'completions': lambda x: x.summary.get('completions'),
69
-
70
- 'version': lambda x: x.tags[0],
71
- 'spec_version': lambda x: x.tags[1],
72
- 'vali_hotkey': lambda x: x.tags[2],
73
  # 'tasks_selected': lambda x: x.tags[3:],
74
-
75
  # System metrics
76
- 'disk_read': lambda x: x.system_metrics.get('system.disk.in'),
77
- 'disk_write': lambda x: x.system_metrics.get('system.disk.out'),
78
  # Really slow stuff below
79
  # 'started_at': lambda x: x.metadata.get('startedAt'),
80
  # 'disk_used': lambda x: x.metadata.get('disk').get('/').get('used'),
@@ -82,50 +87,64 @@ EXTRACTORS = {
82
  }
83
 
84
 
85
- def get_leaderboard(df, ntop=10, entity_choice='identity'):
86
 
87
- df = df.loc[df.validator_permit==False]
88
  df.index = range(df.shape[0])
89
  return df.groupby(entity_choice).I.sum().sort_values().reset_index().tail(ntop)
90
 
 
91
  @st.cache_data()
92
  def get_metagraph(time):
93
- print(f'Loading metagraph with time {time}')
94
  subtensor = bt.subtensor(network=NETWORK)
95
  m = subtensor.metagraph(netuid=NETUID)
96
- meta_cols = ['I','stake','trust','validator_trust','validator_permit','C','R','E','dividends','last_update']
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  df_m = pd.DataFrame({k: getattr(m, k) for k in meta_cols})
99
- df_m['uid'] = range(m.n.item())
100
- df_m['hotkey'] = list(map(lambda a: a.hotkey, m.axons))
101
- df_m['coldkey'] = list(map(lambda a: a.coldkey, m.axons))
102
- df_m['ip'] = list(map(lambda a: a.ip, m.axons))
103
- df_m['port'] = list(map(lambda a: a.port, m.axons))
104
- df_m['coldkey'] = df_m.coldkey.str[:ABBREV_CHARS]
105
- df_m['hotkey'] = df_m.hotkey.str[:ABBREV_CHARS]
106
- df_m['identity'] = df_m.apply(lambda x: f'{x.hotkey} @ uid {x.uid}', axis=1)
107
  return df_m
108
 
109
 
110
  @st.cache_data(show_spinner=False)
111
  def load_downloaded_runs(time, cols=KEYS):
112
 
113
- list_cols = ['rewards','uids']
114
- extra_cols = ['turn']
115
  df_all = pd.DataFrame()
116
 
117
- progress = st.progress(0, text='Loading downloaded data')
118
- paths = glob.glob(os.path.join(LOCAL_WANDB_PATH,'*.parquet'))
119
  for i, path in enumerate(paths):
120
- run_id = path.split('/')[-1].split('.')[0]
121
  frame = pd.read_parquet(path).dropna(subset=cols)
122
- frame._timestamp = frame._timestamp.apply(pd.to_datetime, unit='s')
123
  # handle missing extra cols such as turn which depend on the version of the codebase
124
  found_extra_cols = [c for c in frame.columns if c in extra_cols]
125
- df_long = frame[cols+list_cols+found_extra_cols].explode(list_cols)
126
 
127
- prog_msg = f'Downloading data {i/len(paths)*100:.0f}%'
128
- progress.progress(i/len(paths), text=f'{prog_msg}... **downloading** `{run_id}`')
 
 
129
 
130
  df_all = pd.concat([df_all, df_long.assign(run_id=run_id)], ignore_index=True)
131
 
@@ -133,17 +152,16 @@ def load_downloaded_runs(time, cols=KEYS):
133
 
134
  # Ensure we have consistent naming schema for tasks
135
  task_mapping = {
136
- 'date-based question answering': 'date_qa',
137
- 'question-answering': 'qa',
138
  }
139
-
140
- df_all['task'] = df_all.task.apply(lambda x: task_mapping.get(x, x))
141
 
 
142
 
143
  # Runs which do not have a turn field are imputed to be turn zero (single turn)
144
- df_all['turn'] = df_all.turn.fillna(0)
145
 
146
- df_all.sort_values(by=['_timestamp'], inplace=True)
147
 
148
  return df_all
149
 
@@ -151,88 +169,129 @@ def load_downloaded_runs(time, cols=KEYS):
151
  @st.cache_data(show_spinner=False)
152
  def build_data(timestamp=None, path=BASE_PATH, min_steps=MIN_STEPS, use_cache=True):
153
 
154
- save_path = '_saved_runs.csv'
155
  filters = {}
156
  df = pd.DataFrame()
157
  # Load the last saved runs so that we only need to update the new ones
158
  if use_cache and os.path.exists(save_path):
159
  df = pd.read_csv(save_path)
160
- df['created_at'] = pd.to_datetime(df['created_at'])
161
- df['last_event_at'] = pd.to_datetime(df['last_event_at'])
162
 
163
- timestamp_str = df['last_event_at'].max().isoformat()
164
- filters.update({'updated_at': {'$gte': timestamp_str}})
165
 
166
- progress = st.progress(0, text='Loading data')
167
 
168
  runs = api.runs(path, filters=filters)
169
 
170
  run_data = []
171
  n_events = 0
172
  for i, run in enumerate(tqdm.tqdm(runs, total=len(runs))):
173
- num_steps = run.summary.get('_step',0)
174
- if num_steps<min_steps:
175
  continue
176
  n_events += num_steps
177
- prog_msg = f'Loading data {i/len(runs)*100:.0f}%, (total {n_events:,.0f} events)'
178
- progress.progress(i/len(runs),text=f'{prog_msg}... **downloading** `{os.path.join(*run.path)}`')
179
- if 'netuid_1' in run.tags or 'netuid_61' in run.tags or 'netuid_102' in run.tags:
 
 
 
 
 
 
 
 
 
180
  run_data.append(run)
181
 
182
-
183
  progress.empty()
184
 
185
- df_new = pd.DataFrame([{k: func(run) for k, func in EXTRACTORS.items()} for run in tqdm.tqdm(run_data, total=len(run_data))])
 
 
 
 
 
186
  df = pd.concat([df, df_new], ignore_index=True)
187
- df['duration'] = (df.last_event_at - df.created_at).round('s')
188
- df['identity'] = df['vali_hotkey'].map(IDENTITIES).fillna('unknown')
189
- df['vali_hotkey'] = df['vali_hotkey'].str[:ABBREV_CHARS]
 
 
 
 
 
 
 
190
 
191
  # Drop events that are not related to validator queries
192
- df.dropna(subset='query', inplace=True)
193
 
194
  print(df.completions.apply(type).value_counts())
195
  # Assumes completions is in the frame
196
- df['completions'] = df['completions'].apply(lambda x: x if isinstance(x, list) else eval(x))
197
-
198
- df['completion_words'] = df.completions.apply(lambda x: sum([len(xx.split()) for xx in x]) if isinstance(x, list) else 0)
199
- df['validator_words'] = df.apply(lambda x: len(str(x.query).split()) + len(str(x.challenge).split()) + len(str(x.reference).split()), axis=1 )
 
 
 
 
 
 
 
 
 
200
 
201
  df.to_csv(save_path, index=False)
202
 
203
  return df
204
 
 
205
  @st.cache_data()
206
  def normalize_rewards(df, turn=0, percentile=0.98):
207
- top_reward_stats = df.loc[df.turn==turn].astype({'rewards':float}).groupby('task').rewards.quantile(percentile)
208
-
209
- df['best_reward'] = df.task.map(top_reward_stats)
210
- df['normalized_rewards'] = df['rewards'].astype(float) / df['best_reward']
 
 
 
 
 
211
  return df
212
 
 
213
  @st.cache_data(show_spinner=False)
214
  def download_runs(time, df_vali):
215
 
216
  pbar = tqdm.tqdm(df_vali.index, total=len(df_vali))
217
 
218
- progress = st.progress(0, text='Loading data')
219
 
220
  for i, idx in enumerate(pbar):
221
  row = df_vali.loc[idx]
222
 
223
- prog_msg = f'Downloading data {i/len(df_vali)*100:.0f}%'
224
- progress.progress(i/len(df_vali), text=f'{prog_msg}... **downloading** `{os.path.join(*row.run_id)}`')
 
 
 
225
 
226
- save_path = f'data/wandb/{row.run_id}.parquet'
227
  # Create the directory if it does not exist
228
  os.makedirs(os.path.dirname(save_path), exist_ok=True)
229
 
230
  if os.path.exists(save_path):
231
- pbar.set_description(f'>> Skipping {row.run_id!r} because file {save_path!r} already exists')
 
 
232
  continue
233
 
234
  try:
235
- pbar.set_description(f'* Downloading run {row.run_id!r}')
236
  run = api.run(row.run_path)
237
 
238
  # By default we just download a subset of events (500 most recent)
@@ -241,7 +300,9 @@ def download_runs(time, df_vali):
241
  except KeyboardInterrupt:
242
  break
243
  except Exception as e:
244
- pbar.set_description(f'- Something went wrong with {row.run_id!r}: {print_exc()}\n')
 
 
245
 
246
  progress.empty()
247
 
@@ -250,24 +311,41 @@ def get_productivity(df_runs):
250
 
251
  total_duration = df_runs.last_event_at.max() - df_runs.created_at.min()
252
  total_steps = df_runs.num_steps.sum()
253
- total_completions = (df_runs.num_steps*100).sum() #TODO: Parse from df
254
- total_completion_words = (df_runs.num_steps*df_runs.completion_words).sum()
255
- total_completion_tokens = round(total_completion_words/0.75)
256
- total_validator_words = (df_runs.num_steps*df_runs.apply(lambda x: len(str(x.query).split()) + len(str(x.challenge).split()) + len(str(x.reference).split()), axis=1 )).sum()
257
- total_validator_tokens = round(total_validator_words/0.75)
 
 
 
 
 
 
 
 
258
  total_dataset_tokens = total_completion_tokens + total_validator_tokens
259
 
260
  return {
261
- 'duration':total_duration,
262
- 'total_events':total_steps,
263
- 'total_completions':total_completions,
264
- 'total_completion_tokens':total_completion_tokens,
265
- 'total_validator_tokens':total_validator_tokens,
266
- 'total_tokens':total_dataset_tokens,
267
  }
268
 
 
269
  @st.cache_data(show_spinner=False)
270
- def get_reward_stats(df, exclude_multiturn=True, freq='D', remove_zero_rewards=True, agg='mean', date_min='2024-01-22', date_max='2024-08-12'): #TODO: Set the date_max to the current date
 
 
 
 
 
 
 
 
271
 
272
  df = df.loc[df._timestamp.between(pd.Timestamp(date_min), pd.Timestamp(date_max))]
273
  if exclude_multiturn:
@@ -275,145 +353,311 @@ def get_reward_stats(df, exclude_multiturn=True, freq='D', remove_zero_rewards=T
275
  if remove_zero_rewards:
276
  df = df.loc[df.rewards > 0]
277
 
278
- groups = ['run_id',pd.Grouper(key='_timestamp',freq=freq),'task']
279
- return df.groupby(groups).agg({'rewards':agg, 'normalized_rewards':agg})
 
280
 
281
  def get_release_dates():
282
- release_dates = pd.DataFrame([
283
- {'version': '1.0.0', 'release_date': pd.Timestamp(month=1, day=22, year=2024), 'note': '', 'model': 'zephyr', 'tasks_affected':['qa','summarization']},
284
- {'version': '1.0.1', 'release_date': pd.Timestamp(month=1, day=22, year=2024), 'note': '', 'model': 'zephyr', 'tasks_affected':[]},
285
- {'version': '1.0.2', 'release_date': pd.Timestamp(month=1, day=24, year=2024), 'note': '', 'model': 'zephyr', 'tasks_affected':['qa','summarization']},
286
- {'version': '1.0.3', 'release_date': pd.Timestamp(month=2, day=14, year=2024), 'note': '', 'model': 'zephyr', 'tasks_affected':[]},
287
- {'version': '1.0.4', 'release_date': pd.Timestamp(month=2, day=15, year=2024), 'note': '', 'model': 'zephyr', 'tasks_affected':[]},
288
- {'version': '1.1.0', 'release_date': pd.Timestamp(month=2, day=21, year=2024), 'note': 'decay scores', 'model': 'zephyr', 'tasks_affected':['date_qa','math']},
289
- {'version': '1.1.1', 'release_date': pd.Timestamp(month=2, day=28, year=2024), 'note': 'reduce penalty weight', 'model': 'zephyr', 'tasks_affected':['date_qa','qa','summarization']},
290
- {'version': '1.1.2', 'release_date': pd.Timestamp(month=2, day=29, year=2024), 'note': '', 'model': 'zephyr', 'tasks_affected':[]},
291
- {'version': '1.1.3', 'release_date': pd.Timestamp(month=3, day=11, year=2024), 'note': '', 'model': 'zephyr', 'tasks_affected':[]},
292
- {'version': '1.2.0', 'release_date': pd.Timestamp(month=3, day=19, year=2024), 'note': 'vllm', 'model': 'zephyr', 'tasks_affected':[]},
293
- {'version': '1.3.0', 'release_date': pd.Timestamp(month=3, day=27, year=2024), 'note': '', 'model': 'solar', 'tasks_affected':['all','math']},
294
- {'version': '2.0.0', 'release_date': pd.Timestamp(month=4, day=4, year=2024), 'note': 'streaming', 'model': 'solar', 'tasks_affected':['math','qa','summarization']},
295
- {'version': '2.1.0', 'release_date': pd.Timestamp(month=4, day=18, year=2024), 'note': 'chattensor prompt', 'model': 'solar', 'tasks_affected':['generic']},
296
- {'version': '2.2.0', 'release_date': pd.Timestamp(month=5, day=1, year=2024), 'note': 'multiturn + paraphrase', 'model': 'solar', 'tasks_affected':['sentiment','translation','math']},
297
- {'version': '2.3.0', 'release_date': pd.Timestamp(month=5, day=20, year=2024), 'note': 'llama + freeform date', 'model': 'llama', 'tasks_affected':['all','date_qa']},
298
- {'version': '2.3.1', 'release_date': pd.Timestamp(month=5, day=21, year=2024), 'note': '', 'model': 'llama', 'tasks_affected':['date_qa']},
299
- {'version': '2.4.0', 'release_date': pd.Timestamp(month=6, day=5, year=2024), 'note': 'streaming penalty', 'model': 'llama', 'tasks_affected':[]},
300
- {'version': '2.4.1', 'release_date': pd.Timestamp(month=6, day=6, year=2024), 'note': '', 'model': 'llama', 'tasks_affected':[]},
301
- {'version': '2.4.2', 'release_date': pd.Timestamp(month=6, day=7, year=2024), 'note': '', 'model': 'llama', 'tasks_affected':[]},
302
- {'version': '2.4.2', 'release_date': pd.Timestamp(month=6, day=7, year=2024), 'note': '', 'model': 'llama', 'tasks_affected':[]},
303
- {'version': '2.5.0', 'release_date': pd.Timestamp(month=6, day=18, year=2024), 'note': 'reduce multiturn', 'model': 'llama', 'tasks_affected':['translation','sentiment']},
304
- {'version': '2.5.1', 'release_date': pd.Timestamp(month=6, day=25, year=2024), 'note': 'reduce timeout', 'model': 'llama', 'tasks_affected':[]},
305
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  return release_dates
307
 
308
 
309
- def plot_reward_trends(df_stats, task='qa', window=14, col='normalized_reward', annotate=False, task_label='Question answering'):
 
 
 
 
 
 
 
310
 
311
  stats = df_stats.reset_index()
312
  release_dates = get_release_dates()
313
- stats_task = stats.loc[(stats.task == task)].sort_values(by='_timestamp')
314
- stats_task['rewards_ma'] = stats_task[col].rolling(window, min_periods=0).mean()
315
- fig = px.area(stats_task,
316
- x='_timestamp', y='rewards_ma',
317
- title=f'Reward Trend for {task_label} Task',
318
- labels={'rewards_ma': f'Rewards [{window} day avg.]','_timestamp':''},
319
- width=800,height=600,
320
- )
 
 
 
321
 
322
  if not annotate:
323
  return fig
324
 
325
  # Add annotations based on relevant releases
326
  for idx, row in release_dates.iterrows():
327
- line_color = 'grey'
328
- if task in row['tasks_affected']:
329
- line_color='red'
330
- elif 'all' not in row['tasks_affected']:
331
- line_color='blue'
332
  # TODO add annotation or something
333
- fig.add_vline(row['release_date'], line_color=line_color, opacity=0.6, line_dash='dot', line_width=1)#, annotation_text=str(v))
 
 
 
 
 
 
334
 
335
  return fig
336
 
 
337
  @st.cache_data()
338
  def get_task_counts(df_runs, df_events):
339
- # Get mapping from run id to prompting repo version
340
- run_to_version = df_runs.set_index('run_id').version.to_dict()
 
 
341
 
342
- df_events['version'] = df_events.run_id.map(run_to_version)
343
-
344
  def version_to_spec(version):
345
- major, minor, patch = version.split('.')
346
  return 10_000 * major + 100 * minor + patch
347
-
348
  def get_closest_prev_version(version, my_versions):
349
-
350
  ref_spec = version_to_spec(version)
351
  my_specs = list(map(version_to_spec, my_versions))
352
-
353
  match = my_specs[0]
354
  for spec in my_specs[1:]:
355
- if spec>ref_spec:
356
  break
357
-
358
  match = spec
359
-
360
  return my_versions[my_specs.index(match)]
361
-
362
  # Now estimate the distribution of tasks for each version using the event data
363
- task_rate = df_events.groupby('version').task.value_counts(normalize=True).unstack().fillna(0)
 
 
 
 
 
364
  # Impute missing versions
365
  for v in sorted(df_runs.version.unique()):
366
  if v not in task_rate.index:
367
  prev_version = get_closest_prev_version(v, list(task_rate.index))
368
- print(f'Imputing version {v} with task rate from closes previous version {prev_version!r}')
 
 
369
  task_rate.loc[v] = task_rate.loc[prev_version]
370
-
371
  # get esimated number of each task generated in every run using summary dataframe
372
- task_counts = df_runs.set_index('created_at').sort_index().apply(lambda x: round(task_rate.loc[x.version]*x.num_steps), axis=1).cumsum()
373
- return task_counts
 
 
 
 
 
374
 
375
 
376
  def load_state_vars(username=USERNAME, percentile=0.95):
377
 
378
  UPDATE_INTERVAL = 600
379
 
380
- df_runs = build_data(time.time()//UPDATE_INTERVAL, use_cache=False)
381
-
382
  # df_runs = df_runs.loc[df_runs.netuid.isin([1,61,102])] # Now we filter for the netuid tag in build_data
383
- st.toast(f'Loaded {len(df_runs)} runs')
384
 
385
  df_vali = df_runs.loc[df_runs.username == username]
386
  # df_vali = df_runs
387
 
388
- download_runs(time.time()//UPDATE_INTERVAL, df_vali)
389
 
390
- df_events = load_downloaded_runs(time.time()//UPDATE_INTERVAL)
391
  df_events = normalize_rewards(df_events, percentile=percentile)
392
 
393
- yesterday = pd.Timestamp.now() - pd.Timedelta('1d')
394
- runs_alive_24h_ago = (df_runs.last_event_at > yesterday)
395
 
396
  df_runs_24h = df_runs.loc[runs_alive_24h_ago]
397
-
398
  # weight factor indicates the fraction of events that happened within the last 24 hour.
399
- fraction = 1 - (yesterday - df_runs_24h.created_at) / (pd.Timestamp.now()- df_runs_24h.created_at)
400
- df_runs_24h['fraction'] = fraction.clip(0,1)
401
- df_runs_24h['num_steps'] *= fraction.clip(0,1)
402
-
 
 
403
  df_task_counts = get_task_counts(df_runs, df_events)
404
 
405
- df_m = get_metagraph(time.time()//UPDATE_INTERVAL)
406
 
407
  return {
408
- 'df_runs': df_runs,
409
- 'df_runs_24h': df_runs_24h,
410
- 'df_vali': df_vali,
411
- 'df_events': df_events,
412
- 'metagraph': df_m,
413
- 'df_task_counts': df_task_counts
414
  }
415
 
416
 
417
- if __name__ == '__main__':
418
 
419
- pass
 
14
  # TODO: Store relevant wandb data in a database for faster access
15
 
16
 
17
+ MIN_STEPS = 10 # minimum number of steps in wandb run in order to be worth analyzing
18
  NETUID = 1
19
+ BASE_PATH = "macrocosmos/prompting-validators"
20
+ NETWORK = "finney"
21
+ KEYS = [
22
+ "_step",
23
+ "_timestamp",
24
+ "task",
25
+ "query",
26
+ "reference",
27
+ "challenge",
28
+ "topic",
29
+ "subtopic",
30
+ ]
31
  ABBREV_CHARS = 8
32
+ ENTITY_CHOICES = ("identity", "hotkey", "coldkey")
33
+ LOCAL_WANDB_PATH = "./data/wandb"
34
+ USERNAME = "taostats"
35
 
36
  # Initialize wandb with anonymous login
37
+ wandb.login(anonymous="must")
38
  api = wandb.Api(timeout=600)
39
 
40
  IDENTITIES = {
41
+ "5F4tQyWrhfGVcNhoqeiNsR6KjD4wMZ2kfhLj4oHYuyHbZAc3": "opentensor",
42
+ "5Hddm3iBFD2GLT5ik7LZnT3XJUnRnN8PoeCFgGQgawUVKNm8": "taostats",
43
+ "5HEo565WAy4Dbq3Sv271SAi7syBSofyfhhwRNjFNSM2gP9M2": "foundry",
44
+ "5HK5tp6t2S59DywmHRWPBVJeJ86T61KjurYqeooqj8sREpeN": "bittensor-guru",
45
+ "5FFApaS75bv5pJHfAp2FVLBj9ZaXuFDjEypsaBNc1wCfe52v": "roundtable-21",
46
+ "5EhvL1FVkQPpMjZX4MAADcW42i3xPSF1KiCpuaxTYVr28sux": "tao-validator",
47
+ "5FKstHjZkh4v3qAMSBa1oJcHCLjxYZ8SNTSz1opTv4hR7gVB": "datura",
48
+ "5DvTpiniW9s3APmHRYn8FroUWyfnLtrsid5Mtn5EwMXHN2ed": "first-tensor",
49
+ "5HbLYXUBy1snPR8nfioQ7GoA9x76EELzEq9j7F32vWUQHm1x": "tensorplex",
50
+ "5CsvRJXuR955WojnGMdok1hbhffZyB4N5ocrv82f3p5A2zVp": "owl-ventures",
51
+ "5CXRfP2ekFhe62r7q3vppRajJmGhTi7vwvb2yr79jveZ282w": "rizzo",
52
+ "5HNQURvmjjYhTSksi8Wfsw676b4owGwfLR2BFAQzG7H3HhYf": "neural-internet",
53
  }
54
 
55
  EXTRACTORS = {
56
+ "state": lambda x: x.state,
57
+ "run_id": lambda x: x.id,
58
+ "run_path": lambda x: os.path.join(BASE_PATH, x.id),
59
+ "user": lambda x: x.user.name[:16],
60
+ "username": lambda x: x.user.username[:16],
61
+ "created_at": lambda x: pd.Timestamp(x.created_at),
62
+ "last_event_at": lambda x: pd.Timestamp(x.summary.get("_timestamp"), unit="s"),
 
63
  # 'netuid': lambda x: x.config.get('netuid'),
64
  # 'mock': lambda x: x.config.get('neuron').get('mock'),
65
  # 'sample_size': lambda x: x.config.get('neuron').get('sample_size'),
66
  # 'timeout': lambda x: x.config.get('neuron').get('timeout'),
67
  # 'epoch_length': lambda x: x.config.get('neuron').get('epoch_length'),
68
  # 'disable_set_weights': lambda x: x.config.get('neuron').get('disable_set_weights'),
 
69
  # This stuff is from the last logged event
70
+ "num_steps": lambda x: x.summary.get("_step"),
71
+ "runtime": lambda x: x.summary.get("_runtime"),
72
+ "query": lambda x: x.summary.get("query"),
73
+ "challenge": lambda x: x.summary.get("challenge"),
74
+ "reference": lambda x: x.summary.get("reference"),
75
+ "completions": lambda x: x.summary.get("completions"),
76
+ "version": lambda x: x.tags[0],
77
+ "spec_version": lambda x: x.tags[1],
78
+ "vali_hotkey": lambda x: x.tags[2],
 
79
  # 'tasks_selected': lambda x: x.tags[3:],
 
80
  # System metrics
81
+ "disk_read": lambda x: x.system_metrics.get("system.disk.in"),
82
+ "disk_write": lambda x: x.system_metrics.get("system.disk.out"),
83
  # Really slow stuff below
84
  # 'started_at': lambda x: x.metadata.get('startedAt'),
85
  # 'disk_used': lambda x: x.metadata.get('disk').get('/').get('used'),
 
87
  }
88
 
89
 
90
+ def get_leaderboard(df, ntop=10, entity_choice="identity"):
91
 
92
+ df = df.loc[df.validator_permit == False]
93
  df.index = range(df.shape[0])
94
  return df.groupby(entity_choice).I.sum().sort_values().reset_index().tail(ntop)
95
 
96
+
97
  @st.cache_data()
98
  def get_metagraph(time):
99
+ print(f"Loading metagraph with time {time}")
100
  subtensor = bt.subtensor(network=NETWORK)
101
  m = subtensor.metagraph(netuid=NETUID)
102
+ meta_cols = [
103
+ "I",
104
+ "stake",
105
+ "trust",
106
+ "validator_trust",
107
+ "validator_permit",
108
+ "C",
109
+ "R",
110
+ "E",
111
+ "dividends",
112
+ "last_update",
113
+ ]
114
 
115
  df_m = pd.DataFrame({k: getattr(m, k) for k in meta_cols})
116
+ df_m["uid"] = range(m.n.item())
117
+ df_m["hotkey"] = list(map(lambda a: a.hotkey, m.axons))
118
+ df_m["coldkey"] = list(map(lambda a: a.coldkey, m.axons))
119
+ df_m["ip"] = list(map(lambda a: a.ip, m.axons))
120
+ df_m["port"] = list(map(lambda a: a.port, m.axons))
121
+ df_m["coldkey"] = df_m.coldkey.str[:ABBREV_CHARS]
122
+ df_m["hotkey"] = df_m.hotkey.str[:ABBREV_CHARS]
123
+ df_m["identity"] = df_m.apply(lambda x: f"{x.hotkey} @ uid {x.uid}", axis=1)
124
  return df_m
125
 
126
 
127
  @st.cache_data(show_spinner=False)
128
  def load_downloaded_runs(time, cols=KEYS):
129
 
130
+ list_cols = ["rewards", "uids"]
131
+ extra_cols = ["turn"]
132
  df_all = pd.DataFrame()
133
 
134
+ progress = st.progress(0, text="Loading downloaded data")
135
+ paths = glob.glob(os.path.join(LOCAL_WANDB_PATH, "*.parquet"))
136
  for i, path in enumerate(paths):
137
+ run_id = path.split("/")[-1].split(".")[0]
138
  frame = pd.read_parquet(path).dropna(subset=cols)
139
+ frame._timestamp = frame._timestamp.apply(pd.to_datetime, unit="s")
140
  # handle missing extra cols such as turn which depend on the version of the codebase
141
  found_extra_cols = [c for c in frame.columns if c in extra_cols]
142
+ df_long = frame[cols + list_cols + found_extra_cols].explode(list_cols)
143
 
144
+ prog_msg = f"Downloading data {i/len(paths)*100:.0f}%"
145
+ progress.progress(
146
+ i / len(paths), text=f"{prog_msg}... **downloading** `{run_id}`"
147
+ )
148
 
149
  df_all = pd.concat([df_all, df_long.assign(run_id=run_id)], ignore_index=True)
150
 
 
152
 
153
  # Ensure we have consistent naming schema for tasks
154
  task_mapping = {
155
+ "date-based question answering": "date_qa",
156
+ "question-answering": "qa",
157
  }
 
 
158
 
159
+ df_all["task"] = df_all.task.apply(lambda x: task_mapping.get(x, x))
160
 
161
  # Runs which do not have a turn field are imputed to be turn zero (single turn)
162
+ df_all["turn"] = df_all.turn.fillna(0)
163
 
164
+ df_all.sort_values(by=["_timestamp"], inplace=True)
165
 
166
  return df_all
167
 
 
169
  @st.cache_data(show_spinner=False)
170
  def build_data(timestamp=None, path=BASE_PATH, min_steps=MIN_STEPS, use_cache=True):
171
 
172
+ save_path = "_saved_runs.csv"
173
  filters = {}
174
  df = pd.DataFrame()
175
  # Load the last saved runs so that we only need to update the new ones
176
  if use_cache and os.path.exists(save_path):
177
  df = pd.read_csv(save_path)
178
+ df["created_at"] = pd.to_datetime(df["created_at"])
179
+ df["last_event_at"] = pd.to_datetime(df["last_event_at"])
180
 
181
+ timestamp_str = df["last_event_at"].max().isoformat()
182
+ filters.update({"updated_at": {"$gte": timestamp_str}})
183
 
184
+ progress = st.progress(0, text="Loading data")
185
 
186
  runs = api.runs(path, filters=filters)
187
 
188
  run_data = []
189
  n_events = 0
190
  for i, run in enumerate(tqdm.tqdm(runs, total=len(runs))):
191
+ num_steps = run.summary.get("_step", 0)
192
+ if num_steps < min_steps:
193
  continue
194
  n_events += num_steps
195
+ prog_msg = (
196
+ f"Loading data {i/len(runs)*100:.0f}%, (total {n_events:,.0f} events)"
197
+ )
198
+ progress.progress(
199
+ i / len(runs),
200
+ text=f"{prog_msg}... **downloading** `{os.path.join(*run.path)}`",
201
+ )
202
+ if (
203
+ "netuid_1" in run.tags
204
+ or "netuid_61" in run.tags
205
+ or "netuid_102" in run.tags
206
+ ):
207
  run_data.append(run)
208
 
 
209
  progress.empty()
210
 
211
+ df_new = pd.DataFrame(
212
+ [
213
+ {k: func(run) for k, func in EXTRACTORS.items()}
214
+ for run in tqdm.tqdm(run_data, total=len(run_data))
215
+ ]
216
+ )
217
  df = pd.concat([df, df_new], ignore_index=True)
218
+
219
+ # Ensure that the timestamps are timezone aware
220
+ if df.last_event_at.dt.tz is None:
221
+ df.last_event_at = df.last_event_at.dt.tz_localize("UTC")
222
+ if df.created_at.dt.tz is None:
223
+ df.created_at = df.created_at.dt.tz_localize("UTC")
224
+
225
+ df["duration"] = (df.last_event_at - df.created_at).round("s")
226
+ df["identity"] = df["vali_hotkey"].map(IDENTITIES).fillna("unknown")
227
+ df["vali_hotkey"] = df["vali_hotkey"].str[:ABBREV_CHARS]
228
 
229
  # Drop events that are not related to validator queries
230
+ df.dropna(subset="query", inplace=True)
231
 
232
  print(df.completions.apply(type).value_counts())
233
  # Assumes completions is in the frame
234
+ df["completions"] = df["completions"].apply(
235
+ lambda x: x if isinstance(x, list) else eval(x)
236
+ )
237
+
238
+ df["completion_words"] = df.completions.apply(
239
+ lambda x: sum([len(xx.split()) for xx in x]) if isinstance(x, list) else 0
240
+ )
241
+ df["validator_words"] = df.apply(
242
+ lambda x: len(str(x.query).split())
243
+ + len(str(x.challenge).split())
244
+ + len(str(x.reference).split()),
245
+ axis=1,
246
+ )
247
 
248
  df.to_csv(save_path, index=False)
249
 
250
  return df
251
 
252
+
253
  @st.cache_data()
254
  def normalize_rewards(df, turn=0, percentile=0.98):
255
+ top_reward_stats = (
256
+ df.loc[df.turn == turn]
257
+ .astype({"rewards": float})
258
+ .groupby("task")
259
+ .rewards.quantile(percentile)
260
+ )
261
+
262
+ df["best_reward"] = df.task.map(top_reward_stats)
263
+ df["normalized_rewards"] = df["rewards"].astype(float) / df["best_reward"]
264
  return df
265
 
266
+
267
  @st.cache_data(show_spinner=False)
268
  def download_runs(time, df_vali):
269
 
270
  pbar = tqdm.tqdm(df_vali.index, total=len(df_vali))
271
 
272
+ progress = st.progress(0, text="Loading data")
273
 
274
  for i, idx in enumerate(pbar):
275
  row = df_vali.loc[idx]
276
 
277
+ prog_msg = f"Downloading data {i/len(df_vali)*100:.0f}%"
278
+ progress.progress(
279
+ i / len(df_vali),
280
+ text=f"{prog_msg}... **downloading** `{os.path.join(*row.run_id)}`",
281
+ )
282
 
283
+ save_path = f"data/wandb/{row.run_id}.parquet"
284
  # Create the directory if it does not exist
285
  os.makedirs(os.path.dirname(save_path), exist_ok=True)
286
 
287
  if os.path.exists(save_path):
288
+ pbar.set_description(
289
+ f">> Skipping {row.run_id!r} because file {save_path!r} already exists"
290
+ )
291
  continue
292
 
293
  try:
294
+ pbar.set_description(f"* Downloading run {row.run_id!r}")
295
  run = api.run(row.run_path)
296
 
297
  # By default we just download a subset of events (500 most recent)
 
300
  except KeyboardInterrupt:
301
  break
302
  except Exception as e:
303
+ pbar.set_description(
304
+ f"- Something went wrong with {row.run_id!r}: {print_exc()}\n"
305
+ )
306
 
307
  progress.empty()
308
 
 
311
 
312
  total_duration = df_runs.last_event_at.max() - df_runs.created_at.min()
313
  total_steps = df_runs.num_steps.sum()
314
+ total_completions = (df_runs.num_steps * 100).sum() # TODO: Parse from df
315
+ total_completion_words = (df_runs.num_steps * df_runs.completion_words).sum()
316
+ total_completion_tokens = round(total_completion_words / 0.75)
317
+ total_validator_words = (
318
+ df_runs.num_steps
319
+ * df_runs.apply(
320
+ lambda x: len(str(x.query).split())
321
+ + len(str(x.challenge).split())
322
+ + len(str(x.reference).split()),
323
+ axis=1,
324
+ )
325
+ ).sum()
326
+ total_validator_tokens = round(total_validator_words / 0.75)
327
  total_dataset_tokens = total_completion_tokens + total_validator_tokens
328
 
329
  return {
330
+ "duration": total_duration,
331
+ "total_events": total_steps,
332
+ "total_completions": total_completions,
333
+ "total_completion_tokens": total_completion_tokens,
334
+ "total_validator_tokens": total_validator_tokens,
335
+ "total_tokens": total_dataset_tokens,
336
  }
337
 
338
+
339
  @st.cache_data(show_spinner=False)
340
+ def get_reward_stats(
341
+ df,
342
+ exclude_multiturn=True,
343
+ freq="D",
344
+ remove_zero_rewards=True,
345
+ agg="mean",
346
+ date_min="2024-01-22",
347
+ date_max="2024-08-12",
348
+ ): # TODO: Set the date_max to the current date
349
 
350
  df = df.loc[df._timestamp.between(pd.Timestamp(date_min), pd.Timestamp(date_max))]
351
  if exclude_multiturn:
 
353
  if remove_zero_rewards:
354
  df = df.loc[df.rewards > 0]
355
 
356
+ groups = ["run_id", pd.Grouper(key="_timestamp", freq=freq), "task"]
357
+ return df.groupby(groups).agg({"rewards": agg, "normalized_rewards": agg})
358
+
359
 
360
def get_release_dates():
    """Return a DataFrame describing each prompting-validator release.

    Columns: ``version`` (semver string), ``release_date`` (naive
    ``pd.Timestamp``; all releases are in 2024), ``note`` (free-form change
    summary), ``model`` (LLM used by the validator at that release), and
    ``tasks_affected`` (list of task names whose behavior changed; ``"all"``
    means every task).

    Used by `plot_reward_trends` to draw release markers on reward charts.

    Returns:
        pd.DataFrame: one row per release, in chronological order.
    """

    def _release(version, month, day, note="", model="zephyr", tasks_affected=()):
        # Helper to cut down on repetition: every release shares the same
        # record shape and the same year (2024).
        return {
            "version": version,
            "release_date": pd.Timestamp(year=2024, month=month, day=day),
            "note": note,
            "model": model,
            "tasks_affected": list(tasks_affected),
        }

    # NOTE: the previous version of this table listed 2.4.2 twice with
    # identical fields; the duplicate row has been removed.
    release_dates = pd.DataFrame(
        [
            _release("1.0.0", 1, 22, tasks_affected=["qa", "summarization"]),
            _release("1.0.1", 1, 22),
            _release("1.0.2", 1, 24, tasks_affected=["qa", "summarization"]),
            _release("1.0.3", 2, 14),
            _release("1.0.4", 2, 15),
            _release(
                "1.1.0", 2, 21, note="decay scores", tasks_affected=["date_qa", "math"]
            ),
            _release(
                "1.1.1",
                2,
                28,
                note="reduce penalty weight",
                tasks_affected=["date_qa", "qa", "summarization"],
            ),
            _release("1.1.2", 2, 29),
            _release("1.1.3", 3, 11),
            _release("1.2.0", 3, 19, note="vllm"),
            _release("1.3.0", 3, 27, model="solar", tasks_affected=["all", "math"]),
            _release(
                "2.0.0",
                4,
                4,
                note="streaming",
                model="solar",
                tasks_affected=["math", "qa", "summarization"],
            ),
            _release(
                "2.1.0",
                4,
                18,
                note="chattensor prompt",
                model="solar",
                tasks_affected=["generic"],
            ),
            _release(
                "2.2.0",
                5,
                1,
                note="multiturn + paraphrase",
                model="solar",
                tasks_affected=["sentiment", "translation", "math"],
            ),
            _release(
                "2.3.0",
                5,
                20,
                note="llama + freeform date",
                model="llama",
                tasks_affected=["all", "date_qa"],
            ),
            _release("2.3.1", 5, 21, model="llama", tasks_affected=["date_qa"]),
            _release("2.4.0", 6, 5, note="streaming penalty", model="llama"),
            _release("2.4.1", 6, 6, model="llama"),
            _release("2.4.2", 6, 7, model="llama"),
            _release(
                "2.5.0",
                6,
                18,
                note="reduce multiturn",
                model="llama",
                tasks_affected=["translation", "sentiment"],
            ),
            _release("2.5.1", 6, 25, note="reduce timeout", model="llama"),
        ]
    )
    return release_dates
520
 
521
 
522
def plot_reward_trends(
    df_stats,
    task="qa",
    window=14,
    col="normalized_rewards",
    annotate=False,
    task_label="Question answering",
):
    """Plot a rolling-average reward trend for a single task.

    Args:
        df_stats: reward stats indexed like the output of `get_reward_stats`
            (reset internally; must contain ``task``, ``_timestamp`` and `col`).
        task: task name to filter on (e.g. ``"qa"``).
        window: rolling-mean window, in rows (roughly days at daily grouping).
        col: column to smooth and plot. FIX: the default was
            ``"normalized_reward"``, which does not exist in the frames produced
            by `get_reward_stats` (it aggregates ``"normalized_rewards"``), so
            the default always raised ``KeyError``.
        annotate: when True, draw a vertical line per release from
            `get_release_dates`.
        task_label: human-readable task name for the chart title.

    Returns:
        plotly Figure (area chart).
    """

    stats = df_stats.reset_index()
    release_dates = get_release_dates()
    stats_task = stats.loc[(stats.task == task)].sort_values(by="_timestamp")
    # Smooth the raw per-interval values; min_periods=0 keeps the leading
    # edge of the series instead of producing NaNs.
    stats_task["rewards_ma"] = stats_task[col].rolling(window, min_periods=0).mean()
    fig = px.area(
        stats_task,
        x="_timestamp",
        y="rewards_ma",
        title=f"Reward Trend for {task_label} Task",
        labels={"rewards_ma": f"Rewards [{window} day avg.]", "_timestamp": ""},
        width=800,
        height=600,
    )

    if not annotate:
        return fig

    # Add annotations based on relevant releases
    for idx, row in release_dates.iterrows():
        # Color precedence: red when this task was explicitly affected,
        # blue when the release affected neither this task nor "all",
        # grey otherwise (i.e. releases tagged "all" but not this task).
        # NOTE(review): the grey/blue split looks inverted — a release
        # tagged "all" arguably affects this task too; confirm intent
        # before changing.
        line_color = "grey"
        if task in row["tasks_affected"]:
            line_color = "red"
        elif "all" not in row["tasks_affected"]:
            line_color = "blue"
        # TODO add annotation or something
        fig.add_vline(
            row["release_date"],
            line_color=line_color,
            opacity=0.6,
            line_dash="dot",
            line_width=1,
        )  # , annotation_text=str(v))

    return fig
565
 
566
+
567
@st.cache_data()
def get_task_counts(df_runs, df_events):
    """Estimate the cumulative number of events per task over time.

    The per-task rate is measured from the downloaded event data for each
    validator version; versions with no event data are imputed from the
    closest earlier version. The rates are then scaled by each run's
    ``num_steps`` and accumulated in run-creation order.

    Args:
        df_runs: run summary frame with ``run_id``, ``version``,
            ``created_at`` and ``num_steps`` columns.
        df_events: per-step event frame with ``run_id`` and ``task`` columns
            (a ``version`` column is added in place as a side effect).

    Returns:
        pd.DataFrame: cumulative estimated task counts, indexed by run
        ``created_at``.
    """
    # Get mapping from run id to prompting repo version
    run_to_version = df_runs.set_index("run_id").version.to_dict()

    df_events["version"] = df_events.run_id.map(run_to_version)

    def version_to_spec(version):
        # Encode "major.minor.patch" as one comparable integer.
        # FIX: the components must be cast to int — previously the string
        # pieces from split(".") were used directly, so `10_000 * major`
        # performed string repetition and the later `>` comparisons were
        # lexicographic (e.g. "1.10.0" would have sorted before "1.2.0").
        major, minor, patch = map(int, version.split("."))
        return 10_000 * major + 100 * minor + patch

    def get_closest_prev_version(version, my_versions):
        # Return the latest entry of my_versions not newer than `version`
        # (falls back to the first entry when all are newer). Relies on
        # my_versions being in ascending order, which holds for the sorted
        # groupby index it is called with.
        ref_spec = version_to_spec(version)
        my_specs = list(map(version_to_spec, my_versions))

        match = my_specs[0]
        for spec in my_specs[1:]:
            if spec > ref_spec:
                break

            match = spec

        return my_versions[my_specs.index(match)]

    # Now estimate the distribution of tasks for each version using the event data
    task_rate = (
        df_events.groupby("version")
        .task.value_counts(normalize=True)
        .unstack()
        .fillna(0)
    )
    # Impute missing versions
    for v in sorted(df_runs.version.unique()):
        if v not in task_rate.index:
            prev_version = get_closest_prev_version(v, list(task_rate.index))
            print(
                f"Imputing version {v} with task rate from closes previous version {prev_version!r}"
            )
            task_rate.loc[v] = task_rate.loc[prev_version]

    # get estimated number of each task generated in every run using summary dataframe
    task_counts = (
        df_runs.set_index("created_at")
        .sort_index()
        .apply(lambda x: round(task_rate.loc[x.version] * x.num_steps), axis=1)
        .cumsum()
    )
    return task_counts
616
 
617
 
618
def load_state_vars(username=USERNAME, percentile=0.95):
    """Build all dataframes needed by the dashboard.

    Args:
        username: wandb username whose runs are treated as the validator of
            interest (events are only downloaded for these runs).
        percentile: percentile passed to `normalize_rewards`.

    Returns:
        dict with keys ``df_runs``, ``df_runs_24h``, ``df_vali``,
        ``df_events``, ``metagraph`` and ``df_task_counts``.
    """

    UPDATE_INTERVAL = 600

    # `time.time() // UPDATE_INTERVAL` acts as a cache-busting key: the
    # cached data is refreshed at most every UPDATE_INTERVAL seconds.
    df_runs = build_data(time.time() // UPDATE_INTERVAL, use_cache=False)

    # df_runs = df_runs.loc[df_runs.netuid.isin([1,61,102])] # Now we filter for the netuid tag in build_data
    st.toast(f"Loaded {len(df_runs)} runs")

    df_vali = df_runs.loc[df_runs.username == username]
    # df_vali = df_runs

    download_runs(time.time() // UPDATE_INTERVAL, df_vali)

    df_events = load_downloaded_runs(time.time() // UPDATE_INTERVAL)
    df_events = normalize_rewards(df_events, percentile=percentile)

    # Timezone-aware cutoff; assumes last_event_at/created_at are tz-aware
    # UTC timestamps (comparison raises otherwise) — TODO confirm upstream.
    yesterday = pd.Timestamp.now(tz="UTC") - pd.Timedelta("1d")
    runs_alive_24h_ago = df_runs.last_event_at > yesterday

    # FIX: copy the slice. Previously df_runs_24h was a view of df_runs, so
    # the column assignments below triggered SettingWithCopy behavior and
    # could mutate df_runs.num_steps, which is also returned.
    df_runs_24h = df_runs.loc[runs_alive_24h_ago].copy()

    # weight factor indicates the fraction of events that happened within the last 24 hour.
    fraction = 1 - (yesterday - df_runs_24h.created_at) / (
        pd.Timestamp.now(tz="UTC") - df_runs_24h.created_at
    )
    df_runs_24h["fraction"] = fraction.clip(0, 1)
    df_runs_24h["num_steps"] *= fraction.clip(0, 1)

    df_task_counts = get_task_counts(df_runs, df_events)

    df_m = get_metagraph(time.time() // UPDATE_INTERVAL)

    return {
        "df_runs": df_runs,
        "df_runs_24h": df_runs_24h,
        "df_vali": df_vali,
        "df_events": df_events,
        "metagraph": df_m,
        "df_task_counts": df_task_counts,
    }
659
 
660
 
661
+ if __name__ == "__main__":
662
 
663
+ pass