Update app.py
Browse files
app.py
CHANGED
@@ -14,12 +14,17 @@ import random
|
|
14 |
from params import load_params, save_params
|
15 |
import pandas as pd
|
16 |
import csv
|
|
|
|
|
|
|
17 |
|
18 |
|
19 |
|
20 |
ANNOTATION_CONFIG_FILE = "annotation_config.json"
|
21 |
OUTPUT_FILE_PATH = "dataset.jsonl"
|
22 |
|
|
|
|
|
23 |
def load_llm_config():
|
24 |
params = load_params()
|
25 |
return (
|
@@ -34,6 +39,8 @@ def load_llm_config():
|
|
34 |
params.get('presence_penalty', 0.0)
|
35 |
)
|
36 |
|
|
|
|
|
37 |
def save_llm_config(provider, base_url, workspace, api_key, max_tokens, temperature, top_p, frequency_penalty, presence_penalty):
|
38 |
save_params({
|
39 |
'PROVIDER': provider,
|
@@ -49,6 +56,8 @@ def save_llm_config(provider, base_url, workspace, api_key, max_tokens, temperat
|
|
49 |
return "LLM configuration saved successfully"
|
50 |
|
51 |
|
|
|
|
|
52 |
def load_annotation_config():
|
53 |
try:
|
54 |
with open(ANNOTATION_CONFIG_FILE, 'r') as f:
|
@@ -92,6 +101,8 @@ def load_annotation_config():
|
|
92 |
}
|
93 |
|
94 |
|
|
|
|
|
95 |
def load_csv_dataset(file_path):
|
96 |
data = []
|
97 |
with open(file_path, 'r') as f:
|
@@ -100,20 +111,28 @@ def load_csv_dataset(file_path):
|
|
100 |
data.append(row)
|
101 |
return data
|
102 |
|
|
|
|
|
103 |
def load_txt_dataset(file_path):
|
104 |
with open(file_path, 'r') as f:
|
105 |
return [{"content": line.strip()} for line in f if line.strip()]
|
106 |
|
|
|
|
|
107 |
def save_annotation_config(config):
|
108 |
with open(ANNOTATION_CONFIG_FILE, 'w') as f:
|
109 |
json.dump(config, f, indent=2)
|
110 |
|
|
|
|
|
111 |
def load_jsonl_dataset(file_path):
|
112 |
if not os.path.exists(file_path):
|
113 |
return []
|
114 |
with open(file_path, 'r') as f:
|
115 |
return [json.loads(line.strip()) for line in f if line.strip()]
|
116 |
|
|
|
|
|
117 |
def load_dataset(file):
|
118 |
if file is None:
|
119 |
return "", 0, 0, "No file uploaded", "3", [], [], [], ""
|
@@ -136,6 +155,8 @@ def load_dataset(file):
|
|
136 |
first_row = json.dumps(data[0], indent=2)
|
137 |
return first_row, 0, len(data), f"Row: 1/{len(data)}", "3", [], [], [], ""
|
138 |
|
|
|
|
|
139 |
def save_row(file_path, index, row_data):
|
140 |
file_extension = file_path.split('.')[-1].lower()
|
141 |
|
@@ -150,6 +171,8 @@ def save_row(file_path, index, row_data):
|
|
150 |
|
151 |
return f"Row {index} saved successfully"
|
152 |
|
|
|
|
|
153 |
def save_jsonl_row(file_path, index, row_data):
|
154 |
with open(file_path, 'r') as f:
|
155 |
lines = f.readlines()
|
@@ -159,6 +182,8 @@ def save_jsonl_row(file_path, index, row_data):
|
|
159 |
with open(file_path, 'w') as f:
|
160 |
f.writelines(lines)
|
161 |
|
|
|
|
|
162 |
def save_csv_row(file_path, index, row_data):
|
163 |
df = pd.read_csv(file_path)
|
164 |
row_dict = json.loads(row_data)
|
@@ -166,6 +191,8 @@ def save_csv_row(file_path, index, row_data):
|
|
166 |
df.at[index, col] = value
|
167 |
df.to_csv(file_path, index=False)
|
168 |
|
|
|
|
|
169 |
def save_txt_row(file_path, index, row_data):
|
170 |
with open(file_path, 'r') as f:
|
171 |
lines = f.readlines()
|
@@ -176,6 +203,8 @@ def save_txt_row(file_path, index, row_data):
|
|
176 |
with open(file_path, 'w') as f:
|
177 |
f.writelines(lines)
|
178 |
|
|
|
|
|
179 |
def get_row(file_path, index):
|
180 |
data = load_jsonl_dataset(file_path)
|
181 |
if not data:
|
@@ -184,6 +213,8 @@ def get_row(file_path, index):
|
|
184 |
return json.dumps(data[index], indent=2), len(data)
|
185 |
return "", len(data)
|
186 |
|
|
|
|
|
187 |
def json_to_markdown(json_str):
|
188 |
try:
|
189 |
data = json.loads(json_str)
|
@@ -192,6 +223,8 @@ def json_to_markdown(json_str):
|
|
192 |
except json.JSONDecodeError:
|
193 |
return "Error: Invalid JSON format"
|
194 |
|
|
|
|
|
195 |
def markdown_to_json(markdown_str):
|
196 |
sections = re.split(r'#\s+(System|Instruction|Response)\s*\n', markdown_str)
|
197 |
if len(sections) != 7: # Should be: ['', 'System', content, 'Instruction', content, 'Response', content]
|
@@ -204,10 +237,14 @@ def markdown_to_json(markdown_str):
|
|
204 |
}
|
205 |
return json.dumps(json_data, indent=2)
|
206 |
|
|
|
|
|
207 |
def navigate_rows(file_path: str, current_index: int, direction: Literal["prev", "next"], metadata_config):
|
208 |
new_index = max(0, current_index + (-1 if direction == "prev" else 1))
|
209 |
return load_and_show_row(file_path, new_index, metadata_config)
|
210 |
|
|
|
|
|
211 |
def load_and_show_row(file_path, index, metadata_config):
|
212 |
row_data, total = get_row(file_path, index)
|
213 |
if not row_data:
|
@@ -229,6 +266,8 @@ def load_and_show_row(file_path, index, metadata_config):
|
|
229 |
return (row_data, index, total, f"Row: {index + 1}/{total}", quality,
|
230 |
high_quality_tags, low_quality_tags, toxic_tags, other)
|
231 |
|
|
|
|
|
232 |
def save_row_with_metadata(file_path, index, row_data, config, quality, high_quality_tags, low_quality_tags, toxic_tags, other):
|
233 |
data = json.loads(row_data)
|
234 |
metadata = {
|
@@ -248,6 +287,8 @@ def save_row_with_metadata(file_path, index, row_data, config, quality, high_qua
|
|
248 |
data["metadata"] = metadata
|
249 |
return save_row(file_path, index, json.dumps(data))
|
250 |
|
|
|
|
|
251 |
def update_annotation_ui(config):
|
252 |
quality_choices = [(item["value"], item["label"]) for item in config["quality_scale"]["scale"]]
|
253 |
quality_label = gr.Radio(
|
@@ -271,6 +312,8 @@ def update_annotation_ui(config):
|
|
271 |
|
272 |
return quality_label, *tag_components, other_description
|
273 |
|
|
|
|
|
274 |
def load_config_to_ui(config):
|
275 |
return (
|
276 |
config["quality_scale"]["name"],
|
@@ -280,6 +323,8 @@ def load_config_to_ui(config):
|
|
280 |
[[field["name"], field["description"]] for field in config["free_text_fields"]]
|
281 |
)
|
282 |
|
|
|
|
|
283 |
def save_config_from_ui(name, description, scale, categories, fields, topics, all_topics_text):
|
284 |
if all_topics_text.visible:
|
285 |
topics_list = [topic.strip() for topic in all_topics_text.split("\n") if topic.strip()]
|
@@ -299,6 +344,8 @@ def save_config_from_ui(name, description, scale, categories, fields, topics, al
|
|
299 |
save_annotation_config(new_config)
|
300 |
return "Configuration saved successfully", new_config
|
301 |
|
|
|
|
|
302 |
# Add this new function to generate the preview
|
303 |
def generate_preview(row_data, quality, high_quality_tags, low_quality_tags, toxic_tags, other):
|
304 |
try:
|
@@ -321,6 +368,8 @@ def generate_preview(row_data, quality, high_quality_tags, low_quality_tags, tox
|
|
321 |
except json.JSONDecodeError:
|
322 |
return "Error: Invalid JSON in the current row data"
|
323 |
|
|
|
|
|
324 |
def load_dataset_config():
|
325 |
params = load_params()
|
326 |
with open("system_messages.py", "r") as f:
|
@@ -347,6 +396,8 @@ def load_dataset_config():
|
|
347 |
params.get('presence_penalty', 0.0)
|
348 |
)
|
349 |
|
|
|
|
|
350 |
def edit_all_topics_func(topics):
|
351 |
topics_list = [topic[0] for topic in topics]
|
352 |
jsonl_rows = "\n".join([json.dumps({"topic": topic}) for topic in topics_list])
|
@@ -356,6 +407,8 @@ def edit_all_topics_func(topics):
|
|
356 |
gr.update(visible=True)
|
357 |
)
|
358 |
|
|
|
|
|
359 |
def update_topics_from_text(text):
|
360 |
try:
|
361 |
# Try parsing as JSONL
|
@@ -366,6 +419,8 @@ def update_topics_from_text(text):
|
|
366 |
|
367 |
return gr.Dataframe.update(value=[[topic] for topic in topics_list], visible=True), gr.TextArea.update(visible=False)
|
368 |
|
|
|
|
|
369 |
def save_dataset_config(system_messages, prompt_1, topics, max_tokens, temperature, top_p, frequency_penalty, presence_penalty):
|
370 |
# Save VODALUS_SYSTEM_MESSAGE to system_messages.py
|
371 |
with open("system_messages.py", "w") as f:
|
@@ -426,6 +481,7 @@ def chat_with_llm(message, history):
|
|
426 |
print(f"Error in chat_with_llm: {str(e)}")
|
427 |
return history + [[message, f"Error: {str(e)}"]]
|
428 |
|
|
|
429 |
def update_chat_context(row_data, index, total, quality, high_quality_tags, low_quality_tags, toxic_tags, other):
|
430 |
context = f"""Current app state:
|
431 |
Row: {index + 1}/{total}
|
@@ -440,12 +496,16 @@ def update_chat_context(row_data, index, total, quality, high_quality_tags, low_
|
|
440 |
return [[None, context]]
|
441 |
|
442 |
|
443 |
-
|
|
|
|
|
|
|
|
|
444 |
generated_data = []
|
445 |
for _ in range(num_generations):
|
446 |
topic_selected = random.choice(TOPICS)
|
447 |
system_message_selected = random.choice(SYSTEM_MESSAGES_VODALUS)
|
448 |
-
data = await generate_data(topic_selected, PROMPT_1, system_message_selected, output_file_path)
|
449 |
if data:
|
450 |
generated_data.append(json.dumps(data))
|
451 |
|
@@ -456,15 +516,21 @@ async def run_generate_dataset(num_workers, num_generations, output_file_path):
|
|
456 |
|
457 |
return f"Generated {num_generations} entries and saved to {output_file_path}", "\n".join(generated_data[:5]) + "\n..."
|
458 |
|
|
|
|
|
459 |
def add_topic_row(data):
|
460 |
if isinstance(data, pd.DataFrame):
|
461 |
return pd.concat([data, pd.DataFrame({"Topic": ["New Topic"]})], ignore_index=True)
|
462 |
else:
|
463 |
return data + [["New Topic"]]
|
464 |
|
|
|
|
|
465 |
def remove_last_topic_row(data):
|
466 |
return data[:-1] if len(data) > 1 else data
|
467 |
|
|
|
|
|
468 |
def edit_all_topics_func(topics):
|
469 |
topics_list = [topic[0] for topic in topics]
|
470 |
jsonl_rows = "\n".join([json.dumps({"topic": topic}) for topic in topics_list])
|
@@ -474,6 +540,8 @@ def edit_all_topics_func(topics):
|
|
474 |
gr.update(visible=True)
|
475 |
)
|
476 |
|
|
|
|
|
477 |
def update_topics_from_text(text):
|
478 |
try:
|
479 |
# Try parsing as JSONL
|
@@ -484,6 +552,8 @@ def update_topics_from_text(text):
|
|
484 |
|
485 |
return gr.Dataframe.update(value=[[topic] for topic in topics_list], visible=True), gr.TextArea.update(visible=False)
|
486 |
|
|
|
|
|
487 |
def update_topics_from_text(text):
|
488 |
try:
|
489 |
# Try parsing as JSONL
|
@@ -494,6 +564,82 @@ def update_topics_from_text(text):
|
|
494 |
|
495 |
return gr.Dataframe.update(value=[[topic] for topic in topics_list], visible=True), gr.TextArea.update(visible=False)
|
496 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
497 |
css = """
|
498 |
body, #root {
|
499 |
margin: 0;
|
@@ -740,6 +886,20 @@ with demo:
|
|
740 |
with gr.Row():
|
741 |
save_dataset_config_btn = gr.Button("Save Dataset Configuration", variant="primary")
|
742 |
dataset_config_status = gr.Textbox(label="Status")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
743 |
|
744 |
|
745 |
with gr.Tab("Dataset Generation"):
|
@@ -889,7 +1049,7 @@ with demo:
|
|
889 |
|
890 |
start_generation_btn.click(
|
891 |
run_generate_dataset,
|
892 |
-
inputs=[num_workers, num_generations, output_file_path],
|
893 |
outputs=[generation_status, generation_output]
|
894 |
)
|
895 |
|
@@ -915,6 +1075,30 @@ with demo:
|
|
915 |
outputs=[chatbot]
|
916 |
)
|
917 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
918 |
|
919 |
demo.load(
|
920 |
lambda: (
|
|
|
14 |
from params import load_params, save_params
|
15 |
import pandas as pd
|
16 |
import csv
|
17 |
+
from datasets import load_dataset
|
18 |
+
from huggingface_hub import list_datasets, HfApi, hf_hub_download
|
19 |
+
|
20 |
|
21 |
|
22 |
|
23 |
ANNOTATION_CONFIG_FILE = "annotation_config.json"
|
24 |
OUTPUT_FILE_PATH = "dataset.jsonl"
|
25 |
|
26 |
+
|
27 |
+
|
28 |
def load_llm_config():
|
29 |
params = load_params()
|
30 |
return (
|
|
|
39 |
params.get('presence_penalty', 0.0)
|
40 |
)
|
41 |
|
42 |
+
|
43 |
+
|
44 |
def save_llm_config(provider, base_url, workspace, api_key, max_tokens, temperature, top_p, frequency_penalty, presence_penalty):
|
45 |
save_params({
|
46 |
'PROVIDER': provider,
|
|
|
56 |
return "LLM configuration saved successfully"
|
57 |
|
58 |
|
59 |
+
|
60 |
+
|
61 |
def load_annotation_config():
|
62 |
try:
|
63 |
with open(ANNOTATION_CONFIG_FILE, 'r') as f:
|
|
|
101 |
}
|
102 |
|
103 |
|
104 |
+
|
105 |
+
|
106 |
def load_csv_dataset(file_path):
|
107 |
data = []
|
108 |
with open(file_path, 'r') as f:
|
|
|
111 |
data.append(row)
|
112 |
return data
|
113 |
|
114 |
+
|
115 |
+
|
116 |
def load_txt_dataset(file_path):
|
117 |
with open(file_path, 'r') as f:
|
118 |
return [{"content": line.strip()} for line in f if line.strip()]
|
119 |
|
120 |
+
|
121 |
+
|
122 |
def save_annotation_config(config):
|
123 |
with open(ANNOTATION_CONFIG_FILE, 'w') as f:
|
124 |
json.dump(config, f, indent=2)
|
125 |
|
126 |
+
|
127 |
+
|
128 |
def load_jsonl_dataset(file_path):
|
129 |
if not os.path.exists(file_path):
|
130 |
return []
|
131 |
with open(file_path, 'r') as f:
|
132 |
return [json.loads(line.strip()) for line in f if line.strip()]
|
133 |
|
134 |
+
|
135 |
+
|
136 |
def load_dataset(file):
|
137 |
if file is None:
|
138 |
return "", 0, 0, "No file uploaded", "3", [], [], [], ""
|
|
|
155 |
first_row = json.dumps(data[0], indent=2)
|
156 |
return first_row, 0, len(data), f"Row: 1/{len(data)}", "3", [], [], [], ""
|
157 |
|
158 |
+
|
159 |
+
|
160 |
def save_row(file_path, index, row_data):
|
161 |
file_extension = file_path.split('.')[-1].lower()
|
162 |
|
|
|
171 |
|
172 |
return f"Row {index} saved successfully"
|
173 |
|
174 |
+
|
175 |
+
|
176 |
def save_jsonl_row(file_path, index, row_data):
|
177 |
with open(file_path, 'r') as f:
|
178 |
lines = f.readlines()
|
|
|
182 |
with open(file_path, 'w') as f:
|
183 |
f.writelines(lines)
|
184 |
|
185 |
+
|
186 |
+
|
187 |
def save_csv_row(file_path, index, row_data):
|
188 |
df = pd.read_csv(file_path)
|
189 |
row_dict = json.loads(row_data)
|
|
|
191 |
df.at[index, col] = value
|
192 |
df.to_csv(file_path, index=False)
|
193 |
|
194 |
+
|
195 |
+
|
196 |
def save_txt_row(file_path, index, row_data):
|
197 |
with open(file_path, 'r') as f:
|
198 |
lines = f.readlines()
|
|
|
203 |
with open(file_path, 'w') as f:
|
204 |
f.writelines(lines)
|
205 |
|
206 |
+
|
207 |
+
|
208 |
def get_row(file_path, index):
|
209 |
data = load_jsonl_dataset(file_path)
|
210 |
if not data:
|
|
|
213 |
return json.dumps(data[index], indent=2), len(data)
|
214 |
return "", len(data)
|
215 |
|
216 |
+
|
217 |
+
|
218 |
def json_to_markdown(json_str):
|
219 |
try:
|
220 |
data = json.loads(json_str)
|
|
|
223 |
except json.JSONDecodeError:
|
224 |
return "Error: Invalid JSON format"
|
225 |
|
226 |
+
|
227 |
+
|
228 |
def markdown_to_json(markdown_str):
|
229 |
sections = re.split(r'#\s+(System|Instruction|Response)\s*\n', markdown_str)
|
230 |
if len(sections) != 7: # Should be: ['', 'System', content, 'Instruction', content, 'Response', content]
|
|
|
237 |
}
|
238 |
return json.dumps(json_data, indent=2)
|
239 |
|
240 |
+
|
241 |
+
|
242 |
def navigate_rows(file_path: str, current_index: int, direction: Literal["prev", "next"], metadata_config):
|
243 |
new_index = max(0, current_index + (-1 if direction == "prev" else 1))
|
244 |
return load_and_show_row(file_path, new_index, metadata_config)
|
245 |
|
246 |
+
|
247 |
+
|
248 |
def load_and_show_row(file_path, index, metadata_config):
|
249 |
row_data, total = get_row(file_path, index)
|
250 |
if not row_data:
|
|
|
266 |
return (row_data, index, total, f"Row: {index + 1}/{total}", quality,
|
267 |
high_quality_tags, low_quality_tags, toxic_tags, other)
|
268 |
|
269 |
+
|
270 |
+
|
271 |
def save_row_with_metadata(file_path, index, row_data, config, quality, high_quality_tags, low_quality_tags, toxic_tags, other):
|
272 |
data = json.loads(row_data)
|
273 |
metadata = {
|
|
|
287 |
data["metadata"] = metadata
|
288 |
return save_row(file_path, index, json.dumps(data))
|
289 |
|
290 |
+
|
291 |
+
|
292 |
def update_annotation_ui(config):
|
293 |
quality_choices = [(item["value"], item["label"]) for item in config["quality_scale"]["scale"]]
|
294 |
quality_label = gr.Radio(
|
|
|
312 |
|
313 |
return quality_label, *tag_components, other_description
|
314 |
|
315 |
+
|
316 |
+
|
317 |
def load_config_to_ui(config):
|
318 |
return (
|
319 |
config["quality_scale"]["name"],
|
|
|
323 |
[[field["name"], field["description"]] for field in config["free_text_fields"]]
|
324 |
)
|
325 |
|
326 |
+
|
327 |
+
|
328 |
def save_config_from_ui(name, description, scale, categories, fields, topics, all_topics_text):
|
329 |
if all_topics_text.visible:
|
330 |
topics_list = [topic.strip() for topic in all_topics_text.split("\n") if topic.strip()]
|
|
|
344 |
save_annotation_config(new_config)
|
345 |
return "Configuration saved successfully", new_config
|
346 |
|
347 |
+
|
348 |
+
|
349 |
# Add this new function to generate the preview
|
350 |
def generate_preview(row_data, quality, high_quality_tags, low_quality_tags, toxic_tags, other):
|
351 |
try:
|
|
|
368 |
except json.JSONDecodeError:
|
369 |
return "Error: Invalid JSON in the current row data"
|
370 |
|
371 |
+
|
372 |
+
|
373 |
def load_dataset_config():
|
374 |
params = load_params()
|
375 |
with open("system_messages.py", "r") as f:
|
|
|
396 |
params.get('presence_penalty', 0.0)
|
397 |
)
|
398 |
|
399 |
+
|
400 |
+
|
401 |
def edit_all_topics_func(topics):
|
402 |
topics_list = [topic[0] for topic in topics]
|
403 |
jsonl_rows = "\n".join([json.dumps({"topic": topic}) for topic in topics_list])
|
|
|
407 |
gr.update(visible=True)
|
408 |
)
|
409 |
|
410 |
+
|
411 |
+
|
412 |
def update_topics_from_text(text):
|
413 |
try:
|
414 |
# Try parsing as JSONL
|
|
|
419 |
|
420 |
return gr.Dataframe.update(value=[[topic] for topic in topics_list], visible=True), gr.TextArea.update(visible=False)
|
421 |
|
422 |
+
|
423 |
+
|
424 |
def save_dataset_config(system_messages, prompt_1, topics, max_tokens, temperature, top_p, frequency_penalty, presence_penalty):
|
425 |
# Save VODALUS_SYSTEM_MESSAGE to system_messages.py
|
426 |
with open("system_messages.py", "w") as f:
|
|
|
481 |
print(f"Error in chat_with_llm: {str(e)}")
|
482 |
return history + [[message, f"Error: {str(e)}"]]
|
483 |
|
484 |
+
|
485 |
def update_chat_context(row_data, index, total, quality, high_quality_tags, low_quality_tags, toxic_tags, other):
|
486 |
context = f"""Current app state:
|
487 |
Row: {index + 1}/{total}
|
|
|
496 |
return [[None, context]]
|
497 |
|
498 |
|
499 |
+
|
500 |
+
async def run_generate_dataset(num_workers, num_generations, output_file_path, loaded_dataset):
|
501 |
+
if loaded_dataset is None:
|
502 |
+
return "Error: No dataset loaded. Please load a dataset before generating.", ""
|
503 |
+
|
504 |
generated_data = []
|
505 |
for _ in range(num_generations):
|
506 |
topic_selected = random.choice(TOPICS)
|
507 |
system_message_selected = random.choice(SYSTEM_MESSAGES_VODALUS)
|
508 |
+
data = await generate_data(topic_selected, PROMPT_1, system_message_selected, output_file_path, loaded_dataset)
|
509 |
if data:
|
510 |
generated_data.append(json.dumps(data))
|
511 |
|
|
|
516 |
|
517 |
return f"Generated {num_generations} entries and saved to {output_file_path}", "\n".join(generated_data[:5]) + "\n..."
|
518 |
|
519 |
+
|
520 |
+
|
521 |
def add_topic_row(data):
|
522 |
if isinstance(data, pd.DataFrame):
|
523 |
return pd.concat([data, pd.DataFrame({"Topic": ["New Topic"]})], ignore_index=True)
|
524 |
else:
|
525 |
return data + [["New Topic"]]
|
526 |
|
527 |
+
|
528 |
+
|
529 |
def remove_last_topic_row(data):
|
530 |
return data[:-1] if len(data) > 1 else data
|
531 |
|
532 |
+
|
533 |
+
|
534 |
def edit_all_topics_func(topics):
|
535 |
topics_list = [topic[0] for topic in topics]
|
536 |
jsonl_rows = "\n".join([json.dumps({"topic": topic}) for topic in topics_list])
|
|
|
540 |
gr.update(visible=True)
|
541 |
)
|
542 |
|
543 |
+
|
544 |
+
|
545 |
def update_topics_from_text(text):
|
546 |
try:
|
547 |
# Try parsing as JSONL
|
|
|
552 |
|
553 |
return gr.Dataframe.update(value=[[topic] for topic in topics_list], visible=True), gr.TextArea.update(visible=False)
|
554 |
|
555 |
+
|
556 |
+
|
557 |
def update_topics_from_text(text):
|
558 |
try:
|
559 |
# Try parsing as JSONL
|
|
|
564 |
|
565 |
return gr.Dataframe.update(value=[[topic] for topic in topics_list], visible=True), gr.TextArea.update(visible=False)
|
566 |
|
567 |
+
|
568 |
+
|
569 |
+
def search_huggingface_datasets(query):
|
570 |
+
try:
|
571 |
+
api = HfApi()
|
572 |
+
datasets = api.list_datasets(search=query, limit=20)
|
573 |
+
dataset_ids = [dataset.id for dataset in datasets]
|
574 |
+
return gr.update(choices=dataset_ids, visible=True), ""
|
575 |
+
except Exception as e:
|
576 |
+
print(f"Error searching datasets: {str(e)}")
|
577 |
+
return gr.update(choices=["Error: Could not search datasets"], visible=True), ""
|
578 |
+
|
579 |
+
|
580 |
+
|
581 |
+
def load_huggingface_dataset(dataset_name, split="train"):
|
582 |
+
try:
|
583 |
+
print(f"Attempting to load dataset: {dataset_name}")
|
584 |
+
|
585 |
+
# Check if dataset_name is a string
|
586 |
+
if not isinstance(dataset_name, str):
|
587 |
+
raise ValueError(f"Expected dataset_name to be a string, but got {type(dataset_name)}")
|
588 |
+
|
589 |
+
# Try loading the dataset without specifying a config
|
590 |
+
full_dataset = load_dataset(dataset_name)
|
591 |
+
|
592 |
+
print(f"Dataset loaded. Available splits: {list(full_dataset.keys())}")
|
593 |
+
|
594 |
+
# Select the appropriate split
|
595 |
+
if split in full_dataset:
|
596 |
+
dataset = full_dataset[split]
|
597 |
+
print(f"Using specified split: {split}")
|
598 |
+
else:
|
599 |
+
available_splits = list(full_dataset.keys())
|
600 |
+
if available_splits:
|
601 |
+
dataset = full_dataset[available_splits[0]]
|
602 |
+
split = available_splits[0]
|
603 |
+
print(f"Specified split not found. Using first available split: {split}")
|
604 |
+
else:
|
605 |
+
raise ValueError("No valid splits found in the dataset")
|
606 |
+
|
607 |
+
return dataset, f"Dataset '{dataset_name}' (split: {split}) loaded successfully."
|
608 |
+
except Exception as e:
|
609 |
+
error_msg = f"Error loading dataset: {str(e)}"
|
610 |
+
print(f"Error details: {error_msg}")
|
611 |
+
|
612 |
+
# If loading fails, try to get the dataset card
|
613 |
+
try:
|
614 |
+
dataset_card = hf_hub_download(repo_id=dataset_name, filename="README.md")
|
615 |
+
with open(dataset_card, 'r') as f:
|
616 |
+
card_content = f.read()
|
617 |
+
return None, f"Dataset couldn't be loaded, but here's the dataset card:\n\n{card_content[:500]}..."
|
618 |
+
except:
|
619 |
+
return None, error_msg
|
620 |
+
|
621 |
+
# Wrapper function to handle the Gradio interface
|
622 |
+
def load_dataset_wrapper(dataset_name, split):
|
623 |
+
if not dataset_name:
|
624 |
+
return None, "Please enter a dataset name."
|
625 |
+
dataset, message = load_huggingface_dataset(dataset_name, split)
|
626 |
+
return dataset, message
|
627 |
+
|
628 |
+
|
629 |
+
def get_popular_datasets():
|
630 |
+
return [
|
631 |
+
"wikipedia",
|
632 |
+
"squad",
|
633 |
+
"glue",
|
634 |
+
"imdb",
|
635 |
+
"wmt16",
|
636 |
+
"common_voice",
|
637 |
+
"cnn_dailymail",
|
638 |
+
"amazon_reviews_multi",
|
639 |
+
"yelp_review_full",
|
640 |
+
"ag_news"
|
641 |
+
]
|
642 |
+
|
643 |
css = """
|
644 |
body, #root {
|
645 |
margin: 0;
|
|
|
886 |
with gr.Row():
|
887 |
save_dataset_config_btn = gr.Button("Save Dataset Configuration", variant="primary")
|
888 |
dataset_config_status = gr.Textbox(label="Status")
|
889 |
+
|
890 |
+
# gr.Markdown("### Hugging Face Dataset")
|
891 |
+
# with gr.Row():
|
892 |
+
# dataset_search = gr.Textbox(label="Search Datasets")
|
893 |
+
# search_button = gr.Button("Search")
|
894 |
+
# dataset_input = gr.Textbox(label="Dataset Name", info="Enter a dataset name or select from search results")
|
895 |
+
# dataset_results = gr.Radio(label="Search Results", choices=[], visible=False)
|
896 |
+
# dataset_split = gr.Textbox(label="Dataset Split (optional)", value="train")
|
897 |
+
# load_dataset_button = gr.Button("Load Selected Dataset")
|
898 |
+
# dataset_status = gr.Textbox(label="Dataset Status")
|
899 |
+
|
900 |
+
# Add a state to store the loaded dataset
|
901 |
+
# loaded_dataset = gr.State(None)
|
902 |
+
|
903 |
|
904 |
|
905 |
with gr.Tab("Dataset Generation"):
|
|
|
1049 |
|
1050 |
start_generation_btn.click(
|
1051 |
run_generate_dataset,
|
1052 |
+
inputs=[num_workers, num_generations, output_file_path, loaded_dataset],
|
1053 |
outputs=[generation_status, generation_output]
|
1054 |
)
|
1055 |
|
|
|
1075 |
outputs=[chatbot]
|
1076 |
)
|
1077 |
|
1078 |
+
search_button.click(
|
1079 |
+
search_huggingface_datasets,
|
1080 |
+
inputs=[dataset_search],
|
1081 |
+
outputs=[dataset_results, dataset_input]
|
1082 |
+
)
|
1083 |
+
|
1084 |
+
dataset_results.change(
|
1085 |
+
lambda choice: choice,
|
1086 |
+
inputs=[dataset_results],
|
1087 |
+
outputs=[dataset_input]
|
1088 |
+
)
|
1089 |
+
|
1090 |
+
load_dataset_button.click(
|
1091 |
+
load_dataset_wrapper,
|
1092 |
+
inputs=[dataset_input, dataset_split],
|
1093 |
+
outputs=[loaded_dataset, dataset_status]
|
1094 |
+
)
|
1095 |
+
|
1096 |
+
# Modify the start_generation_btn.click to include the loaded dataset
|
1097 |
+
start_generation_btn.click(
|
1098 |
+
run_generate_dataset,
|
1099 |
+
inputs=[num_workers, num_generations, output_file_path, loaded_dataset],
|
1100 |
+
outputs=[generation_status, generation_output]
|
1101 |
+
)
|
1102 |
|
1103 |
demo.load(
|
1104 |
lambda: (
|