Small updates
Browse files
- kitt/core/__init__.py +2 -2
- main.py +20 -14
kitt/core/__init__.py
CHANGED
@@ -101,7 +101,7 @@ def speed_from_text(voice):
|
|
101 |
return v.speed
|
102 |
|
103 |
|
104 |
-
def
|
105 |
self,
|
106 |
text: str = "",
|
107 |
language_name: str = "",
|
@@ -198,7 +198,7 @@ def tts_gradio(text, voice, cache):
|
|
198 |
(gpt_cond_latent, speaker_embedding) = compute_speaker_embedding(
|
199 |
voice_path, tts_pipeline.synthesizer.tts_config, tts_pipeline, cache
|
200 |
)
|
201 |
-
out =
|
202 |
tts_pipeline.synthesizer,
|
203 |
text,
|
204 |
language_name="en",
|
|
|
101 |
return v.speed
|
102 |
|
103 |
|
104 |
+
def tts_xtts(
|
105 |
self,
|
106 |
text: str = "",
|
107 |
language_name: str = "",
|
|
|
198 |
(gpt_cond_latent, speaker_embedding) = compute_speaker_embedding(
|
199 |
voice_path, tts_pipeline.synthesizer.tts_config, tts_pipeline, cache
|
200 |
)
|
201 |
+
out = tts_xtts(
|
202 |
tts_pipeline.synthesizer,
|
203 |
text,
|
204 |
language_name="en",
|
main.py
CHANGED
@@ -40,7 +40,7 @@ from kitt.skills.routing import calculate_route, find_address
|
|
40 |
|
41 |
ORIGIN = "Mondorf-les-Bains, Luxembourg"
|
42 |
DESTINATION = "Rue Alphonse Weicker, Luxembourg"
|
43 |
-
DEFAULT_LLM_BACKEND = "
|
44 |
ENABLE_HISTORY = True
|
45 |
ENABLE_TTS = True
|
46 |
TTS_BACKEND = "local"
|
@@ -133,11 +133,11 @@ def search_along_route(query=""):
|
|
133 |
|
134 |
def set_time(time_picker):
|
135 |
vehicle.time = time_picker
|
136 |
-
return vehicle.model_dump_json()
|
137 |
|
138 |
|
139 |
def get_vehicle_status(state):
|
140 |
-
return state.value["vehicle"].model_dump_json()
|
141 |
|
142 |
|
143 |
tools = [
|
@@ -232,11 +232,16 @@ def run_llama3_model(query, voice_character, state):
|
|
232 |
)
|
233 |
gr.Info(f"Output text: {output_text}\nGenerating voice output...")
|
234 |
voice_out = None
|
235 |
-
if
|
236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
# voice_out = run_tts_fast(output_text)[0]
|
238 |
-
|
239 |
-
# voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
|
240 |
return (
|
241 |
output_text,
|
242 |
voice_out,
|
@@ -264,7 +269,7 @@ def run_model(query, voice_character, state):
|
|
264 |
return (
|
265 |
text,
|
266 |
voice,
|
267 |
-
vehicle
|
268 |
state,
|
269 |
dict(update_proxy=global_context["update_proxy"]),
|
270 |
)
|
@@ -299,7 +304,8 @@ def update_vehicle_status(trip_progress, origin, destination, state):
|
|
299 |
plot = kitt_utils.plot_route(
|
300 |
global_context["route_points"], vehicle=vehicle.location_coordinates
|
301 |
)
|
302 |
-
return vehicle
|
|
|
303 |
|
304 |
|
305 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -335,8 +341,8 @@ def save_and_transcribe_audio(audio):
|
|
335 |
gr.Info(f"Transcribed text is: {text}\nProcessing the input...")
|
336 |
|
337 |
except Exception as e:
|
338 |
-
|
339 |
-
|
340 |
return text
|
341 |
|
342 |
|
@@ -447,6 +453,9 @@ def create_demo(tts_server: bool = False, model="llama3"):
|
|
447 |
|
448 |
with gr.Row():
|
449 |
with gr.Column(scale=1, min_width=300):
|
|
|
|
|
|
|
450 |
time_picker = gr.Dropdown(
|
451 |
choices=hour_options,
|
452 |
label="What time is it? (HH:MM)",
|
@@ -516,9 +525,6 @@ def create_demo(tts_server: bool = False, model="llama3"):
|
|
516 |
value=dict(update_proxy=0),
|
517 |
label="Global context",
|
518 |
)
|
519 |
-
vehicle_status = gr.JSON(
|
520 |
-
value=vehicle.model_dump_json(), label="Vehicle status"
|
521 |
-
)
|
522 |
with gr.Accordion("Config"):
|
523 |
tts_enabled = gr.Radio(
|
524 |
["Yes", "No"],
|
|
|
40 |
|
41 |
ORIGIN = "Mondorf-les-Bains, Luxembourg"
|
42 |
DESTINATION = "Rue Alphonse Weicker, Luxembourg"
|
43 |
+
DEFAULT_LLM_BACKEND = "replicate"
|
44 |
ENABLE_HISTORY = True
|
45 |
ENABLE_TTS = True
|
46 |
TTS_BACKEND = "local"
|
|
|
133 |
|
134 |
def set_time(time_picker):
|
135 |
vehicle.time = time_picker
|
136 |
+
return vehicle.model_dump_json(indent=2)
|
137 |
|
138 |
|
139 |
def get_vehicle_status(state):
|
140 |
+
return state.value["vehicle"].model_dump_json(indent=2)
|
141 |
|
142 |
|
143 |
tools = [
|
|
|
232 |
)
|
233 |
gr.Info(f"Output text: {output_text}\nGenerating voice output...")
|
234 |
voice_out = None
|
235 |
+
if global_context["tts_enabled"]:
|
236 |
+
if "Fast" in voice_character:
|
237 |
+
voice_out = run_melo_tts(output_text, voice_character)
|
238 |
+
elif global_context["tts_backend"] == "replicate":
|
239 |
+
voice_out = run_tts_replicate(output_text, voice_character)
|
240 |
+
else:
|
241 |
+
voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
|
242 |
+
#
|
243 |
# voice_out = run_tts_fast(output_text)[0]
|
244 |
+
#
|
|
|
245 |
return (
|
246 |
output_text,
|
247 |
voice_out,
|
|
|
269 |
return (
|
270 |
text,
|
271 |
voice,
|
272 |
+
vehicle,
|
273 |
state,
|
274 |
dict(update_proxy=global_context["update_proxy"]),
|
275 |
)
|
|
|
304 |
plot = kitt_utils.plot_route(
|
305 |
global_context["route_points"], vehicle=vehicle.location_coordinates
|
306 |
)
|
307 |
+
return vehicle, plot, state
|
308 |
+
return vehicle.model_dump_json(indent=2), plot, state
|
309 |
|
310 |
|
311 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
341 |
gr.Info(f"Transcribed text is: {text}\nProcessing the input...")
|
342 |
|
343 |
except Exception as e:
|
344 |
+
logger.error(f"Error: {e}")
|
345 |
+
raise Exception("Error transcribing audio.")
|
346 |
return text
|
347 |
|
348 |
|
|
|
453 |
|
454 |
with gr.Row():
|
455 |
with gr.Column(scale=1, min_width=300):
|
456 |
+
vehicle_status = gr.JSON(
|
457 |
+
value=vehicle.model_dump_json(indent=2), label="Vehicle status"
|
458 |
+
)
|
459 |
time_picker = gr.Dropdown(
|
460 |
choices=hour_options,
|
461 |
label="What time is it? (HH:MM)",
|
|
|
525 |
value=dict(update_proxy=0),
|
526 |
label="Global context",
|
527 |
)
|
|
|
|
|
|
|
528 |
with gr.Accordion("Config"):
|
529 |
tts_enabled = gr.Radio(
|
530 |
["Yes", "No"],
|