sasan commited on
Commit
78e760c
·
1 Parent(s): 0f04201

chore: Update vehicle speed and destination handling functions

Browse files
Files changed (3) hide show
  1. kitt/core/model.py +34 -1
  2. kitt/core/tts.py +103 -0
  3. main.py +30 -15
kitt/core/model.py CHANGED
@@ -84,6 +84,7 @@ Don't make assumptions about tool results if <tool_response> XML tags are not pr
84
  Analyze the data once you get the results and call another function.
85
  At each iteration please continue adding the your analysis to previous summary.
86
  Your final response should directly answer the user query. Don't tell what you are doing, just do it.
 
87
 
88
 
89
  Tools:
@@ -92,13 +93,45 @@ Here are the available tools:
92
  Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree.
93
  When using tools, ensure to only use the tools provided and not make up any data and do not provide any explanation as to which tool you are using and why.
94
 
95
- When asked for the weather or points of interest, use the appropriate tool with the current location of the car. Unless the user provides a location, then use that location.
96
  Always assume user wants to travel by car.
97
 
98
  Schema:
99
  Use the following pydantic model json schema for each tool call you will make:
100
  {schema}
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  Instructions:
103
  At the very first turn you don't have <tool_results> so you shouldn't not make up the results.
104
  Please keep a running summary with analysis of previous function results and summaries from previous iterations.
 
84
  Analyze the data once you get the results and call another function.
85
  At each iteration please continue adding the your analysis to previous summary.
86
  Your final response should directly answer the user query. Don't tell what you are doing, just do it.
87
+ Keep your responses very concise and to the point. Don't provide any unnecessary information. Don't refer to user preferences as <user_preferences>.
88
 
89
 
90
  Tools:
 
93
  Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree.
94
  When using tools, ensure to only use the tools provided and not make up any data and do not provide any explanation as to which tool you are using and why.
95
 
96
+ When asked for the weather or points of interest, use the appropriate tool with the current location from <car_status>. If user provides a location, use that location.
97
  Always assume user wants to travel by car.
98
 
99
  Schema:
100
  Use the following pydantic model json schema for each tool call you will make:
101
  {schema}
102
 
103
+ Examples:
104
+
105
+ Example 1:
106
+ User: How is the weather?
107
+ Assistant:
108
+ <tool_call>
109
+ {{"arguments": {{"location": ""}}, "name": "get_weather"}}
110
+ </tool_call>
111
+
112
+ Example 2:
113
+ User: Is there a Spa nearby?
114
+ Assistant:
115
+ <tool_call>
116
+ {{"arguments": {{"search_query": "Spa"}}, "name": "search_points_of_interest"}}
117
+ </tool_call>
118
+
119
+
120
+ Example 3:
121
+ User: How long will it take to get to the destination?
122
+ Assistant:
123
+ <tool_call>
124
+ {{"arguments": {{"destination": ""}}, "name": "calculate_route"}}
125
+ </tool_call>
126
+
127
+ Example 4:
128
+ User: Set the destination to Paris.
129
+ Assistant:
130
+ <tool_call>
131
+ {{"arguments": {{"destination": "Paris"}}, "name": "set_vehicle_destination"}}
132
+ </tool_call>
133
+
134
+
135
  Instructions:
136
  At the very first turn you don't have <tool_results> so you shouldn't not make up the results.
137
  Please keep a running summary with analysis of previous function results and summaries from previous iterations.
kitt/core/tts.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ from replicate import Client
3
+ from loguru import logger
4
+ from kitt.skills.common import config
5
+ import torch
6
+ from parler_tts import ParlerTTSForConditionalGeneration
7
+ from transformers import AutoTokenizer, set_seed
8
+ import soundfile as sf
9
+
10
+ replicate = Client(api_token=config.REPLICATE_API_KEY)
11
+
12
+ Voice = namedtuple("voice", ["name", "neutral", "angry", "speed"])
13
+
14
+ voices_replicate = [
15
+ Voice(
16
+ "Attenborough",
17
+ neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/attenborough-neutral.wav",
18
+ angry=None,
19
+ speed=1.2,
20
+ ),
21
+ Voice(
22
+ "Rick",
23
+ neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/rick-neutral.wav",
24
+ angry=None,
25
+ speed=1.2,
26
+ ),
27
+ Voice(
28
+ "Freeman",
29
+ neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/freeman-neutral.wav",
30
+ angry="https://zebel.ams3.digitaloceanspaces.com/xtts/short/freeman-angry.wav",
31
+ speed=1.1,
32
+ ),
33
+ Voice(
34
+ "Walken",
35
+ neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/walken-neutral.wav",
36
+ angry=None,
37
+ speed=1.1,
38
+ ),
39
+ Voice(
40
+ "Darth Wader",
41
+ neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/darth-neutral.wav",
42
+ angry=None,
43
+ speed=1.15,
44
+ ),
45
+ ]
46
+
47
+ def voice_from_text(voice, voices):
48
+ for v in voices:
49
+ if voice == f"{v.name} - Neutral":
50
+ return v.neutral
51
+ if voice == f"{v.name} - Angry":
52
+ return v.angry
53
+ raise ValueError(f"Voice {voice} not found.")
54
+
55
+
56
+ def speed_from_text(voice, voices):
57
+ for v in voices:
58
+ if voice == f"{v.name} - Neutral":
59
+ return v.speed
60
+ if voice == f"{v.name} - Angry":
61
+ return v.speed
62
+
63
+
64
+ def run_tts_replicate(text: str, voice_character: str):
65
+ voice = voice_from_text(voice_character, voices_replicate)
66
+
67
+ input = {
68
+ "text": text,
69
+ "speaker": voice,
70
+ "cleanup_voice": True
71
+ }
72
+
73
+ output = replicate.run(
74
+ # "afiaka87/tortoise-tts:e9658de4b325863c4fcdc12d94bb7c9b54cbfe351b7ca1b36860008172b91c71",
75
+ "lucataco/xtts-v2:684bc3855b37866c0c65add2ff39c78f3dea3f4ff103a436465326e0f438d55e",
76
+ input=input,
77
+ )
78
+ logger.info(f"sound output: {output}")
79
+ return output
80
+
81
+
82
+ def get_fast_tts():
83
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
84
+
85
+ model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-expresso").to(device)
86
+ tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso")
87
+ return model, tokenizer, device
88
+
89
+
90
+
91
+ fast_tts = get_fast_tts()
92
+
93
+
94
+ def run_tts_fast(text: str):
95
+ model, tokenizer, device = fast_tts
96
+ description = "Thomas speaks moderately slowly in a sad tone with emphasis and high quality audio."
97
+
98
+ input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
99
+ prompt_input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
100
+
101
+ generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
102
+ audio_arr = generation.cpu().numpy().squeeze()
103
+ return model.config.sampling_rate, audio_arr, dict(text=text, voice="Thomas")
main.py CHANGED
@@ -8,6 +8,7 @@ import typer
8
 
9
  from kitt.skills.common import config, vehicle
10
  from kitt.skills.routing import calculate_route
 
11
  import ollama
12
 
13
  from langchain.tools.base import StructuredTool
@@ -33,6 +34,7 @@ from kitt.skills import (
33
  )
34
  from kitt.skills import extract_func_args
35
  from kitt.core import voice_options, tts_gradio
 
36
  # from kitt.core.model import process_query
37
  from kitt.core.model import generate_function_call as process_query
38
  from kitt.core import utils as kitt_utils
@@ -144,7 +146,7 @@ functions = [
144
  get_weather,
145
  find_route,
146
  search_points_of_interest,
147
- search_along_route
148
  ]
149
  openai_tools = [convert_to_openai_tool(tool) for tool in functions]
150
 
@@ -203,8 +205,8 @@ def run_nexusraven_model(query, voice_character, state):
203
 
204
  def run_llama3_model(query, voice_character, state):
205
 
206
- assert len (functions) > 0, "No functions to call"
207
- assert len (openai_tools) > 0, "No openai tools to call"
208
 
209
  output_text = process_query(
210
  query,
@@ -217,7 +219,9 @@ def run_llama3_model(query, voice_character, state):
217
  gr.Info(f"Output text: {output_text}, generating voice output...")
218
  voice_out = None
219
  if state["tts_enabled"]:
220
- voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
 
 
221
  return (
222
  output_text,
223
  voice_out,
@@ -340,10 +344,13 @@ def set_user_preferences(preferences, state):
340
 
341
  def set_enable_history(enable_history, state):
342
  new_enable_history = enable_history == "Yes"
343
- logger.info(f"Enable history was {state['enable_history']} and changed to {new_enable_history}")
 
 
344
  state["enable_history"] = new_enable_history
345
  return state
346
 
 
347
  # to be able to use the microphone on chrome, you will have to go to chrome://flags/#unsafely-treat-insecure-origin-as-secure and enter http://10.186.115.21:7860/
348
  # in "Insecure origins treated as secure", enable it and relaunch chrome
349
 
@@ -354,9 +361,12 @@ def set_enable_history(enable_history, state):
354
 
355
  ORIGIN = "Mondorf-les-Bains, Luxembourg"
356
  DESTINATION = "Rue Alphonse Weicker, Luxembourg"
 
 
 
357
 
358
 
359
- def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = True):
360
  print(f"Running the demo with model: {model} and TTSServer: {tts_server}")
361
  with gr.Blocks(theme=gr.themes.Default()) as demo:
362
  state = gr.State(
@@ -365,10 +375,10 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
365
  "query": "",
366
  "route_points": [],
367
  "model": model,
368
- "tts_enabled": tts_enabled,
369
- "llm_backend": "ollama",
370
  "user_preferences": USER_PREFERENCES,
371
- "enable_history": False,
372
  }
373
  )
374
  trip_points = gr.State(value=[])
@@ -388,6 +398,11 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
388
  value=voice_options[0],
389
  show_label=True,
390
  )
 
 
 
 
 
391
  origin = gr.Textbox(
392
  value=ORIGIN,
393
  label="Origin",
@@ -441,21 +456,21 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
441
  )
442
  with gr.Accordion("Config"):
443
  tts_enabled = gr.Radio(
444
- choices=["Yes", "No"],
445
  label="Enable TTS",
446
- value="No",
447
  interactive=True,
448
  )
449
  llm_backend = gr.Radio(
450
  choices=["Ollama", "Replicate"],
451
  label="LLM Backend",
452
- value="Ollama",
453
  interactive=True,
454
  )
455
  enable_history = gr.Radio(
456
  ["Yes", "No"],
457
  label="Maintain the conversation history?",
458
- value="No",
459
  interactive=True,
460
  )
461
  # Push button
@@ -529,7 +544,7 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
529
  enable_history.change(
530
  fn=set_enable_history, inputs=[enable_history, state], outputs=[state]
531
  )
532
-
533
  return demo
534
 
535
 
@@ -537,7 +552,7 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
537
  gr.close_all()
538
 
539
 
540
- demo = create_demo(False, "llama3", tts_enabled=False)
541
  demo.launch(
542
  debug=True,
543
  server_name="0.0.0.0",
 
8
 
9
  from kitt.skills.common import config, vehicle
10
  from kitt.skills.routing import calculate_route
11
+ from kitt.core.tts import run_tts_replicate, run_tts_fast
12
  import ollama
13
 
14
  from langchain.tools.base import StructuredTool
 
34
  )
35
  from kitt.skills import extract_func_args
36
  from kitt.core import voice_options, tts_gradio
37
+
38
  # from kitt.core.model import process_query
39
  from kitt.core.model import generate_function_call as process_query
40
  from kitt.core import utils as kitt_utils
 
146
  get_weather,
147
  find_route,
148
  search_points_of_interest,
149
+ search_along_route,
150
  ]
151
  openai_tools = [convert_to_openai_tool(tool) for tool in functions]
152
 
 
205
 
206
  def run_llama3_model(query, voice_character, state):
207
 
208
+ assert len(functions) > 0, "No functions to call"
209
+ assert len(openai_tools) > 0, "No openai tools to call"
210
 
211
  output_text = process_query(
212
  query,
 
219
  gr.Info(f"Output text: {output_text}, generating voice output...")
220
  voice_out = None
221
  if state["tts_enabled"]:
222
+ # voice_out = run_tts_replicate(output_text, voice_character)
223
+ voice_out = run_tts_fast(output_text)[0]
224
+ # voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
225
  return (
226
  output_text,
227
  voice_out,
 
344
 
345
  def set_enable_history(enable_history, state):
346
  new_enable_history = enable_history == "Yes"
347
+ logger.info(
348
+ f"Enable history was {state['enable_history']} and changed to {new_enable_history}"
349
+ )
350
  state["enable_history"] = new_enable_history
351
  return state
352
 
353
+
354
  # to be able to use the microphone on chrome, you will have to go to chrome://flags/#unsafely-treat-insecure-origin-as-secure and enter http://10.186.115.21:7860/
355
  # in "Insecure origins treated as secure", enable it and relaunch chrome
356
 
 
361
 
362
  ORIGIN = "Mondorf-les-Bains, Luxembourg"
363
  DESTINATION = "Rue Alphonse Weicker, Luxembourg"
364
+ DEFAULT_LLM_BACKEND = "ollama"
365
+ ENABLE_HISTORY = True
366
+ ENABLE_TTS = True
367
 
368
 
369
+ def create_demo(tts_server: bool = False, model="llama3"):
370
  print(f"Running the demo with model: {model} and TTSServer: {tts_server}")
371
  with gr.Blocks(theme=gr.themes.Default()) as demo:
372
  state = gr.State(
 
375
  "query": "",
376
  "route_points": [],
377
  "model": model,
378
+ "tts_enabled": ENABLE_TTS,
379
+ "llm_backend": DEFAULT_LLM_BACKEND,
380
  "user_preferences": USER_PREFERENCES,
381
+ "enable_history": ENABLE_HISTORY,
382
  }
383
  )
384
  trip_points = gr.State(value=[])
 
398
  value=voice_options[0],
399
  show_label=True,
400
  )
401
+ # voice_character = gr.Textbox(
402
+ # label="Choose a voice",
403
+ # value="freeman",
404
+ # show_label=True,
405
+ # )
406
  origin = gr.Textbox(
407
  value=ORIGIN,
408
  label="Origin",
 
456
  )
457
  with gr.Accordion("Config"):
458
  tts_enabled = gr.Radio(
459
+ ["Yes", "No"],
460
  label="Enable TTS",
461
+ value="Yes" if ENABLE_TTS else "No",
462
  interactive=True,
463
  )
464
  llm_backend = gr.Radio(
465
  choices=["Ollama", "Replicate"],
466
  label="LLM Backend",
467
+ value=DEFAULT_LLM_BACKEND.title(),
468
  interactive=True,
469
  )
470
  enable_history = gr.Radio(
471
  ["Yes", "No"],
472
  label="Maintain the conversation history?",
473
+ value="Yes" if ENABLE_HISTORY else "No",
474
  interactive=True,
475
  )
476
  # Push button
 
544
  enable_history.change(
545
  fn=set_enable_history, inputs=[enable_history, state], outputs=[state]
546
  )
547
+
548
  return demo
549
 
550
 
 
552
  gr.close_all()
553
 
554
 
555
+ demo = create_demo(False, "llama3")
556
  demo.launch(
557
  debug=True,
558
  server_name="0.0.0.0",