import streamlit as st
import requests
import re
import json
import time
import pandas as pd
import labelbox


@st.cache_data(show_spinner=True)
def fetch_databases(cluster_id, formatted_title, databricks_api_key):
    query = "SHOW DATABASES;"
    return execute_databricks_query(query, cluster_id, formatted_title, databricks_api_key)


# Cached function to fetch tables
@st.cache_data(show_spinner=True)
def fetch_tables(selected_database, cluster_id, formatted_title, databricks_api_key):
    query = f"SHOW TABLES IN {selected_database};"
    return execute_databricks_query(query, cluster_id, formatted_title, databricks_api_key)


# Cached function to fetch columns
@st.cache_data(show_spinner=True)
def fetch_columns(selected_database, selected_table, cluster_id, formatted_title, databricks_api_key):
    query = f"SHOW COLUMNS IN {selected_database}.{selected_table};"
    return execute_databricks_query(query, cluster_id, formatted_title, databricks_api_key)


def validate_dataset_name(name):
    """Validate the dataset name."""
    # Check length
    if len(name) > 256:
        return "Dataset name should be limited to 256 characters."
    # Check allowed characters
    allowed_characters_pattern = re.compile(r'^[A-Za-z0-9 _\-.,()\/]+$')
    if not allowed_characters_pattern.match(name):
        return ("Dataset name can only contain letters, numbers, spaces, and the following "
                "punctuation symbols: _-.,()/. Other characters are not supported.")
    return None


def create_new_dataset_labelbox(new_dataset_name):
    client = labelbox.Client(api_key=labelbox_api_key)
    dataset = client.create_dataset(name=new_dataset_name)
    return dataset.uid


def get_dataset_from_labelbox(labelbox_api_key):
    client = labelbox.Client(api_key=labelbox_api_key)
    datasets = client.get_datasets()
    return datasets


def destroy_databricks_context(cluster_id, context_id, domain, databricks_api_key):
    DOMAIN = f"https://{domain}"
    TOKEN = f"Bearer {databricks_api_key}"
    headers = {
        "Authorization": TOKEN,
        "Content-Type": "application/json",
    }
    # Destroy context
    destroy_payload = {
        "clusterId": cluster_id,
        "contextId": context_id
    }
    destroy_response = requests.post(
        f"{DOMAIN}/api/1.2/contexts/destroy",
        headers=headers,
        data=json.dumps(destroy_payload)
    )
    if destroy_response.status_code != 200:
        raise ValueError("Failed to destroy context.")


def execute_databricks_query(query, cluster_id, domain, databricks_api_key):
    DOMAIN = f"https://{domain}"
    TOKEN = f"Bearer {databricks_api_key}"
    headers = {
        "Authorization": TOKEN,
        "Content-Type": "application/json",
    }
    # Create context
    context_payload = {
        "clusterId": cluster_id,
        "language": "sql"
    }
    context_response = requests.post(
        f"{DOMAIN}/api/1.2/contexts/create",
        headers=headers,
        data=json.dumps(context_payload)
    )
    context_response_data = context_response.json()
    if 'id' not in context_response_data:
        raise ValueError("Failed to create context.")
    context_id = context_response_data['id']

    # Execute query
    command_payload = {
        "clusterId": cluster_id,
        "contextId": context_id,
        "language": "sql",
        "command": query
    }
    command_response = requests.post(
        f"{DOMAIN}/api/1.2/commands/execute",
        headers=headers,
        data=json.dumps(command_payload)
    ).json()
    if 'id' not in command_response:
        raise ValueError("Failed to execute command.")
    command_id = command_response['id']

    # Wait for the command to complete
    while True:
        status_response = requests.get(
            f"{DOMAIN}/api/1.2/commands/status",
            headers=headers,
            params={
                "clusterId": cluster_id,
                "contextId": context_id,
                "commandId": command_id
            }
        ).json()
        command_status = status_response.get("status")
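        # The 1.2 Command Execution API reports a lifecycle status for each command
        # (values such as "Queued", "Running", "Finished", "Cancelled", and "Error";
        # see the Databricks docs for the authoritative list). Only the terminal
        # states are handled below; otherwise we keep polling.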
        if command_status == "Finished":
            break
        elif command_status in ["Error", "Cancelled"]:
            raise ValueError(f"Command {command_status}. Reason: {status_response.get('results', {}).get('summary')}")
        else:
            time.sleep(1)  # Wait 1 second before checking again

    # Convert the results into a pandas DataFrame
    data = status_response.get('results', {}).get('data', [])
    columns = [col['name'] for col in status_response.get('results', {}).get('schema', [])]
    df = pd.DataFrame(data, columns=columns)

    destroy_databricks_context(cluster_id, context_id, domain, databricks_api_key)
    return df


st.title("Labelbox 🤝 Databricks")
st.header("Pipeline Creator", divider='rainbow')


def is_valid_url_or_uri(value):
    """Check if the provided value is a valid URL or URI."""
    # Check general URLs
    url_pattern = re.compile(
        r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
    )
    # Check general URIs including cloud storage URIs (like gs://, s3://, etc.)
    uri_pattern = re.compile(
        r'^(?:[a-z][a-z0-9+.-]*:|/)(?:/?[^\s]*)?$|^(gs|s3|azure|blob)://[^\s]+'
    )
    return url_pattern.match(value) or uri_pattern.match(value)


is_preview = st.toggle('Run in Preview Mode', value=False)
if is_preview:
    st.success('Running in Preview mode!', icon="✅")
else:
    st.success('Running in Production mode!', icon="✅")

st.subheader("Tell us about your Databricks and Labelbox environments", divider='grey')

cloud = st.selectbox('Which cloud environment does your Databricks Workspace run in?', ['AWS', 'Azure', 'GCP'], index=None)
title = st.text_input('Enter Databricks Domain (e.g., your-workspace.databricks.com)', '')
databricks_api_key = st.text_input('Databricks API Key', type='password')
labelbox_api_key = st.text_input('Labelbox API Key', type='password')

# After the Labelbox API key is entered
if labelbox_api_key:
    # Fetching datasets
    datasets = get_dataset_from_labelbox(labelbox_api_key)

    create_new_dataset = st.toggle("Make me a new dataset", value=False)

    if not create_new_dataset:
        # Existing-dataset path: map dataset names to IDs and let the user pick one
        dataset_name_to_id = {dataset.name: dataset.uid for dataset in datasets}
        selected_dataset_name = st.selectbox("Select an existing dataset:", list(dataset_name_to_id.keys()))
        dataset_id = dataset_name_to_id[selected_dataset_name]
    else:
        # If the user toggles "make me a new dataset"
        new_dataset_name = st.text_input("Enter the new dataset name:")

        # Check if the name is valid
        if new_dataset_name:
            validation_message = validate_dataset_name(new_dataset_name)
            if validation_message:
                st.error(validation_message, icon="🚫")
            else:
                st.success("Valid dataset name!", icon="✅")
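                # Note: the Labelbox dataset itself is not created at this point; it is
                # created further below (via create_new_dataset_labelbox) once the table
                # and column mapping have been selected.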
Dataset_id", icon="✅") dataset_name = new_dataset_name # Define the variables beforehand with default values (if not defined) new_dataset_name = new_dataset_name if 'new_dataset_name' in locals() else None selected_dataset_name = selected_dataset_name if 'selected_dataset_name' in locals() else None if new_dataset_name or selected_dataset_name: # Handling various formats of input formatted_title = re.sub(r'^https?://', '', title) # Remove http:// or https:// formatted_title = re.sub(r'/$', '', formatted_title) # Remove trailing slash if present if formatted_title: st.subheader("Select an existing cluster", divider='grey', help="Jobs will use job clusters to reduce DBUs consumed.") DOMAIN = f"https://{formatted_title}" TOKEN = f"Bearer {databricks_api_key}" HEADERS = { "Authorization": TOKEN, "Content-Type": "application/json", } # Endpoint to list clusters ENDPOINT = "/api/2.0/clusters/list" try: response = requests.get(DOMAIN + ENDPOINT, headers=HEADERS) response.raise_for_status() # Include clusters with cluster_source "UI" or "API" clusters = response.json().get("clusters", []) cluster_dict = { cluster["cluster_name"]: cluster["cluster_id"] for cluster in clusters if cluster.get("cluster_source") in ["UI", "API"] } # Display dropdown with cluster names if cluster_dict: selected_cluster_name = st.selectbox( 'Select a cluster to run on', list(cluster_dict.keys()), key='unique_key_for_cluster_selectbox', index=None, placeholder="Select a cluster..", ) if selected_cluster_name: cluster_id = cluster_dict[selected_cluster_name] except requests.RequestException as e: st.write(f"Error communicating with Databricks API: {str(e)}") except ValueError: st.write("Received unexpected response from Databricks API.") if selected_cluster_name and cluster_id: # Check if the selected cluster is running cluster_state = [cluster["state"] for cluster in clusters if cluster["cluster_id"] == cluster_id][0] # If the cluster is not running, start it if cluster_state != "RUNNING": with st.spinner("Starting the selected cluster. This typically takes 10 minutes. Please wait..."): start_response = requests.post(f"{DOMAIN}/api/2.0/clusters/start", headers=HEADERS, json={"cluster_id": cluster_id}) start_response.raise_for_status() # Poll until the cluster is up or until timeout start_time = time.time() timeout = 1200 # 20 minutes in seconds while True: cluster_response = requests.get(f"{DOMAIN}/api/2.0/clusters/get", headers=HEADERS, params={"cluster_id": cluster_id}).json() if "state" in cluster_response: if cluster_response["state"] == "RUNNING": break elif cluster_response["state"] in ["TERMINATED", "ERROR"]: st.write(f"Error starting cluster. Current state: {cluster_response['state']}") break if (time.time() - start_time) > timeout: st.write("Timeout reached while starting the cluster.") break time.sleep(10) # Check every 10 seconds st.success(f"{selected_cluster_name} is now running!", icon="🏃‍♂️") else: st.success(f"{selected_cluster_name} is already running!", icon="🏃‍♂️") def generate_cron_expression(freq, hour=0, minute=0, day_of_week=None, day_of_month=None): """ Generate a cron expression based on the provided frequency and time. """ if freq == "1 minute": return "0 * * * * ?" elif freq == "1 hour": return f"0 {minute} * * * ?" elif freq == "1 day": return f"0 {minute} {hour} * * ?" elif freq == "1 week": if not day_of_week: raise ValueError("Day of week not provided for weekly frequency.") return f"0 {minute} {hour} ? 
* {day_of_week}" elif freq == "1 month": if not day_of_month: raise ValueError("Day of month not provided for monthly frequency.") return f"0 {minute} {hour} {day_of_month} * ?" else: raise ValueError("Invalid frequency provided") # Streamlit UI st.subheader("Run Frequency", divider='grey') # Dropdown to select frequency freq_options = ["1 day", "1 week", "1 month"] selected_freq = st.selectbox("Select frequency:", freq_options, placeholder="Select frequency..") day_of_week = None day_of_month = None # If the frequency is hourly, daily, weekly, or monthly, ask for a specific time if selected_freq != "1 minute": if selected_freq == "1 week": days_options = ["MON", "TUE", "WED", "THU", "FRI", "SAT", "SUN"] day_of_week = st.selectbox("Select day of the week:", days_options) elif selected_freq == "1 month": day_of_month = st.selectbox("Select day of the month:", list(range(1, 32))) col1, col2 = st.columns(2) with col1: hour = st.selectbox("Hour:", list(range(0, 24))) with col2: minute = st.selectbox("Minute:", list(range(0, 60))) else: hour, minute = 0, 0 # Generate the cron expression frequency = generate_cron_expression(selected_freq, hour, minute, day_of_week, day_of_month) # Assumed DBU consumption rate for a 32GB, 4-core node per hour X = 1 # Replace this with the actual rate from Databricks' pricing or documentation # Calculate DBU consumption for a single run min_dbu_single_run = (X/6) * (1 + 10) # Assuming maximum scaling to 10 workers max_dbu_single_run = (2*X/3) * (1 + 10) # Estimate monthly DBU consumption based on frequency if freq_options == "1 day": min_dbu_monthly = 30 * min_dbu_single_run max_dbu_monthly = 30 * max_dbu_single_run elif freq_options == "1 week": min_dbu_monthly = 4 * min_dbu_single_run max_dbu_monthly = 4 * max_dbu_single_run else: # Monthly min_dbu_monthly = min_dbu_single_run max_dbu_monthly = max_dbu_single_run # Calculate runs per month if selected_freq == "1 day": runs_per_month = 30 elif selected_freq == "1 week": runs_per_month = 4 else: # "1 month" runs_per_month = 1 # Calculate estimated DBU consumption per month min_dbu_monthly = runs_per_month * min_dbu_single_run max_dbu_monthly = runs_per_month * max_dbu_single_run def generate_human_readable_message(freq, hour=0, minute=0, day_of_week=None, day_of_month=None): """ Generate a human-readable message for the scheduling. """ if freq == "1 minute": return "Job will run every minute." elif freq == "1 hour": return f"Job will run once an hour at minute {minute}." elif freq == "1 day": return f"Job will run daily at {hour:02}:{minute:02}." elif freq == "1 week": if not day_of_week: raise ValueError("Day of week not provided for weekly frequency.") return f"Job will run every {day_of_week} at {hour:02}:{minute:02}." elif freq == "1 month": if not day_of_month: raise ValueError("Day of month not provided for monthly frequency.") return f"Job will run once a month on day {day_of_month} at {hour:02}:{minute:02}." else: raise ValueError("Invalid frequency provided") # Generate the human-readable message readable_msg = generate_human_readable_message(selected_freq, hour, minute, day_of_week, day_of_month) # Main code block if frequency: st.success(readable_msg, icon="📅") # Display the estimated DBU consumption to the user st.warning(f"Estimated DBU Consumption:\n- For a single run: {min_dbu_single_run:.2f} to {max_dbu_single_run:.2f} DBUs\n- Monthly (based on {runs_per_month} runs): {min_dbu_monthly:.2f} to {max_dbu_monthly:.2f} DBUs") # Disclaimer st.info("Disclaimer: This is only an estimation. 
st.subheader("Select a table", divider="grey")

# Fetching databases
result_data = fetch_databases(cluster_id, formatted_title, databricks_api_key)
database_names = result_data['databaseName'].tolist()
selected_database = st.selectbox("Select a Database:", database_names, index=None, placeholder="Select a database..")

if selected_database:
    # Fetching tables
    result_data = fetch_tables(selected_database, cluster_id, formatted_title, databricks_api_key)
    table_names = result_data['tableName'].tolist()
    selected_table = st.selectbox("Select a Table:", table_names, index=None, placeholder="Select a table..")

    if selected_table:
        # Fetching columns
        result_data = fetch_columns(selected_database, selected_table, cluster_id, formatted_title, databricks_api_key)
        column_names = result_data['col_name'].tolist()

        st.subheader("Map table schema to Labelbox schema", divider="grey")

        # Fetch the first 5 rows of the selected table
        with st.spinner('Fetching first 5 rows of the selected table...'):
            query = f"SELECT * FROM {selected_database}.{selected_table} LIMIT 5;"
            table_sample_data = execute_databricks_query(query, cluster_id, formatted_title, databricks_api_key)
        st.write(table_sample_data)

        # Define two columns for side-by-side selectboxes
        col1, col2 = st.columns(2)
        with col1:
            selected_row_data = st.selectbox(
                "row_data (required):",
                column_names,
                index=None,
                placeholder="Select a column..",
                help="Select the column that contains the URL/URI bucket location of the data rows you wish to import into Labelbox."
            )
        with col2:
            selected_global_key = st.selectbox(
                "global_key (optional):",
                column_names,
                index=None,
                placeholder="Select a column..",
                help="Select the column that contains the global key. If not provided, a new key will be generated for you."
            )

        # Fetch a single row from the selected table
        query_sample_row = f"SELECT * FROM {selected_database}.{selected_table} LIMIT 1;"
        result_sample = execute_databricks_query(query_sample_row, cluster_id, formatted_title, databricks_api_key)

        if selected_row_data:
            # Extract the value from the selected row_data column
            sample_row_data_value = result_sample[selected_row_data].iloc[0]

            # Validate the extracted value (uses the is_valid_url_or_uri helper defined above)
            if not is_valid_url_or_uri(str(sample_row_data_value)):
                st.warning(f"The sampled row_data value does not look like a valid URL or URI: {sample_row_data_value}", icon="⚠️")

            # Create the Labelbox dataset now if the user asked for a new one
            dataset_id = create_new_dataset_labelbox(new_dataset_name) if create_new_dataset else dataset_id

            # Mode
            mode = "preview" if is_preview else "production"

            # Databricks instance
            databricks_instance = formatted_title

            # New Dataset flag
            new_dataset = 1 if create_new_dataset else 0

            # Table Path
            table_path = f"{selected_database}.{selected_table}"

            # Schema Map
            row_data_input = selected_row_data
            global_key_input = selected_global_key

            # Create the initial dictionary
            schema_map_dict = {'row_data': row_data_input}
            if global_key_input:
                schema_map_dict['global_key'] = global_key_input

            # Swap keys and values so the map goes from table column to Labelbox field
            reversed_schema_map_dict = {v: k for k, v in schema_map_dict.items()}

            # Convert the reversed dictionary to a stringified JSON
            reversed_schema_map_str = json.dumps(reversed_schema_map_dict)

            data = {
                "cloud": cloud,
                "mode": mode,
                "databricks_instance": databricks_instance,
                "databricks_api_key": databricks_api_key,
                "new_dataset": new_dataset,
                "dataset_id": dataset_id,
                "table_path": table_path,
                "labelbox_api_key": labelbox_api_key,
                "frequency": frequency,
                "new_cluster": 0,
                "cluster_id": cluster_id,
                "schema_map": reversed_schema_map_str
            }

            if st.button("Deploy Pipeline!", type="primary"):
                # Fields that must be filled out (currently informational only; no explicit check is performed)
                required_fields = [
                    mode, databricks_instance, databricks_api_key, new_dataset, dataset_id,
                    table_path, labelbox_api_key, frequency, cluster_id, reversed_schema_map_str
                ]

                # Send a POST request to the deployment endpoint
                with st.spinner("Deploying pipeline..."):
                    response = requests.post("https://us-central1-dbt-prod.cloudfunctions.net/deploy-databricks-pipeline", json=data)

                # Check if request was successful
                if response.status_code == 200:
                    # Display the response using Streamlit
                    st.balloons()
                    response = response.json()

                    # Extract the job_id from the response message
                    job_id = response['message'].split('job_id":')[1].split('}')[0]

                    from urllib.parse import urlparse, parse_qs

                    # Parse the Databricks instance URL to extract the organization ID
                    # (empty if the domain carries no "?o=" parameter)
                    parsed_url = urlparse(formatted_title)
                    query_params = parse_qs(parsed_url.query)
                    organization_id = query_params.get("o", [""])[0]

                    # Generate the Databricks Job URL
                    job_url = f"https://{formatted_title}/?o={organization_id}#job/{job_id}"

                    st.success(f"Pipeline deployed successfully! [{job_url}]({job_url}) 🚀")
                else:
                    st.error(f"Failed to deploy pipeline. Response: {response.text}", icon="🚫")

st.markdown(""" """, unsafe_allow_html=True)
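# For reference, the JSON body posted to the deploy-databricks-pipeline endpoint looks
# roughly like the following (hypothetical values; the exact contract is defined by the
# receiving cloud function, not by this script). The API keys are included as well.
#
#   {
#     "cloud": "AWS",
#     "mode": "preview",
#     "databricks_instance": "my-workspace.cloud.databricks.com",
#     "new_dataset": 1,
#     "dataset_id": "<labelbox dataset uid>",
#     "table_path": "my_database.my_table",
#     "frequency": "0 30 9 ? * MON",
#     "new_cluster": 0,
#     "cluster_id": "<databricks cluster id>",
#     "schema_map": "{\"image_url\": \"row_data\", \"item_id\": \"global_key\"}"
#   }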