|
from __future__ import annotations as _annotations

import json
import os
from dataclasses import dataclass
from typing import Any

import gradio as gr
from dotenv import load_dotenv
from httpx import AsyncClient
from pydantic_ai import Agent, ModelRetry, RunContext
from pydantic_ai.messages import (
    ModelStructuredResponse,
    ModelTextResponse,
    ToolReturn,
    UserPrompt,
)
|
|
|
# Populate os.environ from a local .env file (if present) before the API keys
# are read further down; variables already set in the environment win.
load_dotenv(override=False)
|
|
|
|
|
@dataclass
class Deps:
    """Dependencies handed to every weather-agent tool through ``ctx.deps``.

    Either API key may be ``None``; the matching tool then answers with
    placeholder data instead of calling its external service.
    """

    # Shared async HTTP client used for both the geocoding and weather calls.
    client: AsyncClient
    # tomorrow.io key; None makes get_weather return a canned response.
    weather_api_key: str | None
    # geocode.maps.co key; None makes get_lat_lng return fixed coordinates.
    geo_api_key: str | None
|
|
|
|
|
# The packing assistant. Its tools (get_lat_lng, get_weather) are attached
# with the ``@weather_agent.tool`` decorator and receive a ``Deps`` instance.
weather_agent = Agent(
    "openai:gpt-4o",
    system_prompt=(
        "You are an expert packer. A user will ask you for help packing for a "
        "trip given a destination. Use your weather tools to provide a concise "
        "and effective packing list. Also ask follow up questions if necessary."
    ),
    deps_type=Deps,
    # Default retry budget, e.g. for tool calls that raise ModelRetry.
    retries=2,
)
|
|
|
|
|
@weather_agent.tool
async def get_lat_lng(
    ctx: RunContext[Deps], location_description: str
) -> dict[str, float]:
    """Get the latitude and longitude of a location.

    Args:
        ctx: The context.
        location_description: A description of a location.
    """
    # NOTE: pydantic-ai builds the tool description shown to the model from the
    # docstring above, so extra developer notes live in comments instead.
    if ctx.deps.geo_api_key is None:
        # No geocoding key configured: return fixed placeholder coordinates so
        # the demo still works without network access.
        return {"lat": 51.1, "lng": -0.1}

    params = {
        "q": location_description,
        "api_key": ctx.deps.geo_api_key,
    }
    r = await ctx.deps.client.get("https://geocode.maps.co/search", params=params)
    r.raise_for_status()
    data = r.json()

    if not data:
        # Tell the model to try again, e.g. with a less ambiguous description.
        raise ModelRetry("Could not find the location")

    # geocode.maps.co returns Nominatim-style results whose "lat"/"lon" are
    # JSON strings; float() makes the result match the declared
    # dict[str, float] (and is a no-op if numbers are ever returned).
    best_match = data[0]
    return {"lat": float(best_match["lat"]), "lng": float(best_match["lon"])}
|
|
|
|
|
# tomorrow.io ``weatherCode`` -> human-readable description.
_WEATHER_CODE_DESCRIPTIONS = {
    1000: "Clear, Sunny",
    1100: "Mostly Clear",
    1101: "Partly Cloudy",
    1102: "Mostly Cloudy",
    1001: "Cloudy",
    2000: "Fog",
    2100: "Light Fog",
    4000: "Drizzle",
    4001: "Rain",
    4200: "Light Rain",
    4201: "Heavy Rain",
    5000: "Snow",
    5001: "Flurries",
    5100: "Light Snow",
    5101: "Heavy Snow",
    6000: "Freezing Drizzle",
    6001: "Freezing Rain",
    6200: "Light Freezing Rain",
    6201: "Heavy Freezing Rain",
    7000: "Ice Pellets",
    7101: "Heavy Ice Pellets",
    7102: "Light Ice Pellets",
    8000: "Thunderstorm",
}


@weather_agent.tool
async def get_weather(ctx: RunContext[Deps], lat: float, lng: float) -> dict[str, Any]:
    """Get the weather at a location.

    Args:
        ctx: The context.
        lat: Latitude of the location.
        lng: Longitude of the location.
    """
    # The docstring above doubles as the tool schema pydantic-ai sends to the
    # model; keep developer notes in comments.
    if ctx.deps.weather_api_key is None:
        # No API key configured: answer with canned weather.
        return {"temperature": "21 °C", "description": "Sunny"}

    response = await ctx.deps.client.get(
        "https://api.tomorrow.io/v4/weather/realtime",
        params={
            "apikey": ctx.deps.weather_api_key,
            "location": f"{lat},{lng}",
            "units": "metric",
        },
    )
    response.raise_for_status()
    current = response.json()["data"]["values"]

    # Reports the apparent ("feels like") temperature, rounded to whole degrees;
    # unrecognised weather codes map to "Unknown".
    feels_like = current["temperatureApparent"]
    return {
        "temperature": f"{feels_like:0.0f}°C",
        "description": _WEATHER_CODE_DESCRIPTIONS.get(current["weatherCode"], "Unknown"),
    }
|
|
|
|
|
# Friendly names for the agent's tools, used in the chat's tool-call bubbles.
TOOL_TO_DISPLAY_NAME = {"get_lat_lng": "Geocoding API", "get_weather": "Weather API"}

# One HTTP client shared by all tool calls for the lifetime of the app.
client = AsyncClient()

# An unset *or empty* variable (e.g. a blank ``WEATHER_API_KEY=`` line in .env)
# means "no key", so the tools fall back to their placeholder data instead of
# calling the external APIs with an empty key.
weather_api_key = os.getenv("WEATHER_API_KEY") or None
geo_api_key = os.getenv("GEO_API_KEY") or None

deps = Deps(client=client, weather_api_key=weather_api_key, geo_api_key=geo_api_key)
|
|
|
|
|
async def stream_from_agent(prompt: str, chatbot: list[dict], past_messages: list):
    """Run the weather agent on ``prompt`` and stream the exchange into the UI.

    Yields ``(prompt_box_update, chatbot, past_messages_update)`` triples:

    * the prompt box is cleared and disabled while the agent runs, and is
      re-enabled when the run finishes *or fails*;
    * each tool call gets its own assistant bubble, titled with the tool's
      display name and filled with the tool output when its ``ToolReturn``
      arrives;
    * the final text answer is streamed into a trailing assistant bubble.

    Args:
        prompt: The user's message.
        chatbot: Rendered chat messages (dicts with "role"/"content"); appended
            to in place and yielded back.
        past_messages: pydantic-ai message history for the session. It is not
            mutated; the extended history is only emitted with the final
            update, so a failed run leaves the stored history untouched.

    Raises:
        Whatever the agent run raises, after the prompt box has been
        re-enabled.
    """
    # Build the new history on a copy so a failure mid-run can never leave a
    # half-finished turn (e.g. a tool call without its tool return) behind.
    history = list(past_messages)
    # Tool-call bubbles created during *this* run, keyed by tool id (which may
    # be None), in call order. Matching here instead of scanning the whole chat
    # keeps bubbles from earlier turns from being overwritten.
    pending_tool_bubbles: dict[str | None, list[dict]] = {}

    chatbot.append({"role": "user", "content": prompt})
    yield gr.Textbox(interactive=False, value=""), chatbot, gr.skip()

    try:
        async with weather_agent.run_stream(
            prompt, deps=deps, message_history=past_messages
        ) as result:
            for message in result.new_messages():
                history.append(message)
                if isinstance(message, ModelStructuredResponse):
                    for call in message.calls:
                        # Fall back to the raw name so an unexpected tool name
                        # cannot crash the UI with a KeyError.
                        tool_label = TOOL_TO_DISPLAY_NAME.get(
                            call.tool_name, call.tool_name
                        )
                        bubble = {
                            "role": "assistant",
                            "content": "",
                            "metadata": {
                                "title": f"### 🛠️ Using {tool_label}",
                                "id": call.tool_id,
                            },
                        }
                        chatbot.append(bubble)
                        pending_tool_bubbles.setdefault(call.tool_id, []).append(bubble)
                if isinstance(message, ToolReturn):
                    waiting = pending_tool_bubbles.get(message.tool_id)
                    if waiting:
                        waiting.pop(0)["content"] = (
                            f"Output: {json.dumps(message.content)}"
                        )
                yield gr.skip(), chatbot, gr.skip()

            chatbot.append({"role": "assistant", "content": ""})
            async for text in result.stream_text():
                chatbot[-1]["content"] = text
                yield gr.skip(), chatbot, gr.skip()

            # The streamed final answer is not part of new_messages() here, so
            # record it explicitly.
            data = await result.get_data()
            history.append(ModelTextResponse(content=data))
    except Exception:
        # Unlock the prompt box so the user can try again; Gradio still
        # reports the error. (Cancellation is a BaseException and bypasses this.)
        yield gr.Textbox(interactive=True), chatbot, gr.skip()
        raise

    yield gr.Textbox(interactive=True), gr.skip(), history
|
|
|
|
|
def _agent_history_before_turn(chatbot: list[dict], past_messages: list, index: int) -> list:
    """Return the part of ``past_messages`` that precedes the user turn at ``chatbot[index]``.

    ``chatbot`` holds one entry per rendered bubble (including tool-call
    bubbles), while ``past_messages`` holds pydantic-ai messages (system prompt,
    user prompts, tool calls, tool returns, model responses). The two lists do
    not line up, so a chatbot index cannot slice the agent history directly.
    Instead, count the user turns before ``index`` and cut the history just
    before the matching ``UserPrompt``.
    """
    turn = sum(1 for msg in chatbot[:index] if msg.get("role") == "user")
    seen = 0
    for position, message in enumerate(past_messages):
        if isinstance(message, UserPrompt):
            if seen == turn:
                return past_messages[:position]
            seen += 1
    # That turn never reached the history (e.g. its run failed), so everything
    # recorded so far precedes it.
    return list(past_messages)


async def handle_retry(chatbot, past_messages: list, retry_data: gr.RetryData):
    """Re-run the user message at ``retry_data.index``.

    The chat and the agent history are both truncated to just before that
    message, then its text is sent through ``stream_from_agent`` again;
    that generator's updates are forwarded unchanged.
    """
    new_history = chatbot[: retry_data.index]
    previous_prompt = chatbot[retry_data.index]["content"]
    past_messages = _agent_history_before_turn(chatbot, past_messages, retry_data.index)
    async for update in stream_from_agent(previous_prompt, new_history, past_messages):
        yield update


def undo(chatbot, past_messages: list, undo_data: gr.UndoData):
    """Drop the user message at ``undo_data.index`` and everything after it.

    Returns ``(prompt_text, chatbot, past_messages)``: the removed message's
    text (to put back in the prompt box), the truncated chat, and the agent
    history truncated to the same turn boundary.
    """
    new_history = chatbot[: undo_data.index]
    past_messages = _agent_history_before_turn(chatbot, past_messages, undo_data.index)
    return chatbot[undo_data.index]["content"], new_history, past_messages
|
|
|
|
|
def select_data(message: gr.SelectData) -> str:
    """Return the text of the chatbot example the user clicked.

    Wired to ``chatbot.example_select`` so the example lands in the prompt box.
    """
    example = message.value
    return example["text"]
|
|
|
|
|
with gr.Blocks() as demo:
    # Header: logo next to a short description of the assistant.
    gr.HTML(
        """
<div style="display: flex; justify-content: center; align-items: center; gap: 2rem; padding: 1rem; width: 100%">
    <img src="https://ai.pydantic.dev/img/logo-white.svg" style="max-width: 200px; height: auto">
    <div>
        <h1 style="margin: 0 0 1rem 0">Vacation Packing Assistant</h1>
        <h3 style="margin: 0 0 0.5rem 0">
            This assistant will help you pack for your vacation. Enter your destination and it will provide you with a concise packing list based on the weather forecast.
        </h3>
        <h3 style="margin: 0">
            Feel free to ask for help with any other questions you have about your trip!
        </h3>
    </div>
</div>
"""
    )

    # Per-session pydantic-ai message history (passed to the agent as
    # message_history), kept separately from the rendered chat bubbles.
    past_messages = gr.State([])

    chatbot = gr.Chatbot(
        label="Packing Assistant",
        type="messages",
        avatar_images=(None, "https://ai.pydantic.dev/img/logo-white.svg"),
        examples=[
            {"text": "I am going to Paris for the holidays, what should I pack?"},
            {"text": "I am going to Tokyo this week."},
        ],
    )

    with gr.Row():
        prompt = gr.Textbox(
            lines=1,
            show_label=False,
            placeholder="I am planning a trip to Miami, what should I pack?",
        )

    # Event wiring: every handler updates the prompt box, the chat and the
    # agent history together.
    generation = prompt.submit(
        stream_from_agent,
        inputs=[prompt, chatbot, past_messages],
        outputs=[prompt, chatbot, past_messages],
    )
    chatbot.example_select(select_data, inputs=None, outputs=[prompt])
    chatbot.retry(
        handle_retry,
        inputs=[chatbot, past_messages],
        outputs=[prompt, chatbot, past_messages],
    )
    chatbot.undo(
        undo,
        inputs=[chatbot, past_messages],
        outputs=[prompt, chatbot, past_messages],
    )


if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()