Iisakki Rotko commited on
Commit
f228ce0
1 Parent(s): 1c7799f

feat: working with assistant, UI, and others

Browse files
Files changed (2) hide show
  1. icons/explore.svg +1 -0
  2. wanderlust.py +145 -62
icons/explore.svg ADDED
wanderlust.py CHANGED
@@ -2,7 +2,11 @@ import json
2
  import os
3
 
4
  import ipyleaflet
5
- import openai
 
 
 
 
6
 
7
  import solara
8
 
@@ -17,11 +21,11 @@ center = solara.reactive(center_default)
17
  markers = solara.reactive([])
18
 
19
  url = ipyleaflet.basemaps.OpenStreetMap.Mapnik.build_url()
20
- openai.api_key = os.getenv("OPENAI_API_KEY")
21
  model = "gpt-4-1106-preview"
22
 
23
 
24
- function_descriptions = [
25
  {
26
  "type": "function",
27
  "function": {
@@ -94,17 +98,15 @@ functions = {
94
 
95
 
96
  def ai_call(tool_call):
97
- function = tool_call["function"]
98
- name = function["name"]
99
- arguments = json.loads(function["arguments"])
100
  return_value = functions[name](**arguments)
101
- message = {
102
- "role": "tool",
103
- "tool_call_id": tool_call["id"],
104
- "name": tool_call["function"]["name"],
105
- "content": return_value,
106
  }
107
- return message
108
 
109
 
110
  @solara.component
@@ -129,39 +131,66 @@ def Map():
129
  @solara.component
130
  def ChatInterface():
131
  prompt = solara.use_reactive("")
 
 
 
 
132
 
133
  def add_message(value: str):
134
  if value == "":
135
  return
136
- messages.set(messages.value + [{"role": "user", "content": value}])
137
  prompt.set("")
138
-
139
- def ask():
140
- if not messages.value:
 
 
 
 
 
 
 
 
 
 
141
  return
142
- last_message = messages.value[-1]
143
- if last_message["role"] == "user" or last_message["role"] == "tool":
144
- completion = openai.ChatCompletion.create(
145
- model=model,
146
- messages=messages.value,
147
- # Add function calling
148
- tools=function_descriptions,
149
- tool_choice="auto",
150
- )
151
-
152
- output = completion.choices[0].message
153
- print("received", output)
154
  try:
155
- handled_messages = handle_message(output)
156
- messages.value = [*messages.value, output, *handled_messages]
157
-
158
- except Exception as e:
159
- print("errr", e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
  def handle_message(message):
162
  print("handle", message)
163
  messages = []
164
- if message["role"] == "assistant":
165
  tools_calls = message.get("tool_calls", [])
166
  for tool_call in tools_calls:
167
  messages.append(ai_call(tool_call))
@@ -173,38 +202,71 @@ def ChatInterface():
173
  handle_message(message)
174
 
175
  solara.use_effect(handle_initial, [])
176
- result = solara.use_thread(ask, dependencies=[messages.value])
177
  with solara.Column(
178
- style={"height": "100%", "width": "38vw", "justify-content": "center"},
 
 
 
 
 
179
  classes=["chat-interface"],
180
  ):
181
  if len(messages.value) > 0:
182
- with solara.Column(style={"flex-grow": "1", "overflow-y": "auto"}):
183
- for message in messages.value:
184
- if message["role"] == "user":
185
- solara.Text(
186
- message["content"], classes=["chat-message", "user-message"]
187
- )
188
- elif message["role"] == "assistant":
189
- if message["content"]:
190
- solara.Markdown(message["content"])
191
- elif message["tool_calls"]:
192
- solara.Markdown("*Calling map functions*")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  else:
 
 
 
 
194
  solara.Preformatted(
195
  repr(message),
196
  classes=["chat-message", "assistant-message"],
197
  )
198
- elif message["role"] == "tool":
199
- pass # no need to display
200
- else:
201
- solara.Preformatted(
202
- repr(message), classes=["chat-message", "assistant-message"]
203
- )
204
- # solara.Text(message, classes=["chat-message"])
205
  with solara.Column():
206
  solara.InputText(
207
- label="Ask your ",
208
  value=prompt,
209
  style={"flex-grow": "1"},
210
  on_value=add_message,
@@ -234,26 +296,47 @@ def Page():
234
  messages.set(json.load(f))
235
  reset_ui()
236
 
237
- with solara.Column(style={"flex-grow": "1"}, gap=0):
238
- with solara.AppBar():
239
- solara.Button("Save", on_click=save)
240
- solara.Button("Load", on_click=load)
241
- solara.Button("Soft reset", on_click=reset_ui)
242
- with solara.Row(style={"height": "100%"}, justify="space-between"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  ChatInterface().key(f"chat-{reset_counter}")
244
- with solara.Column(style={"width": "58vw", "justify-content": "center"}):
245
  Map() # .key(f"map-{reset_counter}")
246
 
247
  solara.Style(
248
  """
249
  .jupyter-widgets.leaflet-widgets{
250
  height: 100%;
 
251
  }
252
  .solara-autorouter-content{
253
  display: flex;
254
  flex-direction: column;
255
  justify-content: stretch;
256
  }
 
 
 
 
 
257
  """
258
  )
259
 
 
2
  import os
3
 
4
  import ipyleaflet
5
+ from openai import OpenAI, NotFoundError
6
+ from openai.types.beta import Thread
7
+ from openai.types.beta.threads import Run
8
+
9
+ import time
10
 
11
  import solara
12
 
 
21
  markers = solara.reactive([])
22
 
23
  url = ipyleaflet.basemaps.OpenStreetMap.Mapnik.build_url()
24
+ openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
25
  model = "gpt-4-1106-preview"
26
 
27
 
28
+ tools = [
29
  {
30
  "type": "function",
31
  "function": {
 
98
 
99
 
100
  def ai_call(tool_call):
101
+ function = tool_call.function
102
+ name = function.name
103
+ arguments = json.loads(function.arguments)
104
  return_value = functions[name](**arguments)
105
+ tool_outputs = {
106
+ "tool_call_id": tool_call.id,
107
+ "output": return_value,
 
 
108
  }
109
+ return tool_outputs
110
 
111
 
112
  @solara.component
 
131
  @solara.component
132
  def ChatInterface():
133
  prompt = solara.use_reactive("")
134
+ run_id: solara.Reactive[str] = solara.use_reactive(None)
135
+
136
+ thread: Thread = solara.use_memo(openai.beta.threads.create, dependencies=[])
137
+ print("thread id:", thread.id)
138
 
139
  def add_message(value: str):
140
  if value == "":
141
  return
 
142
  prompt.set("")
143
+ new_message = openai.beta.threads.messages.create(
144
+ thread_id=thread.id, content=value, role="user"
145
+ )
146
+ messages.set([*messages.value, new_message])
147
+ run_id.value = openai.beta.threads.runs.create(
148
+ thread_id=thread.id,
149
+ assistant_id="asst_RqVKAzaybZ8un7chIwPCIQdH",
150
+ tools=tools,
151
+ ).id
152
+ print("Run id:", run_id.value)
153
+
154
+ def poll():
155
+ if not run_id.value:
156
  return
157
+ completed = False
158
+ while not completed:
 
 
 
 
 
 
 
 
 
 
159
  try:
160
+ run = openai.beta.threads.runs.retrieve(
161
+ run_id.value, thread_id=thread.id
162
+ ) # When run is complete
163
+ print("run", run.status)
164
+ except NotFoundError:
165
+ print("run not found (Yet)")
166
+ continue
167
+ if run.status == "requires_action":
168
+ for tool_call in run.required_action.submit_tool_outputs.tool_calls:
169
+ tool_output = ai_call(tool_call)
170
+ openai.beta.threads.runs.submit_tool_outputs(
171
+ thread_id=thread.id,
172
+ run_id=run_id.value,
173
+ tool_outputs=[tool_output],
174
+ )
175
+ if run.status == "completed":
176
+ messages.set(
177
+ [
178
+ *messages.value,
179
+ openai.beta.threads.messages.list(thread.id).data[0],
180
+ ]
181
+ )
182
+ run_id.set(None)
183
+ completed = True
184
+ time.sleep(0.1)
185
+ retrieved_messages = openai.beta.threads.messages.list(thread_id=thread.id)
186
+ messages.set(retrieved_messages.data)
187
+
188
+ result = solara.use_thread(poll, dependencies=[run_id.value])
189
 
190
  def handle_message(message):
191
  print("handle", message)
192
  messages = []
193
+ if message.role == "assistant":
194
  tools_calls = message.get("tool_calls", [])
195
  for tool_call in tools_calls:
196
  messages.append(ai_call(tool_call))
 
202
  handle_message(message)
203
 
204
  solara.use_effect(handle_initial, [])
205
+ # result = solara.use_thread(ask, dependencies=[messages.value])
206
  with solara.Column(
207
+ style={
208
+ "height": "100%",
209
+ "width": "38vw",
210
+ "justify-content": "center",
211
+ "background": "linear-gradient(0deg, transparent 75%, white 100%);",
212
+ },
213
  classes=["chat-interface"],
214
  ):
215
  if len(messages.value) > 0:
216
+ # The height works effectively as `min-height`, since flex will grow the container to fill the available space
217
+ with solara.Column(
218
+ style={
219
+ "flex-grow": "1",
220
+ "overflow-y": "auto",
221
+ "height": "100px",
222
+ "flex-direction": "column-reverse",
223
+ }
224
+ ):
225
+ for message in reversed(messages.value):
226
+ with solara.Row(style={"align-items": "flex-start"}):
227
+ if message.role == "user":
228
+ solara.Text(
229
+ message.content[0].text.value,
230
+ classes=["chat-message", "user-message"],
231
+ )
232
+ assert len(message.content) == 1
233
+ elif message.role == "assistant":
234
+ if message.content[0].text.value:
235
+ solara.v.Icon(
236
+ children=["mdi-compass-outline"],
237
+ style_="padding-top: 10px;",
238
+ )
239
+ solara.Markdown(message.content[0].text.value)
240
+ elif message.content.tool_calls:
241
+ solara.v.Icon(
242
+ children=["mdi-map"],
243
+ style_="padding-top: 10px;",
244
+ )
245
+ solara.Markdown("*Calling map functions*")
246
+ else:
247
+ solara.v.Icon(
248
+ children=["mdi-compass-outline"],
249
+ style_="padding-top: 10px;",
250
+ )
251
+ solara.Preformatted(
252
+ repr(message),
253
+ classes=["chat-message", "assistant-message"],
254
+ )
255
+ elif message["role"] == "tool":
256
+ pass # no need to display
257
  else:
258
+ solara.v.Icon(
259
+ children=["mdi-compass-outline"],
260
+ style_="padding-top: 10px;",
261
+ )
262
  solara.Preformatted(
263
  repr(message),
264
  classes=["chat-message", "assistant-message"],
265
  )
266
+ # solara.Text(message, classes=["chat-message"])
 
 
 
 
 
 
267
  with solara.Column():
268
  solara.InputText(
269
+ label="Ask your question here",
270
  value=prompt,
271
  style={"flex-grow": "1"},
272
  on_value=add_message,
 
296
  messages.set(json.load(f))
297
  reset_ui()
298
 
299
+ with solara.Column(
300
+ style={
301
+ "height": "95vh",
302
+ "justify-content": "center",
303
+ "padding": "45px 50px 75px 50px",
304
+ },
305
+ gap="5vh",
306
+ ):
307
+ with solara.Row(justify="space-between"):
308
+ with solara.Row(gap="10px", style={"align-items": "center"}):
309
+ solara.v.Icon(children=["mdi-compass-rose"], size="36px")
310
+ solara.HTML(
311
+ tag="h2",
312
+ unsafe_innerHTML="Wanderlust",
313
+ style={"display": "inline-block"},
314
+ )
315
+ # with solara.Row(gap="10px"):
316
+ # solara.Button("Save", on_click=save)
317
+ # solara.Button("Load", on_click=load)
318
+ # solara.Button("Soft reset", on_click=reset_ui)
319
+ with solara.Row(justify="space-between", style={"flex-grow": "1"}):
320
  ChatInterface().key(f"chat-{reset_counter}")
321
+ with solara.Column(style={"width": "50vw", "justify-content": "center"}):
322
  Map() # .key(f"map-{reset_counter}")
323
 
324
  solara.Style(
325
  """
326
  .jupyter-widgets.leaflet-widgets{
327
  height: 100%;
328
+ border-radius: 20px;
329
  }
330
  .solara-autorouter-content{
331
  display: flex;
332
  flex-direction: column;
333
  justify-content: stretch;
334
  }
335
+ .v-toolbar__title{
336
+ display: flex;
337
+ align-items: center;
338
+ column-gap: 0.5rem;
339
+ }
340
  """
341
  )
342