SaulLu commited on
Commit
e638825
·
1 Parent(s): 4615d65
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import altair as alt
2
  import pandas as pd
3
  import streamlit as st
4
  import wandb
@@ -7,9 +6,12 @@ from dashboard_utils.bubbles import get_new_bubble_data
7
  from dashboard_utils.main_metrics import get_main_metrics
8
  from streamlit_observable import observable
9
 
 
 
 
10
  wandb.login(anonymous="must")
11
 
12
- st.title("Training transformers together dashboard")
13
  st.caption("Training Loss")
14
 
15
  steps, dates, losses, alive_peers = get_main_metrics()
@@ -58,10 +60,9 @@ st.vega_lite_chart(
58
 
59
  st.header("Collaborative training participants")
60
  serialized_data, profiles = get_new_bubble_data()
61
- with st.spinner("Wait for it..."):
62
- observers = observable(
63
- "Participants",
64
- notebook="d/9ae236a507f54046", # "@huggingface/participants-bubbles-chart",
65
- targets=["c_noaws"],
66
- redefine={"serializedData": serialized_data, "profileSimple": profiles},
67
- )
 
 
1
  import pandas as pd
2
  import streamlit as st
3
  import wandb
 
6
  from dashboard_utils.main_metrics import get_main_metrics
7
  from streamlit_observable import observable
8
 
9
+ # Only need to set these here as we are add controls outside of Hydralit, to customise a run Hydralit!
10
+ st.set_page_config(page_title="Dashboard", layout="centered")
11
+
12
  wandb.login(anonymous="must")
13
 
14
+ st.markdown("<h1 style='text-align: center;'>Dashboard</h1>", unsafe_allow_html=True)
15
  st.caption("Training Loss")
16
 
17
  steps, dates, losses, alive_peers = get_main_metrics()
 
60
 
61
  st.header("Collaborative training participants")
62
  serialized_data, profiles = get_new_bubble_data()
63
+ observable(
64
+ "Participants",
65
+ notebook="d/9ae236a507f54046", # "@huggingface/participants-bubbles-chart",
66
+ targets=["c_noaws"],
67
+ redefine={"serializedData": serialized_data, "profileSimple": profiles},
68
+ )
 
dashboard_utils/bubbles.py CHANGED
@@ -1,8 +1,8 @@
1
  import datetime
2
- import json
3
  from concurrent.futures import as_completed
4
  from urllib import parse
5
 
 
6
  import wandb
7
  from requests_futures.sessions import FuturesSession
8
 
@@ -10,12 +10,13 @@ from dashboard_utils.time_tracker import _log, simple_time_tracker
10
 
11
  URL_QUICKSEARCH = "https://huggingface.co/api/quicksearch?"
12
  WANDB_REPO = "learning-at-home/Worker_logs"
 
13
 
14
 
 
15
  @simple_time_tracker(_log)
16
  def get_new_bubble_data():
17
- # serialized_data_points, latest_timestamp = get_serialized_data_points()
18
- serialized_data_points, latest_timestamp = None, None
19
  serialized_data = get_serialized_data(serialized_data_points, latest_timestamp)
20
 
21
  usernames = []
@@ -27,6 +28,7 @@ def get_new_bubble_data():
27
  return serialized_data, profiles
28
 
29
 
 
30
  @simple_time_tracker(_log)
31
  def get_profiles(usernames):
32
  profiles = []
@@ -58,6 +60,7 @@ def get_profiles(usernames):
58
  return profiles
59
 
60
 
 
61
  @simple_time_tracker(_log)
62
  def get_serialized_data_points():
63
 
@@ -105,38 +108,33 @@ def get_serialized_data_points():
105
  return serialized_data_points, latest_timestamp
106
 
107
 
 
108
  @simple_time_tracker(_log)
109
  def get_serialized_data(serialized_data_points, latest_timestamp):
110
- # serialized_data_points_v2 = []
111
- # max_velocity = 1
112
- # for run_name, serialized_data_point in serialized_data_points.items():
113
- # activeRuns = []
114
- # loss = 0
115
- # runtime = 0
116
- # batches = 0
117
- # velocity = 0
118
- # for run in serialized_data_point["Runs"]:
119
- # if run["date"] == latest_timestamp:
120
- # run["date"] = run["date"].isoformat()
121
- # activeRuns.append(run)
122
- # loss += run["loss"]
123
- # velocity += run["velocity"]
124
- # loss = loss / len(activeRuns) if activeRuns else 0
125
- # runtime += run["runtime"]
126
- # batches += run["batches"]
127
- # new_item = {
128
- # "date": latest_timestamp.isoformat(),
129
- # "profileId": run_name,
130
- # "batches": batches,
131
- # "runtime": runtime,
132
- # "activeRuns": activeRuns,
133
- # }
134
- # serialized_data_points_v2.append(new_item)
135
- # serialized_data = {"points": [serialized_data_points_v2], "maxVelocity": max_velocity}
136
- with open(
137
- "/mnt/storage/Documents/hugging_face/colaborative_hub_training/demo_neurips/training-transformers-together-dashboard/data/"
138
- "serializaledata_V2.json",
139
- "r",
140
- ) as f:
141
- serialized_data = json.load(f)
142
  return serialized_data
 
1
  import datetime
 
2
  from concurrent.futures import as_completed
3
  from urllib import parse
4
 
5
+ import streamlit as st
6
  import wandb
7
  from requests_futures.sessions import FuturesSession
8
 
 
10
 
11
  URL_QUICKSEARCH = "https://huggingface.co/api/quicksearch?"
12
  WANDB_REPO = "learning-at-home/Worker_logs"
13
+ CACHE_TTL = 100
14
 
15
 
16
+ @st.cache(ttl=CACHE_TTL)
17
  @simple_time_tracker(_log)
18
  def get_new_bubble_data():
19
+ serialized_data_points, latest_timestamp = get_serialized_data_points()
 
20
  serialized_data = get_serialized_data(serialized_data_points, latest_timestamp)
21
 
22
  usernames = []
 
28
  return serialized_data, profiles
29
 
30
 
31
+ @st.cache(ttl=CACHE_TTL)
32
  @simple_time_tracker(_log)
33
  def get_profiles(usernames):
34
  profiles = []
 
60
  return profiles
61
 
62
 
63
+ @st.cache(ttl=CACHE_TTL)
64
  @simple_time_tracker(_log)
65
  def get_serialized_data_points():
66
 
 
108
  return serialized_data_points, latest_timestamp
109
 
110
 
111
+ @st.cache(ttl=CACHE_TTL)
112
  @simple_time_tracker(_log)
113
  def get_serialized_data(serialized_data_points, latest_timestamp):
114
+ serialized_data_points_v2 = []
115
+ max_velocity = 1
116
+ for run_name, serialized_data_point in serialized_data_points.items():
117
+ activeRuns = []
118
+ loss = 0
119
+ runtime = 0
120
+ batches = 0
121
+ velocity = 0
122
+ for run in serialized_data_point["Runs"]:
123
+ if run["date"] == latest_timestamp:
124
+ run["date"] = run["date"].isoformat()
125
+ activeRuns.append(run)
126
+ loss += run["loss"]
127
+ velocity += run["velocity"]
128
+ loss = loss / len(activeRuns) if activeRuns else 0
129
+ runtime += run["runtime"]
130
+ batches += run["batches"]
131
+ new_item = {
132
+ "date": latest_timestamp.isoformat(),
133
+ "profileId": run_name,
134
+ "batches": batches,
135
+ "runtime": runtime,
136
+ "activeRuns": activeRuns,
137
+ }
138
+ serialized_data_points_v2.append(new_item)
139
+ serialized_data = {"points": [serialized_data_points_v2], "maxVelocity": max_velocity}
 
 
 
 
 
 
140
  return serialized_data
dashboard_utils/main_metrics.py CHANGED
@@ -1,12 +1,15 @@
1
  import datetime
2
 
 
3
  import wandb
4
 
5
  from dashboard_utils.time_tracker import _log, simple_time_tracker
6
 
7
  WANDB_REPO = "learning-at-home/Main_metrics"
 
8
 
9
 
 
10
  @simple_time_tracker(_log)
11
  def get_main_metrics():
12
  api = wandb.Api()
 
1
  import datetime
2
 
3
+ import streamlit as st
4
  import wandb
5
 
6
  from dashboard_utils.time_tracker import _log, simple_time_tracker
7
 
8
  WANDB_REPO = "learning-at-home/Main_metrics"
9
+ CACHE_TTL = 100
10
 
11
 
12
+ @st.cache(ttl=CACHE_TTL)
13
  @simple_time_tracker(_log)
14
  def get_main_metrics():
15
  api = wandb.Api()