chore: Update vehicle speed and destination handling functions
Browse files- kitt/core/model.py +34 -1
- kitt/core/tts.py +103 -0
- main.py +30 -15
kitt/core/model.py
CHANGED
@@ -84,6 +84,7 @@ Don't make assumptions about tool results if <tool_response> XML tags are not pr
|
|
84 |
Analyze the data once you get the results and call another function.
|
85 |
At each iteration please continue adding the your analysis to previous summary.
|
86 |
Your final response should directly answer the user query. Don't tell what you are doing, just do it.
|
|
|
87 |
|
88 |
|
89 |
Tools:
|
@@ -92,13 +93,45 @@ Here are the available tools:
|
|
92 |
Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree.
|
93 |
When using tools, ensure to only use the tools provided and not make up any data and do not provide any explanation as to which tool you are using and why.
|
94 |
|
95 |
-
When asked for the weather or points of interest, use the appropriate tool with the current location
|
96 |
Always assume user wants to travel by car.
|
97 |
|
98 |
Schema:
|
99 |
Use the following pydantic model json schema for each tool call you will make:
|
100 |
{schema}
|
101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
Instructions:
|
103 |
At the very first turn you don't have <tool_results> so you shouldn't not make up the results.
|
104 |
Please keep a running summary with analysis of previous function results and summaries from previous iterations.
|
|
|
84 |
Analyze the data once you get the results and call another function.
|
85 |
At each iteration please continue adding the your analysis to previous summary.
|
86 |
Your final response should directly answer the user query. Don't tell what you are doing, just do it.
|
87 |
+
Keep your responses very concise and to the point. Don't provide any unnecessary information. Don't refer to user preferences as <user_preferences>.
|
88 |
|
89 |
|
90 |
Tools:
|
|
|
93 |
Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree.
|
94 |
When using tools, ensure to only use the tools provided and not make up any data and do not provide any explanation as to which tool you are using and why.
|
95 |
|
96 |
+
When asked for the weather or points of interest, use the appropriate tool with the current location from <car_status>. If user provides a location, use that location.
|
97 |
Always assume user wants to travel by car.
|
98 |
|
99 |
Schema:
|
100 |
Use the following pydantic model json schema for each tool call you will make:
|
101 |
{schema}
|
102 |
|
103 |
+
Examples:
|
104 |
+
|
105 |
+
Example 1:
|
106 |
+
User: How is the weather?
|
107 |
+
Assistant:
|
108 |
+
<tool_call>
|
109 |
+
{{"arguments": {{"location": ""}}, "name": "get_weather"}}
|
110 |
+
</tool_call>
|
111 |
+
|
112 |
+
Example 2:
|
113 |
+
User: Is there a Spa nearby?
|
114 |
+
Assistant:
|
115 |
+
<tool_call>
|
116 |
+
{{"arguments": {{"search_query": "Spa"}}, "name": "search_points_of_interest"}}
|
117 |
+
</tool_call>
|
118 |
+
|
119 |
+
|
120 |
+
Example 3:
|
121 |
+
User: How long will it take to get to the destination?
|
122 |
+
Assistant:
|
123 |
+
<tool_call>
|
124 |
+
{{"arguments": {{"destination": ""}}, "name": "calculate_route"}}
|
125 |
+
</tool_call>
|
126 |
+
|
127 |
+
Example 4:
|
128 |
+
User: Set the destination to Paris.
|
129 |
+
Assistant:
|
130 |
+
<tool_call>
|
131 |
+
{{"arguments": {{"destination": "Paris"}}, "name": "set_vehicle_destination"}}
|
132 |
+
</tool_call>
|
133 |
+
|
134 |
+
|
135 |
Instructions:
|
136 |
At the very first turn you don't have <tool_results> so you shouldn't not make up the results.
|
137 |
Please keep a running summary with analysis of previous function results and summaries from previous iterations.
|
kitt/core/tts.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import namedtuple
|
2 |
+
from replicate import Client
|
3 |
+
from loguru import logger
|
4 |
+
from kitt.skills.common import config
|
5 |
+
import torch
|
6 |
+
from parler_tts import ParlerTTSForConditionalGeneration
|
7 |
+
from transformers import AutoTokenizer, set_seed
|
8 |
+
import soundfile as sf
|
9 |
+
|
10 |
+
replicate = Client(api_token=config.REPLICATE_API_KEY)
|
11 |
+
|
12 |
+
Voice = namedtuple("voice", ["name", "neutral", "angry", "speed"])
|
13 |
+
|
14 |
+
voices_replicate = [
|
15 |
+
Voice(
|
16 |
+
"Attenborough",
|
17 |
+
neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/attenborough-neutral.wav",
|
18 |
+
angry=None,
|
19 |
+
speed=1.2,
|
20 |
+
),
|
21 |
+
Voice(
|
22 |
+
"Rick",
|
23 |
+
neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/rick-neutral.wav",
|
24 |
+
angry=None,
|
25 |
+
speed=1.2,
|
26 |
+
),
|
27 |
+
Voice(
|
28 |
+
"Freeman",
|
29 |
+
neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/freeman-neutral.wav",
|
30 |
+
angry="https://zebel.ams3.digitaloceanspaces.com/xtts/short/freeman-angry.wav",
|
31 |
+
speed=1.1,
|
32 |
+
),
|
33 |
+
Voice(
|
34 |
+
"Walken",
|
35 |
+
neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/walken-neutral.wav",
|
36 |
+
angry=None,
|
37 |
+
speed=1.1,
|
38 |
+
),
|
39 |
+
Voice(
|
40 |
+
"Darth Wader",
|
41 |
+
neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/darth-neutral.wav",
|
42 |
+
angry=None,
|
43 |
+
speed=1.15,
|
44 |
+
),
|
45 |
+
]
|
46 |
+
|
47 |
+
def voice_from_text(voice, voices):
|
48 |
+
for v in voices:
|
49 |
+
if voice == f"{v.name} - Neutral":
|
50 |
+
return v.neutral
|
51 |
+
if voice == f"{v.name} - Angry":
|
52 |
+
return v.angry
|
53 |
+
raise ValueError(f"Voice {voice} not found.")
|
54 |
+
|
55 |
+
|
56 |
+
def speed_from_text(voice, voices):
|
57 |
+
for v in voices:
|
58 |
+
if voice == f"{v.name} - Neutral":
|
59 |
+
return v.speed
|
60 |
+
if voice == f"{v.name} - Angry":
|
61 |
+
return v.speed
|
62 |
+
|
63 |
+
|
64 |
+
def run_tts_replicate(text: str, voice_character: str):
|
65 |
+
voice = voice_from_text(voice_character, voices_replicate)
|
66 |
+
|
67 |
+
input = {
|
68 |
+
"text": text,
|
69 |
+
"speaker": voice,
|
70 |
+
"cleanup_voice": True
|
71 |
+
}
|
72 |
+
|
73 |
+
output = replicate.run(
|
74 |
+
# "afiaka87/tortoise-tts:e9658de4b325863c4fcdc12d94bb7c9b54cbfe351b7ca1b36860008172b91c71",
|
75 |
+
"lucataco/xtts-v2:684bc3855b37866c0c65add2ff39c78f3dea3f4ff103a436465326e0f438d55e",
|
76 |
+
input=input,
|
77 |
+
)
|
78 |
+
logger.info(f"sound output: {output}")
|
79 |
+
return output
|
80 |
+
|
81 |
+
|
82 |
+
def get_fast_tts():
|
83 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
84 |
+
|
85 |
+
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-expresso").to(device)
|
86 |
+
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso")
|
87 |
+
return model, tokenizer, device
|
88 |
+
|
89 |
+
|
90 |
+
|
91 |
+
fast_tts = get_fast_tts()
|
92 |
+
|
93 |
+
|
94 |
+
def run_tts_fast(text: str):
|
95 |
+
model, tokenizer, device = fast_tts
|
96 |
+
description = "Thomas speaks moderately slowly in a sad tone with emphasis and high quality audio."
|
97 |
+
|
98 |
+
input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
|
99 |
+
prompt_input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
|
100 |
+
|
101 |
+
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
|
102 |
+
audio_arr = generation.cpu().numpy().squeeze()
|
103 |
+
return model.config.sampling_rate, audio_arr, dict(text=text, voice="Thomas")
|
main.py
CHANGED
@@ -8,6 +8,7 @@ import typer
|
|
8 |
|
9 |
from kitt.skills.common import config, vehicle
|
10 |
from kitt.skills.routing import calculate_route
|
|
|
11 |
import ollama
|
12 |
|
13 |
from langchain.tools.base import StructuredTool
|
@@ -33,6 +34,7 @@ from kitt.skills import (
|
|
33 |
)
|
34 |
from kitt.skills import extract_func_args
|
35 |
from kitt.core import voice_options, tts_gradio
|
|
|
36 |
# from kitt.core.model import process_query
|
37 |
from kitt.core.model import generate_function_call as process_query
|
38 |
from kitt.core import utils as kitt_utils
|
@@ -144,7 +146,7 @@ functions = [
|
|
144 |
get_weather,
|
145 |
find_route,
|
146 |
search_points_of_interest,
|
147 |
-
search_along_route
|
148 |
]
|
149 |
openai_tools = [convert_to_openai_tool(tool) for tool in functions]
|
150 |
|
@@ -203,8 +205,8 @@ def run_nexusraven_model(query, voice_character, state):
|
|
203 |
|
204 |
def run_llama3_model(query, voice_character, state):
|
205 |
|
206 |
-
assert len
|
207 |
-
assert len
|
208 |
|
209 |
output_text = process_query(
|
210 |
query,
|
@@ -217,7 +219,9 @@ def run_llama3_model(query, voice_character, state):
|
|
217 |
gr.Info(f"Output text: {output_text}, generating voice output...")
|
218 |
voice_out = None
|
219 |
if state["tts_enabled"]:
|
220 |
-
voice_out =
|
|
|
|
|
221 |
return (
|
222 |
output_text,
|
223 |
voice_out,
|
@@ -340,10 +344,13 @@ def set_user_preferences(preferences, state):
|
|
340 |
|
341 |
def set_enable_history(enable_history, state):
|
342 |
new_enable_history = enable_history == "Yes"
|
343 |
-
logger.info(
|
|
|
|
|
344 |
state["enable_history"] = new_enable_history
|
345 |
return state
|
346 |
|
|
|
347 |
# to be able to use the microphone on chrome, you will have to go to chrome://flags/#unsafely-treat-insecure-origin-as-secure and enter http://10.186.115.21:7860/
|
348 |
# in "Insecure origins treated as secure", enable it and relaunch chrome
|
349 |
|
@@ -354,9 +361,12 @@ def set_enable_history(enable_history, state):
|
|
354 |
|
355 |
ORIGIN = "Mondorf-les-Bains, Luxembourg"
|
356 |
DESTINATION = "Rue Alphonse Weicker, Luxembourg"
|
|
|
|
|
|
|
357 |
|
358 |
|
359 |
-
def create_demo(tts_server: bool = False, model="llama3"
|
360 |
print(f"Running the demo with model: {model} and TTSServer: {tts_server}")
|
361 |
with gr.Blocks(theme=gr.themes.Default()) as demo:
|
362 |
state = gr.State(
|
@@ -365,10 +375,10 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
|
|
365 |
"query": "",
|
366 |
"route_points": [],
|
367 |
"model": model,
|
368 |
-
"tts_enabled":
|
369 |
-
"llm_backend":
|
370 |
"user_preferences": USER_PREFERENCES,
|
371 |
-
"enable_history":
|
372 |
}
|
373 |
)
|
374 |
trip_points = gr.State(value=[])
|
@@ -388,6 +398,11 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
|
|
388 |
value=voice_options[0],
|
389 |
show_label=True,
|
390 |
)
|
|
|
|
|
|
|
|
|
|
|
391 |
origin = gr.Textbox(
|
392 |
value=ORIGIN,
|
393 |
label="Origin",
|
@@ -441,21 +456,21 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
|
|
441 |
)
|
442 |
with gr.Accordion("Config"):
|
443 |
tts_enabled = gr.Radio(
|
444 |
-
|
445 |
label="Enable TTS",
|
446 |
-
value="No",
|
447 |
interactive=True,
|
448 |
)
|
449 |
llm_backend = gr.Radio(
|
450 |
choices=["Ollama", "Replicate"],
|
451 |
label="LLM Backend",
|
452 |
-
value=
|
453 |
interactive=True,
|
454 |
)
|
455 |
enable_history = gr.Radio(
|
456 |
["Yes", "No"],
|
457 |
label="Maintain the conversation history?",
|
458 |
-
value="No",
|
459 |
interactive=True,
|
460 |
)
|
461 |
# Push button
|
@@ -529,7 +544,7 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
|
|
529 |
enable_history.change(
|
530 |
fn=set_enable_history, inputs=[enable_history, state], outputs=[state]
|
531 |
)
|
532 |
-
|
533 |
return demo
|
534 |
|
535 |
|
@@ -537,7 +552,7 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
|
|
537 |
gr.close_all()
|
538 |
|
539 |
|
540 |
-
demo = create_demo(False, "llama3"
|
541 |
demo.launch(
|
542 |
debug=True,
|
543 |
server_name="0.0.0.0",
|
|
|
8 |
|
9 |
from kitt.skills.common import config, vehicle
|
10 |
from kitt.skills.routing import calculate_route
|
11 |
+
from kitt.core.tts import run_tts_replicate, run_tts_fast
|
12 |
import ollama
|
13 |
|
14 |
from langchain.tools.base import StructuredTool
|
|
|
34 |
)
|
35 |
from kitt.skills import extract_func_args
|
36 |
from kitt.core import voice_options, tts_gradio
|
37 |
+
|
38 |
# from kitt.core.model import process_query
|
39 |
from kitt.core.model import generate_function_call as process_query
|
40 |
from kitt.core import utils as kitt_utils
|
|
|
146 |
get_weather,
|
147 |
find_route,
|
148 |
search_points_of_interest,
|
149 |
+
search_along_route,
|
150 |
]
|
151 |
openai_tools = [convert_to_openai_tool(tool) for tool in functions]
|
152 |
|
|
|
205 |
|
206 |
def run_llama3_model(query, voice_character, state):
|
207 |
|
208 |
+
assert len(functions) > 0, "No functions to call"
|
209 |
+
assert len(openai_tools) > 0, "No openai tools to call"
|
210 |
|
211 |
output_text = process_query(
|
212 |
query,
|
|
|
219 |
gr.Info(f"Output text: {output_text}, generating voice output...")
|
220 |
voice_out = None
|
221 |
if state["tts_enabled"]:
|
222 |
+
# voice_out = run_tts_replicate(output_text, voice_character)
|
223 |
+
voice_out = run_tts_fast(output_text)[0]
|
224 |
+
# voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
|
225 |
return (
|
226 |
output_text,
|
227 |
voice_out,
|
|
|
344 |
|
345 |
def set_enable_history(enable_history, state):
|
346 |
new_enable_history = enable_history == "Yes"
|
347 |
+
logger.info(
|
348 |
+
f"Enable history was {state['enable_history']} and changed to {new_enable_history}"
|
349 |
+
)
|
350 |
state["enable_history"] = new_enable_history
|
351 |
return state
|
352 |
|
353 |
+
|
354 |
# to be able to use the microphone on chrome, you will have to go to chrome://flags/#unsafely-treat-insecure-origin-as-secure and enter http://10.186.115.21:7860/
|
355 |
# in "Insecure origins treated as secure", enable it and relaunch chrome
|
356 |
|
|
|
361 |
|
362 |
ORIGIN = "Mondorf-les-Bains, Luxembourg"
|
363 |
DESTINATION = "Rue Alphonse Weicker, Luxembourg"
|
364 |
+
DEFAULT_LLM_BACKEND = "ollama"
|
365 |
+
ENABLE_HISTORY = True
|
366 |
+
ENABLE_TTS = True
|
367 |
|
368 |
|
369 |
+
def create_demo(tts_server: bool = False, model="llama3"):
|
370 |
print(f"Running the demo with model: {model} and TTSServer: {tts_server}")
|
371 |
with gr.Blocks(theme=gr.themes.Default()) as demo:
|
372 |
state = gr.State(
|
|
|
375 |
"query": "",
|
376 |
"route_points": [],
|
377 |
"model": model,
|
378 |
+
"tts_enabled": ENABLE_TTS,
|
379 |
+
"llm_backend": DEFAULT_LLM_BACKEND,
|
380 |
"user_preferences": USER_PREFERENCES,
|
381 |
+
"enable_history": ENABLE_HISTORY,
|
382 |
}
|
383 |
)
|
384 |
trip_points = gr.State(value=[])
|
|
|
398 |
value=voice_options[0],
|
399 |
show_label=True,
|
400 |
)
|
401 |
+
# voice_character = gr.Textbox(
|
402 |
+
# label="Choose a voice",
|
403 |
+
# value="freeman",
|
404 |
+
# show_label=True,
|
405 |
+
# )
|
406 |
origin = gr.Textbox(
|
407 |
value=ORIGIN,
|
408 |
label="Origin",
|
|
|
456 |
)
|
457 |
with gr.Accordion("Config"):
|
458 |
tts_enabled = gr.Radio(
|
459 |
+
["Yes", "No"],
|
460 |
label="Enable TTS",
|
461 |
+
value="Yes" if ENABLE_TTS else "No",
|
462 |
interactive=True,
|
463 |
)
|
464 |
llm_backend = gr.Radio(
|
465 |
choices=["Ollama", "Replicate"],
|
466 |
label="LLM Backend",
|
467 |
+
value=DEFAULT_LLM_BACKEND.title(),
|
468 |
interactive=True,
|
469 |
)
|
470 |
enable_history = gr.Radio(
|
471 |
["Yes", "No"],
|
472 |
label="Maintain the conversation history?",
|
473 |
+
value="Yes" if ENABLE_HISTORY else "No",
|
474 |
interactive=True,
|
475 |
)
|
476 |
# Push button
|
|
|
544 |
enable_history.change(
|
545 |
fn=set_enable_history, inputs=[enable_history, state], outputs=[state]
|
546 |
)
|
547 |
+
|
548 |
return demo
|
549 |
|
550 |
|
|
|
552 |
gr.close_all()
|
553 |
|
554 |
|
555 |
+
demo = create_demo(False, "llama3")
|
556 |
demo.launch(
|
557 |
debug=True,
|
558 |
server_name="0.0.0.0",
|