SaulLu commited on
Commit
4adb1cb
2 Parent(s): 57845b8 e638825

Merge pull request #1 from SaulLu/add-cache

Browse files
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import altair as alt
2
  import pandas as pd
3
  import streamlit as st
4
  import wandb
@@ -7,9 +6,12 @@ from dashboard_utils.bubbles import get_new_bubble_data
7
  from dashboard_utils.main_metrics import get_main_metrics
8
  from streamlit_observable import observable
9
 
 
 
 
10
  wandb.login(anonymous="must")
11
 
12
- st.title("Training transformers together dashboard")
13
  st.caption("Training Loss")
14
 
15
  steps, dates, losses, alive_peers = get_main_metrics()
@@ -58,10 +60,9 @@ st.vega_lite_chart(
58
 
59
  st.header("Collaborative training participants")
60
  serialized_data, profiles = get_new_bubble_data()
61
- with st.spinner("Wait for it..."):
62
- observers = observable(
63
- "Participants",
64
- notebook="d/9ae236a507f54046", # "@huggingface/participants-bubbles-chart",
65
- targets=["c_noaws"],
66
- redefine={"serializedData": serialized_data, "profileSimple": profiles},
67
- )
 
 
1
  import pandas as pd
2
  import streamlit as st
3
  import wandb
 
6
  from dashboard_utils.main_metrics import get_main_metrics
7
  from streamlit_observable import observable
8
 
9
+ # Only need to set these here as we are add controls outside of Hydralit, to customise a run Hydralit!
10
+ st.set_page_config(page_title="Dashboard", layout="centered")
11
+
12
  wandb.login(anonymous="must")
13
 
14
+ st.markdown("<h1 style='text-align: center;'>Dashboard</h1>", unsafe_allow_html=True)
15
  st.caption("Training Loss")
16
 
17
  steps, dates, losses, alive_peers = get_main_metrics()
 
60
 
61
  st.header("Collaborative training participants")
62
  serialized_data, profiles = get_new_bubble_data()
63
+ observable(
64
+ "Participants",
65
+ notebook="d/9ae236a507f54046", # "@huggingface/participants-bubbles-chart",
66
+ targets=["c_noaws"],
67
+ redefine={"serializedData": serialized_data, "profileSimple": profiles},
68
+ )
 
dashboard_utils/bubbles.py CHANGED
@@ -2,6 +2,7 @@ import datetime
2
  from concurrent.futures import as_completed
3
  from urllib import parse
4
 
 
5
  import wandb
6
  from requests_futures.sessions import FuturesSession
7
 
@@ -9,23 +10,31 @@ from dashboard_utils.time_tracker import _log, simple_time_tracker
9
 
10
  URL_QUICKSEARCH = "https://huggingface.co/api/quicksearch?"
11
  WANDB_REPO = "learning-at-home/Worker_logs"
 
12
 
13
 
 
14
  @simple_time_tracker(_log)
15
  def get_new_bubble_data():
16
  serialized_data_points, latest_timestamp = get_serialized_data_points()
17
  serialized_data = get_serialized_data(serialized_data_points, latest_timestamp)
18
- profiles = get_profiles(serialized_data_points)
 
 
 
 
 
19
 
20
  return serialized_data, profiles
21
 
22
 
 
23
  @simple_time_tracker(_log)
24
- def get_profiles(serialized_data_points):
25
  profiles = []
26
  with FuturesSession() as session:
27
  futures = []
28
- for username in serialized_data_points.keys():
29
  future = session.get(URL_QUICKSEARCH + parse.urlencode({"type": "user", "q": username}))
30
  future.username = username
31
  futures.append(future)
@@ -51,6 +60,7 @@ def get_profiles(serialized_data_points):
51
  return profiles
52
 
53
 
 
54
  @simple_time_tracker(_log)
55
  def get_serialized_data_points():
56
 
@@ -98,6 +108,7 @@ def get_serialized_data_points():
98
  return serialized_data_points, latest_timestamp
99
 
100
 
 
101
  @simple_time_tracker(_log)
102
  def get_serialized_data(serialized_data_points, latest_timestamp):
103
  serialized_data_points_v2 = []
 
2
  from concurrent.futures import as_completed
3
  from urllib import parse
4
 
5
+ import streamlit as st
6
  import wandb
7
  from requests_futures.sessions import FuturesSession
8
 
 
10
 
11
  URL_QUICKSEARCH = "https://huggingface.co/api/quicksearch?"
12
  WANDB_REPO = "learning-at-home/Worker_logs"
13
+ CACHE_TTL = 100
14
 
15
 
16
+ @st.cache(ttl=CACHE_TTL)
17
  @simple_time_tracker(_log)
18
  def get_new_bubble_data():
19
  serialized_data_points, latest_timestamp = get_serialized_data_points()
20
  serialized_data = get_serialized_data(serialized_data_points, latest_timestamp)
21
+
22
+ usernames = []
23
+ for item in serialized_data["points"][0]:
24
+ usernames.append(item["profileId"])
25
+
26
+ profiles = get_profiles(usernames)
27
 
28
  return serialized_data, profiles
29
 
30
 
31
+ @st.cache(ttl=CACHE_TTL)
32
  @simple_time_tracker(_log)
33
+ def get_profiles(usernames):
34
  profiles = []
35
  with FuturesSession() as session:
36
  futures = []
37
+ for username in usernames:
38
  future = session.get(URL_QUICKSEARCH + parse.urlencode({"type": "user", "q": username}))
39
  future.username = username
40
  futures.append(future)
 
60
  return profiles
61
 
62
 
63
+ @st.cache(ttl=CACHE_TTL)
64
  @simple_time_tracker(_log)
65
  def get_serialized_data_points():
66
 
 
108
  return serialized_data_points, latest_timestamp
109
 
110
 
111
+ @st.cache(ttl=CACHE_TTL)
112
  @simple_time_tracker(_log)
113
  def get_serialized_data(serialized_data_points, latest_timestamp):
114
  serialized_data_points_v2 = []
dashboard_utils/main_metrics.py CHANGED
@@ -1,12 +1,15 @@
1
  import datetime
2
 
 
3
  import wandb
4
 
5
  from dashboard_utils.time_tracker import _log, simple_time_tracker
6
 
7
  WANDB_REPO = "learning-at-home/Main_metrics"
 
8
 
9
 
 
10
  @simple_time_tracker(_log)
11
  def get_main_metrics():
12
  api = wandb.Api()
 
1
  import datetime
2
 
3
+ import streamlit as st
4
  import wandb
5
 
6
  from dashboard_utils.time_tracker import _log, simple_time_tracker
7
 
8
  WANDB_REPO = "learning-at-home/Main_metrics"
9
+ CACHE_TTL = 100
10
 
11
 
12
+ @st.cache(ttl=CACHE_TTL)
13
  @simple_time_tracker(_log)
14
  def get_main_metrics():
15
  api = wandb.Api()
data/serializaledata.json ADDED
The diff for this file is too large to render. See raw diff
 
perso/change_data.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+
4
+ with open(
5
+ "/mnt/storage/Documents/hugging_face/colaborative_hub_training/demo_neurips/training-transformers-together-dashboard/data/"
6
+ "serializaledata.json",
7
+ "r",
8
+ ) as f:
9
+ serialized_data = json.load(f)
10
+
11
+ serialized_data_v2 = serialized_data
12
+ serialized_data_v2["points"] = [[item for item in serialized_data["points"][-1] if random.random() > 0.8]]
13
+
14
+ with open(
15
+ "/mnt/storage/Documents/hugging_face/colaborative_hub_training/demo_neurips/training-transformers-together-dashboard/data/"
16
+ "serializaledata_V2.json",
17
+ "w",
18
+ ) as f:
19
+ f.write(json.dumps(serialized_data_v2))
perso/get_usernames.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ with open(
4
+ "/mnt/storage/Documents/hugging_face/colaborative_hub_training/demo_neurips/training-transformers-together-dashboard/data/"
5
+ "serializaledata_V2.json",
6
+ "r",
7
+ ) as f:
8
+ serialized_data = json.load(f)
9
+
10
+ usernames = []
11
+ for item in serialized_data["points"][0]:
12
+ usernames.append(item["profileId"])
13
+
14
+ print(usernames)