Alexander Watson commited on
Commit
476f41e
1 Parent(s): f402348

fix logging bug, download logs and data

Browse files
Files changed (1) hide show
  1. app.py +94 -19
app.py CHANGED
@@ -1,6 +1,9 @@
1
  import json
2
  import logging
 
 
3
  import time
 
4
  from io import StringIO
5
 
6
  import pandas as pd
@@ -8,8 +11,8 @@ import requests
8
  import streamlit as st
9
  from datasets import load_dataset
10
  from gretel_client import Gretel
11
-
12
- from navigator_helpers import DataAugmentationConfig, DataAugmenter, StreamlitLogHandler
13
 
14
  # Create a StringIO buffer to capture the logging output
15
  log_buffer = StringIO()
@@ -421,9 +424,19 @@ def main():
421
  with col2:
422
  stop_button = st.button("🛑 Stop")
423
 
 
 
 
 
 
 
424
  if start_button:
425
- with st.expander("Augmentation Results", expanded=True):
426
- st.subheader("Augmentation Results")
 
 
 
 
427
  progress_bar = st.progress(0)
428
  tab1, tab2 = st.tabs(["Augmented Data", "Logs"])
429
  with tab1:
@@ -433,20 +446,23 @@ def main():
433
  )
434
  with tab2:
435
  log_container = st.empty()
436
- logs = []
437
  max_log_lines = 50
438
 
439
  def custom_log_handler(msg):
440
- nonlocal logs
441
- logs.append(msg)
442
- if len(logs) > max_log_lines:
443
- logs = logs[-max_log_lines:]
444
- log_text = "\n".join(logs)
445
  log_container.text(log_text)
446
 
447
- handler = StreamlitLogHandler(custom_log_handler)
448
  logger = logging.getLogger("navigator_helpers")
 
 
 
 
 
449
  logger.addHandler(handler)
 
450
  config = DataAugmentationConfig(
451
  input_fields=selected_fields,
452
  output_instruction_field=output_instruction_field,
@@ -463,7 +479,6 @@ def main():
463
  instruction_format_prompt=instruction_format_prompt,
464
  response_format_prompt=response_format_prompt,
465
  )
466
- augmented_data = []
467
  start_time = time.time()
468
  with st.spinner("Generating synthetic data..."):
469
  for index in range(num_records):
@@ -476,10 +491,12 @@ def main():
476
  verbose=True,
477
  )
478
  new_df = augmenter.augment()
479
- augmented_data.append(new_df)
480
- augmented_data_placeholder.subheader("Augmented Data")
481
  augmented_data_placeholder.dataframe(
482
- pd.concat(augmented_data, ignore_index=True)
 
 
483
  )
484
  progress = (index + 1) / num_records
485
  progress_bar.progress(progress)
@@ -500,13 +517,71 @@ def main():
500
  time.sleep(0.1)
501
  logger.removeHandler(handler)
502
  st.success("Data augmentation completed!")
 
 
503
  if stop_button:
504
  st.warning("Augmentation stopped by the user.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505
  st.stop()
506
- else:
507
- st.info(
508
- "Please upload a file or select a dataset from Hugging Face to proceed."
509
- )
 
510
 
511
 
512
  if __name__ == "__main__":
 
1
  import json
2
  import logging
3
+ import os
4
+ import tempfile
5
  import time
6
+ import zipfile
7
  from io import StringIO
8
 
9
  import pandas as pd
 
11
  import streamlit as st
12
  from datasets import load_dataset
13
  from gretel_client import Gretel
14
+ from navigator_helpers import (DataAugmentationConfig, DataAugmenter,
15
+ StreamlitLogHandler)
16
 
17
  # Create a StringIO buffer to capture the logging output
18
  log_buffer = StringIO()
 
424
  with col2:
425
  stop_button = st.button("🛑 Stop")
426
 
427
+ if "logs" not in st.session_state:
428
+ st.session_state.logs = []
429
+
430
+ if "augmented_data" not in st.session_state:
431
+ st.session_state.augmented_data = []
432
+
433
  if start_button:
434
+ # Clear the augmented data and logs before starting a new generation
435
+ st.session_state.augmented_data = []
436
+ st.session_state.logs = []
437
+
438
+ with st.expander("Synthetic Data", expanded=True):
439
+ st.subheader("Synthetic Data Generation")
440
  progress_bar = st.progress(0)
441
  tab1, tab2 = st.tabs(["Augmented Data", "Logs"])
442
  with tab1:
 
446
  )
447
  with tab2:
448
  log_container = st.empty()
 
449
  max_log_lines = 50
450
 
451
  def custom_log_handler(msg):
452
+ st.session_state.logs.append(msg)
453
+ displayed_logs = st.session_state.logs[-max_log_lines:]
454
+ log_text = "\n".join(displayed_logs)
 
 
455
  log_container.text(log_text)
456
 
457
+ # Remove the previous log handler if it exists
458
  logger = logging.getLogger("navigator_helpers")
459
+ for handler in logger.handlers:
460
+ if isinstance(handler, StreamlitLogHandler):
461
+ logger.removeHandler(handler)
462
+
463
+ handler = StreamlitLogHandler(custom_log_handler)
464
  logger.addHandler(handler)
465
+
466
  config = DataAugmentationConfig(
467
  input_fields=selected_fields,
468
  output_instruction_field=output_instruction_field,
 
479
  instruction_format_prompt=instruction_format_prompt,
480
  response_format_prompt=response_format_prompt,
481
  )
 
482
  start_time = time.time()
483
  with st.spinner("Generating synthetic data..."):
484
  for index in range(num_records):
 
491
  verbose=True,
492
  )
493
  new_df = augmenter.augment()
494
+ st.session_state.augmented_data.append(new_df)
495
+ augmented_data_placeholder.subheader("Synthetic Data")
496
  augmented_data_placeholder.dataframe(
497
+ pd.concat(
498
+ st.session_state.augmented_data, ignore_index=True
499
+ )
500
  )
501
  progress = (index + 1) / num_records
502
  progress_bar.progress(progress)
 
517
  time.sleep(0.1)
518
  logger.removeHandler(handler)
519
  st.success("Data augmentation completed!")
520
+ st.stop()
521
+
522
  if stop_button:
523
  st.warning("Augmentation stopped by the user.")
524
+
525
+ # Get the complete logs from the session state
526
+ complete_logs = st.session_state.logs
527
+
528
+ # Convert complete logs to JSONL format
529
+ log_jsonl = "\n".join([json.dumps({"log": log}) for log in complete_logs])
530
+
531
+ # Convert augmented data to JSONL format if it exists
532
+ if st.session_state.augmented_data:
533
+ augmented_df = pd.concat(
534
+ st.session_state.augmented_data, ignore_index=True
535
+ )
536
+ if not augmented_df.empty:
537
+ augmented_data_jsonl = "\n".join(
538
+ [
539
+ json.dumps(row.to_dict())
540
+ for _, row in augmented_df.iterrows()
541
+ ]
542
+ )
543
+ else:
544
+ augmented_data_jsonl = None
545
+ else:
546
+ augmented_data_jsonl = None
547
+
548
+ # Create a temporary directory to store the files
549
+ with tempfile.TemporaryDirectory() as temp_dir:
550
+ # Write the complete logs to a file
551
+ log_file_path = os.path.join(temp_dir, "complete_logs.jsonl")
552
+ with open(log_file_path, "w") as log_file:
553
+ log_file.write(log_jsonl)
554
+
555
+ # Write the augmented data to a file if it exists
556
+ if augmented_data_jsonl:
557
+ augmented_data_file_path = os.path.join(
558
+ temp_dir, "augmented_data.jsonl"
559
+ )
560
+ with open(augmented_data_file_path, "w") as augmented_data_file:
561
+ augmented_data_file.write(augmented_data_jsonl)
562
+
563
+ # Create a ZIP file containing the logs and augmented data
564
+ zip_file_path = os.path.join(temp_dir, "augmentation_results.zip")
565
+ with zipfile.ZipFile(zip_file_path, "w") as zip_file:
566
+ zip_file.write(log_file_path, "complete_logs.jsonl")
567
+ if augmented_data_jsonl:
568
+ zip_file.write(augmented_data_file_path, "augmented_data.jsonl")
569
+
570
+ # Download the ZIP file
571
+ with open(zip_file_path, "rb") as zip_file:
572
+ st.download_button(
573
+ label="Download Synthetic Data and Logs",
574
+ data=zip_file.read(),
575
+ file_name="augmentation_results.zip",
576
+ mime="application/zip",
577
+ )
578
+
579
  st.stop()
580
+
581
+ else:
582
+ st.info(
583
+ "Please upload a file or select a dataset from Hugging Face to proceed."
584
+ )
585
 
586
 
587
  if __name__ == "__main__":