sasan committed on
Commit 42c9935 · 1 Parent(s): 65dafa6

remove nexus and refactor stt

Files changed (2):
  1. kitt/core/stt.py +39 -0
  2. main.py +9 -195
kitt/core/stt.py ADDED
@@ -0,0 +1,39 @@
+ import copy
+ import time
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ import torchaudio
+ from loguru import logger
+ from transformers import pipeline
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ transcriber = pipeline(
+     "automatic-speech-recognition", model="openai/whisper-base.en", device=device
+ )
+
+
+ def save_audio_as_wav(data, sample_rate, file_path):
+     # make a tensor from the numpy array
+     data = torch.tensor(data).reshape(1, -1)
+     torchaudio.save(
+         file_path, data, sample_rate=sample_rate, bits_per_sample=16, encoding="PCM_S"
+     )
+
+
+ def save_and_transcribe_audio(audio):
+     sample_rate, data = audio
+     try:
+         # add timestamp to file name
+         filename = f"recordings/audio{time.time()}.wav"
+         save_audio_as_wav(data, sample_rate, filename)
+         data = data.astype(np.float32)
+         data /= np.max(np.abs(data))
+         text = transcriber({"sampling_rate": sample_rate, "raw": data})["text"]
+         gr.Info(f"Transcribed text is: {text}\nProcessing the input...")
+
+     except Exception as e:
+         logger.error(f"Error: {e}")
+         raise Exception("Error transcribing audio.")
+     return text
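
The new `save_and_transcribe_audio` takes the `(sample_rate, data)` tuple that a numpy-mode `gr.Audio` component returns, writes the raw samples under `recordings/` (that directory must already exist; `torchaudio.save` does not create it), peak-normalizes them to float32 in [-1, 1], and hands them to the Whisper pipeline as `{"sampling_rate": ..., "raw": ...}`. A minimal stand-alone sketch of that contract, assuming Gradio 4.x kwargs (`sources`, `type`); in this app the helper is actually invoked through `save_and_transcribe_run_model` in main.py:

import gradio as gr

from kitt.core.stt import save_and_transcribe_audio

# Illustrative wiring only; not part of this commit.
demo = gr.Interface(
    fn=save_and_transcribe_audio,
    inputs=gr.Audio(sources=["microphone"], type="numpy"),  # yields (sample_rate, int16 ndarray)
    outputs="text",
)

if __name__ == "__main__":
    demo.launch()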
main.py CHANGED
@@ -1,24 +1,14 @@
- import time
-
  import gradio as gr
- import numpy as np
- import ollama
- import torch
- import torchaudio
- import typer
  from langchain.memory import ChatMessageHistory
  from langchain.tools import tool
- from langchain.tools.base import StructuredTool
  from langchain_core.utils.function_calling import convert_to_openai_tool
  from loguru import logger
- from transformers import pipeline

  from kitt.core import tts_gradio
  from kitt.core import utils as kitt_utils
  from kitt.core import voice_options
-
- # from kitt.core.model import process_query
  from kitt.core.model import generate_function_call as process_query
+ from kitt.core.stt import save_and_transcribe_audio
  from kitt.core.tts import prep_for_tts, run_melo_tts, run_tts_fast, run_tts_replicate
  from kitt.skills import (
      code_interpreter,
@@ -34,7 +24,6 @@ from kitt.skills import (
      set_vehicle_destination,
      set_vehicle_speed,
  )
- from kitt.skills import vehicle_status as vehicle_status_fn
  from kitt.skills.common import config, vehicle
  from kitt.skills.routing import calculate_route, find_address
 
@@ -65,54 +54,6 @@ global_context = {
  speaker_embedding_cache = {}
  history = ChatMessageHistory()

- MODEL_FUNC = "nexusraven"
- MODEL_GENERAL = "llama3:instruct"
-
- RAVEN_PROMPT_FUNC = """You are a helpful AI assistant in a car (vehicle), that follows instructions extremely well. \
- Answer questions concisely and do not mention what you base your reply on."
-
- {raven_tools}
-
- {history}
-
- User Query: Question: {input}<human_end>
- """
-
-
- HERMES_PROMPT_FUNC = """
- <|im_start|>system
- You are a helpful AI assistant in a car (vehicle), that follows instructions extremely well. \
- Answer questions concisely and do not mention what you base your reply on.<|im_end|>
- <|im_start|>user
- {{ .Prompt }}<|im_end|>
- <|im_start|>assistant
- """
-
-
- def get_prompt(template, input, history, tools):
-     # "vehicle_status": vehicle_status_fn()[0]
-     kwargs = {"history": history, "input": input}
-     prompt = "<human>:\n"
-     for tool in tools:
-         func_signature, func_docstring = tool.description.split(" - ", 1)
-         prompt += f'Function:\n<func_start>def {func_signature}<func_end>\n<docstring_start>\n"""\n{func_docstring}\n"""\n<docstring_end>\n'
-     kwargs["raven_tools"] = prompt
-
-     if history:
-         kwargs["history"] = f"Previous conversation history:{history}\n"
-
-     return template.format(**kwargs).replace("{{", "{").replace("}}", "}")
-
-
- def use_tool(func_name, kwargs, tools):
-     for tool in tools:
-         if tool.name == func_name:
-             return tool.invoke(input=kwargs)
-     return None
-
-
- # llm = Ollama(model="nexusraven", stop=["\nReflection:", "\nThought:"], keep_alive=60*10)
-

  # Generate options for hours (00-23)
  hour_options = [f"{i:02d}:00:00" for i in range(24)]
@@ -136,20 +77,6 @@ def set_time(time_picker):
      return vehicle


- tools = [
-     # StructuredTool.from_function(get_weather),
-     # StructuredTool.from_function(find_route),
-     # StructuredTool.from_function(vehicle_status_fn),
-     # StructuredTool.from_function(set_vehicle_speed),
-     # StructuredTool.from_function(set_vehicle_destination),
-     # StructuredTool.from_function(search_points_of_interest),
-     # StructuredTool.from_function(search_along_route),
-     # StructuredTool.from_function(date_time_info),
-     # StructuredTool.from_function(get_weather_current_location),
-     # StructuredTool.from_function(code_interpreter),
-     # StructuredTool.from_function(do_anything_else),
- ]
-
  functions = [
      # set_vehicle_speed,
      set_vehicle_destination,
@@ -161,59 +88,11 @@ functions = [
  openai_tools = [convert_to_openai_tool(tool) for tool in functions]


- def run_generic_model(query):
-     print(f"Running the generic model with query: {query}")
-     data = {
-         "prompt": f"Answer the question below in a short and concise manner.\n{query}",
-         "model": MODEL_GENERAL,
-         "options": {
-             # "temperature": 0.1,
-             # "stop":["\nReflection:", "\nThought:"]
-         },
-     }
-     out = ollama.generate(**data)
-     return out["response"]
-
-
  def clear_history():
      logger.info("Clearing the conversation history...")
      history.clear()


- def run_nexusraven_model(query, voice_character, state):
-     global_context["prompt"] = get_prompt(RAVEN_PROMPT_FUNC, query, "", tools)
-     print("Prompt: ", global_context["prompt"])
-     data = {
-         "prompt": global_context["prompt"],
-         # "streaming": False,
-         "model": "nexusraven",
-         # "model": "smangrul/llama-3-8b-instruct-function-calling",
-         "raw": True,
-         "options": {"temperature": 0.5, "stop": ["\nReflection:", "\nThought:"]},
-     }
-     out = ollama.generate(**data)
-     llm_response = out["response"]
-     if "Call: " in llm_response:
-         print(f"llm_response: {llm_response}")
-         llm_response = llm_response.replace("<bot_end>", " ")
-         func_name, kwargs = extract_func_args(llm_response)
-         print(f"Function: {func_name}, Args: {kwargs}")
-         if func_name == "do_anything_else":
-             output_text = run_generic_model(query)
-         else:
-             output_text = use_tool(func_name, kwargs, tools)
-     else:
-         output_text = out["response"]
-
-     if type(output_text) == tuple:
-         output_text = output_text[0]
-     gr.Info(f"Output text: {output_text}\nGenerating voice output...")
-     return (
-         output_text,
-         tts_gradio(output_text, voice_character, speaker_embedding_cache)[0],
-     )
-
-
  def run_llama3_model(query, voice_character, state):

      assert len(functions) > 0, "No functions to call"
@@ -249,18 +128,13 @@ def run_llama3_model(query, voice_character, state):


  def run_model(query, voice_character, state):
-     model = state.get("model", "nexusraven")
+     model = state.get("model", "llama3")
      query = query.strip().replace("'", "")
      logger.info(
          f"Running model: {model} with query: {query}, voice_character: {voice_character} and llm_backend: {state['llm_backend']}, tts_enabled: {state['tts_enabled']}"
      )
      global_context["query"] = query
-     if model == "nexusraven":
-         text, voice = run_nexusraven_model(query, voice_character, state)
-     elif model == "llama3":
-         text, voice = run_llama3_model(query, voice_character, state)
-     else:
-         text, voice = "Error running model", None
+     text, voice = run_llama3_model(query, voice_character, state)

      if not state["enable_history"]:
          history.clear()
@@ -308,44 +182,6 @@ def update_vehicle_status(trip_progress, origin, destination, state):
      return vehicle, plot, state


- device = "cuda" if torch.cuda.is_available() else "cpu"
- transcriber = pipeline(
-     "automatic-speech-recognition", model="openai/whisper-base.en", device=device
- )
-
-
- def save_audio_as_wav(data, sample_rate, file_path):
-     # make a tensor from the numpy array
-     data = torch.tensor(data).reshape(1, -1)
-     torchaudio.save(
-         file_path, data, sample_rate=sample_rate, bits_per_sample=16, encoding="PCM_S"
-     )
-
-
- def save_and_transcribe_audio(audio):
-     try:
-         # capture the audio and save it to a file as wav or mp3
-         # file_name = save("audioinput.wav")
-         sr, y = audio
-         # y = y.astype(np.float32)
-         # y /= np.max(np.abs(y))
-
-         # add timestamp to file name
-         filename = f"recordings/audio{time.time()}.wav"
-         save_audio_as_wav(y, sr, filename)
-
-         sr, y = audio
-         y = y.astype(np.float32)
-         y /= np.max(np.abs(y))
-         text = transcriber({"sampling_rate": sr, "raw": y})["text"]
-         gr.Info(f"Transcribed text is: {text}\nProcessing the input...")
-
-     except Exception as e:
-         logger.error(f"Error: {e}")
-         raise Exception("Error transcribing audio.")
-     return text
-
-
  def save_and_transcribe_run_model(audio, voice_character, state):
      text = save_and_transcribe_audio(audio)
      out_text, out_voice, vehicle_status, state, update_proxy = run_model(
@@ -494,7 +330,12 @@ def create_demo(tts_server: bool = False, model="llama3"):
              0, 100, step=5, label="Trip progress", interactive=True
          )

-         # map_if = gr.Interface(fn=plot_map, inputs=year_input, outputs=map_plot)
+         # with gr.Column(scale=1, min_width=300):
+         #     gr.Image("linkedin-1.png", label="Linkedin - Sasan Jafarnejad")
+         #     gr.Image(
+         #         "team-ubix.png",
+         #         label="Research Team - UBIX - University of Luxembourg",
+         #     )

      with gr.Row():
          with gr.Column():
@@ -647,30 +488,3 @@ demo.launch(
      ssl_verify=False,
      share=False,
  )
-
- app = typer.Typer()
-
-
- @app.command()
- def run(tts_server: bool = False):
-     global demo
-     demo = create_demo(tts_server)
-     demo.launch(
-         debug=True, server_name="0.0.0.0", server_port=7860, ssl_verify=True, share=True
-     )
-
-
- @app.command()
- def dev(tts_server: bool = False, model: str = "llama3"):
-     demo = create_demo(tts_server, model)
-     demo.launch(
-         debug=True,
-         server_name="0.0.0.0",
-         server_port=7860,
-         ssl_verify=False,
-         share=False,
-     )
-
-
- if __name__ == "__main__":
-     app()
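
With the typer CLI (`run`/`dev` subcommands) removed, main.py now builds and launches the interface at module level; the surviving context lines above show the `demo.launch(...)` call that remains. A sketch of the resulting entry point, reconstructed from those context lines (only `ssl_verify` and `share` appear in the diff context; the other kwargs mirror the removed `dev()` command and are assumptions):

# Reconstructed sketch; kwargs other than ssl_verify/share are assumed.
demo = create_demo(tts_server=False, model="llama3")
demo.launch(
    debug=True,
    server_name="0.0.0.0",
    server_port=7860,
    ssl_verify=False,
    share=False,
)

Under that reading, the app now starts with a plain `python main.py` instead of a typer subcommand.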