sasan committed on
Commit
fea02f6
1 Parent(s): b4ec609

chore: Add set_vehicle_speed and set_vehicle_destination functions

kitt/core/model.py CHANGED

@@ -12,6 +12,7 @@ from loguru import logger
 
 
 from kitt.skills import vehicle_status
+from kitt.skills.common import config
 
 
 class FunctionCall(BaseModel):
@@ -29,7 +30,7 @@ class FunctionCall(BaseModel):
 
 schema_json = json.loads(FunctionCall.schema_json())
 HRMS_SYSTEM_PROMPT = """<|im_start|>system
-You are a function calling AI agent with self-recursion.
+You are a function calling AI agent. Your name is KITT. You are embodied in a Car. You know where you are, where you are going, and the current date and time. You can call functions to help with user queries.
 You can call only one function at a time and analyse data you get from function response.
 You are provided with function signatures within <tools></tools> XML tags.
 
@@ -53,7 +54,7 @@ Make sure that the json object above with code markdown block is parseable with
 When using tools, ensure to only use the tools provided and not make up any data and do not provide any explanation as to which tool you are using and why.
 
 Example 1:
-User: How is the weather today?
+User: How is the weather?
 Assistant:
 <tool_call>
 {{"arguments": {{"location": ""}}, "name": "get_weather"}}
@@ -206,14 +207,7 @@ def process_response(user_query, res, history, tools, depth):
     return True, tool_calls, errors
 
 
-def run_inference_step(depth, history, tools, schema_json, dry_run=False):
-    # If we decide to call a function, we need to generate the prompt for the model
-    # based on the history of the conversation so far.
-    # not break the loop
-    openai_tools = [convert_to_openai_function(tool) for tool in tools]
-    prompt = get_prompt(HRMS_SYSTEM_PROMPT, history, openai_tools, schema_json)
-    print(f"Prompt is:{prompt + AI_PREAMBLE}\n------------------\n")
-
+def run_inference_ollama(prompt):
     data = {
         "prompt": prompt
         + "\nThis is the first turn and you don't have <tool_results> to analyze yet"
@@ -230,37 +224,86 @@ def run_inference_step(depth, history, tools, schema_json, dry_run=False):
             "temperature": 0.8,
             # "max_tokens": 1500,
             "num_predict": 1500,
-            "mirostat": 1,
+            # "mirostat": 1,
             # "mirostat_tau": 2,
-            "repeat_penalty": 1.5,
+            "repeat_penalty": 1.1,
             "top_k": 25,
             "top_p": 0.5,
+            "num_ctx": 8000,
             # "num_predict": 1500,
             # "max_tokens": 1500,
         },
    }
 
-    if dry_run:
-        print(prompt + AI_PREAMBLE)
-        return "Didn't really run it."
-
-    client = Client(host='http://localhost:11444')
+    client = Client(host="http://localhost:11434")
     # out = ollama.generate(**data)
     out = client.generate(**data)
-    logger.debug(f"Response from model: {out}")
-    res = out["response"]
-
+    res = out.pop("response")
+    # Report prompt and eval tokens
+    logger.warning(
+        f"Prompt tokens: {out.get('prompt_eval_count')}, Response tokens: {out.get('eval_count')}"
+    )
+    logger.debug(f"Response from Ollama: {res}\nOut:{out}")
     return res
 
 
-def process_query(user_query: str, history: ChatMessageHistory, tools):
-    # Add vehicle status to the history
-    user_query_status = (
-        f"Given that:\n{vehicle_status()[0]}\nAnswer the following:\n{user_query}"
-    )
+def run_inference_step(
+    depth, history, tools, schema_json, dry_run=False, backend="ollama"
+):
+    # If we decide to call a function, we need to generate the prompt for the model
+    # based on the history of the conversation so far.
+    # not break the loop
+    openai_tools = [convert_to_openai_function(tool) for tool in tools]
+    prompt = get_prompt(HRMS_SYSTEM_PROMPT, history, openai_tools, schema_json)
+    print(f"Prompt is:{prompt + AI_PREAMBLE}\n------------------\n")
+
+    if backend == "ollama":
+        output = run_inference_ollama(prompt)
+    else:
+        output = run_inference_replicate(prompt)
+
+    logger.debug(f"Response from model: {output}")
+    return output
+
+
+def run_inference_replicate(prompt):
+    from replicate import Client
+
+    replicate = Client(api_token=config.REPLICATE_API_KEY)
+
+    input = {
+        "prompt": prompt
+        + "\nThis is the first turn and you don't have <tool_results> to analyze yet"
+        + AI_PREAMBLE,
+        "temperature": 0.5,
+        "system_prompt": "",
+        "max_new_tokens": 1024,
+        "repeat_penalty": 1.1,
+        "prompt_template": "{prompt}",
+    }
+
+    output = replicate.run(
+        "mikeei/dolphin-2.9-llama3-8b-gguf:0f79fb14c45ae2b92e1f07d872dceed3afafcacd903258df487d3bec9e393cb2",
+        input=input,
+    )
+    out = "".join(output)
+
+    return out
+
+
+def process_query(
+    user_query: str,
+    history: ChatMessageHistory,
+    user_preferences,
+    tools,
+    backend="ollama",
+):
+    # Add vehicle status to the history
+    user_query_status = f"Given that:\n{vehicle_status()[0]}\nUser preferences:\n{user_preferences}\nAnswer the following:\n{user_query}"
     history.add_message(HumanMessage(content=user_query_status))
     for depth in range(10):
-        out = run_inference_step(depth, history, tools, schema_json)
+        # out = run_inference_step(depth, history, tools, schema_json)
+        out = run_inference_step(depth, history, tools, schema_json, backend=backend)
         print(f"Inference step result:\n{out}\n------------------\n")
         history.add_message(AIMessage(content=out))
         to_continue, tool_calls, errors = process_response(
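
For orientation, here is a minimal sketch of driving the reworked pipeline end to end. The langchain import paths, the StructuredTool wrappers, and a running Ollama server on localhost:11434 are assumptions, not part of the commit:

```python
# Sketch only: exercising the new process_query signature and backend switch.
# Assumes langchain-style imports and a local Ollama server (see
# run_inference_ollama above); with backend="replicate",
# config.REPLICATE_API_KEY must be set instead.
from langchain.memory import ChatMessageHistory
from langchain.tools import StructuredTool

from kitt.core.model import process_query
from kitt.skills import set_vehicle_speed, set_vehicle_destination

history = ChatMessageHistory()
tools = [
    StructuredTool.from_function(set_vehicle_speed),
    StructuredTool.from_function(set_vehicle_destination),
]

# backend="ollama" routes through run_inference_ollama; any other value
# falls through to run_inference_replicate.
output_text = process_query(
    "Take me to the airport, and drive fast.",
    history=history,
    user_preferences="I prefer highways",
    tools=tools,
    backend="ollama",
)
print(output_text)
```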
kitt/skills/__init__.py CHANGED

@@ -5,7 +5,7 @@ from .common import execute_function_call, extract_func_args, vehicle as vehicle
 from .weather import get_weather_current_location, get_weather, get_forecast
 from .routing import find_route
 from .poi import search_points_of_interests, search_along_route_w_coordinates
-from .vehicle import vehicle_status
+from .vehicle import vehicle_status, set_vehicle_speed, set_vehicle_destination
 from .interpreter import code_interpreter
 
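With the widened re-export, the new skills resolve at package level, which is what main.py relies on below:

```python
# After this change, both new skills are importable from the package root:
from kitt.skills import vehicle_status, set_vehicle_speed, set_vehicle_destination
```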
 
kitt/skills/common.py CHANGED

@@ -1,25 +1,34 @@
 import re
-from typing import Union
+from typing import Union, Optional
 
 
 from pydantic_settings import BaseSettings, SettingsConfigDict
 from pydantic import BaseModel
 
 from .. import skills
+from enum import Enum
 
 class Settings(BaseSettings):
     WEATHER_API_KEY: str
     TOMTOM_API_KEY: str
+    REPLICATE_API_KEY: Optional[str]
 
     model_config = SettingsConfigDict(env_file=".env")
 
 
+class Speed(Enum):
+    SLOW = "slow"
+    FAST = "fast"
+
+
 class VehicleStatus(BaseModel):
     location: str
     location_coordinates: tuple[float, float]  # (latitude, longitude)
     date: str
     time: str
     destination: str
+    speed: Speed = Speed.SLOW
+
 
 
 def execute_function_call(text: str, dry_run=False) -> str:
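
A short sketch of the extended models (field values are illustrative). One caveat worth hedging: depending on the pydantic version, `Optional[str]` with no default may still be treated as a required field, so `REPLICATE_API_KEY` might need an explicit `= None` default to be truly optional:

```python
from kitt.skills.common import Speed, VehicleStatus

# speed now defaults to Speed.SLOW via the new field
status = VehicleStatus(
    location="Luxembourg",                   # illustrative values
    location_coordinates=(49.6116, 6.1319),
    date="2024-05-06",
    time="10:30",
    destination="Esch-sur-Alzette",
)
assert status.speed is Speed.SLOW
assert status.speed.value == "slow"
```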
kitt/skills/vehicle.py CHANGED

@@ -1,7 +1,8 @@
-from .common import vehicle
+from .common import vehicle, Speed
 
 
-STATUS_TEMPLATE = """The current location is: {location} ({lat}, {lon})
+STATUS_TEMPLATE = """The current location is: {location}
+Current coordinates: {lat}, {lon}
 The current date and time: {date} {time}
 The current destination is: {destination}"""
 
@@ -32,3 +33,21 @@ def vehicle_status() -> tuple[str, dict[str, str]]:
     vs["lat"] = vs["location_coordinates"][0]
     vs["lon"] = vs["location_coordinates"][1]
     return STATUS_TEMPLATE.format(**vs), vs
+
+
+def set_vehicle_speed(speed: Speed):
+    """Set the speed of the vehicle.
+
+    Args:
+        speed (Speed): The speed of the vehicle. ("slow", "fast")
+    """
+    vehicle.speed = speed
+    return f"The vehicle speed is set to {speed.value}."
+
+
+def set_vehicle_destination(destination: str):
+    """Set the destination of the vehicle.
+
+    Args:
+        destination (str): The destination of the vehicle.
+    """
+    vehicle.destination = destination
+    return f"The vehicle destination is set to {destination}."
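
A direct usage sketch of the two new skills, assuming the shared module-level `vehicle` instance from kitt.skills.common. When the model invokes set_vehicle_speed through a StructuredTool, the argument arrives as the string "slow" or "fast"; pydantic's enum coercion is assumed to map it to `Speed`:

```python
from kitt.skills import set_vehicle_speed, set_vehicle_destination
from kitt.skills.common import Speed

print(set_vehicle_speed(Speed.FAST))
# The vehicle speed is set to fast.

print(set_vehicle_destination("Kirchberg"))
# The vehicle destination is set to Kirchberg.
```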
kitt/skills/weather.py CHANGED

@@ -1,4 +1,5 @@
 import requests
+from loguru import logger
 
 from .common import config, vehicle
 
@@ -19,27 +20,26 @@ def get_weather_current_location():
 
 
 # current weather API
-def get_weather(location: str = ""):
+def get_weather(location: str = "here"):
     """
     Get the current weather in a specified location.
     When responding to user, only mention the weather condition, temperature, and the temperature that it feels like, unless the user asks for more information.
 
     Args:
-        location (string) : Optional. The name of the location, if empty, the vehicle location is used.
+        location (string) : Optional. The name of the location, if empty or here, the vehicle location is used.
 
     Returns:
         dict: The weather data in the specified location.
     """
 
-    if location == "":
-        print(
+    if location == "" or location == "here":
+        logger.warning(
            f"get_weather: location is empty, using the vehicle location. ({vehicle.location})"
        )
        location = vehicle.location
 
     # The endpoint URL provided by WeatherAPI
     url = f"http://api.weatherapi.com/v1/current.json?key={config.WEATHER_API_KEY}&q={location}&aqi=no"
-    print(url)
 
     # Make the API request
     response = requests.get(url)
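
The `"here"` default makes the fallback explicit: an empty string and "here" both resolve to the vehicle's location, and the message now goes through loguru rather than print. Illustrative calls, assuming a valid WEATHER_API_KEY in .env:

```python
from kitt.skills.weather import get_weather

get_weather("Berlin")  # explicit location
get_weather("here")    # falls back to vehicle.location, logged as a warning
get_weather()          # same as above: the default is now "here"
```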
main.py CHANGED

@@ -21,8 +21,10 @@ from kitt.skills import (
     find_route,
     get_forecast,
     vehicle_status as vehicle_status_fn,
+    set_vehicle_speed,
     search_points_of_interests,
     search_along_route_w_coordinates,
+    set_vehicle_destination,
     do_anything_else,
     date_time_info,
     get_weather_current_location,
@@ -120,11 +122,12 @@ def get_vehicle_status(state):
 tools = [
     StructuredTool.from_function(get_weather),
     StructuredTool.from_function(find_route),
-    StructuredTool.from_function(vehicle_status_fn),
+    # StructuredTool.from_function(vehicle_status_fn),
+    StructuredTool.from_function(set_vehicle_speed),
     StructuredTool.from_function(search_points_of_interests),
     StructuredTool.from_function(search_along_route),
-    StructuredTool.from_function(date_time_info),
-    StructuredTool.from_function(get_weather_current_location),
+    # StructuredTool.from_function(date_time_info),
+    # StructuredTool.from_function(get_weather_current_location),
     StructuredTool.from_function(code_interpreter),
     # StructuredTool.from_function(do_anything_else),
 ]
@@ -148,7 +151,7 @@ def clear_history():
     history.clear()
 
 
-def run_nexusraven_model(query, voice_character):
+def run_nexusraven_model(query, voice_character, state):
     global_context["prompt"] = get_prompt(RAVEN_PROMPT_FUNC, query, "", tools)
     print("Prompt: ", global_context["prompt"])
     data = {
@@ -182,11 +185,18 @@ def run_nexusraven_model(query, voice_character):
     )
 
 
-def run_llama3_model(query, voice_character):
-    output_text = process_query(query, history, tools)
+def run_llama3_model(query, voice_character, state):
+    output_text = process_query(
+        query,
+        history=history,
+        user_preferences=state["user_preferences"],
+        tools=tools,
+        backend=state["llm_backend"],
+    )
     gr.Info(f"Output text: {output_text}, generating voice output...")
-    # voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
     voice_out = None
+    if state["tts_enabled"]:
+        voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
     return (
         output_text,
         voice_out,
@@ -196,15 +206,17 @@ def run_model(query, voice_character, state):
 def run_model(query, voice_character, state):
     model = state.get("model", "nexusraven")
     query = query.strip().replace("'", "")
-    print("Query: ", query)
-    print("Model: ", model)
+    logger.info(
+        f"Running model: {model} with query: {query}, voice_character: {voice_character} and llm_backend: {state['llm_backend']}, tts_enabled: {state['tts_enabled']}"
+    )
     global_context["query"] = query
     if model == "nexusraven":
-        return run_nexusraven_model(query, voice_character)
+        text, voice = run_nexusraven_model(query, voice_character, state)
     elif model == "llama3":
-        return run_llama3_model(query, voice_character)
-    return "Error running model", None
-
+        text, voice = run_llama3_model(query, voice_character, state)
+    else:
+        text, voice = "Error running model", None
+    return text, voice, vehicle.model_dump_json()
 
 
 def calculate_route_gradio(origin, destination):
@@ -276,6 +288,32 @@ def save_and_transcribe_run_model(audio, voice_character, state):
     out_text, out_voice = run_model(text, voice_character, state)
     return text, out_text, out_voice
 
+
+def set_tts_enabled(tts_enabled, state):
+    new_tts_enabled = tts_enabled == "Yes"
+    logger.info(
+        f"TTS enabled was {state['tts_enabled']} and changed to {new_tts_enabled}"
+    )
+    state["tts_enabled"] = new_tts_enabled
+    return state
+
+
+def set_llm_backend(llm_backend, state):
+    new_llm_backend = "ollama" if llm_backend == "Ollama" else "replicate"
+    logger.info(
+        f"LLM backend was {state['llm_backend']} and changed to {new_llm_backend}"
+    )
+    state["llm_backend"] = new_llm_backend
+    return state
+
+
+def set_user_preferences(preferences, state):
+    new_preferences = preferences
+    logger.info(f"User preferences changed to: {new_preferences}")
+    state["user_preferences"] = new_preferences
+    return state
+
+
 # to be able to use the microphone on chrome, you will have to go to chrome://flags/#unsafely-treat-insecure-origin-as-secure and enter http://10.186.115.21:7860/
 # in "Insecure origins treated as secure", enable it and relaunch chrome
 
@@ -284,7 +322,7 @@ def save_and_transcribe_run_model(audio, voice_character, state):
 # What's the closest restaurant from here?
 
 
-def create_demo(tts_server: bool = False, model="llama3", tts=True):
+def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = True):
     print(f"Running the demo with model: {model} and TTSServer: {tts_server}")
     with gr.Blocks(theme=gr.themes.Default()) as demo:
         state = gr.State(
@@ -293,7 +331,9 @@ def create_demo(tts_server: bool = False, model="llama3", tts=True):
                 "query": "",
                 "route_points": [],
                 "model": model,
-                "tts": tts,
+                "tts_enabled": tts_enabled,
+                "llm_backend": "Ollama",
+                "user_preferences": "",
             }
         )
         trip_points = gr.State(value=[])
@@ -328,6 +368,12 @@ def create_demo(tts_server: bool = False, model="llama3", tts=True):
                 label="Destination",
                 interactive=True,
             )
+            preferences = gr.Textbox(
+                value="I love italian food\nI like doing sports",
+                label="User preferences",
+                lines=3,
+                interactive=True,
+            )
 
         with gr.Column(scale=2, min_width=600):
             map_plot = gr.Plot()
@@ -363,6 +409,19 @@ def create_demo(tts_server: bool = False, model="llama3", tts=True):
             vehicle_status = gr.JSON(
                 value=vehicle.model_dump_json(), label="Vehicle status"
             )
+            with gr.Accordion("Config"):
+                tts_enabled = gr.Radio(
+                    choices=["Yes", "No"],
+                    label="Enable TTS",
+                    value="No",
+                    interactive=True,
+                )
+                llm_backend = gr.Radio(
+                    choices=["Ollama", "Replicate"],
+                    label="LLM Backend",
+                    value="Ollama",
+                    interactive=True,
+                )
             # Push button
             clear_history_btn = gr.Button(value="Clear History")
         with gr.Column():
@@ -383,6 +442,9 @@ def create_demo(tts_server: bool = False, model="llama3", tts=True):
             inputs=[origin, destination],
             outputs=[map_plot, vehicle_status, trip_progress],
         )
+        preferences.submit(
+            fn=set_user_preferences, inputs=[preferences, state], outputs=[state]
+        )
 
         # Update time based on the time picker
         time_picker.select(fn=set_time, inputs=[time_picker], outputs=[vehicle_status])
@@ -391,12 +453,12 @@ def create_demo(tts_server: bool = False, model="llama3", tts=True):
         input_text.submit(
             fn=run_model,
             inputs=[input_text, voice_character, state],
-            outputs=[output_text, output_audio],
+            outputs=[output_text, output_audio, vehicle_status],
         )
         input_text_debug.submit(
             fn=run_model,
-            inputs=[input_text, voice_character, state],
-            outputs=[output_text, output_audio],
+            inputs=[input_text_debug, voice_character, state],
+            outputs=[output_text, output_audio, vehicle_status],
         )
 
         # Set the vehicle status based on the trip progress
@@ -408,15 +470,26 @@ def create_demo(tts_server: bool = False, model="llama3", tts=True):
 
         # Save and transcribe the audio
         input_audio.stop_recording(
-            fn=save_and_transcribe_run_model, inputs=[input_audio, voice_character, state], outputs=[input_text, output_text, output_audio]
+            fn=save_and_transcribe_run_model,
+            inputs=[input_audio, voice_character, state],
+            outputs=[input_text, output_text, output_audio],
        )
         input_audio_debug.stop_recording(
-            fn=save_and_transcribe_audio, inputs=[input_audio_debug], outputs=[input_text_debug]
+            fn=save_and_transcribe_audio,
+            inputs=[input_audio_debug],
+            outputs=[input_text_debug],
        )
 
         # Clear the history
         clear_history_btn.click(fn=clear_history, inputs=[], outputs=[])
 
+        # Config
+        tts_enabled.change(
+            fn=set_tts_enabled, inputs=[tts_enabled, state], outputs=[state]
+        )
+        llm_backend.change(
+            fn=set_llm_backend, inputs=[llm_backend, state], outputs=[state]
+        )
     return demo
 
 
@@ -424,7 +497,7 @@ def create_demo(tts_server: bool = False, model="llama3", tts=True):
 gr.close_all()
 
 
-demo = create_demo(False, "llama3", tts=False)
+demo = create_demo(False, "llama3", tts_enabled=False)
 demo.launch(
     debug=True,
     server_name="0.0.0.0",
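
The new Config accordion wires the two gr.Radio widgets to plain state-update callbacks, so the state logic can be sanity-checked without launching the UI. A sketch, assuming it runs in main.py's namespace with the same keys create_demo initializes:

```python
state = {"tts_enabled": False, "llm_backend": "Ollama", "user_preferences": ""}

state = set_tts_enabled("Yes", state)
assert state["tts_enabled"] is True

state = set_llm_backend("Replicate", state)
assert state["llm_backend"] == "replicate"

state = set_user_preferences("I love italian food", state)
assert state["user_preferences"] == "I love italian food"
```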