import json
import time
from typing import Any

import streamlit as st

from any_agent import AgentConfig, AnyAgent, TracingConfig
from any_agent.evaluation import TraceEvaluationResult, evaluate
from any_agent.tracing.otel_types import StatusCode
from any_agent.tracing.trace import AgentTrace, TotalTokenUseAndCost

from surf_spot_finder.config import Config
from surf_spot_finder.tools import (
    driving_hours_to_meters,
    get_area_lat_lon,
    get_surfing_spots,
    get_wave_forecast,
    get_wind_forecast,
)


async def run_agent(user_inputs: dict[str, Any]):
    st.markdown("### 🏄 Running Surf Spot Finder...")
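
    # Route Hugging Face inference billing to the mozilla-ai org and pin
    # temperature to 0 for reproducible runs; other providers use defaults.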
if "huggingface" in user_inputs["model_id"]: | |
model_args = { | |
"extra_headers": {"X-HF-Bill-To": "mozilla-ai"}, | |
"temperature": 0.0, | |
} | |
else: | |
model_args = {} | |

    agent_config = AgentConfig(
        model_id=user_inputs["model_id"],
        model_args=model_args,
        tools=[
            get_wind_forecast,
            get_wave_forecast,
            get_area_lat_lon,
            get_surfing_spots,
            driving_hours_to_meters,
        ],
    )
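
    # Bundle the agent with the user's search constraints; an evaluation
    # case is attached only when one was provided.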
    config = Config(
        location=user_inputs["location"],
        max_driving_hours=user_inputs["max_driving_hours"],
        date=user_inputs["date"],
        framework=user_inputs["framework"],
        main_agent=agent_config,
        managed_agents=[],
        evaluation_cases=[user_inputs.get("evaluation_case")]
        if user_inputs.get("evaluation_case")
        else None,
    )
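
    # Instantiate the agent for the chosen framework, with console tracing
    # and per-call cost tracking enabled.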
    agent = await AnyAgent.create_async(
        agent_framework=config.framework,
        agent_config=config.main_agent,
        managed_agents=config.managed_agents,
        tracing=TracingConfig(console=True, cost_info=True),
    )
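
    # Fill the prompt template with the user's location, driving radius, and date.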
    query = config.input_prompt_template.format(
        LOCATION=config.location,
        MAX_DRIVING_HOURS=config.max_driving_hours,
        DATE=config.date,
    )
    st.markdown("#### 🔍 Query")
    st.code(query, language="text")

    start_time = time.time()
    with st.spinner("🤖 Analyzing surf spots..."):
        agent_trace: AgentTrace = await agent.run_async(query)
        agent.exit()
    end_time = time.time()
    execution_time = end_time - start_time
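
    # Aggregate token usage and estimated cost across the whole trace.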
    cost: TotalTokenUseAndCost = agent_trace.get_total_cost()

    st.markdown("### 📊 Results")
    time_col, cost_col, tokens_col = st.columns(3)
    with time_col:
        st.info(f"⏱️ Execution Time: {execution_time:.2f} seconds")
    with cost_col:
        st.info(f"💰 Estimated Cost: ${cost.total_cost:.6f}")
    with tokens_col:
        st.info(f"📦 Total Tokens: {cost.total_tokens:,}")
    st.markdown("#### Final Output")
    st.info(agent_trace.final_output)

    # Display each step of the agent trace with its inputs, outputs, and status
    with st.expander("### 🧩 Agent Trace"):
        for span in agent_trace.spans:
            # Header with name and status
            col1, col2 = st.columns([4, 1])
            with col1:
                st.markdown(f"**{span.name}**")
                if span.attributes:
                    if "input.value" in span.attributes:
                        try:
                            input_value = json.loads(span.attributes["input.value"])
                            if isinstance(input_value, list) and len(input_value) > 0:
                                st.write(f"Input: {input_value[-1]}")
                            else:
                                st.write(f"Input: {input_value}")
                        except (json.JSONDecodeError, TypeError):
                            # Not valid JSON: fall back to showing the raw value
                            st.write(f"Input: {span.attributes['input.value']}")
if "output.value" in span.attributes: | |
try: | |
output_value = json.loads(span.attributes["output.value"]) | |
if isinstance(output_value, list) and len(output_value) > 0: | |
st.write(f"Output: {output_value[-1]}") | |
else: | |
st.write(f"Output: {output_value}") | |
except: # noqa: E722 | |
st.write(f"Output: {span.attributes['output.value']}") | |
            with col2:
                status_color = (
                    "green" if span.status.status_code == StatusCode.OK else "red"
                )
                st.markdown(
                    f"<span style='color: {status_color}'>● {span.status.status_code.name}</span>",
                    unsafe_allow_html=True,
                )
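
    # If an evaluation case was provided, grade the trace against its criteria.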
    if config.evaluation_cases is not None:
        assert (
            len(config.evaluation_cases) == 1
        ), "Only one evaluation case is supported in the demo"
        st.markdown("### 📊 Evaluation Results")
        with st.spinner("Evaluating results..."):
            case = config.evaluation_cases[0]
            result: TraceEvaluationResult = evaluate(
                evaluation_case=case,
                trace=agent_trace,
                agent_framework=config.framework,
            )
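            # Merge checkpoint, hypothesis-answer, and direct results for display.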
            all_results = (
                result.checkpoint_results
                + result.hypothesis_answer_results
                + result.direct_results
            )
            # Create columns for better layout
            col1, col2 = st.columns(2)
            with col1:
                st.markdown("#### Criteria Results")
                for checkpoint in all_results:
                    if checkpoint.passed:
                        st.success(f"✅ {checkpoint.criteria}")
                    else:
                        st.error(f"❌ {checkpoint.criteria}")
            with col2:
                st.markdown("#### Overall Score")
                # `res` avoids shadowing `result` (the TraceEvaluationResult)
                total_points = sum(res.points for res in all_results)
                if total_points == 0:
                    msg = "Total points is 0, cannot calculate score."
                    raise ValueError(msg)
                passed_points = sum(res.points for res in all_results if res.passed)
                # Show the score as a fraction, a progress bar, and a percentage
                st.markdown(f"### {passed_points}/{total_points}")
                percentage = (passed_points / total_points) * 100
                st.progress(percentage / 100)
                st.markdown(f"**{percentage:.1f}%**")