Spaces:
Sleeping
Sleeping
import streamlit as st | |
import requests | |
import re | |
import json | |
import time | |
import pandas as pd | |
import labelbox | |
def fetch_databases(cluster_id, formatted_title, databricks_api_key):
    """List the databases visible from the given Databricks cluster.

    Sends a ``SHOW DATABASES`` statement through ``execute_databricks_query``
    and returns that call's result unchanged.
    """
    return execute_databricks_query(
        "SHOW DATABASES;", cluster_id, formatted_title, databricks_api_key
    )
# Fetch the tables of one database (no caching decorator is applied: every call queries the cluster)
def fetch_tables(selected_database, cluster_id, formatted_title, databricks_api_key):
    """List the tables of ``selected_database``.

    The database name is interpolated into a ``SHOW TABLES IN`` statement, which
    is then run through ``execute_databricks_query``; that call's result is
    returned unchanged.
    """
    statement = f"SHOW TABLES IN {selected_database};"
    return execute_databricks_query(
        statement, cluster_id, formatted_title, databricks_api_key
    )
# Fetch the columns of one table (no caching decorator is applied: every call queries the cluster)
def fetch_columns(selected_database, selected_table, cluster_id, formatted_title, databricks_api_key):
    """List the columns of ``selected_database.selected_table``.

    Both names are interpolated into a ``SHOW COLUMNS IN`` statement, which is
    then run through ``execute_databricks_query``; that call's result is
    returned unchanged.
    """
    statement = f"SHOW COLUMNS IN {selected_database}.{selected_table};"
    return execute_databricks_query(
        statement, cluster_id, formatted_title, databricks_api_key
    )
def validate_dataset_name(name):
    """Validate a proposed dataset name.

    Args:
        name: Candidate dataset name (a string).

    Returns:
        ``None`` when the name is acceptable, otherwise a human-readable
        message describing the first rule that was violated.
    """
    # Check length
    if len(name) > 256:
        return "Dataset name should be limited to 256 characters."

    # Check allowed characters. ``fullmatch`` is used instead of ``match`` with a
    # trailing ``$`` because ``$`` also matches just before a final newline,
    # which previously let names ending in a newline pass validation.
    allowed_characters_pattern = re.compile(r'[A-Za-z0-9 _\-.,()\/]+')
    if not allowed_characters_pattern.fullmatch(name):
        return ("Dataset name can only contain letters, numbers, spaces, and the following punctuation symbols: _-.,()/. Other characters are not supported.")

    return None
def create_new_dataset_labelbox(new_dataset_name, api_key=None):
    """Create a Labelbox dataset and return its ID.

    Args:
        new_dataset_name: Name to give the new dataset.
        api_key: Labelbox API key. When omitted, the module-level
            ``labelbox_api_key`` value is used, which is what this function
            always relied on before the parameter existed.

    Returns:
        The ``uid`` of the dataset returned by ``client.create_dataset``.
    """
    if api_key is None:
        # Backward-compatible fallback to the key collected by the Streamlit input.
        api_key = labelbox_api_key
    client = labelbox.Client(api_key=api_key)
    dataset = client.create_dataset(name=new_dataset_name)
    return dataset.uid
def get_dataset_from_labelbox(labelbox_api_key):
    """Return the datasets reachable with the given Labelbox API key.

    Builds a ``labelbox.Client`` from the key and returns the result of its
    ``get_datasets()`` call as-is.
    """
    return labelbox.Client(api_key=labelbox_api_key).get_datasets()
def destroy_databricks_context(cluster_id, context_id, domain, databricks_api_key):
    """Destroy an execution context on a Databricks cluster.

    Posts the cluster and context IDs to the ``/api/1.2/contexts/destroy``
    endpoint of ``https://<domain>``.

    Raises:
        ValueError: If the endpoint answers with any status code other than 200.
    """
    request_headers = {
        "Authorization": f"Bearer {databricks_api_key}",
        "Content-Type": "application/json",
    }
    body = json.dumps({"clusterId": cluster_id, "contextId": context_id})

    reply = requests.post(
        f"https://{domain}/api/1.2/contexts/destroy",
        headers=request_headers,
        data=body,
    )
    if reply.status_code == 200:
        return
    raise ValueError("Failed to destroy context.")
def execute_databricks_query(query, cluster_id, domain, databricks_api_key):
    """Run a SQL command on a Databricks cluster and return the result as a DataFrame.

    The command goes through the ``/api/1.2`` endpoints of ``https://<domain>``:
    an execution context is created, the command is submitted, its status is
    polled once per second until it reaches a terminal state, and the context
    is destroyed afterwards (also when submitting or polling fails, so that
    contexts are not leaked on the cluster).

    Args:
        query: SQL text to execute.
        cluster_id: ID of the cluster to run on.
        domain: Workspace host name without scheme.
        databricks_api_key: Token sent as a Bearer credential.

    Returns:
        A ``pandas.DataFrame`` built from the ``data`` rows and the ``schema``
        column names of the command result.

    Raises:
        ValueError: If the context cannot be created, the command cannot be
            submitted, the command ends as ``Error``/``Cancelled``, the
            finished command carries an error result, or the context cannot
            be destroyed after an otherwise successful run.
    """
    base_url = f"https://{domain}"
    headers = {
        "Authorization": f"Bearer {databricks_api_key}",
        "Content-Type": "application/json",
    }

    # Create context
    context_payload = {
        "clusterId": cluster_id,
        "language": "sql"
    }
    context_response_data = requests.post(
        f"{base_url}/api/1.2/contexts/create",
        headers=headers,
        data=json.dumps(context_payload)
    ).json()
    if 'id' not in context_response_data:
        raise ValueError("Failed to create context.")
    context_id = context_response_data['id']

    try:
        # Execute query
        command_payload = {
            "clusterId": cluster_id,
            "contextId": context_id,
            "language": "sql",
            "command": query
        }
        command_response = requests.post(
            f"{base_url}/api/1.2/commands/execute",
            headers=headers,
            data=json.dumps(command_payload)
        ).json()
        if 'id' not in command_response:
            raise ValueError("Failed to execute command.")
        command_id = command_response['id']

        # Wait for the command to complete
        while True:
            status_response = requests.get(
                f"{base_url}/api/1.2/commands/status",
                headers=headers,
                params={
                    "clusterId": cluster_id,
                    "contextId": context_id,
                    "commandId": command_id
                }
            ).json()
            command_status = status_response.get("status")
            # ``or {}`` also covers a response whose "results" is present but null.
            results = status_response.get('results') or {}
            if command_status == "Finished":
                break
            if command_status in ("Error", "Cancelled"):
                raise ValueError(f"Command {command_status}. Reason: {results.get('summary')}")
            time.sleep(1)  # Wait for 1 second before checking again

        # NOTE(review): the 1.2 API can report status "Finished" with a result of
        # resultType "error" (e.g. a SQL failure); surface it instead of
        # returning an empty DataFrame — confirm against the API reference.
        if results.get('resultType') == 'error':
            raise ValueError(
                f"Command Error. Reason: {results.get('summary') or results.get('cause')}"
            )

        # Convert the results into a pandas DataFrame
        data = results.get('data', [])
        columns = [col['name'] for col in results.get('schema', [])]
        df = pd.DataFrame(data, columns=columns)
    except Exception:
        # Best-effort cleanup: a failure to destroy the context must not hide
        # the error that actually aborted the query.
        try:
            destroy_databricks_context(cluster_id, context_id, domain, databricks_api_key)
        except (ValueError, requests.RequestException):
            pass
        raise

    # Success path: as before, a failed cleanup is reported to the caller.
    destroy_databricks_context(cluster_id, context_id, domain, databricks_api_key)
    return df
# Page title and section header rendered at the top of the Streamlit app.
# NOTE(review): the non-ASCII characters in the title look mis-encoded (possibly an emoji originally) — confirm against the original file.
st.title("Labelbox π€ Databricks") | |
st.header("Pipeline Creator", divider='rainbow') | |
# Check general URLs (pattern text unchanged; compiled once instead of per call)
_URL_PATTERN = re.compile(
    r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
)
# Check general URIs including cloud storage URIs (like gs://, s3://, etc.)
_URI_PATTERN = re.compile(
    r'^(?:[a-z][a-z0-9+.-]*:|/)(?:/?[^\s]*)?$|^(gs|s3|azure|blob)://[^\s]+'
)


def is_valid_url_or_uri(value):
    """Check if the provided value is a valid URL or URI.

    Args:
        value: Candidate value; anything that is not a ``str`` is rejected.

    Returns:
        A truthy ``re.Match`` object when the WHOLE value matches either the
        URL pattern or the URI pattern, otherwise ``None``.
    """
    if not isinstance(value, str):
        # e.g. None / NaN coming out of a DataFrame cell
        return None
    # ``fullmatch`` rather than ``match``: the URL pattern and the cloud-storage
    # alternative of the URI pattern have no end anchor, so ``match`` accepted
    # values with trailing garbage such as "https://host/a b" or "gs://bucket x".
    return _URL_PATTERN.fullmatch(value) or _URI_PATTERN.fullmatch(value)
# Preview/production switch; the flag is later mapped to the "mode" entry of the deployment payload.
# NOTE(review): the icon strings below look mis-encoded (probably emoji originally) — confirm against the original file.
is_preview = st.toggle('Run in Preview Mode', value=False) | |
if is_preview: | |
st.success('Running in Preview mode!', icon="β ") | |
else: | |
st.success('Running in Production mode!', icon="β ") | |
# Environment inputs: cloud provider, Databricks domain, and the two API keys (both keys use password-type inputs).
st.subheader("Tell us about your Databricks and Labelbox environments", divider='grey') | |
#cloud = "GCP" | |
# index=None means no cloud is preselected; `cloud` is forwarded as-is in the deployment payload.
cloud = st.selectbox('Which cloud environment does your Databricks Workspace run in?', ['AWS', 'Azure', 'GCP'], index=None) | |
# `title` holds the raw domain text; the scheme and a trailing slash are stripped from it further down.
title = st.text_input('Enter Databricks Domain (e.g., <instance>.<cloud>.databricks.com)', '') | |
databricks_api_key = st.text_input('Databricks API Key', type='password') | |
labelbox_api_key = st.text_input('Labelbox API Key', type='password') | |
# Dataset choice: either pick an existing Labelbox dataset or enter a name for a new one.
# After Labelbox API key is entered | |
if labelbox_api_key: | |
# Fetching datasets | |
datasets = get_dataset_from_labelbox(labelbox_api_key) | |
create_new_dataset = st.toggle("Make me a new dataset", value=False) | |
if not create_new_dataset: | |
# The existing logic for selecting datasets goes here. | |
# Name -> uid lookup; datasets that share a name collapse to a single entry (the last one wins).
dataset_name_to_id = {dataset.name: dataset.uid for dataset in datasets} | |
selected_dataset_name = st.selectbox("Select an existing dataset:", list(dataset_name_to_id.keys())) | |
# NOTE(review): presumably the selectbox yields None when there are no datasets, which would make this lookup raise KeyError — confirm.
dataset_id = dataset_name_to_id[selected_dataset_name] | |
else: | |
# If user toggles "make me a new dataset" | |
new_dataset_name = st.text_input("Enter the new dataset name:") | |
# Check if the name is valid | |
if new_dataset_name: | |
validation_message = validate_dataset_name(new_dataset_name) | |
if validation_message: | |
st.error(validation_message, icon="π«") | |
else: | |
# NOTE(review): this f-string has no placeholders, so the literal text "Dataset_id" is shown; the icon strings here also look mis-encoded.
st.success(f"Valid dataset name! Dataset_id", icon="β ") | |
dataset_name = new_dataset_name | |
# NOTE(review): a name that failed validation is still truthy, so the later `if new_dataset_name or selected_dataset_name` check does not appear to block it — confirm intent.
# Define the variables beforehand with default values (if not defined) | |
# The two names are only bound on some branches above; these lines default the missing one to None.
new_dataset_name = new_dataset_name if 'new_dataset_name' in locals() else None | |
selected_dataset_name = selected_dataset_name if 'selected_dataset_name' in locals() else None | |
# Cluster selection: once a dataset is chosen and a domain entered, list the workspace clusters and let the user pick one.
if new_dataset_name or selected_dataset_name: | |
# Handling various formats of input | |
formatted_title = re.sub(r'^https?://', '', title) # Remove http:// or https:// | |
formatted_title = re.sub(r'/$', '', formatted_title) # Remove trailing slash if present | |
if formatted_title: | |
st.subheader("Select an existing cluster", divider='grey', help="Jobs will use job clusters to reduce DBUs consumed.") | |
DOMAIN = f"https://{formatted_title}" | |
TOKEN = f"Bearer {databricks_api_key}" | |
HEADERS = { | |
"Authorization": TOKEN, | |
"Content-Type": "application/json", | |
} | |
# Endpoint to list clusters | |
ENDPOINT = "/api/2.0/clusters/list" | |
try: | |
response = requests.get(DOMAIN + ENDPOINT, headers=HEADERS) | |
response.raise_for_status() | |
# Include clusters with cluster_source "UI" or "API" | |
clusters = response.json().get("clusters", []) | |
# Cluster name -> cluster ID; clusters sharing a name collapse to a single entry.
cluster_dict = { | |
cluster["cluster_name"]: cluster["cluster_id"] | |
for cluster in clusters if cluster.get("cluster_source") in ["UI", "API"] | |
} | |
# Display dropdown with cluster names | |
if cluster_dict: | |
selected_cluster_name = st.selectbox( | |
'Select a cluster to run on', | |
list(cluster_dict.keys()), | |
key='unique_key_for_cluster_selectbox', | |
index=None, | |
placeholder="Select a cluster..", | |
) | |
if selected_cluster_name: | |
cluster_id = cluster_dict[selected_cluster_name] | |
# NOTE(review): `selected_cluster_name` and `cluster_id` are only bound on the success paths above; when the request fails or no cluster qualifies, the code that reads them afterwards may hit a NameError — confirm.
except requests.RequestException as e: | |
st.write(f"Error communicating with Databricks API: {str(e)}") | |
except ValueError: | |
st.write("Received unexpected response from Databricks API.") | |
# Make sure the chosen cluster is running: start it if needed and poll its state every 10 seconds for up to 20 minutes.
if selected_cluster_name and cluster_id: | |
# Check if the selected cluster is running | |
cluster_state = [cluster["state"] for cluster in clusters if cluster["cluster_id"] == cluster_id][0] | |
# If the cluster is not running, start it | |
if cluster_state != "RUNNING": | |
with st.spinner("Starting the selected cluster. This typically takes 10 minutes. Please wait..."): | |
# NOTE(review): raise_for_status() is not wrapped in a try here, so a failed start request appears to propagate as an exception — confirm.
start_response = requests.post(f"{DOMAIN}/api/2.0/clusters/start", headers=HEADERS, json={"cluster_id": cluster_id}) | |
start_response.raise_for_status() | |
# Poll until the cluster is up or until timeout | |
start_time = time.time() | |
timeout = 1200 # 20 minutes in seconds | |
while True: | |
cluster_response = requests.get(f"{DOMAIN}/api/2.0/clusters/get", headers=HEADERS, params={"cluster_id": cluster_id}).json() | |
if "state" in cluster_response: | |
if cluster_response["state"] == "RUNNING": | |
break | |
elif cluster_response["state"] in ["TERMINATED", "ERROR"]: | |
st.write(f"Error starting cluster. Current state: {cluster_response['state']}") | |
break | |
if (time.time() - start_time) > timeout: | |
st.write("Timeout reached while starting the cluster.") | |
break | |
time.sleep(10) # Check every 10 seconds | |
# NOTE(review): all three `break`s leave the same loop, so this success message looks like it is also shown after an error state or a timeout — confirm the original nesting.
# NOTE(review): the icon strings below look mis-encoded (probably emoji originally).
st.success(f"{selected_cluster_name} is now running!", icon="πββοΈ") | |
else: | |
st.success(f"{selected_cluster_name} is already running!", icon="πββοΈ") | |
def generate_cron_expression(freq, hour=0, minute=0, day_of_week=None, day_of_month=None):
    """Build the six-field cron string for a schedule frequency.

    ``freq`` is one of "1 minute", "1 hour", "1 day", "1 week" or "1 month".
    ``hour``/``minute`` are placed into the expression as given; ``day_of_week``
    is required for weekly schedules and ``day_of_month`` for monthly ones.

    Raises:
        ValueError: For an unknown frequency, or when the day argument needed
            by a weekly/monthly schedule is missing or falsy.
    """
    if freq == "1 minute":
        return "0 * * * * ?"

    if freq == "1 hour":
        return f"0 {minute} * * * ?"

    if freq == "1 day":
        return f"0 {minute} {hour} * * ?"

    if freq == "1 week":
        if day_of_week:
            return f"0 {minute} {hour} ? * {day_of_week}"
        raise ValueError("Day of week not provided for weekly frequency.")

    if freq == "1 month":
        if day_of_month:
            return f"0 {minute} {hour} {day_of_month} * ?"
        raise ValueError("Day of month not provided for monthly frequency.")

    raise ValueError("Invalid frequency provided")
# Streamlit UI: schedule selection and a rough DBU estimate for the chosen schedule
st.subheader("Run Frequency", divider='grey')

# Dropdown to select frequency
freq_options = ["1 day", "1 week", "1 month"]
selected_freq = st.selectbox("Select frequency:", freq_options, placeholder="Select frequency..")

day_of_week = None
day_of_month = None

# Every frequency except "1 minute" needs a time of day; weekly and monthly
# schedules additionally need a day.
if selected_freq != "1 minute":
    if selected_freq == "1 week":
        days_options = ["MON", "TUE", "WED", "THU", "FRI", "SAT", "SUN"]
        day_of_week = st.selectbox("Select day of the week:", days_options)
    elif selected_freq == "1 month":
        day_of_month = st.selectbox("Select day of the month:", list(range(1, 32)))

    col1, col2 = st.columns(2)
    with col1:
        hour = st.selectbox("Hour:", list(range(0, 24)))
    with col2:
        minute = st.selectbox("Minute:", list(range(0, 60)))
else:
    hour, minute = 0, 0

# Generate the cron expression
frequency = generate_cron_expression(selected_freq, hour, minute, day_of_week, day_of_month)

# Assumed DBU consumption rate for a 32GB, 4-core node per hour
X = 1  # Replace this with the actual rate from Databricks' pricing or documentation

# Calculate DBU consumption for a single run
min_dbu_single_run = (X/6) * (1 + 10)  # Assuming maximum scaling to 10 workers
max_dbu_single_run = (2*X/3) * (1 + 10)

# Runs per month for the selected frequency (anything other than daily/weekly counts as one run).
# The earlier estimate that compared the `freq_options` LIST against a string was
# always false and its result was overwritten below, so it has been dropped.
runs_per_month = {"1 day": 30, "1 week": 4}.get(selected_freq, 1)

# Calculate estimated DBU consumption per month
min_dbu_monthly = runs_per_month * min_dbu_single_run
max_dbu_monthly = runs_per_month * max_dbu_single_run
def generate_human_readable_message(freq, hour=0, minute=0, day_of_week=None, day_of_month=None):
    """Describe a schedule in plain English.

    Accepts the same arguments as the cron generator. Hours and minutes are
    zero-padded to two digits for daily, weekly and monthly schedules.

    Raises:
        ValueError: For an unknown frequency, or when the day argument needed
            by a weekly/monthly schedule is missing or falsy.
    """
    if freq == "1 minute":
        return "Job will run every minute."

    if freq == "1 hour":
        return f"Job will run once an hour at minute {minute}."

    if freq == "1 day":
        return f"Job will run daily at {hour:02}:{minute:02}."

    if freq == "1 week":
        if day_of_week:
            return f"Job will run every {day_of_week} at {hour:02}:{minute:02}."
        raise ValueError("Day of week not provided for weekly frequency.")

    if freq == "1 month":
        if day_of_month:
            return f"Job will run once a month on day {day_of_month} at {hour:02}:{minute:02}."
        raise ValueError("Day of month not provided for monthly frequency.")

    raise ValueError("Invalid frequency provided")
# Schedule summary: show the plain-English schedule plus the estimated DBU range computed earlier.
# Generate the human-readable message | |
readable_msg = generate_human_readable_message(selected_freq, hour, minute, day_of_week, day_of_month) | |
# Main code block | |
# `frequency` is the cron string produced by generate_cron_expression.
if frequency: | |
# NOTE(review): the icon string below looks mis-encoded (probably an emoji originally).
st.success(readable_msg, icon="π ") | |
# Display the estimated DBU consumption to the user | |
st.warning(f"Estimated DBU Consumption:\n- For a single run: {min_dbu_single_run:.2f} to {max_dbu_single_run:.2f} DBUs\n- Monthly (based on {runs_per_month} runs): {min_dbu_monthly:.2f} to {max_dbu_monthly:.2f} DBUs") | |
# Disclaimer | |
st.info("Disclaimer: This is only an estimation. Always monitor the job in Databricks to assess actual DBU consumption.") | |
# Table selection and payload assembly: drill down database -> table -> columns, preview rows, map columns to Labelbox fields, then build the deployment payload.
st.subheader("Select a table", divider="grey") | |
# Fetching databases | |
# NOTE(review): each fetch_* call and each execute_databricks_query call below creates and destroys its own execution context; nothing visible here caches the results between calls.
result_data = fetch_databases(cluster_id, formatted_title, databricks_api_key) | |
database_names = result_data['databaseName'].tolist() | |
selected_database = st.selectbox("Select a Database:", database_names, index=None, placeholder="Select a database..") | |
if selected_database: | |
# Fetching tables | |
result_data = fetch_tables(selected_database, cluster_id, formatted_title, databricks_api_key) | |
table_names = result_data['tableName'].tolist() | |
selected_table = st.selectbox("Select a Table:", table_names, index=None, placeholder="Select a table..") | |
if selected_table: | |
# Fetching columns | |
result_data = fetch_columns(selected_database, selected_table, cluster_id, formatted_title, databricks_api_key) | |
column_names = result_data['col_name'].tolist() | |
st.subheader("Map table schema to Labelbox schema", divider="grey") | |
# Your existing code to handle schema mapping... | |
# Fetch the first 5 rows of the selected table | |
with st.spinner('Fetching first 5 rows of the selected table...'): | |
query = f"SELECT * FROM {selected_database}.{selected_table} LIMIT 5;" | |
table_sample_data = execute_databricks_query(query, cluster_id, formatted_title, databricks_api_key) | |
st.write(table_sample_data) | |
# Define two columns for side-by-side selectboxes | |
col1, col2 = st.columns(2) | |
with col1: | |
selected_row_data = st.selectbox( | |
"row_data (required):", | |
column_names, | |
index=None, | |
placeholder="Select a column..", | |
help="Select the column that contains the URL/URI bucket location of the data rows you wish to import into Labelbox." | |
) | |
with col2: | |
selected_global_key = st.selectbox( | |
"global_key (optional):", | |
column_names, | |
index=None, | |
placeholder="Select a column..", | |
help="Select the column that contains the global key. If not provided, a new key will be generated for you." | |
) | |
# Fetch a single row from the selected table | |
# NOTE(review): this is a second round trip to the same table right after the 5-row sample above.
query_sample_row = f"SELECT * FROM {selected_database}.{selected_table} LIMIT 1;" | |
result_sample = execute_databricks_query(query_sample_row, cluster_id, formatted_title, databricks_api_key) | |
if selected_row_data: | |
# Extract the value from the selected row_data column | |
# NOTE(review): .iloc[0] presumably raises IndexError when the table has no rows — confirm.
sample_row_data_value = result_sample[selected_row_data].iloc[0] | |
# Validate the extracted value | |
# NOTE(review): no validation call follows; `sample_row_data_value` is not used again in the visible code.
# NOTE(review): create_new_dataset_labelbox is called here, before the deploy button is checked — it looks like a dataset may be created on every rerun of the script; confirm against the original nesting.
dataset_id = create_new_dataset_labelbox(new_dataset_name) if create_new_dataset else dataset_id | |
# Mode | |
mode = "preview" if is_preview else "production" | |
# Databricks instance and API key | |
databricks_instance = formatted_title | |
# The self-assignments in this section (databricks_api_key, dataset_id, frequency) have no effect.
databricks_api_key = databricks_api_key | |
# Dataset ID and New Dataset | |
new_dataset = 1 if create_new_dataset else 0 | |
dataset_id = dataset_id | |
# Table Path | |
table_path = f"{selected_database}.{selected_table}" | |
# Frequency | |
frequency = frequency | |
# Schema Map | |
row_data_input = selected_row_data | |
global_key_input = selected_global_key | |
# Create the initial dictionary | |
schema_map_dict = {'row_data': row_data_input} | |
if global_key_input: | |
schema_map_dict['global_key'] = global_key_input | |
# Swap keys and values | |
# After the swap the table column names are the keys; if the same column was picked for both fields, only the 'global_key' entry survives.
reversed_schema_map_dict = {v: k for k, v in schema_map_dict.items()} | |
# Convert the reversed dictionary to a stringified JSON | |
reversed_schema_map_str = json.dumps(reversed_schema_map_dict) | |
# Payload for the deployment request; note that it carries both API keys.
data = { | |
"cloud": cloud, | |
"mode": mode, | |
"databricks_instance": databricks_instance, | |
"databricks_api_key": databricks_api_key, | |
"new_dataset": new_dataset, | |
"dataset_id": dataset_id, | |
"table_path": table_path, | |
"labelbox_api_key": labelbox_api_key, | |
"frequency": frequency, | |
"new_cluster": 0, | |
"cluster_id": cluster_id, | |
"schema_map": reversed_schema_map_str | |
} | |
# Deployment: on button press, post the payload to the deployment endpoint and show a link to the created job.
if st.button("Deploy Pipeline!", type="primary"): | |
# Ensure all fields are filled out | |
# NOTE(review): `required_fields` is built but never checked in the visible code, so nothing is actually validated here.
required_fields = [ | |
mode, databricks_instance, databricks_api_key, new_dataset, dataset_id, | |
table_path, labelbox_api_key, frequency, cluster_id, reversed_schema_map_str | |
] | |
# Sending a POST request to the Flask app endpoint | |
with st.spinner("Deploying pipeline..."): | |
# NOTE(review): this request sends both API keys to the endpoint below and sets no timeout.
response = requests.post("https://us-central1-dbt-prod.cloudfunctions.net/deploy-databricks-pipeline", json=data) | |
# Check if request was successful | |
if response.status_code == 200: | |
# Display the response using Streamlit | |
st.balloons() | |
response = response.json() | |
# Extract the job_id | |
# The job id is cut out of the 'message' string by splitting on 'job_id":' and '}'; this presumably breaks if the message format changes — confirm.
job_id = response['message'].split('job_id":')[1].split('}')[0] | |
from urllib.parse import urlparse, parse_qs | |
# Parse the Databricks instance URL to extract the organization ID | |
# NOTE(review): `formatted_title` has had its scheme stripped earlier and normally has no query string, so `organization_id` is presumably always "" — confirm.
parsed_url = urlparse(formatted_title) | |
query_params = parse_qs(parsed_url.query) | |
organization_id = query_params.get("o", [""])[0] | |
# Generate the Databricks Job URL | |
# NOTE(review): the link uses http:// while every API call in this file uses https://.
job_url = f"http://{formatted_title}/?o={organization_id}#job/{job_id}" | |
# NOTE(review): the trailing character in the message below and the icon further down look mis-encoded (probably emoji originally).
st.success(f"Pipeline deployed successfully! [{job_url}]({job_url}) π") | |
else: | |
st.error(f"Failed to deploy pipeline. Response: {response.text}", icon="π«") | |
# Inject raw CSS (hence unsafe_allow_html=True) that adds 1000px of bottom padding to the main content container.
st.markdown(""" | |
<style> | |
/* Add a large bottom padding to the main content */ | |
.main .block-container { | |
padding-bottom: 1000px; /* Adjust this value as needed */ | |
} | |
</style> | |
""", unsafe_allow_html=True) | |