BeTaLabs committed
Commit 113eb1c · 1 parent: 8158179

Update app.py

Files changed (1)
  1. app.py +493 -72
app.py CHANGED
@@ -1,20 +1,54 @@
1
  import gradio as gr
 
2
  import json
3
  import re
4
  from datetime import datetime
5
  from typing import Literal
6
  import os
7
  import importlib
8
- from llm_handler import send_to_llm, agent, settings
9
  from main import generate_data, PROMPT_1
10
  from topics import TOPICS
11
  from system_messages import SYSTEM_MESSAGES_VODALUS
12
  import random
13
 
14
 
15
  ANNOTATION_CONFIG_FILE = "annotation_config.json"
16
  OUTPUT_FILE_PATH = "dataset.jsonl"
17
 
18
  def load_annotation_config():
19
  try:
20
  with open(ANNOTATION_CONFIG_FILE, 'r') as f:
@@ -57,6 +91,19 @@ def load_annotation_config():
57
  ]
58
  }
59
 
60
  def save_annotation_config(config):
61
  with open(ANNOTATION_CONFIG_FILE, 'w') as f:
62
  json.dump(config, f, indent=2)
@@ -66,8 +113,44 @@ def load_jsonl_dataset(file_path):
66
  return []
67
  with open(file_path, 'r') as f:
68
  return [json.loads(line.strip()) for line in f if line.strip()]
69
 
70
  def save_row(file_path, index, row_data):
71
  with open(file_path, 'r') as f:
72
  lines = f.readlines()
73
 
@@ -75,8 +158,23 @@ def save_row(file_path, index, row_data):
75
 
76
  with open(file_path, 'w') as f:
77
  f.writelines(lines)
78
 
79
- return f"Row {index} saved successfully"
80
 
81
  def get_row(file_path, index):
82
  data = load_jsonl_dataset(file_path)
@@ -106,19 +204,19 @@ def markdown_to_json(markdown_str):
106
  }
107
  return json.dumps(json_data, indent=2)
108
 
109
- def navigate_rows(file_path: str, current_index: int, direction: Literal[-1, 1], metadata_config):
110
- new_index = max(0, current_index + direction)
111
  return load_and_show_row(file_path, new_index, metadata_config)
112
 
113
  def load_and_show_row(file_path, index, metadata_config):
114
  row_data, total = get_row(file_path, index)
115
  if not row_data:
116
- return ("", index, total, "3", [], [], [], "")
117
 
118
  try:
119
  data = json.loads(row_data)
120
  except json.JSONDecodeError:
121
- return (row_data, index, total, "3", [], [], [], "Error: Invalid JSON")
122
 
123
  metadata = data.get("metadata", {}).get("annotation", {})
124
 
@@ -128,7 +226,7 @@ def load_and_show_row(file_path, index, metadata_config):
128
  toxic_tags = metadata.get("tags", {}).get("toxic", [])
129
  other = metadata.get("free_text", {}).get("Additional Notes", "")
130
 
131
- return (row_data, index, total, quality,
132
  high_quality_tags, low_quality_tags, toxic_tags, other)
133
 
134
  def save_row_with_metadata(file_path, index, row_data, config, quality, high_quality_tags, low_quality_tags, toxic_tags, other):
@@ -182,7 +280,12 @@ def load_config_to_ui(config):
182
  [[field["name"], field["description"]] for field in config["free_text_fields"]]
183
  )
184
 
185
- def save_config_from_ui(name, description, scale, categories, fields):
186
  new_config = {
187
  "quality_scale": {
188
  "name": name,
@@ -190,7 +293,8 @@ def save_config_from_ui(name, description, scale, categories, fields):
190
  "scale": [{"value": row[0], "label": row[1]} for row in scale]
191
  },
192
  "tag_categories": [{"name": row[0], "type": row[1], "tags": row[2].split(", ")} for row in categories],
193
- "free_text_fields": [{"name": row[0], "description": row[1]} for row in fields]
 
194
  }
195
  save_annotation_config(new_config)
196
  return "Configuration saved successfully", new_config
@@ -218,7 +322,7 @@ def generate_preview(row_data, quality, high_quality_tags, low_quality_tags, tox
218
  return "Error: Invalid JSON in the current row data"
219
 
220
  def load_dataset_config():
221
- # Load VODALUS_SYSTEM_MESSAGE from system_messages.py
222
  with open("system_messages.py", "r") as f:
223
  system_messages_content = f.read()
224
  vodalus_system_message = re.search(r'SYSTEM_MESSAGES_VODALUS = \[(.*?)\]', system_messages_content, re.DOTALL).group(1).strip()[3:-3] # Extract the content between triple quotes
@@ -232,9 +336,37 @@ def load_dataset_config():
232
  topics_module = importlib.import_module("topics")
233
  topics_list = topics_module.TOPICS
234
 
235
- return vodalus_system_message, prompt_1, [[topic] for topic in topics_list]
236
 
237
- def save_dataset_config(system_messages, prompt_1, topics):
238
  # Save VODALUS_SYSTEM_MESSAGE to system_messages.py
239
  with open("system_messages.py", "w") as f:
240
  f.write(f'SYSTEM_MESSAGES_VODALUS = [\n"""\n{system_messages}\n""",\n]\n')
@@ -261,8 +393,17 @@ def save_dataset_config(system_messages, prompt_1, topics):
261
 
262
  with open("topics.py", "w") as f:
263
  f.write(topics_content)
264
 
265
  return "Dataset configuration saved successfully"
 
266
 
267
 
268
  def chat_with_llm(message, history):
@@ -273,7 +414,12 @@ def chat_with_llm(message, history):
273
  msg_list.append({"role": "assistant", "content": h[1]})
274
  msg_list.append({"role": "user", "content": message})
275
 
276
- response, _ = send_to_llm(agent, msg_list)
277
 
278
  return history + [[message, response]]
279
  except Exception as e:
@@ -283,14 +429,15 @@ def chat_with_llm(message, history):
283
  def update_chat_context(row_data, index, total, quality, high_quality_tags, low_quality_tags, toxic_tags, other):
284
  context = f"""Current app state:
285
  Row: {index + 1}/{total}
286
- Data: {row_data}
287
  Quality: {quality}
288
  High Quality Tags: {', '.join(high_quality_tags)}
289
  Low Quality Tags: {', '.join(low_quality_tags)}
290
  Toxic Tags: {', '.join(toxic_tags)}
291
  Additional Notes: {other}
 
 
292
  """
293
- return [[None, context]] # Return as a list of message pairs
294
 
295
 
296
  async def run_generate_dataset(num_workers, num_generations, output_file_path):
@@ -309,34 +456,191 @@ async def run_generate_dataset(num_workers, num_generations, output_file_path):
309
 
310
  return f"Generated {num_generations} entries and saved to {output_file_path}", "\n".join(generated_data[:5]) + "\n..."
311
 
312
- demo = gr.Blocks()
313
 
314
  with demo:
315
- gr.Markdown("# JSONL Dataset Editor and Annotation Tool")
316
 
317
  config = gr.State(load_annotation_config())
318
 
319
  with gr.Row():
320
- with gr.Column(scale=3):
321
  with gr.Tab("Dataset Editor"):
322
- with gr.Row():
323
- file_path = gr.Textbox(label="JSONL File Path", value=OUTPUT_FILE_PATH)
324
- load_button = gr.Button("Load Dataset")
 
 
325
 
326
  with gr.Row():
327
  prev_button = gr.Button("← Previous")
328
- row_index = gr.Number(value=0, label="Current Row", precision=0)
329
- total_rows = gr.Number(value=0, label="Total Rows", precision=0)
 
330
  next_button = gr.Button("Next →")
331
 
332
  with gr.Row():
333
  with gr.Column(scale=3):
334
- row_editor = gr.TextArea(label="Edit Row", lines=20)
335
 
336
  with gr.Column(scale=2):
337
  quality_label = gr.Radio(label="Relevance for Training", choices=[])
338
  tag_components = [gr.CheckboxGroup(label=f"Tag Group {i+1}", choices=[]) for i in range(3)]
339
  other_description = gr.Textbox(label="Additional annotations", lines=3)
340
 
341
  with gr.Row():
342
  to_markdown_button = gr.Button("Convert to Markdown")
@@ -349,50 +653,94 @@ with demo:
349
 
350
  with gr.Tab("Annotation Configuration"):
351
  with gr.Row():
352
- with gr.Column():
353
- quality_scale_name = gr.Textbox(label="Quality Scale Name")
354
- quality_scale_description = gr.Textbox(label="Quality Scale Description")
355
  quality_scale = gr.Dataframe(
356
  headers=["Value", "Label"],
357
  datatype=["str", "str"],
358
- label="Quality Scale",
359
- interactive=True
360
  )
361
 
362
  with gr.Row():
363
- tag_categories = gr.Dataframe(
364
- headers=["Name", "Type", "Tags"],
365
- datatype=["str", "str", "str"],
366
- label="Tag Categories",
367
- interactive=True
368
- )
369
 
370
  with gr.Row():
371
- free_text_fields = gr.Dataframe(
372
- headers=["Name", "Description"],
373
- datatype=["str", "str"],
374
- label="Free Text Fields",
375
- interactive=True
376
- )
377
 
378
- save_config_btn = gr.Button("Save Configuration")
379
- config_status = gr.Textbox(label="Status")
 
 
380
 
381
  with gr.Tab("Dataset Configuration"):
382
  with gr.Row():
383
- vodalus_system_message = gr.TextArea(label="VODALUS_SYSTEM_MESSAGE", lines=10)
384
- prompt_1 = gr.TextArea(label="PROMPT_1", lines=10)
385
 
386
  with gr.Row():
387
- topics = gr.Dataframe(
388
- headers=["Topic"],
389
- datatype=["str"],
390
- label="TOPICS",
391
- interactive=True
392
- )
393
 
394
- save_dataset_config_btn = gr.Button("Save Dataset Configuration")
395
- dataset_config_status = gr.Textbox(label="Status")
396
 
397
  with gr.Tab("Dataset Generation"):
398
  with gr.Row():
@@ -406,16 +754,54 @@ with demo:
406
  generation_status = gr.Textbox(label="Generation Status")
407
  generation_output = gr.TextArea(label="Generation Output", lines=10)
408
 
409
- with gr.Column(scale=1):
410
- gr.Markdown("## AI Assistant")
411
- chatbot = gr.Chatbot(height=600)
412
- msg = gr.Textbox(label="Chat with AI Assistant")
413
- clear = gr.Button("Clear")
414
 
415
  load_button.click(
416
- load_and_show_row,
417
- inputs=[file_path, gr.Number(value=0), config],
418
- outputs=[row_editor, row_index, total_rows, quality_label, *tag_components, other_description]
419
  ).then(
420
  update_annotation_ui,
421
  inputs=[config],
@@ -424,8 +810,8 @@ with demo:
424
 
425
  prev_button.click(
426
  navigate_rows,
427
- inputs=[file_path, row_index, gr.Number(value=-1), config],
428
- outputs=[row_editor, row_index, total_rows, quality_label, *tag_components, other_description]
429
  ).then(
430
  update_annotation_ui,
431
  inputs=[config],
@@ -434,8 +820,8 @@ with demo:
434
 
435
  next_button.click(
436
  navigate_rows,
437
- inputs=[file_path, row_index, gr.Number(value=1), config],
438
- outputs=[row_editor, row_index, total_rows, quality_label, *tag_components, other_description]
439
  ).then(
440
  update_annotation_ui,
441
  inputs=[config],
@@ -444,7 +830,7 @@ with demo:
444
 
445
  save_row_button.click(
446
  save_row_with_metadata,
447
- inputs=[file_path, row_index, row_editor, config, quality_label,
448
  tag_components[0], tag_components[1], tag_components[2], other_description],
449
  outputs=[editor_status]
450
  ).then(
@@ -476,7 +862,7 @@ with demo:
476
 
477
  save_config_btn.click(
478
  save_config_from_ui,
479
- inputs=[quality_scale_name, quality_scale_description, quality_scale, tag_categories, free_text_fields],
480
  outputs=[config_status, config]
481
  ).then(
482
  update_annotation_ui,
@@ -492,12 +878,12 @@ with demo:
492
 
493
  demo.load(
494
  load_dataset_config,
495
- outputs=[vodalus_system_message, prompt_1, topics]
496
  )
497
 
498
  save_dataset_config_btn.click(
499
  save_dataset_config,
500
- inputs=[vodalus_system_message, prompt_1, topics],
501
  outputs=[dataset_config_status]
502
  )
503
 
@@ -507,10 +893,21 @@ with demo:
507
  outputs=[generation_status, generation_output]
508
  )
509
 
510
  msg.submit(chat_with_llm, [msg, chatbot], [chatbot])
511
  clear.click(lambda: None, None, chatbot, queue=False)
512
 
513
- # Update chat context when navigating rows or loading dataset
514
  for button in [load_button, prev_button, next_button]:
515
  button.click(
516
  update_chat_context,
@@ -518,6 +915,30 @@ with demo:
518
  outputs=[chatbot]
519
  )
520
 
521
- if __name__ == "__main__":
522
- demo.launch(share=True)
523
 
The corresponding updated sections of app.py (added lines marked with +):
1
  import gradio as gr
2
+ from gradio import update
3
  import json
4
  import re
5
  from datetime import datetime
6
  from typing import Literal
7
  import os
8
  import importlib
9
+ from llm_handler import send_to_llm
10
  from main import generate_data, PROMPT_1
11
  from topics import TOPICS
12
  from system_messages import SYSTEM_MESSAGES_VODALUS
13
  import random
14
+ from params import load_params, save_params
15
+ import pandas as pd
16
+ import csv
17
+
18
 
19
 
20
  ANNOTATION_CONFIG_FILE = "annotation_config.json"
21
  OUTPUT_FILE_PATH = "dataset.jsonl"
22
 
23
+ def load_llm_config():
24
+ params = load_params()
25
+ return (
26
+ params.get('PROVIDER', ''),
27
+ params.get('BASE_URL', ''),
28
+ params.get('WORKSPACE', ''),
29
+ params.get('API_KEY', ''),
30
+ params.get('max_tokens', 2048),
31
+ params.get('temperature', 0.7),
32
+ params.get('top_p', 0.9),
33
+ params.get('frequency_penalty', 0.0),
34
+ params.get('presence_penalty', 0.0)
35
+ )
36
+
37
+ def save_llm_config(provider, base_url, workspace, api_key, max_tokens, temperature, top_p, frequency_penalty, presence_penalty):
38
+ save_params({
39
+ 'PROVIDER': provider,
40
+ 'BASE_URL': base_url,
41
+ 'WORKSPACE': workspace,
42
+ 'API_KEY': api_key,
43
+ 'max_tokens': max_tokens,
44
+ 'temperature': temperature,
45
+ 'top_p': top_p,
46
+ 'frequency_penalty': frequency_penalty,
47
+ 'presence_penalty': presence_penalty
48
+ })
49
+ return "LLM configuration saved successfully"
50
+
51
+
52
  def load_annotation_config():
53
  try:
54
  with open(ANNOTATION_CONFIG_FILE, 'r') as f:
 
91
  ]
92
  }
93
 
94
+
95
+ def load_csv_dataset(file_path):
96
+ data = []
97
+ with open(file_path, 'r') as f:
98
+ reader = csv.DictReader(f)
99
+ for row in reader:
100
+ data.append(row)
101
+ return data
102
+
103
+ def load_txt_dataset(file_path):
104
+ with open(file_path, 'r') as f:
105
+ return [{"content": line.strip()} for line in f if line.strip()]
106
+
107
  def save_annotation_config(config):
108
  with open(ANNOTATION_CONFIG_FILE, 'w') as f:
109
  json.dump(config, f, indent=2)
 
113
  return []
114
  with open(file_path, 'r') as f:
115
  return [json.loads(line.strip()) for line in f if line.strip()]
116
+
117
+ def load_dataset(file):
118
+ if file is None:
119
+ return "", 0, 0, "No file uploaded", "3", [], [], [], ""
120
+
121
+ file_path = file.name
122
+ file_extension = os.path.splitext(file_path)[1].lower()
123
+
124
+ if file_extension == '.csv':
125
+ data = load_csv_dataset(file_path)
126
+ elif file_extension == '.txt':
127
+ data = load_txt_dataset(file_path)
128
+ elif file_extension == '.jsonl':
129
+ data = load_jsonl_dataset(file_path)
130
+ else:
131
+ return "", 0, 0, f"Unsupported file type: {file_extension}", "3", [], [], [], ""
132
+
133
+ if not data:
134
+ return "", 0, 0, "No data found in the file", "3", [], [], [], ""
135
+
136
+ first_row = json.dumps(data[0], indent=2)
137
+ return first_row, 0, len(data), f"Row: 1/{len(data)}", "3", [], [], [], ""
138
 
139
  def save_row(file_path, index, row_data):
140
+ file_extension = file_path.split('.')[-1].lower()
141
+
142
+ if file_extension == 'jsonl':
143
+ save_jsonl_row(file_path, index, row_data)
144
+ elif file_extension == 'csv':
145
+ save_csv_row(file_path, index, row_data)
146
+ elif file_extension == 'txt':
147
+ save_txt_row(file_path, index, row_data)
148
+ else:
149
+ raise ValueError(f"Unsupported file format: {file_extension}")
150
+
151
+ return f"Row {index} saved successfully"
152
+
153
+ def save_jsonl_row(file_path, index, row_data):
154
  with open(file_path, 'r') as f:
155
  lines = f.readlines()
156
 
 
158
 
159
  with open(file_path, 'w') as f:
160
  f.writelines(lines)
161
+
162
+ def save_csv_row(file_path, index, row_data):
163
+ df = pd.read_csv(file_path)
164
+ row_dict = json.loads(row_data)
165
+ for col, value in row_dict.items():
166
+ df.at[index, col] = value
167
+ df.to_csv(file_path, index=False)
168
+
169
+ def save_txt_row(file_path, index, row_data):
170
+ with open(file_path, 'r') as f:
171
+ lines = f.readlines()
172
 
173
+ row_dict = json.loads(row_data)
174
+ lines[index] = row_dict.get('content', '') + '\n'
175
+
176
+ with open(file_path, 'w') as f:
177
+ f.writelines(lines)
178
 
179
  def get_row(file_path, index):
180
  data = load_jsonl_dataset(file_path)
 
204
  }
205
  return json.dumps(json_data, indent=2)
206
 
207
+ def navigate_rows(file_path: str, current_index: int, direction: Literal["prev", "next"], metadata_config):
208
+ new_index = max(0, current_index + (-1 if direction == "prev" else 1))
209
  return load_and_show_row(file_path, new_index, metadata_config)
210
 
211
  def load_and_show_row(file_path, index, metadata_config):
212
  row_data, total = get_row(file_path, index)
213
  if not row_data:
214
+ return ("", index, total, f"Row: {index + 1}/{total}", "3", [], [], [], "")
215
 
216
  try:
217
  data = json.loads(row_data)
218
  except json.JSONDecodeError:
219
+ return (row_data, index, total, f"Row: {index + 1}/{total}", "3", [], [], [], "Error: Invalid JSON")
220
 
221
  metadata = data.get("metadata", {}).get("annotation", {})
222
 
 
226
  toxic_tags = metadata.get("tags", {}).get("toxic", [])
227
  other = metadata.get("free_text", {}).get("Additional Notes", "")
228
 
229
+ return (row_data, index, total, f"Row: {index + 1}/{total}", quality,
230
  high_quality_tags, low_quality_tags, toxic_tags, other)
231
 
232
  def save_row_with_metadata(file_path, index, row_data, config, quality, high_quality_tags, low_quality_tags, toxic_tags, other):
 
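
For reference, load_and_show_row above reads annotations from metadata.annotation, and save_row_with_metadata writes them back in the same shape. A dataset.jsonl line following that layout might look like the example below (stored on a single line; the values, and the quality / high_quality / low_quality key names, are illustrative and inferred from the variable names rather than spelled out in this diff):

{"content": "Example training sample text", "metadata": {"annotation": {"quality": "3", "tags": {"high_quality": ["well_structured"], "low_quality": [], "toxic": []}, "free_text": {"Additional Notes": "Reviewed manually"}}}}
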
280
  [[field["name"], field["description"]] for field in config["free_text_fields"]]
281
  )
282
 
283
+ def save_config_from_ui(name, description, scale, categories, fields, topics, all_topics_text):
284
+ if all_topics_text.visible:
285
+ topics_list = [topic.strip() for topic in all_topics_text.split("\n") if topic.strip()]
286
+ else:
287
+ topics_list = [topic[0] for topic in topics]
288
+
289
  new_config = {
290
  "quality_scale": {
291
  "name": name,
 
293
  "scale": [{"value": row[0], "label": row[1]} for row in scale]
294
  },
295
  "tag_categories": [{"name": row[0], "type": row[1], "tags": row[2].split(", ")} for row in categories],
296
+ "free_text_fields": [{"name": row[0], "description": row[1]} for row in fields],
297
+ "topics": topics_list
298
  }
299
  save_annotation_config(new_config)
300
  return "Configuration saved successfully", new_config
 
322
  return "Error: Invalid JSON in the current row data"
323
 
324
  def load_dataset_config():
325
+ params = load_params()
326
  with open("system_messages.py", "r") as f:
327
  system_messages_content = f.read()
328
  vodalus_system_message = re.search(r'SYSTEM_MESSAGES_VODALUS = \[(.*?)\]', system_messages_content, re.DOTALL).group(1).strip()[3:-3] # Extract the content between triple quotes
 
336
  topics_module = importlib.import_module("topics")
337
  topics_list = topics_module.TOPICS
338
 
339
+ return (
340
+ vodalus_system_message,
341
+ prompt_1,
342
+ [[topic] for topic in topics_list],
343
+ params.get('max_tokens', 2048),
344
+ params.get('temperature', 0.7),
345
+ params.get('top_p', 0.9),
346
+ params.get('frequency_penalty', 0.0),
347
+ params.get('presence_penalty', 0.0)
348
+ )
349
+
350
+ def edit_all_topics_func(topics):
351
+ topics_list = [topic[0] for topic in topics]
352
+ jsonl_rows = "\n".join([json.dumps({"topic": topic}) for topic in topics_list])
353
+ return (
354
+ gr.update(visible=False),
355
+ gr.update(value=jsonl_rows, visible=True),
356
+ gr.update(visible=True)
357
+ )
358
+
359
+ def update_topics_from_text(text):
360
+ try:
361
+ # Try parsing as JSONL
362
+ topics_list = [json.loads(line)["topic"] for line in text.split("\n") if line.strip()]
363
+ except json.JSONDecodeError:
364
+ # If parsing fails, treat as plain text
365
+ topics_list = [topic.strip() for topic in text.split("\n") if topic.strip()]
366
+
367
+ return gr.Dataframe.update(value=[[topic] for topic in topics_list], visible=True), gr.TextArea.update(visible=False)
368
 
369
+ def save_dataset_config(system_messages, prompt_1, topics, max_tokens, temperature, top_p, frequency_penalty, presence_penalty):
370
  # Save VODALUS_SYSTEM_MESSAGE to system_messages.py
371
  with open("system_messages.py", "w") as f:
372
  f.write(f'SYSTEM_MESSAGES_VODALUS = [\n"""\n{system_messages}\n""",\n]\n')
 
393
 
394
  with open("topics.py", "w") as f:
395
  f.write(topics_content)
396
+
397
+ save_params({
398
+ 'max_tokens': max_tokens,
399
+ 'temperature': temperature,
400
+ 'top_p': top_p,
401
+ 'frequency_penalty': frequency_penalty,
402
+ 'presence_penalty': presence_penalty
403
+ })
404
 
405
  return "Dataset configuration saved successfully"
406
+
407
 
408
 
409
  def chat_with_llm(message, history):
 
414
  msg_list.append({"role": "assistant", "content": h[1]})
415
  msg_list.append({"role": "user", "content": message})
416
 
417
+ response, _ = send_to_llm(msg_list)
418
+
419
+ return history + [[message, response]]
420
+ except Exception as e:
421
+ print(f"Error in chat_with_llm: {str(e)}")
422
+ return history + [[message, f"Error: {str(e)}"]]
423
 
424
  return history + [[message, response]]
425
  except Exception as e:
 
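
The updated call sites pass only the message list, so llm_handler.send_to_llm is assumed to accept a single msg_list argument and return a (response, extra) pair, matching the "response, _ = send_to_llm(msg_list)" unpacking above. A purely illustrative stub with that shape, not the actual llm_handler implementation:

def send_to_llm(msg_list):
    # msg_list is a list of {"role": ..., "content": ...} dicts, as built in chat_with_llm.
    # A real implementation would forward these to the configured provider;
    # this stub only demonstrates the assumed return shape.
    response = "stub response"
    extra = {}  # e.g. token usage or provider metadata
    return response, extra
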
429
  def update_chat_context(row_data, index, total, quality, high_quality_tags, low_quality_tags, toxic_tags, other):
430
  context = f"""Current app state:
431
  Row: {index + 1}/{total}
 
432
  Quality: {quality}
433
  High Quality Tags: {', '.join(high_quality_tags)}
434
  Low Quality Tags: {', '.join(low_quality_tags)}
435
  Toxic Tags: {', '.join(toxic_tags)}
436
  Additional Notes: {other}
437
+
438
+ Data: {row_data}
439
  """
440
+ return [[None, context]]
441
 
442
 
443
  async def run_generate_dataset(num_workers, num_generations, output_file_path):
 
456
 
457
  return f"Generated {num_generations} entries and saved to {output_file_path}", "\n".join(generated_data[:5]) + "\n..."
458
 
459
+ def add_topic_row(data):
460
+ if isinstance(data, pd.DataFrame):
461
+ return pd.concat([data, pd.DataFrame({"Topic": ["New Topic"]})], ignore_index=True)
462
+ else:
463
+ return data + [["New Topic"]]
464
+
465
+ def remove_last_topic_row(data):
466
+ return data[:-1] if len(data) > 1 else data
467
+
468
+ def edit_all_topics_func(topics):
469
+ topics_list = [topic[0] for topic in topics]
470
+ jsonl_rows = "\n".join([json.dumps({"topic": topic}) for topic in topics_list])
471
+ return (
472
+ gr.update(visible=False),
473
+ gr.update(value=jsonl_rows, visible=True),
474
+ gr.update(visible=True)
475
+ )
476
+
477
+ def update_topics_from_text(text):
478
+ try:
479
+ # Try parsing as JSONL
480
+ topics_list = [json.loads(line)["topic"] for line in text.split("\n") if line.strip()]
481
+ except json.JSONDecodeError:
482
+ # If parsing fails, treat as plain text
483
+ topics_list = [topic.strip() for topic in text.split("\n") if topic.strip()]
484
+
485
+ return gr.Dataframe.update(value=[[topic] for topic in topics_list], visible=True), gr.TextArea.update(visible=False)
486
+
487
+ def update_topics_from_text(text):
488
+ try:
489
+ # Try parsing as JSONL
490
+ topics_list = [json.loads(line)["topic"] for line in text.split("\n") if line.strip()]
491
+ except json.JSONDecodeError:
492
+ # If parsing fails, treat as plain text
493
+ topics_list = [topic.strip() for topic in text.split("\n") if topic.strip()]
494
+
495
+ return gr.Dataframe.update(value=[[topic] for topic in topics_list], visible=True), gr.TextArea.update(visible=False)
496
+
497
+ css = """
498
+ body, #root {
499
+ margin: 0;
500
+ padding: 0;
501
+ width: 100%;
502
+ height: 100%;
503
+ overflow-x: hidden;
504
+ }
505
+ .gradio-container {
506
+ max-width: 100% !important;
507
+ width: 100% !important;
508
+ margin: 0 auto !important;
509
+ padding: 0 !important;
510
+ }
511
+ .message-row {
512
+ justify-content: space-evenly !important;
513
+ }
514
+ .message-bubble-border {
515
+ border-radius: 6px !important;
516
+ }
517
+ .message-buttons-bot, .message-buttons-user {
518
+ right: 10px !important;
519
+ left: auto !important;
520
+ bottom: 2px !important;
521
+ }
522
+ .dark.message-bubble-border {
523
+ border-color: #343140 !important;
524
+ }
525
+ .dark.user {
526
+ background: #1e1c26 !important;
527
+ }
528
+ .dark.assistant.dark, .dark.pending.dark {
529
+ background: #16141c !important;
530
+ }
531
+ .tab-nav {
532
+ border-bottom: 2px solid #e0e0e0 !important;
533
+ }
534
+ .tab-nav button {
535
+ font-size: 16px !important;
536
+ padding: 10px 20px !important;
537
+ }
538
+ .input-row {
539
+ margin-bottom: 20px !important;
540
+ }
541
+ .button-row {
542
+ display: flex !important;
543
+ justify-content: space-between !important;
544
+ margin-top: 20px !important;
545
+ }
546
+ #row-editor {
547
+ height: 80vh !important;
548
+ font-size: 16px !important;
549
+ }
550
+
551
+ .file-upload-row {
552
+ height: 50px !important;
553
+ margin-bottom: 1rem !important;
554
+ }
555
+
556
+ .file-upload-row > .gr-column {
557
+ min-width: 0 !important;
558
+ }
559
+
560
+ .compact-file-upload {
561
+ height: 50px !important;
562
+ overflow: hidden !important;
563
+ }
564
+
565
+ .compact-file-upload > .file-preview {
566
+ min-height: 0 !important;
567
+ max-height: 50px !important;
568
+ padding: 0 !important;
569
+ }
570
+
571
+ .compact-file-upload > .file-preview > .file-preview-handler {
572
+ height: 50px !important;
573
+ padding: 0 8px !important;
574
+ display: flex !important;
575
+ align-items: center !important;
576
+ }
577
+
578
+ .compact-file-upload > .file-preview > .file-preview-handler > .file-preview-title {
579
+ white-space: nowrap !important;
580
+ overflow: hidden !important;
581
+ text-overflow: ellipsis !important;
582
+ flex: 1 !important;
583
+ }
584
+
585
+ .compact-file-upload > .file-preview > .file-preview-handler > .file-preview-remove {
586
+ padding: 0 !important;
587
+ min-width: 24px !important;
588
+ width: 24px !important;
589
+ height: 24px !important;
590
+ }
591
+
592
+ .compact-button {
593
+ height: 50px !important;
594
+ min-height: 40px !important;
595
+ width: 100% !important;
596
+ }
597
+
598
+ .compact-file-upload > label {
599
+ height: 50px !important;
600
+ padding: 0 8px !important;
601
+ display: flex !important;
602
+ align-items: center !important;
603
+ justify-content: left !important;
604
+ }
605
+ """
606
+
607
+ demo = gr.Blocks(theme='Ama434/neutral-barlow', css=css)
608
 
609
  with demo:
610
+ gr.Markdown("# Dataset Editor and Annotation Tool")
611
 
612
  config = gr.State(load_annotation_config())
613
 
614
  with gr.Row():
615
+ with gr.Column(min_width=1000):
616
  with gr.Tab("Dataset Editor"):
617
+ with gr.Row(elem_classes="file-upload-row"):
618
+ with gr.Column(scale=3, min_width=400):
619
+ file_upload = gr.File(label="Upload Dataset File (.txt, .jsonl, or .csv)", elem_classes="compact-file-upload")
620
+ with gr.Column(scale=1, min_width=100):
621
+ load_button = gr.Button("Load Dataset", elem_classes="compact-button")
622
 
623
  with gr.Row():
624
  prev_button = gr.Button("← Previous")
625
+ row_index = gr.State(value=0)
626
+ total_rows = gr.State(value=0)
627
+ current_row_display = gr.Textbox(label="Current Row", interactive=False)
628
  next_button = gr.Button("Next →")
629
 
630
  with gr.Row():
631
  with gr.Column(scale=3):
632
+ row_editor = gr.TextArea(label="Edit Row", lines=40)
633
 
634
  with gr.Column(scale=2):
635
  quality_label = gr.Radio(label="Relevance for Training", choices=[])
636
  tag_components = [gr.CheckboxGroup(label=f"Tag Group {i+1}", choices=[]) for i in range(3)]
637
  other_description = gr.Textbox(label="Additional annotations", lines=3)
638
+
639
+ # Add the AI Assistant as a collapsible accordion
640
+ with gr.Accordion("AI Assistant", open=False):
641
+ chatbot = gr.Chatbot(height=300)
642
+ msg = gr.Textbox(label="Chat with AI Assistant")
643
+ clear = gr.Button("Clear")
644
 
645
  with gr.Row():
646
  to_markdown_button = gr.Button("Convert to Markdown")
 
653
 
654
  with gr.Tab("Annotation Configuration"):
655
  with gr.Row():
656
+ with gr.Column(scale=1):
657
+ gr.Markdown("### Quality Scale")
658
+ quality_scale_name = gr.Textbox(label="Scale Name")
659
+ quality_scale_description = gr.Textbox(label="Scale Description", lines=2)
660
+
661
+ with gr.Column(scale=2):
662
  quality_scale = gr.Dataframe(
663
  headers=["Value", "Label"],
664
  datatype=["str", "str"],
665
+ label="Quality Scale Options",
666
+ interactive=True,
667
+ col_count=(2, "fixed"),
668
+ row_count=(5, "dynamic"),
669
+ height=400,
670
+ wrap=True
671
  )
672
 
673
+ gr.Markdown("### Tag Categories")
674
+ tag_categories = gr.Dataframe(
675
+ headers=["Name", "Type", "Tags"],
676
+ datatype=["str", "str", "str"],
677
+ label="Tag Categories",
678
+ interactive=True,
679
+ col_count=(3, "fixed"),
680
+ row_count=(3, "dynamic"),
681
+ height=250,
682
+ wrap=True
683
+ )
684
+
685
  with gr.Row():
686
+ add_tag_category = gr.Button("Add Category")
687
+ remove_tag_category = gr.Button("Remove Last Category")
688
+
689
+ gr.Markdown("### Free Text Fields")
690
+ free_text_fields = gr.Dataframe(
691
+ headers=["Name", "Description"],
692
+ datatype=["str", "str"],
693
+ label="Free Text Fields",
694
+ interactive=True,
695
+ col_count=(2, "fixed"),
696
+ row_count=(2, "dynamic"),
697
+ height=300,
698
+ wrap=True
699
+ )
700
 
701
  with gr.Row():
702
+ add_free_text_field = gr.Button("Add Field")
703
+ remove_free_text_field = gr.Button("Remove Last Field")
704
 
705
+
706
+ with gr.Row():
707
+ save_config_btn = gr.Button("Save Configuration", variant="primary")
708
+ config_status = gr.Textbox(label="Status", interactive=False)
709
 
710
  with gr.Tab("Dataset Configuration"):
711
  with gr.Row():
712
+ vodalus_system_message = gr.TextArea(label="System Message for JSONL Dataset", lines=10)
713
+ prompt_1 = gr.TextArea(label="Dataset Generation Prompt", lines=10)
714
+
715
+ gr.Markdown("### Topics")
716
+ with gr.Row():
717
+ with gr.Column(scale=2):
718
+ topics = gr.Dataframe(
719
+ headers=["Topic"],
720
+ datatype=["str"],
721
+ label="Topics",
722
+ interactive=True,
723
+ col_count=(1, "fixed"),
724
+ row_count=(5, "dynamic"),
725
+ height=200,
726
+ wrap=True
727
+ )
728
+
729
+ with gr.Column(scale=1):
730
+ with gr.Row():
731
+ add_topic = gr.Button("Add Topic")
732
+ remove_topic = gr.Button("Remove Last Topic")
733
+ edit_all_topics = gr.Button("Edit All Topics")
734
+ all_topics_edit = gr.TextArea(label="Edit All Topics (JSONL or Plain Text)", visible=False, lines=10)
735
+ format_info = gr.Markdown("""
736
+ Enter topics as JSONL (e.g., {"topic": "Example Topic"}) or plain text (one topic per line).
737
+ JSONL format allows for additional metadata if needed.
738
+ """, visible=False)
739
 
740
  with gr.Row():
741
+ save_dataset_config_btn = gr.Button("Save Dataset Configuration", variant="primary")
742
+ dataset_config_status = gr.Textbox(label="Status")
743
 
 
 
744
 
745
  with gr.Tab("Dataset Generation"):
746
  with gr.Row():
 
754
  generation_status = gr.Textbox(label="Generation Status")
755
  generation_output = gr.TextArea(label="Generation Output", lines=10)
756
 
757
+ with gr.Tab("LLM Configuration"):
758
+ with gr.Row():
759
+ provider = gr.Dropdown(choices=["local-model", "anything-llm"], label="LLM Provider")
760
+ base_url = gr.Textbox(label="Base URL (for local model)")
761
+ with gr.Row():
762
+ workspace = gr.Textbox(label="Workspace (for AnythingLLM)")
763
+ api_key = gr.Textbox(label="API Key (for AnythingLLM)")
764
+
765
+ with gr.Accordion("Advanced Options", open=False):
766
+ with gr.Row():
767
+ max_tokens = gr.Slider(minimum=100, maximum=4096, value=2048, step=1, label="Max Tokens")
768
+ temperature = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.01, label="Temperature")
769
+ with gr.Row():
770
+ top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.01, label="Top P")
771
+ frequency_penalty = gr.Slider(minimum=0, maximum=2, value=0.0, step=0.01, label="Frequency Penalty")
772
+ presence_penalty = gr.Slider(minimum=0, maximum=2, value=0.0, step=0.01, label="Presence Penalty")
773
+
774
+ save_llm_config_btn = gr.Button("Save LLM Configuration")
775
+ llm_config_status = gr.Textbox(label="Status")
776
+
777
+ add_topic.click(
778
+ lambda x: x + [["New Topic"]],
779
+ inputs=[topics],
780
+ outputs=[topics]
781
+ )
782
+
783
+ remove_topic.click(
784
+ lambda x: x[:-1] if len(x) > 0 else x,
785
+ inputs=[topics],
786
+ outputs=[topics]
787
+ )
788
+
789
+ edit_all_topics.click(
790
+ edit_all_topics_func,
791
+ inputs=[topics],
792
+ outputs=[topics, all_topics_edit, format_info]
793
+ )
794
+
795
+ all_topics_edit.submit(
796
+ update_topics_from_text,
797
+ inputs=[all_topics_edit],
798
+ outputs=[topics, all_topics_edit, format_info]
799
+ )
800
 
801
  load_button.click(
802
+ load_dataset,
803
+ inputs=[file_upload],
804
+ outputs=[row_editor, row_index, total_rows, current_row_display, quality_label, *tag_components, other_description]
805
  ).then(
806
  update_annotation_ui,
807
  inputs=[config],
 
810
 
811
  prev_button.click(
812
  navigate_rows,
813
+ inputs=[file_upload, row_index, gr.State("prev"), config],
814
+ outputs=[row_editor, row_index, total_rows, current_row_display, quality_label, *tag_components, other_description]
815
  ).then(
816
  update_annotation_ui,
817
  inputs=[config],
 
820
 
821
  next_button.click(
822
  navigate_rows,
823
+ inputs=[file_upload, row_index, gr.State("next"), config],
824
+ outputs=[row_editor, row_index, total_rows, current_row_display, quality_label, *tag_components, other_description]
825
  ).then(
826
  update_annotation_ui,
827
  inputs=[config],
 
830
 
831
  save_row_button.click(
832
  save_row_with_metadata,
833
+ inputs=[file_upload, row_index, row_editor, config, quality_label,
834
  tag_components[0], tag_components[1], tag_components[2], other_description],
835
  outputs=[editor_status]
836
  ).then(
 
862
 
863
  save_config_btn.click(
864
  save_config_from_ui,
865
+ inputs=[quality_scale_name, quality_scale_description, quality_scale, tag_categories, free_text_fields, topics, all_topics_edit],
866
  outputs=[config_status, config]
867
  ).then(
868
  update_annotation_ui,
 
878
 
879
  demo.load(
880
  load_dataset_config,
881
+ outputs=[vodalus_system_message, prompt_1, topics, max_tokens, temperature, top_p, frequency_penalty, presence_penalty]
882
  )
883
 
884
  save_dataset_config_btn.click(
885
  save_dataset_config,
886
+ inputs=[vodalus_system_message, prompt_1, topics, max_tokens, temperature, top_p, frequency_penalty, presence_penalty],
887
  outputs=[dataset_config_status]
888
  )
889
 
 
893
  outputs=[generation_status, generation_output]
894
  )
895
 
896
+ demo.load(
897
+ load_llm_config,
898
+ outputs=[provider, base_url, workspace, api_key, max_tokens, temperature, top_p, frequency_penalty, presence_penalty]
899
+ )
900
+
901
+ save_llm_config_btn.click(
902
+ save_llm_config,
903
+ inputs=[provider, base_url, workspace, api_key, max_tokens, temperature, top_p, frequency_penalty, presence_penalty],
904
+ outputs=[llm_config_status]
905
+ )
906
+
907
  msg.submit(chat_with_llm, [msg, chatbot], [chatbot])
908
  clear.click(lambda: None, None, chatbot, queue=False)
909
 
910
+
911
  for button in [load_button, prev_button, next_button]:
912
  button.click(
913
  update_chat_context,
 
915
  outputs=[chatbot]
916
  )
917
 
 
 
918
 
919
+ demo.load(
920
+ lambda: (
921
+ initial_values := load_dataset_config(),
922
+ gr.update(value=initial_values[0]), # vodalus_system_message
923
+ gr.update(value=initial_values[1]), # prompt_1
924
+ gr.update(value=initial_values[2]), # topics_data
925
+ gr.update(value=initial_values[3]), # max_tokens_val
926
+ gr.update(value=initial_values[4]), # temperature_val
927
+ gr.update(value=initial_values[5]), # top_p_val
928
+ gr.update(value=initial_values[6]), # frequency_penalty_val
929
+ gr.update(value=initial_values[7]) # presence_penalty_val
930
+ )[1:], # We return a tuple slice to exclude the initial_values assignment
931
+ outputs=[
932
+ vodalus_system_message,
933
+ prompt_1,
934
+ topics,
935
+ max_tokens,
936
+ temperature,
937
+ top_p,
938
+ frequency_penalty,
939
+ presence_penalty
940
+ ]
941
+ )
942
+
943
+ if __name__ == "__main__":
944
+ demo.launch(share=True)
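
The new configuration code imports load_params and save_params from a params module that is not part of this commit. A minimal sketch of what such a module could look like, assuming the settings are kept in a JSON file (the params.json file name and the merge behaviour are assumptions, not taken from the repository):

import json
import os

PARAMS_FILE = "params.json"  # assumed location of the saved settings

def load_params():
    # Return the saved settings, or an empty dict if nothing has been saved yet.
    if not os.path.exists(PARAMS_FILE):
        return {}
    with open(PARAMS_FILE, "r") as f:
        return json.load(f)

def save_params(new_params):
    # Merge the incoming values into the existing settings and write them back.
    params = load_params()
    params.update(new_params)
    with open(PARAMS_FILE, "w") as f:
        json.dump(params, f, indent=2)

Merging rather than overwriting matters here because save_llm_config and save_dataset_config each write only their own subset of keys, so a plain overwrite would drop the other tab's settings.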