# wanderlust / wanderlust.py
# maartenbreddels's picture
# feat: adaptive label for InputText
# 56ce50a
import json
import os
import time
from pathlib import Path
import ipyleaflet
from openai import NotFoundError, OpenAI
from openai.types.beta import Thread
import solara
# Directory containing this module; used below to locate style.css.
HERE = Path(__file__).parent
# Initial map view. The center is stored as (latitude, longitude), matching
# the order update_map() writes into `center`.
center_default = (0, 0)
zoom_default = 2
# Reactive application state read by the UI components in this file.
# `messages` holds both OpenAI thread messages and tool-output dicts.
messages = solara.reactive([])
zoom_level = solara.reactive(zoom_default)
center = solara.reactive(center_default)
# Each entry is a dict with "location" and "label" keys (see add_marker()).
markers = solara.reactive([])
# Tile URL for the OpenStreetMap Mapnik basemap, passed to the TileLayer.
url = ipyleaflet.basemaps.OpenStreetMap.Mapnik.build_url()
# OpenAI client; the API key is read from the OPENAI_API_KEY environment variable.
openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# NOTE(review): `model` is not referenced in the visible part of this file —
# confirm whether it is still needed.
model = "gpt-4-1106-preview"
# CSS text loaded from style.css next to this file; injected by Page().
app_style = (HERE / "style.css").read_text()
def _param(json_type, description):
    """Return the JSON-schema entry for a single tool parameter."""
    return {"type": json_type, "description": description}


def _function_tool(name, description, properties):
    """Return a function-tool declaration in which every property is required."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                # All declared properties are mandatory, in declaration order.
                "required": list(properties),
            },
        },
    }


# Declare tools for openai assistant to use
tools = [
    _function_tool(
        "update_map",
        "Update map to center on a particular location",
        {
            "longitude": _param(
                "number", "Longitude of the location to center the map on"
            ),
            "latitude": _param(
                "number", "Latitude of the location to center the map on"
            ),
            "zoom": _param("integer", "Zoom level of the map"),
        },
    ),
    _function_tool(
        "add_marker",
        "Add marker to the map",
        {
            "longitude": _param("number", "Longitude of the location to the marker"),
            "latitude": _param("number", "Latitude of the location to the marker"),
            "label": _param("string", "Text to display on the marker"),
        },
    ),
]
def update_map(longitude, latitude, zoom):
    """Re-center the shared map state and set its zoom level.

    The center is stored as ``(latitude, longitude)``. Returns the status
    string that is sent back to the assistant as the tool output.
    """
    center.value = (latitude, longitude)
    zoom_level.value = zoom
    return "Map updated"
def add_marker(longitude, latitude, label):
    """Append a marker to the shared marker list.

    The marker is stored as a dict with a ``(latitude, longitude)`` location
    and its label. Returns the status string used as the tool output.
    """
    new_marker = {"location": (latitude, longitude), "label": label}
    # Assign a new list rather than mutating in place so the reactive
    # variable sees a changed value.
    markers.value = markers.value + [new_marker]
    return "Marker added"
# Dispatch table from the tool name the assistant requests to the local
# handler; each handler is registered under its own function name.
functions = {handler.__name__: handler for handler in (update_map, add_marker)}
def assistant_tool_call(tool_call):
    """Execute one tool call requested by the OpenAI assistant.

    Decodes the JSON-encoded arguments, looks the handler up in ``functions``
    by name, and calls it with those arguments as keyword arguments.

    Returns a dict with ``tool_call_id`` and ``output`` keys, as submitted
    back to the run by the caller.
    """
    requested = tool_call.function
    handler_name = requested.name
    kwargs = json.loads(requested.arguments)
    handler = functions[handler_name]
    output = handler(**kwargs)
    return {"tool_call_id": tool_call.id, "output": output}
@solara.component
def Map():
    """Render the leaflet map from the shared reactive state.

    The view is bound in both directions: ``zoom_level`` and ``center`` drive
    the map, and manual pan/zoom by the user is written back into them. Without
    the write-back the reactive values go stale after the user moves the map,
    so a later ``update_map`` to the same location would not move it.

    Every entry in ``markers`` becomes a fixed marker whose label is shown as
    the marker's tooltip title.
    """
    marker_layers = [
        ipyleaflet.Marker.element(
            location=marker["location"],
            draggable=False,
            # Surface the label the assistant supplied; it was previously
            # stored but never displayed.
            title=marker["label"],
        )
        for marker in markers.value
    ]
    ipyleaflet.Map.element(  # type: ignore
        zoom=zoom_level.value,
        on_zoom=zoom_level.set,
        center=center.value,
        on_center=center.set,
        scroll_wheel_zoom=True,
        layers=[
            ipyleaflet.TileLayer.element(url=url),
            *marker_layers,
        ],
    )
def _message_text(message):
    """Return the text of a thread message's first content part, or "" if it has none."""
    content = getattr(message, "content", None)
    if not content:
        return ""
    text = getattr(content[0], "text", None)
    return getattr(text, "value", None) or ""


@solara.component
def ChatMessage(message):
    """Render a single chat entry.

    ``message`` is either a tool-output dict (with an ``"output"`` key) or a
    thread message object exposing ``role`` and ``content``. User text is shown
    in bold, assistant text as Markdown, and anything unrecognised falls back
    to its ``repr``.
    """
    icon_style = "padding-top: 10px;"
    with solara.Row(style={"align-items": "flex-start"}):
        # Catch "messages" that are actually tool calls
        if isinstance(message, dict):
            icon = "mdi-map" if message["output"] == "Map updated" else "mdi-map-marker"
            solara.v.Icon(children=[icon], style_=icon_style)
            solara.Markdown(message["output"])
        elif message.role == "user":
            solara.Text(_message_text(message), style={"font-weight": "bold;"})
        elif message.role == "assistant":
            text = _message_text(message)
            if text:
                solara.v.Icon(children=["mdi-compass-outline"], style_=icon_style)
                solara.Markdown(text)
            # `content` is a list of parts, so reading `.tool_calls` from it
            # raised AttributeError; look on the message and tolerate absence.
            elif getattr(message, "tool_calls", None):
                solara.v.Icon(children=["mdi-map"], style_=icon_style)
                solara.Markdown("*Calling map functions*")
            else:
                solara.v.Icon(children=["mdi-compass-outline"], style_=icon_style)
                solara.Preformatted(repr(message))
        else:
            solara.v.Icon(children=["mdi-compass-outline"], style_=icon_style)
            solara.Preformatted(repr(message))
@solara.component
def ChatBox(children=[]):
    """Scrollable container that keeps the newest chat message in view.

    The inner column is laid out with ``column-reverse`` and the children are
    passed in reversed order. The two reversals cancel out visually, while the
    scroll position rests at the bottom of the container by default.
    """
    newest_first = list(children)
    newest_first.reverse()

    # Extra top padding on what is visually the first message.
    padding_css = """
.chat-box > :last-child{
padding-top: 7.5vh;
}
"""
    # The height works effectively as `min-height`, since flex will grow the
    # container to fill the available space
    scroll_style = {
        "flex-grow": "1",
        "overflow-y": "auto",
        "height": "100px",
        "flex-direction": "column-reverse",
    }

    with solara.Column(style={"flex-grow": "1"}):
        solara.Style(padding_css)
        solara.Column(
            style=scroll_style,
            classes=["chat-box"],
            children=newest_first,
        )
@solara.component
def ChatInterface():
    """Chat panel: message history, prompt input, and the assistant run loop.

    Submitting a prompt posts it to the OpenAI thread and starts a run. A
    background thread then polls the run, executes any tool calls the assistant
    requests, and appends the assistant's reply once the run completes. A run
    that ends unsuccessfully is reported through the error banner instead of
    being polled forever.
    """
    prompt = solara.use_reactive("")
    run_id: solara.Reactive[str] = solara.use_reactive(None)
    # Create a thread to hold the conversation only once when this component is created
    thread: Thread = solara.use_memo(openai.beta.threads.create, dependencies=[])

    def add_message(value: str):
        if value == "":
            return
        prompt.set("")
        new_message = openai.beta.threads.messages.create(
            thread_id=thread.id, content=value, role="user"
        )
        messages.set([*messages.value, new_message])
        # this creates a new run for the thread
        # it also triggers a rerender (since run_id.value changes)
        # which will trigger the poll function below to start in a thread
        run_id.value = openai.beta.threads.runs.create(
            thread_id=thread.id,
            assistant_id="asst_RqVKAzaybZ8un7chIwPCIQdH",
            tools=tools,
        ).id

    def poll():
        if not run_id.value:
            return
        poll_interval = 0.1
        # Statuses after which the run will never reach "completed".
        dead_statuses = ("failed", "cancelled", "expired")
        while True:
            try:
                run = openai.beta.threads.runs.retrieve(
                    run_id.value, thread_id=thread.id
                )
            # Above will raise NotFoundError when run creation is still in progress
            except NotFoundError:
                # Back off before retrying instead of hammering the API.
                time.sleep(poll_interval)
                continue
            if run.status == "requires_action":
                tool_outputs = []
                for tool_call in run.required_action.submit_tool_outputs.tool_calls:
                    tool_output = assistant_tool_call(tool_call)
                    tool_outputs.append(tool_output)
                    # Show each tool result in the chat as soon as it is available.
                    messages.set([*messages.value, tool_output])
                openai.beta.threads.runs.submit_tool_outputs(
                    thread_id=thread.id,
                    run_id=run_id.value,
                    tool_outputs=tool_outputs,
                )
            elif run.status == "completed":
                messages.set(
                    [
                        *messages.value,
                        openai.beta.threads.messages.list(thread.id).data[0],
                    ]
                )
                run_id.set(None)
                return
            elif run.status in dead_statuses:
                # Raise so use_thread reports ERROR and the UI shows the reason;
                # previously this state made the loop spin indefinitely.
                raise RuntimeError(
                    f"Assistant run ended with status {run.status!r}: "
                    f"{getattr(run, 'last_error', None)!r}"
                )
            time.sleep(poll_interval)

    # run/restart a thread any time the run_id changes
    result = solara.use_thread(poll, dependencies=[run_id.value])
    is_running = result.state == solara.ResultState.RUNNING

    # Create DOM for chat interface
    with solara.Column(classes=["chat-interface"]):
        if len(messages.value) > 0:
            with ChatBox():
                for message in messages.value:
                    ChatMessage(message)
        with solara.Column():
            solara.InputText(
                label="Where do you want to go?"
                if len(messages.value) == 0
                else "Ask more question here",
                value=prompt,
                style={"flex-grow": "1"},
                on_value=add_message,
                disabled=is_running,
            )
            solara.ProgressLinear(is_running)
            if result.state == solara.ResultState.ERROR:
                solara.Error(repr(result.error))
@solara.component
def Page():
    """Top-level page: a header bar with links, then the chat next to the map."""

    def link_button(href, title):
        # target="_blank" links are still easiest to do via ipyvuetify
        return solara.v.Btn(
            icon=True,
            tag="a",
            attributes={"href": href, "title": title, "target": "_blank"},
        )

    def github_icon():
        return solara.v.Icon(children=["mdi-github-circle"])

    with solara.Column(classes=["ui-container"], gap="5vh"):
        # Header: app title on the left, external links on the right.
        with solara.Row(justify="space-between"):
            with solara.Row(gap="10px", style={"align-items": "center"}):
                solara.v.Icon(children=["mdi-compass-rose"], size="36px")
                solara.HTML(
                    tag="h2",
                    unsafe_innerHTML="Wanderlust",
                    style={"display": "inline-block"},
                )
            with solara.Row(
                gap="30px",
                style={"align-items": "center"},
                classes=["link-container"],
                justify="end",
            ):
                with solara.Row(gap="5px", style={"align-items": "center"}):
                    solara.Text("Source Code:", style="font-weight: bold;")
                    with link_button(
                        "https://github.com/widgetti/wanderlust",
                        "Wanderlust Source Code",
                    ):
                        github_icon()
                with solara.Row(gap="5px", style={"align-items": "center"}):
                    solara.Text("Powered by Solara:", style="font-weight: bold;")
                    with link_button("https://solara.dev/", "Solara"):
                        solara.HTML(
                            tag="img",
                            attributes={
                                "src": "https://solara.dev/static/public/logo.svg",
                                "width": "24px",
                            },
                        )
                    with link_button(
                        "https://github.com/widgetti/solara",
                        "Solara Source Code",
                    ):
                        github_icon()
        # Main area: chat on the left, map on the right.
        with solara.Row(
            justify="space-between", style={"flex-grow": "1"}, classes=["container-row"]
        ):
            ChatInterface()
            with solara.Column(classes=["map-container"]):
                Map()
        solara.Style(app_style)