github-actions[bot]
Sync with https://github.com/mozilla-ai/surf-spot-finder
89845e5
import copy
import json
from datetime import datetime, timedelta

import pandas as pd
import requests
import streamlit as st
from any_agent import AgentFramework
from any_agent.evaluation import EvaluationCase
from any_agent.evaluation.schemas import CheckpointCriteria
from any_agent.tracing.trace import _is_tracing_supported
from pydantic import BaseModel, ConfigDict

from constants import DEFAULT_EVALUATION_CASE, MODEL_OPTIONS
class UserInputs(BaseModel):
    """Validated bundle of the values collected by the Streamlit input form.

    Unknown fields are rejected because the model is configured with
    ``extra="forbid"``.
    """

    model_config = ConfigDict(extra="forbid")
    # Identifier of the model selected for the agent.
    model_id: str
    # Free-text location typed by the user.
    location: str
    # Maximum driving time, in hours.
    max_driving_hours: int
    # Selected calendar date combined with the selected time of day.
    date: datetime
    # Agent framework selected by the user.
    framework: str
    # Evaluation case (LLM judge and checkpoints) to evaluate the run with.
    evaluation_case: EvaluationCase
    # Whether the evaluation should be run.
    run_evaluation: bool
@st.cache_resource
def get_area(area_name: str) -> list[dict]:
    """Search for an area by name with Nominatim.

    Uses the [Nominatim API](https://nominatim.org/release-docs/develop/api/Search/).

    Args:
        area_name (str): The free-text name of the area to search for.

    Returns:
        list[dict]: The decoded JSON body of the search response. The search
            endpoint answers with a JSON array of matches, so the result is
            empty (falsy) when nothing was found.

    Raises:
        requests.HTTPError: If Nominatim responds with an error status code.
        requests.RequestException: If the request fails or exceeds the
            5 second timeout.
    """
    response = requests.get(
        "https://nominatim.openstreetmap.org/search",
        # Pass the query through ``params`` so requests URL-encodes it; user
        # text containing "&", "#", "+" or "%" would otherwise corrupt the
        # query string or inject extra parameters.
        params={"q": area_name, "format": "json"},
        headers={"User-Agent": "Mozilla/5.0"},
        timeout=5,
    )
    response.raise_for_status()
    return response.json()
# Limits applied to the custom evaluation editor for the purpose of this demo.
_MAX_CHECKPOINTS = 20
_MAX_CRITERIA_WORDS = 100


def _build_checkpoints(checkpoints_df: pd.DataFrame) -> list[CheckpointCriteria]:
    """Convert the rows of the edited table back into CheckpointCriteria objects.

    Rows without criteria text are skipped; rows that cannot be converted are
    reported with ``st.error`` and left out of the result.
    """
    new_ckpts = []
    for _, row in checkpoints_df.iterrows():
        criteria = row["criteria"]
        # Rows freshly added in the editor hold None/NaN until the user types
        # something; treat them like empty rows instead of reporting an error.
        if not isinstance(criteria, str) or not criteria.strip():
            continue
        try:
            # Don't let people write essays for criteria in this demo
            if len(criteria.split(" ")) > _MAX_CRITERIA_WORDS:
                raise ValueError("Criteria is too long")
            new_ckpts.append(
                CheckpointCriteria(criteria=criteria, points=row["points"])
            )
        except (ValueError, TypeError) as e:
            # pydantic validation errors are ValueError subclasses.
            st.error(f"Error creating checkpoint: {e}")
    return new_ckpts


def get_user_inputs() -> UserInputs:
    """Render the input form and return the user's selections.

    Draws the location, driving-time, date/time, framework and model widgets,
    plus an expander in which the evaluation judge model and the evaluation
    checkpoints can be customised.

    Returns:
        UserInputs: The validated values currently selected in the form.
    """
    default_val = "Los Angeles California, US"
    location = st.text_input("Enter a location", value=default_val)
    if location:
        try:
            location_check = get_area(location)
        except requests.RequestException as e:
            # A lookup failure (timeout, HTTP error) should not take the
            # whole form down; report it and keep rendering.
            st.error(f"❌ Could not validate location: {e}")
        else:
            if not location_check:
                st.error("❌ Invalid location")

    max_driving_hours = st.number_input(
        "Enter the maximum driving hours", min_value=1, value=2
    )

    col_date, col_time = st.columns([2, 1])
    with col_date:
        selected_day = st.date_input(
            "Select a date in the future", value=datetime.now() + timedelta(days=1)
        )
    with col_time:
        # default to 9am
        selected_time = st.selectbox(
            "Select a time",
            [datetime.strptime(f"{i:02d}:00", "%H:%M").time() for i in range(24)],
            index=9,
        )
    date = datetime.combine(selected_day, selected_time)

    supported_frameworks = [
        framework for framework in AgentFramework if _is_tracing_supported(framework)
    ]
    framework = st.selectbox(
        "Select the agent framework to use",
        supported_frameworks,
        index=2,
        format_func=lambda x: x.name,
    )

    model_id = st.selectbox(
        "Select the model to use",
        MODEL_OPTIONS,
        index=1,
        format_func=lambda x: "/".join(x.split("/")[-3:]),
    )

    # Add evaluation case section
    with st.expander("Custom Evaluation"):
        evaluation_model_id = st.selectbox(
            "Select the model to use for LLM-as-a-Judge evaluation",
            MODEL_OPTIONS,
            index=2,
            format_func=lambda x: "/".join(x.split("/")[-3:]),
        )
        # Work on a private copy: DEFAULT_EVALUATION_CASE is a module-level
        # object that outlives a single script run, so mutating it in place
        # would leak one user's edits into the defaults of every later run.
        evaluation_case = copy.deepcopy(DEFAULT_EVALUATION_CASE)
        evaluation_case.llm_judge = evaluation_model_id

        # Show the checkpoints as an editable table.
        checkpoints_df = pd.DataFrame(
            [checkpoint.model_dump() for checkpoint in evaluation_case.checkpoints]
        )
        checkpoints_df = st.data_editor(
            checkpoints_df,
            column_config={
                "points": st.column_config.NumberColumn(label="Points"),
                "criteria": st.column_config.TextColumn(label="Criteria"),
            },
            hide_index=True,
            num_rows="dynamic",
        )

        # don't let a user add more than the allowed number of checkpoints
        if len(checkpoints_df) > _MAX_CHECKPOINTS:
            st.error(
                f"You can only add up to {_MAX_CHECKPOINTS} checkpoints"
                " for the purpose of this demo."
            )
            checkpoints_df = checkpoints_df[:_MAX_CHECKPOINTS]

        evaluation_case.checkpoints = _build_checkpoints(checkpoints_df)

    return UserInputs(
        model_id=model_id,
        location=location,
        max_driving_hours=max_driving_hours,
        date=date,
        framework=framework,
        evaluation_case=evaluation_case,
        run_evaluation=st.checkbox("Run Evaluation", value=True),
    )