davanstrien HF staff commited on
Commit
35be45b
1 Parent(s): 60bd318

app example

Browse files
Files changed (3) hide show
  1. app.py +204 -0
  2. requirements.in +2 -0
  3. requirements.txt +223 -0
app.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ import markdown
4
+
5
+
6
+ def create_chat_html(messages, dataset_id, offset, compare_mode=False, column=""):
7
+ chat_html = ""
8
+ turn_number = 1
9
+ for i in range(0, len(messages), 2):
10
+ user_message = messages[i]
11
+ system_message = messages[i + 1] if i + 1 < len(messages) else None
12
+ user_role = user_message["role"]
13
+ user_content = user_message["content"]
14
+ user_content_html = markdown.markdown(user_content)
15
+ user_content_length = len(user_content)
16
+ user_html = f'<div class="user-message" style="justify-content: right;">'
17
+ user_html += f'<div class="message-content">'
18
+ user_html += (
19
+ f"<strong>Turn {turn_number} - {user_role.capitalize()}:</strong><br>"
20
+ )
21
+ user_html += f"<em>Length: {user_content_length} characters</em><br><br>"
22
+ user_html += f"{user_content_html}"
23
+ user_html += "</div></div>"
24
+ chat_html += user_html
25
+ if system_message:
26
+ system_role = system_message["role"]
27
+ system_content = system_message["content"]
28
+ system_content_html = markdown.markdown(system_content)
29
+ system_content_length = len(system_content)
30
+ system_html = f'<div class="system-message" style="justify-content: left;">'
31
+ system_html += f'<div class="message-content">'
32
+ system_html += f"<strong>{system_role.capitalize()}:</strong><br>"
33
+ system_html += (
34
+ f"<em>Length: {system_content_length} characters</em><br><br>"
35
+ )
36
+ system_html += f"{system_content_html}"
37
+ system_html += "</div></div>"
38
+ chat_html += system_html
39
+ turn_number += 1
40
+
41
+ if compare_mode:
42
+ chat_html = f'<div class="column {column}">{chat_html}</div>'
43
+
44
+ style = """
45
+ <style>
46
+ .user-message, .system-message {
47
+ display: flex;
48
+ margin: 10px;
49
+ }
50
+ .user-message .message-content {
51
+ background-color: #c2e3f7;
52
+ color: #000000;
53
+ }
54
+ .system-message .message-content {
55
+ background-color: #f5f5f5;
56
+ color: #000000;
57
+ }
58
+ .message-content {
59
+ padding: 10px;
60
+ border-radius: 10px;
61
+ max-width: 70%;
62
+ word-wrap: break-word;
63
+ }
64
+ .container {
65
+ display: flex;
66
+ justify-content: space-between;
67
+ }
68
+ .column {
69
+ width: 48%;
70
+ }
71
+ </style>
72
+ """
73
+
74
+ dataset_url = f"https://huggingface.co/datasets/{dataset_id}/viewer/default/train?row={offset}"
75
+ dataset_link = f"[View dataset row]({dataset_url})"
76
+
77
+ return dataset_link, style + chat_html
78
+
79
+
80
+ def fetch_data(
81
+ dataset_id, chosen_column, rejected_column, current_offset, direction, compare_mode
82
+ ):
83
+ change = 1 if direction == "Next" else -1
84
+ new_offset = max(0, current_offset + change)
85
+
86
+ base_url = f"https://datasets-server.huggingface.co/rows?dataset={dataset_id}&config=default&split=train&offset={new_offset}&length=1"
87
+ response = requests.get(base_url)
88
+ if response.status_code != 200:
89
+ return "", "Failed to fetch data", new_offset
90
+ data = response.json()
91
+
92
+ if compare_mode:
93
+ if chosen_column and rejected_column:
94
+ chosen_messages = data["rows"][0]["row"].get(chosen_column, [])
95
+ rejected_messages = data["rows"][0]["row"].get(rejected_column, [])
96
+ chosen_link, chosen_html = create_chat_html(
97
+ chosen_messages,
98
+ dataset_id,
99
+ new_offset,
100
+ compare_mode=True,
101
+ column="chosen",
102
+ )
103
+ rejected_link, rejected_html = create_chat_html(
104
+ rejected_messages,
105
+ dataset_id,
106
+ new_offset,
107
+ compare_mode=True,
108
+ column="rejected",
109
+ )
110
+ chat_html = f'<div class="container">{chosen_html}{rejected_html}</div>'
111
+ else:
112
+ return (
113
+ "",
114
+ "Please provide both chosen and rejected columns for comparison",
115
+ new_offset,
116
+ )
117
+ else:
118
+ if chosen_column:
119
+ messages = data["rows"][0]["row"].get(chosen_column, [])
120
+ else:
121
+ for key, value in data["rows"][0]["row"].items():
122
+ if (
123
+ isinstance(value, list)
124
+ and len(value) > 0
125
+ and isinstance(value[0], dict)
126
+ and "role" in value[0]
127
+ ):
128
+ messages = value
129
+ break
130
+ else:
131
+ return "", "No suitable chat column found", new_offset
132
+ _, chat_html = create_chat_html(messages, dataset_id, new_offset)
133
+
134
+ dataset_url = f"https://huggingface.co/datasets/{dataset_id}/viewer/default/train?row={new_offset}"
135
+ dataset_link = f"[View dataset row]({dataset_url})"
136
+
137
+ return dataset_link, chat_html, new_offset
138
+
139
+
140
+ def update_column_names(compare_mode):
141
+ if compare_mode:
142
+ return "chosen", "rejected"
143
+ else:
144
+ return "", ""
145
+
146
+
147
+ with gr.Blocks() as demo:
148
+ with gr.Row():
149
+ dataset_id = gr.Textbox(
150
+ label="Dataset ID", placeholder="e.g., davanstrien/cosmochat"
151
+ )
152
+ chosen_column = gr.Textbox(
153
+ label="Chosen Column",
154
+ placeholder="Column containing chosen chat data",
155
+ )
156
+ rejected_column = gr.Textbox(
157
+ label="Rejected Column",
158
+ placeholder="Column containing rejected chat data",
159
+ )
160
+ compare_mode = gr.Checkbox(label="Compare chosen and rejected chats")
161
+ current_offset = gr.State(value=0)
162
+
163
+ with gr.Row():
164
+ back_button = gr.Button("Back")
165
+ next_button = gr.Button("Next")
166
+
167
+ dataset_link = gr.Markdown()
168
+ output_html = gr.HTML()
169
+
170
+ compare_mode.change(
171
+ fn=update_column_names,
172
+ inputs=compare_mode,
173
+ outputs=[chosen_column, rejected_column],
174
+ )
175
+
176
+ back_button.click(
177
+ lambda data, chosen, rejected, offset, compare: fetch_data(
178
+ data, chosen, rejected, offset, "Back", compare
179
+ ),
180
+ inputs=[
181
+ dataset_id,
182
+ chosen_column,
183
+ rejected_column,
184
+ current_offset,
185
+ compare_mode,
186
+ ],
187
+ outputs=[dataset_link, output_html, current_offset],
188
+ )
189
+
190
+ next_button.click(
191
+ lambda data, chosen, rejected, offset, compare: fetch_data(
192
+ data, chosen, rejected, offset, "Next", compare
193
+ ),
194
+ inputs=[
195
+ dataset_id,
196
+ chosen_column,
197
+ rejected_column,
198
+ current_offset,
199
+ compare_mode,
200
+ ],
201
+ outputs=[dataset_link, output_html, current_offset],
202
+ )
203
+
204
+ demo.launch(debug=True, share=True)
requirements.in ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gradio
2
+ markdown
requirements.txt ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv pip compile requirements.in -o requirements.txt
3
+ aiofiles==23.2.1
4
+ # via gradio
5
+ altair==5.3.0
6
+ # via gradio
7
+ annotated-types==0.6.0
8
+ # via pydantic
9
+ anyio==4.3.0
10
+ # via
11
+ # httpx
12
+ # starlette
13
+ # watchfiles
14
+ attrs==23.2.0
15
+ # via
16
+ # jsonschema
17
+ # referencing
18
+ certifi==2024.2.2
19
+ # via
20
+ # httpcore
21
+ # httpx
22
+ # requests
23
+ charset-normalizer==3.3.2
24
+ # via requests
25
+ click==8.1.7
26
+ # via
27
+ # typer
28
+ # uvicorn
29
+ contourpy==1.2.1
30
+ # via matplotlib
31
+ cycler==0.12.1
32
+ # via matplotlib
33
+ dnspython==2.6.1
34
+ # via email-validator
35
+ email-validator==2.1.1
36
+ # via fastapi
37
+ fastapi==0.111.0
38
+ # via
39
+ # fastapi-cli
40
+ # gradio
41
+ fastapi-cli==0.0.3
42
+ # via fastapi
43
+ ffmpy==0.3.2
44
+ # via gradio
45
+ filelock==3.14.0
46
+ # via huggingface-hub
47
+ fonttools==4.51.0
48
+ # via matplotlib
49
+ fsspec==2024.3.1
50
+ # via
51
+ # gradio-client
52
+ # huggingface-hub
53
+ gradio==4.29.0
54
+ gradio-client==0.16.1
55
+ # via gradio
56
+ h11==0.14.0
57
+ # via
58
+ # httpcore
59
+ # uvicorn
60
+ httpcore==1.0.5
61
+ # via httpx
62
+ httptools==0.6.1
63
+ # via uvicorn
64
+ httpx==0.27.0
65
+ # via
66
+ # fastapi
67
+ # gradio
68
+ # gradio-client
69
+ huggingface-hub==0.23.0
70
+ # via
71
+ # gradio
72
+ # gradio-client
73
+ idna==3.7
74
+ # via
75
+ # anyio
76
+ # email-validator
77
+ # httpx
78
+ # requests
79
+ importlib-resources==6.4.0
80
+ # via gradio
81
+ jinja2==3.1.4
82
+ # via
83
+ # altair
84
+ # fastapi
85
+ # gradio
86
+ jsonschema==4.22.0
87
+ # via altair
88
+ jsonschema-specifications==2023.12.1
89
+ # via jsonschema
90
+ kiwisolver==1.4.5
91
+ # via matplotlib
92
+ markdown==3.6
93
+ markdown-it-py==3.0.0
94
+ # via rich
95
+ markupsafe==2.1.5
96
+ # via
97
+ # gradio
98
+ # jinja2
99
+ matplotlib==3.8.4
100
+ # via gradio
101
+ mdurl==0.1.2
102
+ # via markdown-it-py
103
+ numpy==1.26.4
104
+ # via
105
+ # altair
106
+ # contourpy
107
+ # gradio
108
+ # matplotlib
109
+ # pandas
110
+ orjson==3.10.3
111
+ # via
112
+ # fastapi
113
+ # gradio
114
+ packaging==24.0
115
+ # via
116
+ # altair
117
+ # gradio
118
+ # gradio-client
119
+ # huggingface-hub
120
+ # matplotlib
121
+ pandas==2.2.2
122
+ # via
123
+ # altair
124
+ # gradio
125
+ pillow==10.3.0
126
+ # via
127
+ # gradio
128
+ # matplotlib
129
+ pydantic==2.7.1
130
+ # via
131
+ # fastapi
132
+ # gradio
133
+ pydantic-core==2.18.2
134
+ # via pydantic
135
+ pydub==0.25.1
136
+ # via gradio
137
+ pygments==2.18.0
138
+ # via rich
139
+ pyparsing==3.1.2
140
+ # via matplotlib
141
+ python-dateutil==2.9.0.post0
142
+ # via
143
+ # matplotlib
144
+ # pandas
145
+ python-dotenv==1.0.1
146
+ # via uvicorn
147
+ python-multipart==0.0.9
148
+ # via
149
+ # fastapi
150
+ # gradio
151
+ pytz==2024.1
152
+ # via pandas
153
+ pyyaml==6.0.1
154
+ # via
155
+ # gradio
156
+ # huggingface-hub
157
+ # uvicorn
158
+ referencing==0.35.1
159
+ # via
160
+ # jsonschema
161
+ # jsonschema-specifications
162
+ requests==2.31.0
163
+ # via huggingface-hub
164
+ rich==13.7.1
165
+ # via typer
166
+ rpds-py==0.18.1
167
+ # via
168
+ # jsonschema
169
+ # referencing
170
+ ruff==0.4.3
171
+ # via gradio
172
+ semantic-version==2.10.0
173
+ # via gradio
174
+ shellingham==1.5.4
175
+ # via typer
176
+ six==1.16.0
177
+ # via python-dateutil
178
+ sniffio==1.3.1
179
+ # via
180
+ # anyio
181
+ # httpx
182
+ starlette==0.37.2
183
+ # via fastapi
184
+ tomlkit==0.12.0
185
+ # via gradio
186
+ toolz==0.12.1
187
+ # via altair
188
+ tqdm==4.66.4
189
+ # via huggingface-hub
190
+ typer==0.12.3
191
+ # via
192
+ # fastapi-cli
193
+ # gradio
194
+ typing-extensions==4.11.0
195
+ # via
196
+ # fastapi
197
+ # gradio
198
+ # gradio-client
199
+ # huggingface-hub
200
+ # pydantic
201
+ # pydantic-core
202
+ # typer
203
+ tzdata==2024.1
204
+ # via pandas
205
+ ujson==5.9.0
206
+ # via fastapi
207
+ urllib3==2.2.1
208
+ # via
209
+ # gradio
210
+ # requests
211
+ uvicorn==0.29.0
212
+ # via
213
+ # fastapi
214
+ # fastapi-cli
215
+ # gradio
216
+ uvloop==0.19.0
217
+ # via uvicorn
218
+ watchfiles==0.21.0
219
+ # via uvicorn
220
+ websockets==11.0.3
221
+ # via
222
+ # gradio-client
223
+ # uvicorn