natolambert
commited on
Commit
·
8e499f4
1
Parent(s):
e4cd4cd
add dataset viewer
Browse files
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
import os
|
3 |
from huggingface_hub import HfApi, snapshot_download
|
|
|
4 |
from src.utils import load_all_data
|
5 |
from src.md import ABOUT_TEXT
|
6 |
import numpy as np
|
@@ -10,6 +11,7 @@ api = HfApi()
|
|
10 |
COLLAB_TOKEN = os.environ.get("COLLAB_TOKEN")
|
11 |
evals_repo = "ai2-rlhf-collab/rm-benchmark-results"
|
12 |
prefs_repo = "ai2-rlhf-collab/rm-testset-results"
|
|
|
13 |
repo_dir_herm = "./evals/herm/"
|
14 |
repo_dir_prefs = "./evals/prefs/"
|
15 |
|
@@ -27,7 +29,6 @@ repo = snapshot_download(
|
|
27 |
etag_timeout=30,
|
28 |
repo_type="dataset",
|
29 |
)
|
30 |
-
# repo.git_pull()
|
31 |
|
32 |
repo_pref_sets = snapshot_download(
|
33 |
local_dir=repo_dir_prefs,
|
@@ -37,7 +38,6 @@ repo_pref_sets = snapshot_download(
|
|
37 |
etag_timeout=30,
|
38 |
repo_type="dataset",
|
39 |
)
|
40 |
-
# repo_pref_sets.git_pull()
|
41 |
|
42 |
def avg_over_herm(dataframe):
|
43 |
"""
|
@@ -69,6 +69,14 @@ col_types_herm_avg = ["markdown"] + ["number"] * (len(herm_data_avg.columns) - 1
|
|
69 |
col_types_prefs = ["markdown"] + ["number"] * (len(prefs_data.columns) - 1)
|
70 |
# col_types_prefs_sub = ["markdown"] + ["number"] * (len(prefs_data_sub.columns) - 1)
|
71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
with gr.Blocks() as app:
|
73 |
# create tabs for the app, moving the current table to one titled "HERM" and the benchmark_text to a tab called "About"
|
74 |
with gr.Row():
|
@@ -101,8 +109,20 @@ with gr.Blocks() as app:
|
|
101 |
with gr.TabItem("About"):
|
102 |
with gr.Row():
|
103 |
gr.Markdown(ABOUT_TEXT)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
-
# Load data when app starts
|
106 |
def load_data_on_start():
|
107 |
data_herm = load_all_data(repo_dir_herm)
|
108 |
herm_table.update(data_herm)
|
|
|
1 |
import gradio as gr
|
2 |
import os
|
3 |
from huggingface_hub import HfApi, snapshot_download
|
4 |
+
from datasets import load_dataset
|
5 |
from src.utils import load_all_data
|
6 |
from src.md import ABOUT_TEXT
|
7 |
import numpy as np
|
|
|
11 |
COLLAB_TOKEN = os.environ.get("COLLAB_TOKEN")
|
12 |
evals_repo = "ai2-rlhf-collab/rm-benchmark-results"
|
13 |
prefs_repo = "ai2-rlhf-collab/rm-testset-results"
|
14 |
+
eval_set_repo = "ai2-rlhf-collab/rm-benchmark-dev"
|
15 |
repo_dir_herm = "./evals/herm/"
|
16 |
repo_dir_prefs = "./evals/prefs/"
|
17 |
|
|
|
29 |
etag_timeout=30,
|
30 |
repo_type="dataset",
|
31 |
)
|
|
|
32 |
|
33 |
repo_pref_sets = snapshot_download(
|
34 |
local_dir=repo_dir_prefs,
|
|
|
38 |
etag_timeout=30,
|
39 |
repo_type="dataset",
|
40 |
)
|
|
|
41 |
|
42 |
def avg_over_herm(dataframe):
|
43 |
"""
|
|
|
69 |
col_types_prefs = ["markdown"] + ["number"] * (len(prefs_data.columns) - 1)
|
70 |
# col_types_prefs_sub = ["markdown"] + ["number"] * (len(prefs_data_sub.columns) - 1)
|
71 |
|
72 |
+
# for showing random samples
|
73 |
+
eval_set = load_dataset(eval_set_repo, use_auth_token=COLLAB_TOKEN, split="filtered")
|
74 |
+
def random_sample(r: gr.Request):
|
75 |
+
sample_index = np.random.randint(0, len(eval_set) - 1)
|
76 |
+
sample = eval_set[sample_index]
|
77 |
+
markdown_text = '\n\n'.join([f"**{key}**: {value}" for key, value in sample.items()])
|
78 |
+
return markdown_text
|
79 |
+
|
80 |
with gr.Blocks() as app:
|
81 |
# create tabs for the app, moving the current table to one titled "HERM" and the benchmark_text to a tab called "About"
|
82 |
with gr.Row():
|
|
|
109 |
with gr.TabItem("About"):
|
110 |
with gr.Row():
|
111 |
gr.Markdown(ABOUT_TEXT)
|
112 |
+
|
113 |
+
with gr.TabItem("Dataset Viewer"):
|
114 |
+
with gr.Row():
|
115 |
+
# loads one sample
|
116 |
+
gr.Markdown("## Random Dataset Sample Viewer")
|
117 |
+
button = gr.Button("Show Random Sample")
|
118 |
+
|
119 |
+
with gr.Row():
|
120 |
+
sample_display = gr.Markdown("{sampled data loads here}")
|
121 |
+
|
122 |
+
button.click(fn=random_sample, outputs=sample_display)
|
123 |
+
|
124 |
|
125 |
+
# Load data when app starts, TODO make this used somewhere...
|
126 |
def load_data_on_start():
|
127 |
data_herm = load_all_data(repo_dir_herm)
|
128 |
herm_table.update(data_herm)
|