Alexander Watson
commited on
Commit
β’
476f41e
1
Parent(s):
f402348
fix logging bug, download logs and data
Browse files
app.py
CHANGED
@@ -1,6 +1,9 @@
|
|
1 |
import json
|
2 |
import logging
|
|
|
|
|
3 |
import time
|
|
|
4 |
from io import StringIO
|
5 |
|
6 |
import pandas as pd
|
@@ -8,8 +11,8 @@ import requests
|
|
8 |
import streamlit as st
|
9 |
from datasets import load_dataset
|
10 |
from gretel_client import Gretel
|
11 |
-
|
12 |
-
|
13 |
|
14 |
# Create a StringIO buffer to capture the logging output
|
15 |
log_buffer = StringIO()
|
@@ -421,9 +424,19 @@ def main():
|
|
421 |
with col2:
|
422 |
stop_button = st.button("π Stop")
|
423 |
|
|
|
|
|
|
|
|
|
|
|
|
|
424 |
if start_button:
|
425 |
-
|
426 |
-
|
|
|
|
|
|
|
|
|
427 |
progress_bar = st.progress(0)
|
428 |
tab1, tab2 = st.tabs(["Augmented Data", "Logs"])
|
429 |
with tab1:
|
@@ -433,20 +446,23 @@ def main():
|
|
433 |
)
|
434 |
with tab2:
|
435 |
log_container = st.empty()
|
436 |
-
logs = []
|
437 |
max_log_lines = 50
|
438 |
|
439 |
def custom_log_handler(msg):
|
440 |
-
|
441 |
-
logs
|
442 |
-
|
443 |
-
logs = logs[-max_log_lines:]
|
444 |
-
log_text = "\n".join(logs)
|
445 |
log_container.text(log_text)
|
446 |
|
447 |
-
handler
|
448 |
logger = logging.getLogger("navigator_helpers")
|
|
|
|
|
|
|
|
|
|
|
449 |
logger.addHandler(handler)
|
|
|
450 |
config = DataAugmentationConfig(
|
451 |
input_fields=selected_fields,
|
452 |
output_instruction_field=output_instruction_field,
|
@@ -463,7 +479,6 @@ def main():
|
|
463 |
instruction_format_prompt=instruction_format_prompt,
|
464 |
response_format_prompt=response_format_prompt,
|
465 |
)
|
466 |
-
augmented_data = []
|
467 |
start_time = time.time()
|
468 |
with st.spinner("Generating synthetic data..."):
|
469 |
for index in range(num_records):
|
@@ -476,10 +491,12 @@ def main():
|
|
476 |
verbose=True,
|
477 |
)
|
478 |
new_df = augmenter.augment()
|
479 |
-
augmented_data.append(new_df)
|
480 |
-
augmented_data_placeholder.subheader("
|
481 |
augmented_data_placeholder.dataframe(
|
482 |
-
pd.concat(
|
|
|
|
|
483 |
)
|
484 |
progress = (index + 1) / num_records
|
485 |
progress_bar.progress(progress)
|
@@ -500,13 +517,71 @@ def main():
|
|
500 |
time.sleep(0.1)
|
501 |
logger.removeHandler(handler)
|
502 |
st.success("Data augmentation completed!")
|
|
|
|
|
503 |
if stop_button:
|
504 |
st.warning("Augmentation stopped by the user.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
505 |
st.stop()
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
|
|
510 |
|
511 |
|
512 |
if __name__ == "__main__":
|
|
|
1 |
import json
|
2 |
import logging
|
3 |
+
import os
|
4 |
+
import tempfile
|
5 |
import time
|
6 |
+
import zipfile
|
7 |
from io import StringIO
|
8 |
|
9 |
import pandas as pd
|
|
|
11 |
import streamlit as st
|
12 |
from datasets import load_dataset
|
13 |
from gretel_client import Gretel
|
14 |
+
from navigator_helpers import (DataAugmentationConfig, DataAugmenter,
|
15 |
+
StreamlitLogHandler)
|
16 |
|
17 |
# Create a StringIO buffer to capture the logging output
|
18 |
log_buffer = StringIO()
|
|
|
424 |
with col2:
|
425 |
stop_button = st.button("π Stop")
|
426 |
|
427 |
+
if "logs" not in st.session_state:
|
428 |
+
st.session_state.logs = []
|
429 |
+
|
430 |
+
if "augmented_data" not in st.session_state:
|
431 |
+
st.session_state.augmented_data = []
|
432 |
+
|
433 |
if start_button:
|
434 |
+
# Clear the augmented data and logs before starting a new generation
|
435 |
+
st.session_state.augmented_data = []
|
436 |
+
st.session_state.logs = []
|
437 |
+
|
438 |
+
with st.expander("Synthetic Data", expanded=True):
|
439 |
+
st.subheader("Synthetic Data Generation")
|
440 |
progress_bar = st.progress(0)
|
441 |
tab1, tab2 = st.tabs(["Augmented Data", "Logs"])
|
442 |
with tab1:
|
|
|
446 |
)
|
447 |
with tab2:
|
448 |
log_container = st.empty()
|
|
|
449 |
max_log_lines = 50
|
450 |
|
451 |
def custom_log_handler(msg):
|
452 |
+
st.session_state.logs.append(msg)
|
453 |
+
displayed_logs = st.session_state.logs[-max_log_lines:]
|
454 |
+
log_text = "\n".join(displayed_logs)
|
|
|
|
|
455 |
log_container.text(log_text)
|
456 |
|
457 |
+
# Remove the previous log handler if it exists
|
458 |
logger = logging.getLogger("navigator_helpers")
|
459 |
+
for handler in logger.handlers:
|
460 |
+
if isinstance(handler, StreamlitLogHandler):
|
461 |
+
logger.removeHandler(handler)
|
462 |
+
|
463 |
+
handler = StreamlitLogHandler(custom_log_handler)
|
464 |
logger.addHandler(handler)
|
465 |
+
|
466 |
config = DataAugmentationConfig(
|
467 |
input_fields=selected_fields,
|
468 |
output_instruction_field=output_instruction_field,
|
|
|
479 |
instruction_format_prompt=instruction_format_prompt,
|
480 |
response_format_prompt=response_format_prompt,
|
481 |
)
|
|
|
482 |
start_time = time.time()
|
483 |
with st.spinner("Generating synthetic data..."):
|
484 |
for index in range(num_records):
|
|
|
491 |
verbose=True,
|
492 |
)
|
493 |
new_df = augmenter.augment()
|
494 |
+
st.session_state.augmented_data.append(new_df)
|
495 |
+
augmented_data_placeholder.subheader("Synthetic Data")
|
496 |
augmented_data_placeholder.dataframe(
|
497 |
+
pd.concat(
|
498 |
+
st.session_state.augmented_data, ignore_index=True
|
499 |
+
)
|
500 |
)
|
501 |
progress = (index + 1) / num_records
|
502 |
progress_bar.progress(progress)
|
|
|
517 |
time.sleep(0.1)
|
518 |
logger.removeHandler(handler)
|
519 |
st.success("Data augmentation completed!")
|
520 |
+
st.stop()
|
521 |
+
|
522 |
if stop_button:
|
523 |
st.warning("Augmentation stopped by the user.")
|
524 |
+
|
525 |
+
# Get the complete logs from the session state
|
526 |
+
complete_logs = st.session_state.logs
|
527 |
+
|
528 |
+
# Convert complete logs to JSONL format
|
529 |
+
log_jsonl = "\n".join([json.dumps({"log": log}) for log in complete_logs])
|
530 |
+
|
531 |
+
# Convert augmented data to JSONL format if it exists
|
532 |
+
if st.session_state.augmented_data:
|
533 |
+
augmented_df = pd.concat(
|
534 |
+
st.session_state.augmented_data, ignore_index=True
|
535 |
+
)
|
536 |
+
if not augmented_df.empty:
|
537 |
+
augmented_data_jsonl = "\n".join(
|
538 |
+
[
|
539 |
+
json.dumps(row.to_dict())
|
540 |
+
for _, row in augmented_df.iterrows()
|
541 |
+
]
|
542 |
+
)
|
543 |
+
else:
|
544 |
+
augmented_data_jsonl = None
|
545 |
+
else:
|
546 |
+
augmented_data_jsonl = None
|
547 |
+
|
548 |
+
# Create a temporary directory to store the files
|
549 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
550 |
+
# Write the complete logs to a file
|
551 |
+
log_file_path = os.path.join(temp_dir, "complete_logs.jsonl")
|
552 |
+
with open(log_file_path, "w") as log_file:
|
553 |
+
log_file.write(log_jsonl)
|
554 |
+
|
555 |
+
# Write the augmented data to a file if it exists
|
556 |
+
if augmented_data_jsonl:
|
557 |
+
augmented_data_file_path = os.path.join(
|
558 |
+
temp_dir, "augmented_data.jsonl"
|
559 |
+
)
|
560 |
+
with open(augmented_data_file_path, "w") as augmented_data_file:
|
561 |
+
augmented_data_file.write(augmented_data_jsonl)
|
562 |
+
|
563 |
+
# Create a ZIP file containing the logs and augmented data
|
564 |
+
zip_file_path = os.path.join(temp_dir, "augmentation_results.zip")
|
565 |
+
with zipfile.ZipFile(zip_file_path, "w") as zip_file:
|
566 |
+
zip_file.write(log_file_path, "complete_logs.jsonl")
|
567 |
+
if augmented_data_jsonl:
|
568 |
+
zip_file.write(augmented_data_file_path, "augmented_data.jsonl")
|
569 |
+
|
570 |
+
# Download the ZIP file
|
571 |
+
with open(zip_file_path, "rb") as zip_file:
|
572 |
+
st.download_button(
|
573 |
+
label="Download Synthetic Data and Logs",
|
574 |
+
data=zip_file.read(),
|
575 |
+
file_name="augmentation_results.zip",
|
576 |
+
mime="application/zip",
|
577 |
+
)
|
578 |
+
|
579 |
st.stop()
|
580 |
+
|
581 |
+
else:
|
582 |
+
st.info(
|
583 |
+
"Please upload a file or select a dataset from Hugging Face to proceed."
|
584 |
+
)
|
585 |
|
586 |
|
587 |
if __name__ == "__main__":
|