import streamlit as st
import requests
import re
import json
import time
import pandas as pd
import labelbox
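
# Streamlit app that deploys a recurring Databricks -> Labelbox ingestion pipeline:
# the user supplies Databricks and Labelbox credentials, picks (or creates) a Labelbox
# dataset, selects a cluster and a source table, maps columns to row_data/global_key,
# chooses a schedule, and the app POSTs the resulting configuration to a deployment endpoint.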

# Fetch the list of databases visible to the selected cluster
def fetch_databases(cluster_id, formatted_title, databricks_api_key):
    query = "SHOW DATABASES;"
    return execute_databricks_query(query, cluster_id, formatted_title, databricks_api_key)

# Fetch the tables in a database
def fetch_tables(selected_database, cluster_id, formatted_title, databricks_api_key):
    query = f"SHOW TABLES IN {selected_database};"
    return execute_databricks_query(query, cluster_id, formatted_title, databricks_api_key)

# Fetch the columns of a table
def fetch_columns(selected_database, selected_table, cluster_id, formatted_title, databricks_api_key):
    query = f"SHOW COLUMNS IN {selected_database}.{selected_table};"
    return execute_databricks_query(query, cluster_id, formatted_title, databricks_api_key)

def validate_dataset_name(name):
    """Validate the dataset name."""
    # Check length
    if len(name) > 256:
        return "Dataset name should be limited to 256 characters."
    # Check allowed characters
    allowed_characters_pattern = re.compile(r'^[A-Za-z0-9 _\-.,()\/]+$')
    if not allowed_characters_pattern.match(name):
        return ("Dataset name can only contain letters, numbers, spaces, and the following "
                "punctuation symbols: _-.,()/. Other characters are not supported.")
    return None
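# Example: validate_dataset_name("traffic cams (v2)") returns None (valid), while
# validate_dataset_name("bad*name") returns the unsupported-characters message.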

def create_new_dataset_labelbox(new_dataset_name):
    # Uses the module-level labelbox_api_key entered in the UI below
    client = labelbox.Client(api_key=labelbox_api_key)
    dataset = client.create_dataset(name=new_dataset_name)
    dataset_id = dataset.uid
    return dataset_id

def get_dataset_from_labelbox(labelbox_api_key):
    client = labelbox.Client(api_key=labelbox_api_key)
    datasets = client.get_datasets()
    return datasets

def destroy_databricks_context(cluster_id, context_id, domain, databricks_api_key):
    DOMAIN = f"https://{domain}"
    TOKEN = f"Bearer {databricks_api_key}"
    headers = {
        "Authorization": TOKEN,
        "Content-Type": "application/json",
    }
    # Destroy the execution context
    destroy_payload = {
        "clusterId": cluster_id,
        "contextId": context_id
    }
    destroy_response = requests.post(
        f"{DOMAIN}/api/1.2/contexts/destroy",
        headers=headers,
        data=json.dumps(destroy_payload)
    )
    if destroy_response.status_code != 200:
        raise ValueError("Failed to destroy context.")

def execute_databricks_query(query, cluster_id, domain, databricks_api_key):
    DOMAIN = f"https://{domain}"
    TOKEN = f"Bearer {databricks_api_key}"
    headers = {
        "Authorization": TOKEN,
        "Content-Type": "application/json",
    }
    # Create an execution context
    context_payload = {
        "clusterId": cluster_id,
        "language": "sql"
    }
    context_response = requests.post(
        f"{DOMAIN}/api/1.2/contexts/create",
        headers=headers,
        data=json.dumps(context_payload)
    )
    context_response_data = context_response.json()
    if 'id' not in context_response_data:
        raise ValueError("Failed to create context.")
    context_id = context_response_data['id']
    # Execute the query
    command_payload = {
        "clusterId": cluster_id,
        "contextId": context_id,
        "language": "sql",
        "command": query
    }
    command_response = requests.post(
        f"{DOMAIN}/api/1.2/commands/execute",
        headers=headers,
        data=json.dumps(command_payload)
    ).json()
    if 'id' not in command_response:
        raise ValueError("Failed to execute command.")
    command_id = command_response['id']
    # Poll until the command completes
    while True:
        status_response = requests.get(
            f"{DOMAIN}/api/1.2/commands/status",
            headers=headers,
            params={
                "clusterId": cluster_id,
                "contextId": context_id,
                "commandId": command_id
            }
        ).json()
        command_status = status_response.get("status")
        if command_status == "Finished":
            break
        elif command_status in ["Error", "Cancelled"]:
            raise ValueError(f"Command {command_status}. Reason: {status_response.get('results', {}).get('summary')}")
        else:
            time.sleep(1)  # Wait 1 second before checking again
    # Convert the results into a pandas DataFrame
    data = status_response.get('results', {}).get('data', [])
    columns = [col['name'] for col in status_response.get('results', {}).get('schema', [])]
    df = pd.DataFrame(data, columns=columns)
    destroy_databricks_context(cluster_id, context_id, domain, databricks_api_key)
    return df
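# Example (hypothetical values): execute_databricks_query("SHOW DATABASES;", "0123-456789-abcde",
# "myworkspace.cloud.databricks.com", api_key) returns a DataFrame with a 'databaseName' column.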

st.title("Labelbox 🤝 Databricks")
st.header("Pipeline Creator", divider='rainbow')

def is_valid_url_or_uri(value):
    """Check if the provided value is a valid URL or URI."""
    # Check general URLs
    url_pattern = re.compile(
        r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
    )
    # Check general URIs including cloud storage URIs (like gs://, s3://, etc.)
    uri_pattern = re.compile(
        r'^(?:[a-z][a-z0-9+.-]*:|/)(?:/?[^\s]*)?$|^(gs|s3|azure|blob)://[^\s]+'
    )
    return url_pattern.match(value) or uri_pattern.match(value)
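# Example: both "https://example.com/image.jpg" and "gs://my-bucket/image.jpg" match,
# while a bare file name such as "image.jpg" does not.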

is_preview = st.toggle('Run in Preview Mode', value=False)
if is_preview:
    st.success('Running in Preview mode!', icon="✅")
else:
    st.success('Running in Production mode!', icon="✅")

st.subheader("Tell us about your Databricks and Labelbox environments", divider='grey')

#cloud = "GCP"
cloud = st.selectbox('Which cloud environment does your Databricks Workspace run in?', ['AWS', 'Azure', 'GCP'], index=None)
title = st.text_input('Enter Databricks Domain (e.g., <instance>.<cloud>.databricks.com)', '')
databricks_api_key = st.text_input('Databricks API Key', type='password')
labelbox_api_key = st.text_input('Labelbox API Key', type='password')

# After the Labelbox API key is entered
if labelbox_api_key:
    # Fetch existing datasets
    datasets = get_dataset_from_labelbox(labelbox_api_key)
    create_new_dataset = st.toggle("Make me a new dataset", value=False)
    if not create_new_dataset:
        # Select an existing dataset
        dataset_name_to_id = {dataset.name: dataset.uid for dataset in datasets}
        selected_dataset_name = st.selectbox("Select an existing dataset:", list(dataset_name_to_id.keys()))
        dataset_id = dataset_name_to_id[selected_dataset_name]
    else:
        # User toggled "make me a new dataset"
        new_dataset_name = st.text_input("Enter the new dataset name:")
        # Check if the name is valid
        if new_dataset_name:
            validation_message = validate_dataset_name(new_dataset_name)
            if validation_message:
                st.error(validation_message, icon="🚫")
            else:
                st.success("Valid dataset name!", icon="✅")
                dataset_name = new_dataset_name

# Define the variables beforehand with default values (if not yet defined)
new_dataset_name = new_dataset_name if 'new_dataset_name' in locals() else None
selected_dataset_name = selected_dataset_name if 'selected_dataset_name' in locals() else None

if new_dataset_name or selected_dataset_name:
    # Handle various formats of the Databricks domain input
    formatted_title = re.sub(r'^https?://', '', title)     # Remove http:// or https://
    formatted_title = re.sub(r'/$', '', formatted_title)   # Remove trailing slash if present
    if formatted_title:
        st.subheader("Select an existing cluster", divider='grey', help="Jobs will use job clusters to reduce DBUs consumed.")
        DOMAIN = f"https://{formatted_title}"
        TOKEN = f"Bearer {databricks_api_key}"
        HEADERS = {
            "Authorization": TOKEN,
            "Content-Type": "application/json",
        }
        # Endpoint to list clusters
        ENDPOINT = "/api/2.0/clusters/list"
        selected_cluster_name = None
        cluster_id = None
        try:
            response = requests.get(DOMAIN + ENDPOINT, headers=HEADERS)
            response.raise_for_status()
            # Include clusters with cluster_source "UI" or "API"
            clusters = response.json().get("clusters", [])
            cluster_dict = {
                cluster["cluster_name"]: cluster["cluster_id"]
                for cluster in clusters if cluster.get("cluster_source") in ["UI", "API"]
            }
            # Display a dropdown with cluster names
            if cluster_dict:
                selected_cluster_name = st.selectbox(
                    'Select a cluster to run on',
                    list(cluster_dict.keys()),
                    key='unique_key_for_cluster_selectbox',
                    index=None,
                    placeholder="Select a cluster..",
                )
                if selected_cluster_name:
                    cluster_id = cluster_dict[selected_cluster_name]
        except requests.RequestException as e:
            st.write(f"Error communicating with Databricks API: {str(e)}")
        except ValueError:
            st.write("Received unexpected response from Databricks API.")
        if selected_cluster_name and cluster_id:
            # Check if the selected cluster is running
            cluster_state = [cluster["state"] for cluster in clusters if cluster["cluster_id"] == cluster_id][0]
            # If the cluster is not running, start it
            if cluster_state != "RUNNING":
                with st.spinner("Starting the selected cluster. This typically takes 10 minutes. Please wait..."):
                    start_response = requests.post(f"{DOMAIN}/api/2.0/clusters/start", headers=HEADERS, json={"cluster_id": cluster_id})
                    start_response.raise_for_status()
                    # Poll until the cluster is up or until timeout
                    start_time = time.time()
                    timeout = 1200  # 20 minutes in seconds
                    while True:
                        cluster_response = requests.get(f"{DOMAIN}/api/2.0/clusters/get", headers=HEADERS, params={"cluster_id": cluster_id}).json()
                        if "state" in cluster_response:
                            if cluster_response["state"] == "RUNNING":
                                break
                            elif cluster_response["state"] in ["TERMINATED", "ERROR"]:
                                st.write(f"Error starting cluster. Current state: {cluster_response['state']}")
                                break
                        if (time.time() - start_time) > timeout:
                            st.write("Timeout reached while starting the cluster.")
                            break
                        time.sleep(10)  # Check every 10 seconds
                st.success(f"{selected_cluster_name} is now running!", icon="🏃")
            else:
                st.success(f"{selected_cluster_name} is already running!", icon="🏃")

            def generate_cron_expression(freq, hour=0, minute=0, day_of_week=None, day_of_month=None):
                """
                Generate a cron expression based on the provided frequency and time.
                """
                if freq == "1 minute":
                    return "0 * * * * ?"
                elif freq == "1 hour":
                    return f"0 {minute} * * * ?"
                elif freq == "1 day":
                    return f"0 {minute} {hour} * * ?"
                elif freq == "1 week":
                    if not day_of_week:
                        raise ValueError("Day of week not provided for weekly frequency.")
                    return f"0 {minute} {hour} ? * {day_of_week}"
                elif freq == "1 month":
                    if not day_of_month:
                        raise ValueError("Day of month not provided for monthly frequency.")
                    return f"0 {minute} {hour} {day_of_month} * ?"
                else:
                    raise ValueError("Invalid frequency provided")
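            # Example: generate_cron_expression("1 week", hour=9, minute=30, day_of_week="MON")
            # yields "0 30 9 ? * MON" (Quartz-style cron, as used by Databricks job schedules).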

            # Streamlit UI
            st.subheader("Run Frequency", divider='grey')
            # Dropdown to select frequency
            freq_options = ["1 day", "1 week", "1 month"]
            selected_freq = st.selectbox("Select frequency:", freq_options, placeholder="Select frequency..")
            day_of_week = None
            day_of_month = None
            # For daily, weekly, or monthly frequencies, ask for a specific time
            if selected_freq != "1 minute":
                if selected_freq == "1 week":
                    days_options = ["MON", "TUE", "WED", "THU", "FRI", "SAT", "SUN"]
                    day_of_week = st.selectbox("Select day of the week:", days_options)
                elif selected_freq == "1 month":
                    day_of_month = st.selectbox("Select day of the month:", list(range(1, 32)))
                col1, col2 = st.columns(2)
                with col1:
                    hour = st.selectbox("Hour:", list(range(0, 24)))
                with col2:
                    minute = st.selectbox("Minute:", list(range(0, 60)))
            else:
                hour, minute = 0, 0
            # Generate the cron expression
            frequency = generate_cron_expression(selected_freq, hour, minute, day_of_week, day_of_month)

            # Assumed DBU consumption rate for a 32 GB, 4-core node per hour
            X = 1  # Replace this with the actual rate from Databricks' pricing or documentation
            # Calculate DBU consumption for a single run, assuming scaling up to 10 workers
            min_dbu_single_run = (X / 6) * (1 + 10)
            max_dbu_single_run = (2 * X / 3) * (1 + 10)
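            # The factors above appear to assume a run lasting between 10 minutes (X/6 of an hour)
            # and 40 minutes (2X/3 of an hour) on a driver plus up to 10 workers; this is a rough,
            # hypothetical estimate rather than an authoritative DBU calculation.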
            # Estimate runs per month from the selected frequency
            if selected_freq == "1 day":
                runs_per_month = 30
            elif selected_freq == "1 week":
                runs_per_month = 4
            else:  # "1 month"
                runs_per_month = 1
            # Calculate estimated DBU consumption per month
            min_dbu_monthly = runs_per_month * min_dbu_single_run
            max_dbu_monthly = runs_per_month * max_dbu_single_run

            def generate_human_readable_message(freq, hour=0, minute=0, day_of_week=None, day_of_month=None):
                """
                Generate a human-readable message for the schedule.
                """
                if freq == "1 minute":
                    return "Job will run every minute."
                elif freq == "1 hour":
                    return f"Job will run once an hour at minute {minute}."
                elif freq == "1 day":
                    return f"Job will run daily at {hour:02}:{minute:02}."
                elif freq == "1 week":
                    if not day_of_week:
                        raise ValueError("Day of week not provided for weekly frequency.")
                    return f"Job will run every {day_of_week} at {hour:02}:{minute:02}."
                elif freq == "1 month":
                    if not day_of_month:
                        raise ValueError("Day of month not provided for monthly frequency.")
                    return f"Job will run once a month on day {day_of_month} at {hour:02}:{minute:02}."
                else:
                    raise ValueError("Invalid frequency provided")
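            # Example: generate_human_readable_message("1 week", 9, 30, "MON")
            # returns "Job will run every MON at 09:30."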

            # Generate the human-readable message
            readable_msg = generate_human_readable_message(selected_freq, hour, minute, day_of_week, day_of_month)

            # Main code block
            if frequency:
                st.success(readable_msg, icon="📅")
                # Display the estimated DBU consumption to the user
                st.warning(
                    f"Estimated DBU Consumption:\n"
                    f"- For a single run: {min_dbu_single_run:.2f} to {max_dbu_single_run:.2f} DBUs\n"
                    f"- Monthly (based on {runs_per_month} runs): {min_dbu_monthly:.2f} to {max_dbu_monthly:.2f} DBUs"
                )
                # Disclaimer
                st.info("Disclaimer: This is only an estimation. Always monitor the job in Databricks to assess actual DBU consumption.")
| st.subheader("Select a table", divider="grey") | |
| # Fetching databases | |
| result_data = fetch_databases(cluster_id, formatted_title, databricks_api_key) | |
| database_names = result_data['databaseName'].tolist() | |
| selected_database = st.selectbox("Select a Database:", database_names, index=None, placeholder="Select a database..") | |
| if selected_database: | |
| # Fetching tables | |
| result_data = fetch_tables(selected_database, cluster_id, formatted_title, databricks_api_key) | |
| table_names = result_data['tableName'].tolist() | |
| selected_table = st.selectbox("Select a Table:", table_names, index=None, placeholder="Select a table..") | |
| if selected_table: | |
| # Fetching columns | |
| result_data = fetch_columns(selected_database, selected_table, cluster_id, formatted_title, databricks_api_key) | |
| column_names = result_data['col_name'].tolist() | |
| st.subheader("Map table schema to Labelbox schema", divider="grey") | |
| # Your existing code to handle schema mapping... | |
| # Fetch the first 5 rows of the selected table | |
| with st.spinner('Fetching first 5 rows of the selected table...'): | |
| query = f"SELECT * FROM {selected_database}.{selected_table} LIMIT 5;" | |
| table_sample_data = execute_databricks_query(query, cluster_id, formatted_title, databricks_api_key) | |
| st.write(table_sample_data) | |
| # Define two columns for side-by-side selectboxes | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| selected_row_data = st.selectbox( | |
| "row_data (required):", | |
| column_names, | |
| index=None, | |
| placeholder="Select a column..", | |
| help="Select the column that contains the URL/URI bucket location of the data rows you wish to import into Labelbox." | |
| ) | |
| with col2: | |
| selected_global_key = st.selectbox( | |
| "global_key (optional):", | |
| column_names, | |
| index=None, | |
| placeholder="Select a column..", | |
| help="Select the column that contains the global key. If not provided, a new key will be generated for you." | |
| ) | |
| # Fetch a single row from the selected table | |
| query_sample_row = f"SELECT * FROM {selected_database}.{selected_table} LIMIT 1;" | |
| result_sample = execute_databricks_query(query_sample_row, cluster_id, formatted_title, databricks_api_key) | |
                        if selected_row_data:
                            # Extract the value from the selected row_data column
                            sample_row_data_value = result_sample[selected_row_data].iloc[0]
                            # Validate the extracted value
                            if not is_valid_url_or_uri(str(sample_row_data_value)):
                                st.warning("The selected row_data column does not look like a URL/URI. Double-check the column mapping.")
                            dataset_id = create_new_dataset_labelbox(new_dataset_name) if create_new_dataset else dataset_id
                            # Mode
                            mode = "preview" if is_preview else "production"
                            # Databricks instance
                            databricks_instance = formatted_title
                            # New Dataset flag
                            new_dataset = 1 if create_new_dataset else 0
                            # Table Path
                            table_path = f"{selected_database}.{selected_table}"
                            # Schema Map
                            row_data_input = selected_row_data
                            global_key_input = selected_global_key
                            # Create the initial dictionary
                            schema_map_dict = {'row_data': row_data_input}
                            if global_key_input:
                                schema_map_dict['global_key'] = global_key_input
                            # Swap keys and values so the map goes from table column -> Labelbox field
                            reversed_schema_map_dict = {v: k for k, v in schema_map_dict.items()}
                            # Convert the reversed dictionary to a stringified JSON
                            reversed_schema_map_str = json.dumps(reversed_schema_map_dict)
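                            # Example: mapping column "image_url" to row_data and "asset_id" to global_key
                            # (hypothetical column names) produces '{"image_url": "row_data", "asset_id": "global_key"}'.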
                            data = {
                                "cloud": cloud,
                                "mode": mode,
                                "databricks_instance": databricks_instance,
                                "databricks_api_key": databricks_api_key,
                                "new_dataset": new_dataset,
                                "dataset_id": dataset_id,
                                "table_path": table_path,
                                "labelbox_api_key": labelbox_api_key,
                                "frequency": frequency,
                                "new_cluster": 0,
                                "cluster_id": cluster_id,
                                "schema_map": reversed_schema_map_str
                            }
| if st.button("Deploy Pipeline!", type="primary"): | |
| # Ensure all fields are filled out | |
| required_fields = [ | |
| mode, databricks_instance, databricks_api_key, new_dataset, dataset_id, | |
| table_path, labelbox_api_key, frequency, cluster_id, reversed_schema_map_str | |
| ] | |
| # Sending a POST request to the Flask app endpoint | |
| with st.spinner("Deploying pipeline..."): | |
| response = requests.post("https://us-central1-dbt-prod.cloudfunctions.net/deploy-databricks-pipeline", json=data) | |
| # Check if request was successful | |
| if response.status_code == 200: | |
| # Display the response using Streamlit | |
| st.balloons() | |
| response = response.json() | |
| # Extract the job_id | |
| job_id = response['message'].split('job_id":')[1].split('}')[0] | |
| from urllib.parse import urlparse, parse_qs | |
| # Parse the Databricks instance URL to extract the organization ID | |
| parsed_url = urlparse(formatted_title) | |
| query_params = parse_qs(parsed_url.query) | |
| organization_id = query_params.get("o", [""])[0] | |
| # Generate the Databricks Job URL | |
| job_url = f"http://{formatted_title}/?o={organization_id}#job/{job_id}" | |
| st.success(f"Pipeline deployed successfully! [{job_url}]({job_url}) π") | |
| else: | |
| st.error(f"Failed to deploy pipeline. Response: {response.text}", icon="π«") | |
| st.markdown(""" | |
| <style> | |
| /* Add a large bottom padding to the main content */ | |
| .main .block-container { | |
| padding-bottom: 1000px; /* Adjust this value as needed */ | |
| } | |
| </style> | |
| """, unsafe_allow_html=True) | |