chore: Add set_vehicle_speed and set_vehicle_destination functions
- kitt/core/model.py      +68 -25
- kitt/skills/__init__.py  +1  -1
- kitt/skills/common.py   +10  -1
- kitt/skills/vehicle.py  +21  -2
- kitt/skills/weather.py   +5  -5
- main.py                 +94 -21
kitt/core/model.py
CHANGED

@@ -12,6 +12,7 @@ from loguru import logger


 from kitt.skills import vehicle_status
+from kitt.skills.common import config


 class FunctionCall(BaseModel):
@@ -29,7 +30,7 @@ class FunctionCall(BaseModel):

 schema_json = json.loads(FunctionCall.schema_json())
 HRMS_SYSTEM_PROMPT = """<|im_start|>system
-You are a function calling AI agent with
+You are a function calling AI agent. Your name is KITT. You are embodied in a Car. You know where you are, where you are going, and the current date and time. You can call functions to help with user queries.
 You can call only one function at a time and analyse data you get from function response.
 You are provided with function signatures within <tools></tools> XML tags.

@@ -53,7 +54,7 @@ Make sure that the json object above with code markdown block is parseable with
 When using tools, ensure to only use the tools provided and not make up any data and do not provide any explanation as to which tool you are using and why.

 Example 1:
-User: How is the weather
+User: How is the weather?
 Assistant:
 <tool_call>
 {{"arguments": {{"location": ""}}, "name": "get_weather"}}
@@ -206,14 +207,7 @@ def process_response(user_query, res, history, tools, depth):
     return True, tool_calls, errors


-def run_inference_step(depth, history, tools, schema_json, dry_run=False):
-    # If we decide to call a function, we need to generate the prompt for the model
-    # based on the history of the conversation so far.
-    # not break the loop
-    openai_tools = [convert_to_openai_function(tool) for tool in tools]
-    prompt = get_prompt(HRMS_SYSTEM_PROMPT, history, openai_tools, schema_json)
-    print(f"Prompt is:{prompt + AI_PREAMBLE}\n------------------\n")
-
+def run_inference_ollama(prompt):
     data = {
         "prompt": prompt
         + "\nThis is the first turn and you don't have <tool_results> to analyze yet"
@@ -230,37 +224,86 @@ def run_inference_step(depth, history, tools, schema_json, dry_run=False):
             "temperature": 0.8,
             # "max_tokens": 1500,
             "num_predict": 1500,
-            "mirostat": 1,
+            # "mirostat": 1,
            # "mirostat_tau": 2,
-            "repeat_penalty": 1.
+            "repeat_penalty": 1.1,
             "top_k": 25,
             "top_p": 0.5,
+            "num_ctx": 8000,
             # "num_predict": 1500,
             # "max_tokens": 1500,
         },
     }

-
-    print(prompt + AI_PREAMBLE)
-    return "Didn't really run it."
-
-    client = Client(host='http://localhost:11444')
+    client = Client(host="http://localhost:11434")
     # out = ollama.generate(**data)
     out = client.generate(**data)
-
-
-
+    res = out.pop("response")
+    # Report prompt and eval tokens
+    logger.warning(
+        f"Prompt tokens: {out.get('prompt_eval_count')}, Response tokens: {out.get('eval_count')}"
+    )
+    logger.debug(f"Response from Ollama: {res}\nOut:{out}")
     return res


-def process_query(user_query: str, history: ChatMessageHistory, tools):
-    # Add vehicle status to the history
-    user_query_status = (
-        f"Given that:\n{vehicle_status()[0]}\nAnswer the following:\n{user_query}"
+def run_inference_step(
+    depth, history, tools, schema_json, dry_run=False, backend="ollama"
+):
+    # If we decide to call a function, we need to generate the prompt for the model
+    # based on the history of the conversation so far.
+    # not break the loop
+    openai_tools = [convert_to_openai_function(tool) for tool in tools]
+    prompt = get_prompt(HRMS_SYSTEM_PROMPT, history, openai_tools, schema_json)
+    print(f"Prompt is:{prompt + AI_PREAMBLE}\n------------------\n")
+
+    if backend == "ollama":
+        output = run_inference_ollama(prompt)
+    else:
+        output = run_inference_replicate(prompt)
+
+    logger.debug(f"Response from model: {output}")
+    return output
+
+
+def run_inference_replicate(prompt):
+    from replicate import Client
+
+    replicate = Client(api_token=config.REPLICATE_API_KEY)
+
+    input = {
+        "prompt": prompt
+        + "\nThis is the first turn and you don't have <tool_results> to analyze yet"
+        + AI_PREAMBLE,
+        "temperature": 0.5,
+        "system_prompt": "",
+        "max_new_tokens": 1024,
+        "repeat_penalty": 1.1,
+        "prompt_template": "{prompt}",
+    }
+
+    output = replicate.run(
+        "mikeei/dolphin-2.9-llama3-8b-gguf:0f79fb14c45ae2b92e1f07d872dceed3afafcacd903258df487d3bec9e393cb2",
+        input=input,
     )
+    out = "".join(output)
+
+    return out
+
+
+def process_query(
+    user_query: str,
+    history: ChatMessageHistory,
+    user_preferences,
+    tools,
+    backend="ollama",
+):
+    # Add vehicle status to the history
+    user_query_status = f"Given that:\n{vehicle_status()[0]}\nUser preferences:\n{user_preferences}\nAnswer the following:\n{user_query}"
     history.add_message(HumanMessage(content=user_query_status))
     for depth in range(10):
-        out = run_inference_step(depth, history, tools, schema_json)
+        # out = run_inference_step(depth, history, tools, schema_json)
+        out = run_inference_step(depth, history, tools, schema_json, backend=backend)
         print(f"Inference step result:\n{out}\n------------------\n")
         history.add_message(AIMessage(content=out))
         to_continue, tool_calls, errors = process_response(
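Taken together, the model.py changes split inference into per-backend helpers and turn run_inference_step into prompt building plus a thin dispatch, while process_query gains user_preferences and backend parameters. A minimal, self-contained sketch of the dispatch pattern (the backend bodies are stubbed out so it runs standalone; the stub return strings are illustrative, not real model output):

def run_inference_ollama(prompt: str) -> str:
    # Stand-in for the real Ollama call in the diff.
    return f"<ollama reply to {prompt!r}>"

def run_inference_replicate(prompt: str) -> str:
    # Stand-in for the real Replicate call in the diff.
    return f"<replicate reply to {prompt!r}>"

def run_inference_step(prompt: str, backend: str = "ollama") -> str:
    # Mirrors the diff's if/else: any backend value other than
    # "ollama" falls through to Replicate.
    if backend == "ollama":
        return run_inference_ollama(prompt)
    return run_inference_replicate(prompt)

print(run_inference_step("How is the weather?"))
print(run_inference_step("How is the weather?", backend="replicate"))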
kitt/skills/__init__.py
CHANGED

@@ -5,7 +5,7 @@ from .common import execute_function_call, extract_func_args, vehicle as vehicle
 from .weather import get_weather_current_location, get_weather, get_forecast
 from .routing import find_route
 from .poi import search_points_of_interests, search_along_route_w_coordinates
-from .vehicle import vehicle_status
+from .vehicle import vehicle_status, set_vehicle_speed, set_vehicle_destination
 from .interpreter import code_interpreter

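With the widened re-export, callers can keep importing the new skills from the package root, for example:

from kitt.skills import set_vehicle_speed, set_vehicle_destination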
kitt/skills/common.py
CHANGED

@@ -1,25 +1,34 @@
 import re
-from typing import Union
+from typing import Union, Optional


 from pydantic_settings import BaseSettings, SettingsConfigDict
 from pydantic import BaseModel

 from .. import skills
+from enum import Enum

 class Settings(BaseSettings):
     WEATHER_API_KEY: str
     TOMTOM_API_KEY: str
+    REPLICATE_API_KEY: Optional[str]

     model_config = SettingsConfigDict(env_file=".env")


+class Speed(Enum):
+    SLOW = "slow"
+    FAST = "fast"
+
+
 class VehicleStatus(BaseModel):
     location: str
     location_coordinates: tuple[float, float]  # (latitude, longitude)
     date: str
     time: str
     destination: str
+    speed: Speed = Speed.SLOW
+


 def execute_function_call(text: str, dry_run=False) -> str:
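The new Speed enum plus the defaulted speed field give VehicleStatus a typed speed without breaking existing constructors. A self-contained sketch of the shape (field values are invented for illustration). One caveat worth noting: under pydantic v2 an Optional[str] annotation with no explicit "= None" default, as on REPLICATE_API_KEY above, still leaves the field required, so a default may be intended there.

from enum import Enum
from pydantic import BaseModel


class Speed(Enum):
    SLOW = "slow"
    FAST = "fast"


class VehicleStatus(BaseModel):
    location: str
    location_coordinates: tuple[float, float]  # (latitude, longitude)
    date: str
    time: str
    destination: str
    speed: Speed = Speed.SLOW  # new field, defaults to SLOW


# Hypothetical values, only to show the default in action:
status = VehicleStatus(
    location="Luxembourg",
    location_coordinates=(49.61, 6.13),
    date="2025-05-06",
    time="08:00:20",
    destination="Kirchberg",
)
print(status.speed)              # Speed.SLOW
print(status.model_dump_json())  # speed serialized as "slow"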
kitt/skills/vehicle.py
CHANGED

@@ -1,7 +1,8 @@
-from .common import vehicle
+from .common import vehicle, Speed


-STATUS_TEMPLATE = """The current location is: {location}
+STATUS_TEMPLATE = """The current location is: {location}
+Current coordinates: {lat}, {lon}
 The current date and time: {date} {time}
 The current destination is: {destination}"""

@@ -32,3 +33,21 @@ def vehicle_status() -> tuple[str, dict[str, str]]:
     vs["lat"] = vs["location_coordinates"][0]
     vs["lon"] = vs["location_coordinates"][1]
     return STATUS_TEMPLATE.format(**vs), vs
+
+
+
+def set_vehicle_speed(speed: Speed):
+    """Set the speed of the vehicle.
+    Args:
+        speed (Speed): The speed of the vehicle. ("slow", "fast")
+    """
+    vehicle.speed = speed
+    return f"The vehicle speed is set to {speed.value}."
+
+def set_vehicle_destination(destination: str):
+    """Set the destination of the vehicle.
+    Args:
+        destination (str): The destination of the vehicle.
+    """
+    vehicle.destination = destination
+    return f"The vehicle destination is set to {destination}."
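Both setters follow the same contract: mutate the shared vehicle object and return a human-readable confirmation string, which is what the tool-calling loop feeds back to the model as the tool result. Hypothetical usage (the destination string is invented):

from kitt.skills.common import Speed
from kitt.skills.vehicle import set_vehicle_speed, set_vehicle_destination

print(set_vehicle_speed(Speed.FAST))
# The vehicle speed is set to fast.
print(set_vehicle_destination("Esch-sur-Alzette"))
# The vehicle destination is set to Esch-sur-Alzette.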
kitt/skills/weather.py
CHANGED

@@ -1,4 +1,5 @@
 import requests
+from loguru import logger

 from .common import config, vehicle

@@ -19,27 +20,26 @@ def get_weather_current_location():


 # current weather API
-def get_weather(location: str = ""):
+def get_weather(location: str = "here"):
     """
     Get the current weather in a specified location.
     When responding to user, only mention the weather condition, temperature, and the temperature that it feels like, unless the user asks for more information.

     Args:
-        location (string) : Optional. The name of the location, if empty, the vehicle location is used.
+        location (string) : Optional. The name of the location, if empty or here, the vehicle location is used.

     Returns:
         dict: The weather data in the specified location.
     """

-    if location == "":
-        print(
+    if location == "" or location == "here":
+        logger.warning(
             f"get_weather: location is empty, using the vehicle location. ({vehicle.location})"
         )
         location = vehicle.location

     # The endpoint URL provided by WeatherAPI
     url = f"http://api.weatherapi.com/v1/current.json?key={config.WEATHER_API_KEY}&q={location}&aqi=no"
-    print(url)

     # Make the API request
     response = requests.get(url)
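The behavioural change here is the location sentinel: the default is now "here", and both "" and "here" fall back to the vehicle's current location before the WeatherAPI URL is built. A stubbed sketch of just that branch (no network call; "Luxembourg" is a stand-in for vehicle.location):

def resolve_location(location: str = "here") -> str:
    # Mirrors the fallback check in get_weather.
    if location in ("", "here"):
        return "Luxembourg"  # stand-in for vehicle.location
    return location

print(resolve_location())         # Luxembourg (fallback)
print(resolve_location("Paris"))  # Paris (explicit location wins)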
main.py
CHANGED

@@ -21,8 +21,10 @@ from kitt.skills import (
     find_route,
     get_forecast,
     vehicle_status as vehicle_status_fn,
+    set_vehicle_speed,
     search_points_of_interests,
     search_along_route_w_coordinates,
+    set_vehicle_destination,
     do_anything_else,
     date_time_info,
     get_weather_current_location,
@@ -120,11 +122,12 @@ def get_vehicle_status(state):
 tools = [
     StructuredTool.from_function(get_weather),
     StructuredTool.from_function(find_route),
-    StructuredTool.from_function(vehicle_status_fn),
+    # StructuredTool.from_function(vehicle_status_fn),
+    StructuredTool.from_function(set_vehicle_speed),
     StructuredTool.from_function(search_points_of_interests),
     StructuredTool.from_function(search_along_route),
-    StructuredTool.from_function(date_time_info),
-    StructuredTool.from_function(get_weather_current_location),
+    # StructuredTool.from_function(date_time_info),
+    # StructuredTool.from_function(get_weather_current_location),
     StructuredTool.from_function(code_interpreter),
     # StructuredTool.from_function(do_anything_else),
 ]
@@ -148,7 +151,7 @@ def clear_history():
     history.clear()


-def run_nexusraven_model(query, voice_character):
+def run_nexusraven_model(query, voice_character, state):
     global_context["prompt"] = get_prompt(RAVEN_PROMPT_FUNC, query, "", tools)
     print("Prompt: ", global_context["prompt"])
     data = {
@@ -182,11 +185,18 @@ def run_nexusraven_model(query, voice_character):
     )


-def run_llama3_model(query, voice_character):
-    output_text = process_query(query, history, tools)
+def run_llama3_model(query, voice_character, state):
+    output_text = process_query(
+        query,
+        history=history,
+        user_preferences=state["user_preferences"],
+        tools=tools,
+        backend=state["llm_backend"],
+    )
     gr.Info(f"Output text: {output_text}, generating voice output...")
-    # voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
     voice_out = None
+    if state["tts_enabled"]:
+        voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
     return (
         output_text,
         voice_out,
@@ -196,15 +206,17 @@ def run_llama3_model(query, voice_character):
 def run_model(query, voice_character, state):
     model = state.get("model", "nexusraven")
     query = query.strip().replace("'", "")
-
-
+    logger.info(
+        f"Running model: {model} with query: {query}, voice_character: {voice_character} and llm_backend: {state['llm_backend']}, tts_enabled: {state['tts_enabled']}"
+    )
     global_context["query"] = query
     if model == "nexusraven":
-        text, voice = run_nexusraven_model(query, voice_character)
+        text, voice = run_nexusraven_model(query, voice_character, state)
     elif model == "llama3":
-        text, voice = run_llama3_model(query, voice_character)
-
-
+        text, voice = run_llama3_model(query, voice_character, state)
+    else:
+        text, voice = "Error running model", None
+    return text, voice, vehicle.model_dump_json()


 def calculate_route_gradio(origin, destination):
@@ -276,6 +288,32 @@ def save_and_transcribe_run_model(audio, voice_character, state):
     out_text, out_voice = run_model(text, voice_character, state)
     return text, out_text, out_voice

+
+def set_tts_enabled(tts_enabled, state):
+    new_tts_enabled = tts_enabled == "Yes"
+    logger.info(
+        f"TTS enabled was {state['tts_enabled']} and changed to {new_tts_enabled}"
+    )
+    state["tts_enabled"] = new_tts_enabled
+    return state
+
+
+def set_llm_backend(llm_backend, state):
+    new_llm_backend = "ollama" if llm_backend == "Ollama" else "replicate"
+    logger.info(
+        f"LLM backend was {state['llm_backend']} and changed to {new_llm_backend}"
+    )
+    state["llm_backend"] = new_llm_backend
+    return state
+
+
+def set_user_preferences(preferences, state):
+    new_preferences = preferences
+    logger.info(f"User preferences changed to: {new_preferences}")
+    state["user_preferences"] = new_preferences
+    return state
+
+
 # to be able to use the microphone on chrome, you will have to go to chrome://flags/#unsafely-treat-insecure-origin-as-secure and enter http://10.186.115.21:7860/
 # in "Insecure origins treated as secure", enable it and relaunch chrome

@@ -284,7 +322,7 @@ def save_and_transcribe_run_model(audio, voice_character, state):
 # What's the closest restaurant from here?


-def create_demo(tts_server: bool = False, model="llama3", tts=True):
+def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = True):
     print(f"Running the demo with model: {model} and TTSServer: {tts_server}")
     with gr.Blocks(theme=gr.themes.Default()) as demo:
         state = gr.State(
@@ -293,7 +331,9 @@ def create_demo(tts_server: bool = False, model="llama3", tts=True):
                 "query": "",
                 "route_points": [],
                 "model": model,
-                "tts": tts,
+                "tts_enabled": tts_enabled,
+                "llm_backend": "Ollama",
+                "user_preferences": "",
             }
         )
         trip_points = gr.State(value=[])
@@ -328,6 +368,12 @@ def create_demo(tts_server: bool = False, model="llama3", tts=True):
                     label="Destination",
                     interactive=True,
                 )
+                preferences = gr.Textbox(
+                    value="I love italian food\nI like doing sports",
+                    label="User preferences",
+                    lines=3,
+                    interactive=True,
+                )

             with gr.Column(scale=2, min_width=600):
                 map_plot = gr.Plot()
@@ -363,6 +409,19 @@ def create_demo(tts_server: bool = False, model="llama3", tts=True):
                 vehicle_status = gr.JSON(
                     value=vehicle.model_dump_json(), label="Vehicle status"
                 )
+                with gr.Accordion("Config"):
+                    tts_enabled = gr.Radio(
+                        choices=["Yes", "No"],
+                        label="Enable TTS",
+                        value="No",
+                        interactive=True,
+                    )
+                    llm_backend = gr.Radio(
+                        choices=["Ollama", "Replicate"],
+                        label="LLM Backend",
+                        value="Ollama",
+                        interactive=True,
+                    )
                 # Push button
                 clear_history_btn = gr.Button(value="Clear History")
             with gr.Column():
@@ -383,6 +442,9 @@ def create_demo(tts_server: bool = False, model="llama3", tts=True):
             inputs=[origin, destination],
             outputs=[map_plot, vehicle_status, trip_progress],
         )
+        preferences.submit(
+            fn=set_user_preferences, inputs=[preferences, state], outputs=[state]
+        )

         # Update time based on the time picker
         time_picker.select(fn=set_time, inputs=[time_picker], outputs=[vehicle_status])
@@ -391,12 +453,12 @@ def create_demo(tts_server: bool = False, model="llama3", tts=True):
         input_text.submit(
             fn=run_model,
             inputs=[input_text, voice_character, state],
-            outputs=[output_text, output_audio],
+            outputs=[output_text, output_audio, vehicle_status],
         )
         input_text_debug.submit(
             fn=run_model,
-            inputs=[
-            outputs=[output_text, output_audio],
+            inputs=[input_text_debug, voice_character, state],
+            outputs=[output_text, output_audio, vehicle_status],
         )

         # Set the vehicle status based on the trip progress
@@ -408,15 +470,26 @@ def create_demo(tts_server: bool = False, model="llama3", tts=True):

         # Save and transcribe the audio
         input_audio.stop_recording(
-            fn=save_and_transcribe_run_model,
+            fn=save_and_transcribe_run_model,
+            inputs=[input_audio, voice_character, state],
+            outputs=[input_text, output_text, output_audio],
         )
         input_audio_debug.stop_recording(
-            fn=save_and_transcribe_audio,
+            fn=save_and_transcribe_audio,
+            inputs=[input_audio_debug],
+            outputs=[input_text_debug],
         )

         # Clear the history
         clear_history_btn.click(fn=clear_history, inputs=[], outputs=[])

+        # Config
+        tts_enabled.change(
+            fn=set_tts_enabled, inputs=[tts_enabled, state], outputs=[state]
+        )
+        llm_backend.change(
+            fn=set_llm_backend, inputs=[llm_backend, state], outputs=[state]
+        )
     return demo


@@ -424,7 +497,7 @@ def create_demo(tts_server: bool = False, model="llama3", tts=True):
     gr.close_all()


-    demo = create_demo(False, "llama3", tts=False)
+    demo = create_demo(False, "llama3", tts_enabled=False)
     demo.launch(
         debug=True,
         server_name="0.0.0.0",
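All three config callbacks in main.py share one shape: take the widget's new value and the session state, map the UI label to an internal value, log the transition, and return the updated state for Gradio to store. A minimal sketch of that pattern outside Gradio (a plain dict stands in for gr.State, and stdlib logging for loguru, so it runs without extra dependencies):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def set_llm_backend(llm_backend: str, state: dict) -> dict:
    # Map the radio label ("Ollama"/"Replicate") to the internal id.
    new_llm_backend = "ollama" if llm_backend == "Ollama" else "replicate"
    logger.info(
        "LLM backend was %s and changed to %s",
        state.get("llm_backend"), new_llm_backend,
    )
    state["llm_backend"] = new_llm_backend
    return state


state = {"llm_backend": "ollama", "tts_enabled": False}
state = set_llm_backend("Replicate", state)
assert state["llm_backend"] == "replicate"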