# app.py — uploaded by ksaramout (commit 4ec16ee)
import streamlit as st
import requests
import re
import json
import time
import pandas as pd
import labelbox
@st.cache_data(show_spinner=True)
def fetch_databases(cluster_id, formatted_title, databricks_api_key):
    """Return the output of ``SHOW DATABASES`` on ``cluster_id`` as a DataFrame.

    Wrapped in ``st.cache_data`` so reruns with identical arguments reuse the
    stored result instead of querying the cluster again.

    Args:
        cluster_id: Databricks cluster to run the statement on.
        formatted_title: Workspace host name without scheme.
        databricks_api_key: Databricks personal access token.
    """
    show_databases_sql = "SHOW DATABASES;"
    return execute_databricks_query(
        show_databases_sql,
        cluster_id,
        formatted_title,
        databricks_api_key,
    )
# Cached function to fetch tables
@st.cache_data(show_spinner=True)
def fetch_tables(selected_database, cluster_id, formatted_title, databricks_api_key):
    """Return the output of ``SHOW TABLES IN <database>`` as a DataFrame.

    Cached with ``st.cache_data`` on the argument tuple.

    Args:
        selected_database: Database (schema) name, as listed by SHOW DATABASES.
        cluster_id: Databricks cluster to run the statement on.
        formatted_title: Workspace host name without scheme.
        databricks_api_key: Databricks personal access token.
    """
    # Backtick-quote the identifier (doubling embedded backticks) so names that
    # contain special characters or collide with SQL keywords still parse.
    quoted_database = "`" + str(selected_database).replace("`", "``") + "`"
    query = f"SHOW TABLES IN {quoted_database};"
    return execute_databricks_query(query, cluster_id, formatted_title, databricks_api_key)
# Cached function to fetch columns
@st.cache_data(show_spinner=True)
def fetch_columns(selected_database, selected_table, cluster_id, formatted_title, databricks_api_key):
    """Return the output of ``SHOW COLUMNS IN <database>.<table>`` as a DataFrame.

    Cached with ``st.cache_data`` on the argument tuple.

    Args:
        selected_database: Database (schema) name.
        selected_table: Table name inside ``selected_database``.
        cluster_id: Databricks cluster to run the statement on.
        formatted_title: Workspace host name without scheme.
        databricks_api_key: Databricks personal access token.
    """
    # Quote each identifier separately (doubling embedded backticks) so
    # special characters or keywords in either name do not break the statement.
    quoted_database = "`" + str(selected_database).replace("`", "``") + "`"
    quoted_table = "`" + str(selected_table).replace("`", "``") + "`"
    query = f"SHOW COLUMNS IN {quoted_database}.{quoted_table};"
    return execute_databricks_query(query, cluster_id, formatted_title, databricks_api_key)
_DATASET_NAME_MAX_LENGTH = 256
# Letters, digits, space and the punctuation _ - . , ( ) /  (compiled once).
_DATASET_NAME_PATTERN = re.compile(r'[A-Za-z0-9 _\-.,()\/]+')


def validate_dataset_name(name):
    """Validate a Labelbox dataset name.

    Args:
        name: Candidate dataset name.

    Returns:
        None if the name is acceptable, otherwise a user-facing error message.
    """
    # Check length
    if len(name) > _DATASET_NAME_MAX_LENGTH:
        return "Dataset name should be limited to 256 characters."
    # Check allowed characters. fullmatch is used because `match` with `$`
    # also accepts a trailing newline (e.g. "abc\n").
    if not _DATASET_NAME_PATTERN.fullmatch(name):
        return ("Dataset name can only contain letters, numbers, spaces, and the following punctuation symbols: _-.,()/. Other characters are not supported.")
    return None
def create_new_dataset_labelbox(new_dataset_name, api_key=None):
    """Create a Labelbox dataset and return its uid.

    Args:
        new_dataset_name: Name for the new dataset.
        api_key: Labelbox API key. When omitted, falls back to the module-level
            ``labelbox_api_key`` (the value the original one-argument signature
            relied on implicitly), so existing callers keep working.

    Returns:
        The ``uid`` of the newly created dataset.
    """
    key = labelbox_api_key if api_key is None else api_key
    client = labelbox.Client(api_key=key)
    return client.create_dataset(name=new_dataset_name).uid
def get_dataset_from_labelbox(labelbox_api_key):
    """Return the datasets visible to a Labelbox account.

    Args:
        labelbox_api_key: API key used to construct the ``labelbox.Client``.

    Returns:
        Whatever ``Client.get_datasets()`` returns, unchanged.
    """
    return labelbox.Client(api_key=labelbox_api_key).get_datasets()
def destroy_databricks_context(cluster_id, context_id, domain, databricks_api_key, timeout=30):
    """Destroy a Command Execution API (1.2) context on a cluster.

    Args:
        cluster_id: Cluster that owns the context.
        context_id: Context id returned by ``/api/1.2/contexts/create``.
        domain: Workspace host name without scheme (``https://`` is prepended).
        databricks_api_key: Databricks personal access token (sent as Bearer).
        timeout: Seconds before the HTTP request is abandoned; without it a
            stalled connection would block forever.

    Raises:
        ValueError: if the API answers with a status other than 200.
        requests.RequestException: on network failure or timeout.
    """
    response = requests.post(
        f"https://{domain}/api/1.2/contexts/destroy",
        headers={
            "Authorization": f"Bearer {databricks_api_key}",
            "Content-Type": "application/json",
        },
        json={"clusterId": cluster_id, "contextId": context_id},
        timeout=timeout,
    )
    if response.status_code != 200:
        raise ValueError(
            f"Failed to destroy context. HTTP {response.status_code}: {response.text}"
        )
def _run_command_in_context(base_url, headers, cluster_id, context_id, query,
                            poll_interval, timeout, request_timeout):
    """Submit ``query`` in an existing context, wait for it, and return a DataFrame."""
    command_data = requests.post(
        f"{base_url}/api/1.2/commands/execute",
        headers=headers,
        json={
            "clusterId": cluster_id,
            "contextId": context_id,
            "language": "sql",
            "command": query,
        },
        timeout=request_timeout,
    ).json()
    if 'id' not in command_data:
        raise ValueError("Failed to execute command.")
    status_params = {
        "clusterId": cluster_id,
        "contextId": context_id,
        "commandId": command_data['id'],
    }

    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        status_data = requests.get(
            f"{base_url}/api/1.2/commands/status",
            headers=headers,
            params=status_params,
            timeout=request_timeout,
        ).json()
        command_status = status_data.get("status")
        # "results" can be absent or null while the command is still running.
        results = status_data.get('results') or {}
        if command_status == "Finished":
            break
        if command_status in ("Error", "Cancelled"):
            raise ValueError(f"Command {command_status}. Reason: {results.get('summary')}")
        if deadline is not None and time.monotonic() >= deadline:
            raise ValueError(
                f"Command did not finish within {timeout} seconds "
                f"(last status: {command_status})."
            )
        time.sleep(poll_interval)

    # A failing SQL statement is reported as status "Finished" with
    # resultType "error"; without this check it would become an empty frame.
    if results.get('resultType') == "error":
        raise ValueError(
            f"Command failed. Reason: {results.get('summary') or results.get('cause')}"
        )
    columns = [col['name'] for col in results.get('schema') or []]
    return pd.DataFrame(results.get('data') or [], columns=columns)


def execute_databricks_query(query, cluster_id, domain, databricks_api_key,
                             poll_interval=1.0, timeout=600.0, request_timeout=30.0):
    """Run a SQL statement on an interactive cluster and return the result as a DataFrame.

    Uses the Databricks Command Execution API (1.2): create a SQL execution
    context, submit ``query``, poll its status until it finishes, then destroy
    the context. The context is destroyed on failure too.

    Args:
        query: SQL text to execute.
        cluster_id: ID of the running cluster to execute on.
        domain: Workspace host name without scheme (``https://`` is prepended).
        databricks_api_key: Databricks personal access token (sent as Bearer).
        poll_interval: Seconds to wait between status polls.
        timeout: Maximum seconds to wait for the command to finish; ``None``
            waits indefinitely.
        request_timeout: Per-HTTP-request timeout in seconds.

    Returns:
        pandas.DataFrame of the result rows, with column names from the result
        schema (empty when the command returns no table data).

    Raises:
        ValueError: if the context or command cannot be created, the command
            ends in Error/Cancelled or with an error result, it exceeds
            ``timeout``, or destroying the context fails after a successful run.
        requests.RequestException: on network failures or HTTP timeouts.
    """
    base_url = f"https://{domain}"
    headers = {
        "Authorization": f"Bearer {databricks_api_key}",
        "Content-Type": "application/json",
    }

    # Create context
    context_data = requests.post(
        f"{base_url}/api/1.2/contexts/create",
        headers=headers,
        json={"clusterId": cluster_id, "language": "sql"},
        timeout=request_timeout,
    ).json()
    if 'id' not in context_data:
        raise ValueError("Failed to create context.")
    context_id = context_data['id']

    try:
        frame = _run_command_in_context(
            base_url, headers, cluster_id, context_id, query,
            poll_interval, timeout, request_timeout,
        )
    except Exception:
        # Best-effort cleanup so failed queries do not leave execution
        # contexts open on the cluster; a cleanup failure must not mask the
        # original error, which is re-raised below.
        try:
            destroy_databricks_context(cluster_id, context_id, domain, databricks_api_key)
        except (ValueError, requests.RequestException):
            pass
        raise
    destroy_databricks_context(cluster_id, context_id, domain, databricks_api_key)
    return frame
# Page title and header. The handshake emoji was mis-encoded in the original
# literal and rendered as an unrelated glyph.
st.title("Labelbox 🤝 Databricks")
st.header("Pipeline Creator", divider='rainbow')
# Compiled once at import time instead of on every call.
_URL_PATTERN = re.compile(
    r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
)
# General URIs, including cloud storage URIs (gs://, s3://, ...) and absolute paths.
_URI_PATTERN = re.compile(
    r'^(?:[a-z][a-z0-9+.-]*:|/)(?:/?[^\s]*)?$|^(gs|s3|azure|blob)://[^\s]+'
)


def is_valid_url_or_uri(value):
    """Return True if ``value`` looks like a URL or URI.

    Accepts strings that start with an http(s) URL, whitespace-free strings
    of the form ``scheme:...`` with a lowercase scheme (e.g. ``gs://``,
    ``s3://``, ``dbfs:/``), and whitespace-free absolute paths starting with
    ``/``. Non-string values (None, NaN or numbers read from a table) return
    False instead of raising TypeError.
    """
    if not isinstance(value, str):
        return False
    # NOTE(review): the check is permissive: the URL pattern is only anchored
    # at the start, and `$-_` inside its character class is a range (0x24-0x5F).
    return bool(_URL_PATTERN.match(value) or _URI_PATTERN.match(value))
# Run mode: preview vs production.
is_preview = st.toggle('Run in Preview Mode', value=False)
# The icon must be a single emoji; the original literal was mis-encoded.
if is_preview:
    st.success('Running in Preview mode!', icon="✅")
else:
    st.success('Running in Production mode!', icon="✅")

st.subheader("Tell us about your Databricks and Labelbox environments", divider='grey')
# index=None leaves the box empty (value None) until the user picks a cloud.
cloud = st.selectbox('Which cloud environment does your Databricks Workspace run in?', ['AWS', 'Azure', 'GCP'], index=None)
title = st.text_input('Enter Databricks Domain (e.g., <instance>.<cloud>.databricks.com)', '')
databricks_api_key = st.text_input('Databricks API Key', type='password')
labelbox_api_key = st.text_input('Labelbox API Key', type='password')
# After Labelbox API key is entered, guide the user through dataset, cluster,
# schedule and table selection, then deploy the pipeline. Each step returns
# early until the user has completed it, so later widgets only appear once the
# earlier choices are valid.

_HTTP_TIMEOUT_S = 30
_CLUSTER_START_TIMEOUT_S = 1200  # 20 minutes in seconds
_CLUSTER_POLL_INTERVAL_S = 10
# Assumed DBU consumption rate for a 32GB, 4-core node per hour.
# Replace this with the actual rate from Databricks' pricing or documentation.
_DBU_RATE_PER_NODE_HOUR = 1
_MAX_WORKERS = 10  # Assuming maximum scaling to 10 workers
_RUNS_PER_MONTH = {"1 day": 30, "1 week": 4, "1 month": 1}
_DEPLOY_ENDPOINT = "https://us-central1-dbt-prod.cloudfunctions.net/deploy-databricks-pipeline"
_DEPLOY_TIMEOUT_S = (10, 600)  # (connect, read) seconds for the deploy request


def generate_cron_expression(freq, hour=0, minute=0, day_of_week=None, day_of_month=None):
    """
    Generate a Quartz-style cron expression (seconds field first, ``?`` for the
    unused day field) for the given frequency and time.

    Raises:
        ValueError: if ``freq`` is unknown or its required day is missing.
    """
    if freq == "1 minute":
        return "0 * * * * ?"
    elif freq == "1 hour":
        return f"0 {minute} * * * ?"
    elif freq == "1 day":
        return f"0 {minute} {hour} * * ?"
    elif freq == "1 week":
        if not day_of_week:
            raise ValueError("Day of week not provided for weekly frequency.")
        return f"0 {minute} {hour} ? * {day_of_week}"
    elif freq == "1 month":
        if not day_of_month:
            raise ValueError("Day of month not provided for monthly frequency.")
        return f"0 {minute} {hour} {day_of_month} * ?"
    else:
        raise ValueError("Invalid frequency provided")


def generate_human_readable_message(freq, hour=0, minute=0, day_of_week=None, day_of_month=None):
    """
    Generate a human-readable description of the schedule built by
    generate_cron_expression for the same arguments.

    Raises:
        ValueError: if ``freq`` is unknown or its required day is missing.
    """
    if freq == "1 minute":
        return "Job will run every minute."
    elif freq == "1 hour":
        return f"Job will run once an hour at minute {minute}."
    elif freq == "1 day":
        return f"Job will run daily at {hour:02}:{minute:02}."
    elif freq == "1 week":
        if not day_of_week:
            raise ValueError("Day of week not provided for weekly frequency.")
        return f"Job will run every {day_of_week} at {hour:02}:{minute:02}."
    elif freq == "1 month":
        if not day_of_month:
            raise ValueError("Day of month not provided for monthly frequency.")
        return f"Job will run once a month on day {day_of_month} at {hour:02}:{minute:02}."
    else:
        raise ValueError("Invalid frequency provided")


def _quote_identifier(name):
    """Backtick-quote a SQL identifier, doubling any embedded backticks."""
    return "`" + str(name).replace("`", "``") + "`"


def _parse_databricks_domain(raw_domain):
    """Split a user-entered workspace address into ``(host, organization_id)``.

    Accepts a bare host or a full URL, with or without scheme, trailing slash,
    path or ``?o=<id>`` query. ``organization_id`` is the ``o`` query
    parameter, or "" when absent. Returns ``("", "")`` for empty input.
    """
    from urllib.parse import parse_qs, urlparse

    candidate = (raw_domain or "").strip()
    if not re.match(r"^https?://", candidate, flags=re.IGNORECASE):
        candidate = f"https://{candidate}"
    parsed = urlparse(candidate)
    return parsed.netloc, parse_qs(parsed.query).get("o", [""])[0]


def _extract_job_id(payload):
    """Return the Databricks job id from the deploy endpoint's JSON body, or None.

    The ``message`` field presumably embeds ``"job_id": <id>`` (that is what
    the original string splitting assumed); a dict ``message`` with a
    ``job_id`` key is accepted as well.
    """
    message = payload.get("message") if isinstance(payload, dict) else None
    if isinstance(message, dict):
        job_id = message.get("job_id")
        return None if job_id is None else str(job_id)
    if not isinstance(message, str):
        return None
    match = re.search(r'job_id"\s*:\s*"?([^",}\s]+)', message)
    return match.group(1) if match else None


def _select_dataset():
    """Let the user pick an existing Labelbox dataset or name a new one.

    Returns:
        ``(create_new, dataset_id, new_dataset_name)`` once a valid choice is
        made, otherwise None. For a new dataset, ``dataset_id`` is None (the
        dataset is only created at deploy time). For an existing one,
        ``new_dataset_name`` is None.
    """
    try:
        datasets = get_dataset_from_labelbox(labelbox_api_key)
    except Exception as exc:  # UI boundary: show any Labelbox client error
        st.error(f"Could not connect to Labelbox: {exc}", icon="🚫")
        return None

    create_new = st.toggle("Make me a new dataset", value=False)
    if not create_new:
        try:
            dataset_name_to_id = {dataset.name: dataset.uid for dataset in datasets}
        except Exception as exc:  # UI boundary: e.g. an invalid API key
            st.error(f"Could not fetch datasets from Labelbox: {exc}", icon="🚫")
            return None
        selected_dataset_name = st.selectbox("Select an existing dataset:", list(dataset_name_to_id.keys()))
        if selected_dataset_name is None:  # selectbox with no options
            st.info("No existing datasets found. Toggle 'Make me a new dataset' to create one.")
            return None
        return False, dataset_name_to_id[selected_dataset_name], None

    new_dataset_name = st.text_input("Enter the new dataset name:")
    if not new_dataset_name:
        return None
    validation_message = validate_dataset_name(new_dataset_name)
    if validation_message:
        st.error(validation_message, icon="🚫")
        return None
    st.success("Valid dataset name! It will be created when you deploy the pipeline.", icon="✅")
    return True, None, new_dataset_name


def _select_cluster(base_url, headers):
    """List clusters created from the UI or API and let the user pick one.

    Returns:
        ``(cluster_name, cluster_id, state)`` for the chosen cluster, with
        ``state`` taken from the clusters/list response. Returns None when
        nothing is selected or the API call fails.
    """
    st.subheader("Select an existing cluster", divider='grey', help="Jobs will use job clusters to reduce DBUs consumed.")
    try:
        response = requests.get(f"{base_url}/api/2.0/clusters/list", headers=headers, timeout=_HTTP_TIMEOUT_S)
        response.raise_for_status()
        clusters = response.json().get("clusters", [])
    except requests.RequestException as e:
        st.error(f"Error communicating with Databricks API: {str(e)}")
        return None
    except ValueError:
        st.error("Received unexpected response from Databricks API.")
        return None

    # Include clusters with cluster_source "UI" or "API"
    cluster_dict = {
        cluster["cluster_name"]: cluster["cluster_id"]
        for cluster in clusters if cluster.get("cluster_source") in ("UI", "API")
    }
    if not cluster_dict:
        st.warning("No clusters created from the UI or API were found in this workspace.")
        return None

    selected_cluster_name = st.selectbox(
        'Select a cluster to run on',
        list(cluster_dict.keys()),
        key='unique_key_for_cluster_selectbox',
        index=None,
        placeholder="Select a cluster..",
    )
    if not selected_cluster_name:
        return None
    cluster_id = cluster_dict[selected_cluster_name]
    state = next((c.get("state") for c in clusters if c.get("cluster_id") == cluster_id), None)
    return selected_cluster_name, cluster_id, state


def _ensure_cluster_running(base_url, headers, cluster_name, cluster_id, state):
    """Start the cluster if it is terminated and wait until it is RUNNING.

    Returns True once the cluster runs. Returns False after showing an error
    if the cluster ends up TERMINATED/ERROR after a start, the API fails, or
    it is not running within _CLUSTER_START_TIMEOUT_S seconds.
    """
    if state == "RUNNING":
        st.success(f"{cluster_name} is already running!", icon="🏃‍♂️")
        return True

    with st.spinner("Starting the selected cluster. This typically takes 10 minutes. Please wait..."):
        deadline = time.monotonic() + _CLUSTER_START_TIMEOUT_S
        start_requested = False
        try:
            while state != "RUNNING":
                if state == "TERMINATED" and not start_requested:
                    # clusters/start is issued only for a terminated cluster;
                    # one that is pending/restarting/terminating is polled (and
                    # started once it reaches TERMINATED).
                    requests.post(
                        f"{base_url}/api/2.0/clusters/start",
                        headers=headers,
                        json={"cluster_id": cluster_id},
                        timeout=_HTTP_TIMEOUT_S,
                    ).raise_for_status()
                    start_requested = True
                elif state in ("TERMINATED", "ERROR"):
                    st.error(f"Error starting cluster. Current state: {state}")
                    return False
                if time.monotonic() > deadline:
                    st.error("Timeout reached while starting the cluster.")
                    return False
                time.sleep(_CLUSTER_POLL_INTERVAL_S)
                cluster_info = requests.get(
                    f"{base_url}/api/2.0/clusters/get",
                    headers=headers,
                    params={"cluster_id": cluster_id},
                    timeout=_HTTP_TIMEOUT_S,
                ).json()
                state = cluster_info.get("state", state)
        except (requests.RequestException, ValueError) as exc:
            st.error(f"Error communicating with Databricks API: {exc}")
            return False

    st.success(f"{cluster_name} is now running!", icon="🏃‍♂️")
    return True


def _select_schedule():
    """Render the run-frequency controls and DBU estimate; return the cron expression."""
    st.subheader("Run Frequency", divider='grey')
    freq_options = ["1 day", "1 week", "1 month"]
    selected_freq = st.selectbox("Select frequency:", freq_options, placeholder="Select frequency..")
    day_of_week = None
    day_of_month = None
    if selected_freq == "1 week":
        days_options = ["MON", "TUE", "WED", "THU", "FRI", "SAT", "SUN"]
        day_of_week = st.selectbox("Select day of the week:", days_options)
    elif selected_freq == "1 month":
        day_of_month = st.selectbox("Select day of the month:", list(range(1, 32)))
        if day_of_month > 28:
            # A fixed day-of-month cron field never fires in shorter months.
            st.caption(f"Months without a day {day_of_month} will be skipped.")
    col1, col2 = st.columns(2)
    with col1:
        hour = st.selectbox("Hour:", list(range(0, 24)))
    with col2:
        minute = st.selectbox("Minute:", list(range(0, 60)))

    frequency = generate_cron_expression(selected_freq, hour, minute, day_of_week, day_of_month)
    readable_msg = generate_human_readable_message(selected_freq, hour, minute, day_of_week, day_of_month)

    # Per-run DBU range for a driver plus _MAX_WORKERS workers; the 1/6 and 2/3
    # factors presumably encode a 10-40 minute run — TODO confirm.
    min_dbu_single_run = (_DBU_RATE_PER_NODE_HOUR / 6) * (1 + _MAX_WORKERS)
    max_dbu_single_run = (2 * _DBU_RATE_PER_NODE_HOUR / 3) * (1 + _MAX_WORKERS)
    runs_per_month = _RUNS_PER_MONTH.get(selected_freq, 1)
    min_dbu_monthly = runs_per_month * min_dbu_single_run
    max_dbu_monthly = runs_per_month * max_dbu_single_run

    st.success(readable_msg, icon="📅")
    st.warning(f"Estimated DBU Consumption:\n- For a single run: {min_dbu_single_run:.2f} to {max_dbu_single_run:.2f} DBUs\n- Monthly (based on {runs_per_month} runs): {min_dbu_monthly:.2f} to {max_dbu_monthly:.2f} DBUs")
    st.info("Disclaimer: This is only an estimation. Always monitor the job in Databricks to assess actual DBU consumption.")
    return frequency


def _fetch_names(label, fetch, column, *args):
    """Call a cached fetch_* helper and return one result column as a list, or None on failure."""
    try:
        return fetch(*args)[column].tolist()
    except (ValueError, KeyError, requests.RequestException) as exc:
        st.error(f"Could not list {label}: {exc}", icon="🚫")
        return None


def _select_table(cluster_id, host):
    """Let the user pick a database and table.

    Returns:
        ``(database, table, column_names)``, or None until both are chosen.
    """
    st.subheader("Select a table", divider="grey")
    database_names = _fetch_names("databases", fetch_databases, 'databaseName',
                                  cluster_id, host, databricks_api_key)
    if database_names is None:
        return None
    selected_database = st.selectbox("Select a Database:", database_names, index=None, placeholder="Select a database..")
    if not selected_database:
        return None

    table_names = _fetch_names("tables", fetch_tables, 'tableName',
                               selected_database, cluster_id, host, databricks_api_key)
    if table_names is None:
        return None
    selected_table = st.selectbox("Select a Table:", table_names, index=None, placeholder="Select a table..")
    if not selected_table:
        return None

    column_names = _fetch_names("columns", fetch_columns, 'col_name',
                                selected_database, selected_table, cluster_id, host, databricks_api_key)
    if column_names is None:
        return None
    return selected_database, selected_table, column_names


def _warn_if_not_uri(sample, column):
    """Warn when the first sample value of ``column`` does not look like a URL/URI."""
    if sample.empty:
        st.warning("The selected table returned no rows, so the row_data column could not be checked.")
        return
    if column not in sample.columns:
        return
    value = sample[column].iloc[0]
    if not (isinstance(value, str) and is_valid_url_or_uri(value)):
        st.warning(
            f"The first value of '{column}' ({value!r}) does not look like a URL or URI; "
            "row_data should contain the URL/URI bucket location of each data row."
        )


def _select_schema_mapping(cluster_id, host, quoted_table_path, column_names):
    """Show sample rows and the column pickers.

    Returns:
        ``(row_data_column, global_key_column_or_None)``, or None until a
        valid row_data column is chosen.
    """
    st.subheader("Map table schema to Labelbox schema", divider="grey")
    with st.spinner('Fetching first 5 rows of the selected table...'):
        try:
            sample = execute_databricks_query(
                f"SELECT * FROM {quoted_table_path} LIMIT 5;", cluster_id, host, databricks_api_key
            )
        except (ValueError, requests.RequestException) as exc:
            st.error(f"Could not read the selected table: {exc}", icon="🚫")
            return None
    st.write(sample)

    # Two columns for side-by-side selectboxes
    col1, col2 = st.columns(2)
    with col1:
        selected_row_data = st.selectbox(
            "row_data (required):",
            column_names,
            index=None,
            placeholder="Select a column..",
            help="Select the column that contains the URL/URI bucket location of the data rows you wish to import into Labelbox."
        )
    with col2:
        selected_global_key = st.selectbox(
            "global_key (optional):",
            column_names,
            index=None,
            placeholder="Select a column..",
            help="Select the column that contains the global key. If not provided, a new key will be generated for you."
        )
    if not selected_row_data:
        return None
    if selected_global_key == selected_row_data:
        # The schema map is keyed by column, so one column cannot serve both roles.
        st.error("row_data and global_key must be different columns.", icon="🚫")
        return None
    _warn_if_not_uri(sample, selected_row_data)
    return selected_row_data, selected_global_key


def _deploy_pipeline(*, create_new, dataset_id, new_dataset_name, host, organization_id,
                     cluster_id, table_path, frequency, row_data_column, global_key_column):
    """Render the Deploy button; on click, create the dataset if needed and deploy."""
    if not st.button("Deploy Pipeline!", type="primary"):
        return
    if not cloud:
        st.error("Please select the cloud environment your Databricks Workspace runs in.", icon="🚫")
        return

    # The dataset is created only here, so reruns triggered by other widgets
    # do not create a new Labelbox dataset each time.
    if create_new:
        try:
            dataset_id = create_new_dataset_labelbox(new_dataset_name)
        except Exception as exc:  # UI boundary: show any Labelbox client error
            st.error(f"Failed to create the Labelbox dataset: {exc}", icon="🚫")
            return

    schema_map = {'row_data': row_data_column}
    if global_key_column:
        schema_map['global_key'] = global_key_column
    data = {
        "cloud": cloud,
        "mode": "preview" if is_preview else "production",
        "databricks_instance": host,
        "databricks_api_key": databricks_api_key,
        "new_dataset": 1 if create_new else 0,
        "dataset_id": dataset_id,
        "table_path": table_path,
        "labelbox_api_key": labelbox_api_key,
        "frequency": frequency,
        "new_cluster": 0,
        "cluster_id": cluster_id,
        # Sent as a JSON string mapping table column -> Labelbox field.
        "schema_map": json.dumps({column: field for field, column in schema_map.items()}),
    }

    with st.spinner("Deploying pipeline..."):
        try:
            deploy_response = requests.post(_DEPLOY_ENDPOINT, json=data, timeout=_DEPLOY_TIMEOUT_S)
        except requests.RequestException as exc:
            st.error(f"Failed to deploy pipeline. Error: {exc}", icon="🚫")
            return
    if deploy_response.status_code != 200:
        st.error(f"Failed to deploy pipeline. Response: {deploy_response.text}", icon="🚫")
        return

    st.balloons()
    try:
        payload = deploy_response.json()
    except ValueError:
        payload = None
    job_id = _extract_job_id(payload)
    if job_id is None:
        st.success("Pipeline deployed successfully! 🚀")
        st.write(deploy_response.text)
        return
    job_url = f"https://{host}/?o={organization_id}#job/{job_id}"
    st.success(f"Pipeline deployed successfully! [{job_url}]({job_url}) 🚀")


def _render_pipeline_builder():
    """Run the guided pipeline-creation flow, stopping at the first incomplete step."""
    if not labelbox_api_key:
        return
    dataset_choice = _select_dataset()
    if dataset_choice is None:
        return
    create_new, dataset_id, new_dataset_name = dataset_choice

    host, organization_id = _parse_databricks_domain(title)
    if not host:
        return
    base_url = f"https://{host}"
    headers = {
        "Authorization": f"Bearer {databricks_api_key}",
        "Content-Type": "application/json",
    }

    cluster = _select_cluster(base_url, headers)
    if cluster is None:
        return
    cluster_name, cluster_id, cluster_state = cluster
    if not _ensure_cluster_running(base_url, headers, cluster_name, cluster_id, cluster_state):
        return

    frequency = _select_schedule()

    table = _select_table(cluster_id, host)
    if table is None:
        return
    database, table_name, column_names = table
    quoted_table_path = f"{_quote_identifier(database)}.{_quote_identifier(table_name)}"
    mapping = _select_schema_mapping(cluster_id, host, quoted_table_path, column_names)
    if mapping is None:
        return
    row_data_column, global_key_column = mapping

    _deploy_pipeline(
        create_new=create_new,
        dataset_id=dataset_id,
        new_dataset_name=new_dataset_name,
        host=host,
        organization_id=organization_id,
        cluster_id=cluster_id,
        table_path=f"{database}.{table_name}",
        frequency=frequency,
        row_data_column=row_data_column,
        global_key_column=global_key_column,
    )


_render_pipeline_builder()
# Inject CSS that adds a large bottom padding to the main content container.
_BOTTOM_PADDING_CSS = """
<style>
/* Add a large bottom padding to the main content */
.main .block-container {
padding-bottom: 1000px; /* Adjust this value as needed */
}
</style>
"""
st.markdown(_BOTTOM_PADDING_CSS, unsafe_allow_html=True)