upload-label / src /streamlit_app.py
Tomatillo's picture
Update src/streamlit_app.py
216cc4c verified
import json
import streamlit as st
from segments import SegmentsClient
def update_label_on_platform(target_uuid: str, label_data: dict, user_api_key: str, labelset: str = "ground-truth") -> str:
    """
    Overwrite the label of a Segments.ai sample with the provided label data.

    The uploaded label JSON may either wrap its payload under an "attributes"
    key or be the attributes payload itself. If ``label_data`` is a dict that
    contains an "attributes" key, that value is sent; otherwise ``label_data``
    is sent as-is.

    Parameters:
        target_uuid (str): The UUID of the target sample.
        label_data (dict): The parsed label JSON to upload.
        user_api_key (str): The API key provided by the user.
        labelset (str): The labelset whose label is overwritten.
            Defaults to "ground-truth".

    Returns:
        str: A success message once the label has been updated.

    Raises:
        Exception: If creating the client or updating the label fails. The
            underlying error is chained as ``__cause__``.
    """
    # Pick the payload: nested under "attributes", or the whole JSON if flat.
    if isinstance(label_data, dict) and "attributes" in label_data:
        attributes_to_update = label_data["attributes"]
    else:
        attributes_to_update = label_data

    try:
        # Construct the client inside the try so that any error raised while
        # creating it is reported with the same message as an update failure.
        client = SegmentsClient(user_api_key)
        client.update_label(target_uuid, labelset=labelset, attributes=attributes_to_update)
    except Exception as e:
        # Chain the original error so its traceback is not lost.
        raise Exception(f"Error updating target label on Segments.ai: {e}") from e
    return "Label updated successfully on Segments.ai."
# ---------------------- Streamlit UI ----------------------
st.title("Upload and Overwrite Label on Segments.ai")
st.markdown(
    "Select a label file (JSON) from your local machine, enter your API key, and specify the target UUID where the label should be updated."
)

# Strip surrounding whitespace: pasted keys/UUIDs often carry stray spaces or
# newlines, and whitespace-only input should fail the emptiness checks below.
user_api_key = st.text_input("Enter your API key", type="password", value="").strip()
uploaded_file = st.file_uploader("Choose a label file", type=["json"])
target_uuid = st.text_input("Target UUID", value="").strip()

if st.button("Upload Label"):
    if not user_api_key:
        st.error("Please enter your API key.")
    elif uploaded_file is None:
        st.error("Please upload a label file.")
    elif not target_uuid:
        st.error("Please enter a target UUID.")
    else:
        try:
            # getvalue() returns the whole upload regardless of the current
            # read position; json.loads accepts the raw bytes directly.
            label_data = json.loads(uploaded_file.getvalue())
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            st.error(f"The uploaded file is not valid JSON: {e}")
        else:
            if not isinstance(label_data, dict):
                # A top-level list/scalar cannot be a label payload; reject it
                # here instead of forwarding it to the platform.
                st.error("The label file must contain a JSON object.")
            else:
                try:
                    # Update label on the platform using the user-provided API key.
                    result_message = update_label_on_platform(target_uuid, label_data, user_api_key)
                except Exception as e:
                    st.error(f"An error occurred: {e}")
                else:
                    st.success(result_message)