sasan committed on
Commit
cdb2b77
·
1 Parent(s): e3db752

Small updates

Browse files
Files changed (2) hide show
  1. kitt/core/__init__.py +2 -2
  2. main.py +20 -14
kitt/core/__init__.py CHANGED
@@ -101,7 +101,7 @@ def speed_from_text(voice):
101
  return v.speed
102
 
103
 
104
- def tts(
105
  self,
106
  text: str = "",
107
  language_name: str = "",
@@ -198,7 +198,7 @@ def tts_gradio(text, voice, cache):
198
  (gpt_cond_latent, speaker_embedding) = compute_speaker_embedding(
199
  voice_path, tts_pipeline.synthesizer.tts_config, tts_pipeline, cache
200
  )
201
- out = tts(
202
  tts_pipeline.synthesizer,
203
  text,
204
  language_name="en",
 
101
  return v.speed
102
 
103
 
104
+ def tts_xtts(
105
  self,
106
  text: str = "",
107
  language_name: str = "",
 
198
  (gpt_cond_latent, speaker_embedding) = compute_speaker_embedding(
199
  voice_path, tts_pipeline.synthesizer.tts_config, tts_pipeline, cache
200
  )
201
+ out = tts_xtts(
202
  tts_pipeline.synthesizer,
203
  text,
204
  language_name="en",
main.py CHANGED
@@ -40,7 +40,7 @@ from kitt.skills.routing import calculate_route, find_address
40
 
41
  ORIGIN = "Mondorf-les-Bains, Luxembourg"
42
  DESTINATION = "Rue Alphonse Weicker, Luxembourg"
43
- DEFAULT_LLM_BACKEND = "ollama"
44
  ENABLE_HISTORY = True
45
  ENABLE_TTS = True
46
  TTS_BACKEND = "local"
@@ -133,11 +133,11 @@ def search_along_route(query=""):
133
 
134
  def set_time(time_picker):
135
  vehicle.time = time_picker
136
- return vehicle.model_dump_json()
137
 
138
 
139
  def get_vehicle_status(state):
140
- return state.value["vehicle"].model_dump_json()
141
 
142
 
143
  tools = [
@@ -232,11 +232,16 @@ def run_llama3_model(query, voice_character, state):
232
  )
233
  gr.Info(f"Output text: {output_text}\nGenerating voice output...")
234
  voice_out = None
235
- if state["tts_enabled"]:
236
- # voice_out = run_tts_replicate(output_text, voice_character)
 
 
 
 
 
 
237
  # voice_out = run_tts_fast(output_text)[0]
238
- voice_out = run_melo_tts(output_text, voice_character)
239
- # voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
240
  return (
241
  output_text,
242
  voice_out,
@@ -264,7 +269,7 @@ def run_model(query, voice_character, state):
264
  return (
265
  text,
266
  voice,
267
- vehicle.model_dump_json(),
268
  state,
269
  dict(update_proxy=global_context["update_proxy"]),
270
  )
@@ -299,7 +304,8 @@ def update_vehicle_status(trip_progress, origin, destination, state):
299
  plot = kitt_utils.plot_route(
300
  global_context["route_points"], vehicle=vehicle.location_coordinates
301
  )
302
- return vehicle.model_dump_json(), plot, state
 
303
 
304
 
305
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -335,8 +341,8 @@ def save_and_transcribe_audio(audio):
335
  gr.Info(f"Transcribed text is: {text}\nProcessing the input...")
336
 
337
  except Exception as e:
338
- print(f"Error: {e}")
339
- return "Error transcribing audio."
340
  return text
341
 
342
 
@@ -447,6 +453,9 @@ def create_demo(tts_server: bool = False, model="llama3"):
447
 
448
  with gr.Row():
449
  with gr.Column(scale=1, min_width=300):
 
 
 
450
  time_picker = gr.Dropdown(
451
  choices=hour_options,
452
  label="What time is it? (HH:MM)",
@@ -516,9 +525,6 @@ def create_demo(tts_server: bool = False, model="llama3"):
516
  value=dict(update_proxy=0),
517
  label="Global context",
518
  )
519
- vehicle_status = gr.JSON(
520
- value=vehicle.model_dump_json(), label="Vehicle status"
521
- )
522
  with gr.Accordion("Config"):
523
  tts_enabled = gr.Radio(
524
  ["Yes", "No"],
 
40
 
41
  ORIGIN = "Mondorf-les-Bains, Luxembourg"
42
  DESTINATION = "Rue Alphonse Weicker, Luxembourg"
43
+ DEFAULT_LLM_BACKEND = "replicate"
44
  ENABLE_HISTORY = True
45
  ENABLE_TTS = True
46
  TTS_BACKEND = "local"
 
133
 
134
  def set_time(time_picker):
135
  vehicle.time = time_picker
136
+ return vehicle.model_dump_json(indent=2)
137
 
138
 
139
  def get_vehicle_status(state):
140
+ return state.value["vehicle"].model_dump_json(indent=2)
141
 
142
 
143
  tools = [
 
232
  )
233
  gr.Info(f"Output text: {output_text}\nGenerating voice output...")
234
  voice_out = None
235
+ if global_context["tts_enabled"]:
236
+ if "Fast" in voice_character:
237
+ voice_out = run_melo_tts(output_text, voice_character)
238
+ elif global_context["tts_backend"] == "replicate":
239
+ voice_out = run_tts_replicate(output_text, voice_character)
240
+ else:
241
+ voice_out = tts_gradio(output_text, voice_character, speaker_embedding_cache)[0]
242
+ #
243
  # voice_out = run_tts_fast(output_text)[0]
244
+ #
 
245
  return (
246
  output_text,
247
  voice_out,
 
269
  return (
270
  text,
271
  voice,
272
+ vehicle,
273
  state,
274
  dict(update_proxy=global_context["update_proxy"]),
275
  )
 
304
  plot = kitt_utils.plot_route(
305
  global_context["route_points"], vehicle=vehicle.location_coordinates
306
  )
307
+ return vehicle, plot, state
308
+ return vehicle.model_dump_json(indent=2), plot, state
309
 
310
 
311
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
341
  gr.Info(f"Transcribed text is: {text}\nProcessing the input...")
342
 
343
  except Exception as e:
344
+ logger.error(f"Error: {e}")
345
+ raise Exception("Error transcribing audio.")
346
  return text
347
 
348
 
 
453
 
454
  with gr.Row():
455
  with gr.Column(scale=1, min_width=300):
456
+ vehicle_status = gr.JSON(
457
+ value=vehicle.model_dump_json(indent=2), label="Vehicle status"
458
+ )
459
  time_picker = gr.Dropdown(
460
  choices=hour_options,
461
  label="What time is it? (HH:MM)",
 
525
  value=dict(update_proxy=0),
526
  label="Global context",
527
  )
 
 
 
528
  with gr.Accordion("Config"):
529
  tts_enabled = gr.Radio(
530
  ["Yes", "No"],