Alexander Watson committed on
Commit
06594f2
1 Parent(s): fb73a92

update interfaces

Browse files
Files changed (1) hide show
  1. app.py +105 -83
app.py CHANGED
@@ -11,7 +11,11 @@ import requests
11
  import streamlit as st
12
  from datasets import load_dataset
13
  from gretel_client import Gretel
14
- from navigator_helpers import DataAugmentationConfig, DataAugmenter, StreamlitLogHandler
 
 
 
 
15
 
16
  # Create a StringIO buffer to capture the logging output
17
  log_buffer = StringIO()
@@ -103,6 +107,9 @@ def main():
103
  )
104
  if "gretel" not in st.session_state:
105
  st.session_state.gretel = None
 
 
 
106
  if st.button("Validate API Key"):
107
  if api_key:
108
  try:
@@ -340,6 +347,22 @@ def main():
340
  st.markdown("---")
341
  st.markdown("### Format Prompts")
342
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
  instruction_format_prompt = st.text_area(
344
  "Instruction Format Prompt",
345
  value=st.session_state.get(
@@ -365,58 +388,53 @@ def main():
365
  st.write(
366
  "Get started with your current configuration using the SDK code below:"
367
  )
 
368
 
369
- config_text = f"""
370
- #!pip install -Uqq git+https://github.com/gretelai/navigator-helpers.git
371
-
372
- import logging
373
- import sys
374
- import pandas as pd
375
-
376
- from navigator_helpers import DataAugmentationConfig, DataAugmenter
377
-
378
- # Configure the logger
379
- logger = logging.getLogger()
380
- logger.setLevel(logging.INFO)
381
-
382
- DATASET = "YOUR_DATASET"
383
- API_KEY = "YOUR_API_KEY"
384
-
385
- df = pd.read_csv(DATASET)
386
-
387
- # Create the data augmentation configuration
388
- config = DataAugmentationConfig(
389
- input_fields={st.session_state.selected_fields},
390
- output_instruction_field="{output_instruction_field}",
391
- output_response_field="{output_response_field}",
392
- num_instructions={num_instructions},
393
- num_responses={num_responses},
394
- temperature={temperature},
395
- max_tokens_instruction={max_tokens_instruction},
396
- max_tokens_response={max_tokens_response},
397
- api_key=API_KEY,
398
- navigator_tabular="{navigator_tabular}",
399
- navigator_llm="{navigator_llm}",
400
- co_teach_llms={co_teach_llms},
401
- instruction_format_prompt='''{instruction_format_prompt}''',
402
- response_format_prompt='''{response_format_prompt}'''
403
- )
404
-
405
- # Create the data augmenter and perform augmentation
406
- augmenter = DataAugmenter(
407
- df,
408
- config,
409
- use_aaa={use_aaa},
410
- output_file="results.csv",
411
- verbose=True,
412
- )
413
- new_df = augmenter.augment()
414
- """
415
  st.code(config_text, language="python")
416
  st.download_button(
417
  label="Download SDK Code",
418
  data=config_text,
419
- file_name="data_augmentation_code.py",
420
  mime="text/plain",
421
  )
422
 
@@ -431,20 +449,20 @@ def main():
431
  if "logs" not in st.session_state:
432
  st.session_state.logs = []
433
 
434
- if "augmented_data" not in st.session_state:
435
- st.session_state.augmented_data = []
436
 
437
  if start_button:
438
- # Clear the augmented data and logs before starting a new generation
439
- st.session_state.augmented_data = []
440
  st.session_state.logs = []
441
 
442
  with st.expander("Synthetic Data", expanded=True):
443
  st.subheader("Synthetic Data Generation")
444
  progress_bar = st.progress(0)
445
- tab1, tab2 = st.tabs(["Augmented Data", "Logs"])
446
  with tab1:
447
- augmented_data_placeholder = st.empty()
448
  st.info(
449
  "Click on the 'Logs' tab to see and debug real-time logging for each record as it is generated by the agents."
450
  )
@@ -467,7 +485,7 @@ def main():
467
  handler = StreamlitLogHandler(custom_log_handler)
468
  logger.addHandler(handler)
469
 
470
- config = DataAugmentationConfig(
471
  input_fields=selected_fields,
472
  output_instruction_field=output_instruction_field,
473
  output_response_field=output_response_field,
@@ -480,26 +498,28 @@ def main():
480
  navigator_tabular=navigator_tabular,
481
  navigator_llm=navigator_llm,
482
  co_teach_llms=co_teach_llms,
 
483
  instruction_format_prompt=instruction_format_prompt,
484
  response_format_prompt=response_format_prompt,
485
  )
 
486
  start_time = time.time()
487
  with st.spinner("Generating synthetic data..."):
488
  for index in range(num_records):
489
  row = df.iloc[index]
490
- augmenter = DataAugmenter(
491
  pd.DataFrame([row]),
492
  config,
493
  use_aaa=use_aaa,
494
  output_file="results.csv",
495
  verbose=True,
496
  )
497
- new_df = augmenter.augment()
498
- st.session_state.augmented_data.append(new_df)
499
- augmented_data_placeholder.subheader("Synthetic Data")
500
- augmented_data_placeholder.dataframe(
501
  pd.concat(
502
- st.session_state.augmented_data, ignore_index=True
503
  )
504
  )
505
  progress = (index + 1) / num_records
@@ -520,11 +540,11 @@ def main():
520
 
521
  time.sleep(0.1)
522
  logger.removeHandler(handler)
523
- st.success("Data augmentation completed!")
524
  st.stop()
525
 
526
  if stop_button:
527
- st.warning("Augmentation stopped by the user.")
528
 
529
  # Get the complete logs from the session state
530
  complete_logs = st.session_state.logs
@@ -532,22 +552,22 @@ def main():
532
  # Convert complete logs to JSONL format
533
  log_jsonl = "\n".join([json.dumps({"log": log}) for log in complete_logs])
534
 
535
- # Convert augmented data to JSONL format if it exists
536
- if st.session_state.augmented_data:
537
- augmented_df = pd.concat(
538
- st.session_state.augmented_data, ignore_index=True
539
  )
540
- if not augmented_df.empty:
541
- augmented_data_jsonl = "\n".join(
542
  [
543
  json.dumps(row.to_dict())
544
- for _, row in augmented_df.iterrows()
545
  ]
546
  )
547
  else:
548
- augmented_data_jsonl = None
549
  else:
550
- augmented_data_jsonl = None
551
 
552
  # Create a temporary directory to store the files
553
  with tempfile.TemporaryDirectory() as temp_dir:
@@ -556,26 +576,28 @@ def main():
556
  with open(log_file_path, "w") as log_file:
557
  log_file.write(log_jsonl)
558
 
559
- # Write the augmented data to a file if it exists
560
- if augmented_data_jsonl:
561
- augmented_data_file_path = os.path.join(
562
  temp_dir, "synthetic_data.jsonl"
563
  )
564
- with open(augmented_data_file_path, "w") as augmented_data_file:
565
- augmented_data_file.write(augmented_data_jsonl)
566
 
567
  # Write the SDK code to a file
568
- sdk_file_path = os.path.join(temp_dir, "data_augmentation_code.py")
569
  with open(sdk_file_path, "w") as sdk_file:
570
  sdk_file.write(config_text)
571
 
572
- # Create a ZIP file containing the logs, augmented data, and SDK code
573
- zip_file_path = os.path.join(temp_dir, "augmentation_results.zip")
574
  with zipfile.ZipFile(zip_file_path, "w") as zip_file:
575
  zip_file.write(log_file_path, "complete_logs.jsonl")
576
- if augmented_data_jsonl:
577
- zip_file.write(augmented_data_file_path, "augmented_data.jsonl")
578
- zip_file.write(sdk_file_path, "data_augmentation_code.py")
 
 
579
 
580
  # Download the ZIP file
581
  with open(zip_file_path, "rb") as zip_file:
 
11
  import streamlit as st
12
  from datasets import load_dataset
13
  from gretel_client import Gretel
14
+ from navigator_helpers import (
15
+ DataSynthesisConfig,
16
+ TrainingDataSynthesizer,
17
+ StreamlitLogHandler,
18
+ )
19
 
20
  # Create a StringIO buffer to capture the logging output
21
  log_buffer = StringIO()
 
107
  )
108
  if "gretel" not in st.session_state:
109
  st.session_state.gretel = None
110
+ if "synthesized_data" not in st.session_state:
111
+ st.session_state.synthesized_data = []
112
+
113
  if st.button("Validate API Key"):
114
  if api_key:
115
  try:
 
347
  st.markdown("---")
348
  st.markdown("### Format Prompts")
349
 
350
+ st.markdown("---")
351
+ st.markdown("### Format Prompts")
352
+
353
+ system_prompt = st.text_area(
354
+ "System Prompt",
355
+ value=st.session_state.get(
356
+ "system_prompt",
357
+ "You are an AI assistant tasked with generating high-quality instruction-response pairs.\n"
358
+ "Your goal is to create diverse, engaging, and informative content that covers a wide range of topics.\n"
359
+ "When generating instructions, aim for clear, concise questions or commands that prompt thoughtful responses.\n"
360
+ "When generating responses, provide detailed, accurate, and helpful information that directly addresses the instruction.",
361
+ ),
362
+ help="Specify the system prompt for the LLM",
363
+ )
364
+ st.session_state.system_prompt = system_prompt
365
+
366
  instruction_format_prompt = st.text_area(
367
  "Instruction Format Prompt",
368
  value=st.session_state.get(
 
388
  st.write(
389
  "Get started with your current configuration using the SDK code below:"
390
  )
391
+ config_text = f"""#!pip install -Uqq git+https://github.com/gretelai/navigator-helpers.git
392
 
393
+ import logging
394
+ import pandas as pd
395
+ from navigator_helpers import DataSynthesisConfig, TrainingDataSynthesizer
396
+
397
+ # Configure the logger
398
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
399
+
400
+ DATASET = "YOUR_DATASET"
401
+ API_KEY = "YOUR_API_KEY"
402
+
403
+ df = pd.read_csv(DATASET)
404
+
405
+ # Create the data synthesis configuration
406
+ config = DataSynthesisConfig(
407
+ input_fields={st.session_state.selected_fields},
408
+ output_instruction_field="{output_instruction_field}",
409
+ output_response_field="{output_response_field}",
410
+ num_instructions={num_instructions},
411
+ num_responses={num_responses},
412
+ temperature={temperature},
413
+ max_tokens_instruction={max_tokens_instruction},
414
+ max_tokens_response={max_tokens_response},
415
+ api_key=API_KEY,
416
+ navigator_tabular="{navigator_tabular}",
417
+ navigator_llm="{navigator_llm}",
418
+ co_teach_llms={co_teach_llms},
419
+ system_prompt='''{system_prompt}''',
420
+ instruction_format_prompt='''{instruction_format_prompt}''',
421
+ response_format_prompt='''{response_format_prompt}'''
422
+ )
423
+
424
+ # Create the training data synthesizer and perform synthesis
425
+ synthesizer = TrainingDataSynthesizer(
426
+ df,
427
+ config,
428
+ use_aaa={use_aaa},
429
+ output_file="results.csv",
430
+ verbose=True,
431
+ )
432
+ new_df = synthesizer.generate()"""
 
 
 
 
 
 
433
  st.code(config_text, language="python")
434
  st.download_button(
435
  label="Download SDK Code",
436
  data=config_text,
437
+ file_name="data_synthesis_code.py",
438
  mime="text/plain",
439
  )
440
 
 
449
  if "logs" not in st.session_state:
450
  st.session_state.logs = []
451
 
452
+ if "synthetic_data" not in st.session_state:
453
+ st.session_state.synthetic_data = []
454
 
455
  if start_button:
456
+ # Clear the synthetic data and logs before starting a new generation
457
+ st.session_state.synthetic_data = []
458
  st.session_state.logs = []
459
 
460
  with st.expander("Synthetic Data", expanded=True):
461
  st.subheader("Synthetic Data Generation")
462
  progress_bar = st.progress(0)
463
+ tab1, tab2 = st.tabs(["synthetic Data", "Logs"])
464
  with tab1:
465
+ synthetic_data_placeholder = st.empty()
466
  st.info(
467
  "Click on the 'Logs' tab to see and debug real-time logging for each record as it is generated by the agents."
468
  )
 
485
  handler = StreamlitLogHandler(custom_log_handler)
486
  logger.addHandler(handler)
487
 
488
+ config = DataSynthesisConfig(
489
  input_fields=selected_fields,
490
  output_instruction_field=output_instruction_field,
491
  output_response_field=output_response_field,
 
498
  navigator_tabular=navigator_tabular,
499
  navigator_llm=navigator_llm,
500
  co_teach_llms=co_teach_llms,
501
+ system_prompt=system_prompt,
502
  instruction_format_prompt=instruction_format_prompt,
503
  response_format_prompt=response_format_prompt,
504
  )
505
+
506
  start_time = time.time()
507
  with st.spinner("Generating synthetic data..."):
508
  for index in range(num_records):
509
  row = df.iloc[index]
510
+ synthesizer = TrainingDataSynthesizer(
511
  pd.DataFrame([row]),
512
  config,
513
  use_aaa=use_aaa,
514
  output_file="results.csv",
515
  verbose=True,
516
  )
517
+ new_df = synthesizer.generate()
518
+ st.session_state.synthetic_data.append(new_df)
519
+ synthetic_data_placeholder.subheader("Synthetic Data")
520
+ synthetic_data_placeholder.dataframe(
521
  pd.concat(
522
+ st.session_state.synthetic_data, ignore_index=True
523
  )
524
  )
525
  progress = (index + 1) / num_records
 
540
 
541
  time.sleep(0.1)
542
  logger.removeHandler(handler)
543
+ st.success("Data synthetic completed!")
544
  st.stop()
545
 
546
  if stop_button:
547
+ st.warning("Synthesis stopped by the user.")
548
 
549
  # Get the complete logs from the session state
550
  complete_logs = st.session_state.logs
 
552
  # Convert complete logs to JSONL format
553
  log_jsonl = "\n".join([json.dumps({"log": log}) for log in complete_logs])
554
 
555
+ # Convert synthesized data to JSONL format if it exists
556
+ if st.session_state.synthesized_data:
557
+ synthesized_df = pd.concat(
558
+ st.session_state.synthesized_data, ignore_index=True
559
  )
560
+ if not synthesized_df.empty:
561
+ synthesized_data_jsonl = "\n".join(
562
  [
563
  json.dumps(row.to_dict())
564
+ for _, row in synthesized_df.iterrows()
565
  ]
566
  )
567
  else:
568
+ synthesized_data_jsonl = None
569
  else:
570
+ synthesized_data_jsonl = None
571
 
572
  # Create a temporary directory to store the files
573
  with tempfile.TemporaryDirectory() as temp_dir:
 
576
  with open(log_file_path, "w") as log_file:
577
  log_file.write(log_jsonl)
578
 
579
+ # Write the synthesized data to a file if it exists
580
+ if synthesized_data_jsonl:
581
+ synthesized_data_file_path = os.path.join(
582
  temp_dir, "synthetic_data.jsonl"
583
  )
584
+ with open(synthesized_data_file_path, "w") as synthesized_data_file:
585
+ synthesized_data_file.write(synthesized_data_jsonl)
586
 
587
  # Write the SDK code to a file
588
+ sdk_file_path = os.path.join(temp_dir, "data_synthesis_code.py")
589
  with open(sdk_file_path, "w") as sdk_file:
590
  sdk_file.write(config_text)
591
 
592
+ # Create a ZIP file containing the logs, synthesized data, and SDK code
593
+ zip_file_path = os.path.join(temp_dir, "synthesis_results.zip")
594
  with zipfile.ZipFile(zip_file_path, "w") as zip_file:
595
  zip_file.write(log_file_path, "complete_logs.jsonl")
596
+ if synthesized_data_jsonl:
597
+ zip_file.write(
598
+ synthesized_data_file_path, "synthesized_data.jsonl"
599
+ )
600
+ zip_file.write(sdk_file_path, "data_synthesis_code.py")
601
 
602
  # Download the ZIP file
603
  with open(zip_file_path, "rb") as zip_file: