Qiddiya-Smart-Guide / src /agents.py
munals's picture
Upload 33 files
214f910 verified
from __future__ import annotations
import logging
from typing import Dict, List, TypedDict, Optional
from .data_loader import (
load_attractions,
load_wait_history,
get_attraction_by_name,
)
from .models import ItineraryPlan, ItineraryStop, UserRequest
from .routing import greedy_schedule, order_by_nearest_neighbor
from .llm import generate_narrative_logs
# Module logger; _append_log mirrors every workflow log line here at INFO level.
logger = logging.getLogger("qiddiya.agents")
class QiddiyaState(TypedDict, total=False):
    """Shared mutable state passed between the planning-workflow nodes.

    ``total=False``: every key is optional, so nodes read with ``state.get``.
    """

    # Raw request payload; each node validates it with UserRequest.model_validate.
    user_request: Dict
    # Human-readable workflow log, appended to by _append_log.
    logs: List[str]
    # attraction id -> mean historical wait (whole minutes), set by wait_time_predictor_node.
    wait_time_forecast: Dict[str, int] | None
    # ItineraryPlan.model_dump() from route_optimizer_node (reflection_node may overwrite it).
    raw_plan: Dict | None
    # raw_plan plus narrative logs, set by experience_writer_node; read by critic/reflection.
    final_plan: Dict | None
    # "; "-joined critic violations ("" when the plan passed).
    critique: str | None
    # Number of reflection rounds the critic has triggered (it stops triggering at 2).
    reflection_round: int
    # Set by reflection_node: only schedule these (re-ordered geographically);
    # route_optimizer_node resets it to [] once consumed.
    refined_attraction_ids: List[str] | None
def _append_log(state: QiddiyaState, message: str) -> None:
    """Record *message* in the shared workflow log and mirror it to the module logger.

    The list under ``state["logs"]`` is created on first use and extended in place
    afterwards, so anyone holding a reference to it sees the new entry.
    """
    entries = state.setdefault("logs", [])
    entries.append(message)
    logger.info(message)
def orchestrator_node(state: QiddiyaState) -> QiddiyaState:
    """Entry node of the planning workflow: note that planning has begun.

    Apart from the new log line, the state is handed on untouched.
    """
    kickoff_message = "Orchestrator: starting planning workflow."
    _append_log(state, kickoff_message)
    return state
def wait_time_predictor_node(state: QiddiyaState) -> QiddiyaState:
    """Forecast each attraction's wait as its historical mean wait, in whole minutes.

    Reads the history table from ``load_wait_history()`` (columns ``attraction_id``
    and ``wait_minutes``) and stores ``{attraction_id: int(mean wait)}`` under
    ``state["wait_time_forecast"]``. The mean is truncated by ``int``, not rounded.

    Rows with a missing ``attraction_id``, and attractions whose recorded waits are
    all missing, are skipped instead of failing on ``int(NaN)``.
    """
    _append_log(state, "Wait-Time Agent: computing simple demand-based forecast.")
    history = load_wait_history()
    # One grouped pass instead of re-filtering the whole table per attraction.
    # sort=False keeps ids in first-appearance order (like iterating .unique());
    # groupby drops missing keys by default, and dropna() removes all-NaN groups.
    mean_waits = (
        history.groupby("attraction_id", sort=False)["wait_minutes"].mean().dropna()
    )
    forecast: Dict[str, int] = {
        attr_id: int(mean_wait) for attr_id, mean_wait in mean_waits.items()
    }
    state["wait_time_forecast"] = forecast
    _append_log(
        state,
        f"Wait-Time Agent: produced forecast for {len(forecast)} attractions.",
    )
    return state
def route_optimizer_node(state: QiddiyaState) -> QiddiyaState:
    """Build the raw itinerary: choose attractions, order them by geography, then time them.

    Candidate selection:

    * After a reflection pass (``state["refined_attraction_ids"]`` non-empty), only
      those ids that exist in the attraction catalogue are scheduled. The list is
      then reset to ``[]`` so the refinement is consumed exactly once.
    * Otherwise the user's must-do attractions (resolved by name) are scheduled. If
      none resolve, up to 8 attractions are suggested. Users with
      ``intensity_preference <= 3`` only get attractions with ``thrill_level <= 3``;
      higher-intensity users may get any attraction (the same rule
      ``reflection_node`` uses when adding optional stops).

    Candidates are ordered with ``order_by_nearest_neighbor`` and timed with
    ``greedy_schedule``. The resulting ``ItineraryPlan`` is stored as a dict in
    ``state["raw_plan"]``. Stops that are not must-dos are flagged ``is_suggested``.
    ``coverage_score`` is the number of scheduled must-do stops divided by the number
    of requested must-do names (minimum 1).
    """
    req = UserRequest.model_validate(state["user_request"])
    attractions = load_attractions()
    node_lookup = {a.id: a.node_id for a in attractions.values()}

    # Resolve must-do names once; reused for candidate selection and coverage scoring.
    must_do_ids: List[str] = []
    for name in req.must_do_attractions:
        attr = get_attraction_by_name(name)
        if attr:
            must_do_ids.append(attr.id)
    must_do_set = set(must_do_ids)

    refined = state.get("refined_attraction_ids") or []
    if refined:
        _append_log(
            state,
            f"Route Agent: re-building schedule from refined list ({len(refined)} attractions) after reflection.",
        )
        candidate = [aid for aid in refined if aid in attractions]
        state["refined_attraction_ids"] = []  # clear so we don't persist for future runs
    else:
        _append_log(state, "Route Agent: building schedule by geography (nearest-neighbor from hub).")
        if must_do_ids:
            # Only schedule what the user asked for: no extra attractions.
            candidate = must_do_ids
        else:
            # No must-dos selected: suggest a few, keeping gentle users off big thrills.
            suggestions = [
                a.id
                for a in attractions.values()
                if req.intensity_preference > 3 or a.thrill_level <= 3
            ]
            candidate = suggestions[:8]  # cap suggestions when nothing selected
    # Order by space (minimize walking), not by user selection order.
    ordered_ids = order_by_nearest_neighbor(candidate, node_lookup)

    wait_lookup = state.get("wait_time_forecast") or {}
    # Only hours and minutes matter; tolerate a trailing ":SS" like critic_node does.
    start_hour, start_minute = (int(part) for part in req.start_time.split(":")[:2])
    start_minutes = start_hour * 60 + start_minute
    stops_raw, total_wait, total_walk = greedy_schedule(
        attractions_order=ordered_ids,
        start_time_minutes=start_minutes,
        walking_weight=float(6 - req.walking_tolerance),
        wait_time_lookup=wait_lookup,
        node_lookup=node_lookup,
    )
    stops = [ItineraryStop.model_validate(s).model_dump() for s in stops_raw]
    for stop in stops:
        stop["is_suggested"] = stop["attraction_id"] not in must_do_set

    enjoyment = min(10.0, float(len(stops)) * (0.5 + 0.1 * req.intensity_preference))
    covered_must_dos = sum(1 for stop in stops if stop.get("attraction_id") in must_do_set)
    coverage = covered_must_dos / max(len(req.must_do_attractions), 1)

    plan = ItineraryPlan(
        visit_date=req.visit_date,
        total_wait_minutes=total_wait,
        total_walking_m=total_walk,
        coverage_score=coverage,
        enjoyment_score=enjoyment,
        stops=[ItineraryStop.model_validate(s) for s in stops],
        logs=state.get("logs", []),
    )
    state["raw_plan"] = plan.model_dump()
    _append_log(
        state,
        f"Route Agent: produced raw plan with {len(stops)} stops, "
        f"wait={total_wait}min, walk={total_walk}m.",
    )
    return state
def experience_writer_node(state: QiddiyaState) -> QiddiyaState:
    """Attach visitor-facing narrative to the raw plan and store it as ``state["final_plan"]``.

    The final plan's ``logs`` are a fresh list: the workflow log so far (including
    this node's opening entry) followed by the lines from ``generate_narrative_logs``.
    The closing log entry is added to ``state["logs"]`` only, not to the plan.
    """
    _append_log(state, "Guide Agent: generating visitor-friendly annotations (Groq if available).")
    draft = ItineraryPlan.model_validate(state["raw_plan"])
    request = UserRequest.model_validate(state["user_request"])
    narrative = generate_narrative_logs(draft, request)
    prior_logs = list(state.get("logs", []))
    combined_logs = prior_logs + narrative
    annotated = draft.model_copy(update={"logs": combined_logs})
    state["final_plan"] = annotated.model_dump()
    _append_log(state, "Guide Agent: added narrative guidance to itinerary logs.")
    return state
def critic_node(state: QiddiyaState) -> QiddiyaState:
    """Check the final plan against fixed criteria and decide whether to reflect.

    A violation is recorded when:

    1. a requested must-do attraction (resolved by name) is absent from the stops;
    2. ``total_wait_minutes`` exceeds the visit window (end_time - start_time, minutes);
    3. ``total_walking_m`` exceeds ``3500 + 800 * (walking_tolerance - 1)``;
    4. ``enjoyment_score`` is below 3.0 (reflection may then add nearby stops).

    If there are violations and fewer than 2 reflection rounds have run, the
    violations joined with ``"; "`` become ``state["critique"]`` and
    ``state["reflection_round"]`` is incremented. Otherwise ``critique`` holds the
    remaining violations (or ``""`` when none) and ``reflection_round`` is untouched.
    """

    def _minutes_of_day(clock: str) -> int:
        # Only the hour and minute fields are used; anything after them is ignored.
        fields = clock.split(":")
        return int(fields[0]) * 60 + int(fields[1])

    _append_log(state, "Critic Agent: validating constraints and safety.")
    plan = ItineraryPlan.model_validate(state["final_plan"])
    req = UserRequest.model_validate(state["user_request"])

    problems: List[str] = []

    resolved = [get_attraction_by_name(name) for name in req.must_do_attractions]
    scheduled_ids = {stop.attraction_id for stop in plan.stops}
    missing_count = sum(1 for attr in resolved if attr and attr.id not in scheduled_ids)
    if missing_count == 1:
        problems.append("One must-do attraction is not scheduled.")
    elif missing_count > 1:
        problems.append(f"{missing_count} must-do attraction(s) are not scheduled.")

    window_minutes = _minutes_of_day(req.end_time) - _minutes_of_day(req.start_time)
    if plan.total_wait_minutes > window_minutes:
        problems.append("Total wait time exceeds visit window.")

    # Tolerance-scaled walking cap in the same units as total_walking_m.
    walk_limit = 3500 + 800 * (req.walking_tolerance - 1)
    if plan.total_walking_m > walk_limit:
        problems.append("Planned walking distance exceeds user tolerance.")

    if plan.enjoyment_score < 3.0:
        problems.append("Enjoyment score is low.")

    rounds_done = int(state.get("reflection_round", 0))
    summary = "; ".join(problems)

    if problems and rounds_done < 2:
        _append_log(
            state,
            f"Critic Agent: found violations, triggering reflection round {rounds_done + 1}.",
        )
        state["critique"] = summary
        state["reflection_round"] = rounds_done + 1
        return state

    if problems:
        _append_log(state, "Critic Agent: violations remain but reflection limit reached.")
    else:
        _append_log(state, "Critic Agent: plan accepted with no violations.")
    state["critique"] = summary
    return state
def reflection_node(state: QiddiyaState) -> QiddiyaState:
    """Adjust the plan in response to the Critic's critique.

    * Empty/missing critique: log and return the state unchanged otherwise.
    * Critique mentions "enjoyment" and the plan has fewer than 6 stops: append the
      optional attractions closest (by ``shortest_path_distance``) to the last stop,
      or to ``"HUB"`` if there are no stops. At most 3 are added, and only as many
      as needed to reach 5 stops. Optional means not already planned; users with
      ``intensity_preference <= 3`` only get ``thrill_level <= 3`` attractions.
      The extended id list goes to ``state["refined_attraction_ids"]``, and the
      function returns without running the walking trim.
    * Otherwise, when the critique mentions "walking"/"distance" and
      ``walking_tolerance <= 2``, drop non-must-do stops whose
      ``walking_distance_m`` exceeds 300. For any other critique no stop is
      dropped, so the same attractions are simply re-scheduled. The kept plan
      goes to ``state["raw_plan"]`` and its ids to
      ``state["refined_attraction_ids"]``.
    """
    critique: Optional[str] = state.get("critique")
    if not critique:
        _append_log(state, "Orchestrator: no critique to address, skipping reflection.")
        return state
    _append_log(
        state,
        f"Orchestrator: applying heuristic reflection based on critique: {critique}",
    )
    plan = ItineraryPlan.model_validate(state["final_plan"])
    req = UserRequest.model_validate(state["user_request"])
    attractions = load_attractions()
    node_lookup = {a.id: a.node_id for a in attractions.values()}
    c_lower = critique.lower()
    must_do_ids_set: set = set()
    for name in req.must_do_attractions:
        attr = get_attraction_by_name(name)
        if attr:
            must_do_ids_set.add(attr.id)

    # --- Enjoyment: add up to 3 nearby optional stops to boost the plan ---
    if "enjoyment" in c_lower and len(plan.stops) < 6:
        from .routing import shortest_path_distance

        current_ids = [s.attraction_id for s in plan.stops]
        already_planned = set(current_ids)
        last_node = plan.stops[-1].node_id if plan.stops else "HUB"
        # Aim for at least 5 stops (enjoyment >= 3.0 even at intensity 1), adding at most 3.
        add_count = min(3, max(0, 5 - len(plan.stops)))
        to_add: List[str] = []
        if add_count > 0:
            optional_ids = [
                a.id for a in attractions.values()
                if a.id not in already_planned
                and (req.intensity_preference > 3 or a.thrill_level <= 3)
            ]
            # Nearest optional attractions from the last stop.
            to_add = sorted(
                optional_ids,
                key=lambda aid: shortest_path_distance(last_node, node_lookup[aid]),
            )[:add_count]
        if to_add:
            state["refined_attraction_ids"] = current_ids + to_add
            state["raw_plan"] = plan.model_dump()
            _append_log(
                state,
                f"Orchestrator: reflection adding {len(to_add)} nearby attraction(s) to boost enjoyment.",
            )
        else:
            state["refined_attraction_ids"] = current_ids
            _append_log(state, "Orchestrator: no optional attractions to add for enjoyment.")
        return state

    # --- Walking: trim optional long-walk stops (only for low walking tolerance) ---
    trim_long_walks = ("walking" in c_lower or "distance" in c_lower) and req.walking_tolerance <= 2
    kept_stops: List[ItineraryStop] = [
        stop for stop in plan.stops
        if stop.attraction_id in must_do_ids_set
        or not (trim_long_walks and stop.walking_distance_m > 300)
    ]
    removed_count = len(plan.stops) - len(kept_stops)
    final_plan = plan.model_copy(update={"stops": kept_stops})
    state["raw_plan"] = final_plan.model_dump()
    state["refined_attraction_ids"] = [s.attraction_id for s in kept_stops]
    if removed_count:
        message = (
            f"Orchestrator: reflection adjusted plan to {len(kept_stops)} stops "
            f"(removed {removed_count} long-walk stop(s))."
        )
    else:
        # Report honestly when the critique didn't lead to any stop being dropped.
        message = f"Orchestrator: reflection kept all {len(kept_stops)} stops (no long-walk stops removed)."
    _append_log(state, message)
    return state