attilabalint commited on
Commit
f1e08ee
1 Parent(s): 53a0c92

initial commit

Browse files
Files changed (5) hide show
  1. app.py +52 -0
  2. components.py +10 -0
  3. images/energyville_logo.png +0 -0
  4. images/ku_leuven_logo.png +0 -0
  5. utils.py +29 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from components import summary_view
4
+ import utils
5
+
6
+ wandb_api_key = st.secrets["wandb"]
7
+
8
+ st.set_page_config(page_title='Electricity Demand Dashboard', layout='wide')
9
+
10
+
11
+ @st.cache_data(ttl=86400)
12
+ def fetch_data():
13
+ return utils.get_wandb_data(
14
+ st.secrets['wandb']['entity'],
15
+ "enfobench-electricity-demand",
16
+ st.secrets["wandb"]['api_key'],
17
+ job_type="metrics",
18
+ )
19
+
20
+
21
+ data = fetch_data()
22
+ models = sorted(data['model'].unique().tolist())
23
+ models_to_plot = set()
24
+ model_groups: dict[str, list[str]] = {}
25
+
26
+
27
+ for model in models:
28
+ group, model_name = model.split(".", maxsplit=1)
29
+ if group not in model_groups:
30
+ model_groups[group] = []
31
+ model_groups[group].append(model_name)
32
+
33
+
34
+ with st.sidebar:
35
+ left, right = st.columns(2) # Create two columns within the right column for side-by-side images
36
+ with left:
37
+ st.image("./images/ku_leuven_logo.png") # Adjust the path and width as needed
38
+ with right:
39
+ st.image("./images/energyville_logo.png")
40
+
41
+ view = st.selectbox("View", ["Summary", "Raw Data"], index=0)
42
+
43
+ st.header("Models to include")
44
+ for model_group, models in model_groups.items():
45
+ st.text(model_group)
46
+ for model_name in models:
47
+ to_plot = st.checkbox(model_name, value=True)
48
+ if to_plot:
49
+ models_to_plot.add(f"{model_group}.{model_name}")
50
+
51
+ if view == "Summary":
52
+ summary_view(data, models_to_plot)
components.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+
4
+ def summary_view(data, models_to_plot: set[str]):
5
+ st.title("Summary View")
6
+ st.write(data)
7
+ st.write(models_to_plot)
8
+
9
+ if st.button('Say Hello'):
10
+ st.balloons()
images/energyville_logo.png ADDED
images/ku_leuven_logo.png ADDED
utils.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import wandb
3
+
4
+
5
+ def get_wandb_data(entity: str, project: str, api_key: str, job_type: str) -> pd.DataFrame:
6
+ api = wandb.Api(api_key=api_key)
7
+
8
+ # Project is specified by <entity/project-name>
9
+ filter_dict = {"jobType": job_type}
10
+ runs = api.runs(f"{entity}/{project}", filters=filter_dict)
11
+
12
+ summary_list, config_list, name_list = [], [], []
13
+ for run in runs:
14
+ # .summary contains the output keys/values for metrics like accuracy.
15
+ # We call ._json_dict to omit large files
16
+ summary_list.append(run.summary._json_dict)
17
+
18
+ # .config contains the hyperparameters.
19
+ # We remove special values that start with _.
20
+ config_list.append({k: v for k, v in run.config.items()})
21
+
22
+ # .name is the human-readable name of the run.
23
+ name_list.append(run.name)
24
+
25
+ summary_df = pd.json_normalize(summary_list, max_level=1)
26
+ config_df = pd.json_normalize(config_list, max_level=2)
27
+ runs_df = pd.concat([summary_df, config_df], axis=1)
28
+ runs_df.index = name_list
29
+ return runs_df