|
import streamlit as st |
|
from streamlit.errors import EntryNotFoundError |
|
|
|
from hub import pull_seed_data_from_repo, push_pipeline_to_hub |
|
from defaults import ( |
|
DEFAULT_SYSTEM_PROMPT, |
|
PIPELINE_PATH, |
|
PROJECT_NAME, |
|
ARGILLA_SPACE_REPO_ID, |
|
DATASET_REPO_ID, |
|
ARGILLA_SPACE_NAME, |
|
ARGILLA_URL, |
|
PROJECT_SPACE_REPO_ID, |
|
HUB_USERNAME, |
|
) |
|
from utils import project_sidebar |
|
|
|
from pipeline import serialize_pipeline, run_pipeline, create_pipelines_run_command |
|
|
|
# Tab title and icon for this page, then the sidebar built by utils.project_sidebar.
st.set_page_config(page_title="Domain Data Grower", page_icon="π§βπΎ")
project_sidebar()
|
|
|
|
|
|
|
|
|
|
|
# Page heading: app title, a divider, then the step-3 section title and blurb.
st.header("π§βπΎ Domain Data Grower")
st.divider()
st.subheader("Step 3. Run the pipeline to generate synthetic data")
# NOTE(review): this blurb describes project setup rather than running the
# pipeline; confirm the wording is intended for step 3.
st.write(
    "Define the project details, including the project name, domain, and API credentials"
)
|
|
|
|
|
|
|
|
|
|
|
|
|
st.divider()

st.markdown("### Pipeline Configuration")

# --- Hub: where the seed data is pulled from --------------------------------
st.write("π€ Hub details to pull the seed data")
hub_username = st.text_input("Hub Username", HUB_USERNAME)
project_name = st.text_input("Project Name", PROJECT_NAME)
# Repository id built from the two inputs above; the run buttons below pull
# the seed data from (and push the pipeline config to) this repo.
repo_id = f"{hub_username}/{project_name}"
hub_token = st.text_input("Hub Token", type="password")

# --- Inference endpoint used by the generated pipeline ----------------------
st.write("π€ Inference configuration")
st.write(
    "Add the url of the Huggingface inference API or endpoint that your pipeline should use. You can find compatible models here:"
)
st.link_button(
    "π€ Inference compatible models on the hub",
    "https://huggingface.co/models?pipeline_tag=text-generation&other=endpoints_compatible&sort=trending",
)
base_url = st.text_input("Base URL")

# --- Argilla: where the generated dataset is pushed -------------------------
st.write("π¬ Argilla API details to push the generated dataset")
argilla_url = st.text_input("Argilla API URL", ARGILLA_URL)
# The API key is a credential, so mask it like the Hub token above.
argilla_api_key = st.text_input("Argilla API Key", "owner.apikey", type="password")
# Defaults to the project name so the dataset is easy to find in Argilla.
argilla_dataset_name = st.text_input("Argilla Dataset Name", project_name)
st.divider()
|
|
|
|
|
|
|
|
|
|
|
st.markdown("### Run the pipeline")

st.write(
    "Once you've defined the pipeline configuration, you can run the pipeline locally or on this space."
)

# Adjacent string literals are joined at compile time, so each message is a
# single Markdown paragraph no matter how the source lines are wrapped
# (triple-quoted text with embedded blank lines would split sentences into
# separate paragraphs).
st.write(
    "We recommend running the pipeline locally if you're planning on generating a large dataset. "
    "But running the pipeline on this space is a handy way to get started quickly. "
    "Your synthetic samples will be pushed to Argilla and available for review."
)
st.write(
    "If you're planning on running the pipeline on the space, be aware that it "
    "will take some time to complete and you will need to maintain a "
    "connection to the space."
)
|
|
|
|
|
# "Run locally": serialize the pipeline config, push it to the Hub repo, and
# show the shell commands the user can run on their own machine.
if st.button("π» Run pipeline locally", key="run_pipeline_local"):
    if all(
        [
            argilla_api_key,
            argilla_url,
            base_url,
            hub_username,
            project_name,
            hub_token,
            argilla_dataset_name,
        ]
    ):
        # Track success explicitly: a missing seed file should produce a
        # friendly error rather than an uncaught exception. st.stop() is not
        # used because it would also hide the "run here" button rendered below.
        # NOTE(review): confirm pull_seed_data_from_repo raises this
        # streamlit.errors.EntryNotFoundError; huggingface_hub defines its own
        # EntryNotFoundError, which would not be caught here.
        seed_data_pulled = False
        with st.spinner("Pulling seed data from the Hub..."):
            try:
                seed_data = pull_seed_data_from_repo(
                    repo_id=repo_id,
                    hub_token=hub_token,
                )
                seed_data_pulled = True
            except EntryNotFoundError:
                st.error(
                    "Seed data not found. Please make sure you pushed the data seed in Step 2."
                )

        if seed_data_pulled:
            perspectives = seed_data["perspectives"]
            topics = seed_data["topics"]
            examples = seed_data["examples"]
            domain_expert_prompt = seed_data["domain_expert_prompt"]

            with st.spinner("Serializing the pipeline configuration..."):
                serialize_pipeline(
                    argilla_api_key=argilla_api_key,
                    argilla_dataset_name=argilla_dataset_name,
                    argilla_api_url=argilla_url,
                    topics=topics,
                    perspectives=perspectives,
                    pipeline_config_path=PIPELINE_PATH,
                    # Fall back to the default prompt when the seed has none.
                    domain_expert_prompt=domain_expert_prompt or DEFAULT_SYSTEM_PROMPT,
                    hub_token=hub_token,
                    endpoint_base_url=base_url,
                    examples=examples,
                )
                push_pipeline_to_hub(
                    pipeline_path=PIPELINE_PATH,
                    hub_token=hub_token,
                    hub_username=hub_username,
                    project_name=project_name,
                )

            st.success(f"Pipeline configuration saved to {hub_username}/{project_name}")

            st.info(
                "To run the pipeline locally, you need to have the `distilabel` library installed. You can install it using the following command:"
            )
            st.text(
                "Execute the following command to generate a synthetic dataset from the seed data:"
            )
            command_to_run = create_pipelines_run_command(
                hub_token=hub_token,
                pipeline_config_path=PIPELINE_PATH,
                argilla_dataset_name=argilla_dataset_name,
            )
            # NOTE(review): command_to_run is built with hub_token; if the token
            # ends up in the command, it is rendered in plain text below.
            # The first two elements are dropped, as in the original display.
            shell_lines = [
                "pip install git+https://github.com/argilla-io/distilabel.git",
                f"git clone https://huggingface.co/{hub_username}/{project_name}",
                f"cd {project_name}",
                " ".join(command_to_run[2:]),
            ]
            st.code("\n".join(shell_lines), language="bash")
    else:
        st.error("Please fill all the required fields.")
|
|
|
|
|
|
|
|
|
|
|
# "Run here": serialize the pipeline config and execute it inside this app,
# streaming its output into a log expander.
if st.button("π₯ Run pipeline right here, right now!"):
    if all(
        [
            argilla_api_key,
            argilla_url,
            base_url,
            hub_username,
            project_name,
            hub_token,
            argilla_dataset_name,
        ]
    ):
        # Only continue once the seed data was actually pulled. Previously a
        # missing seed file showed the error and then crashed with a NameError
        # on `seed_data` a few lines later.
        # NOTE(review): confirm pull_seed_data_from_repo raises this
        # streamlit.errors.EntryNotFoundError; huggingface_hub defines its own
        # EntryNotFoundError, which would not be caught here.
        seed_data_pulled = False
        with st.spinner("Pulling seed data from the Hub..."):
            try:
                seed_data = pull_seed_data_from_repo(
                    repo_id=repo_id,
                    hub_token=hub_token,
                )
                seed_data_pulled = True
            except EntryNotFoundError:
                st.error(
                    "Seed data not found. Please make sure you pushed the data seed in Step 2."
                )

        if seed_data_pulled:
            perspectives = seed_data["perspectives"]
            topics = seed_data["topics"]
            examples = seed_data["examples"]
            domain_expert_prompt = seed_data["domain_expert_prompt"]

            with st.spinner("Serializing the pipeline configuration..."):
                serialize_pipeline(
                    argilla_api_key=argilla_api_key,
                    argilla_dataset_name=argilla_dataset_name,
                    argilla_api_url=argilla_url,
                    topics=topics,
                    perspectives=perspectives,
                    pipeline_config_path=PIPELINE_PATH,
                    # Fall back to the default prompt when the seed has none.
                    domain_expert_prompt=domain_expert_prompt or DEFAULT_SYSTEM_PROMPT,
                    hub_token=hub_token,
                    endpoint_base_url=base_url,
                    examples=examples,
                )

            with st.spinner("Starting the pipeline..."):
                logs = run_pipeline(PIPELINE_PATH)

            st.success("Pipeline started successfully! π")

            # Each item yielded by `logs` is shown as one line of text;
            # presumably run_pipeline streams process output — confirm.
            with st.expander(label="View Logs", expanded=True):
                for out in logs:
                    st.text(out)
    else:
        st.error("Please fill all the required fields.")
|
|