makaveli10 committed on
Commit
f1e930a
1 Parent(s): e5e84e9

add whisper output to queue for llm process

Browse files
Files changed (3) hide show
  1. llm_service.py +5 -2
  2. main.py +51 -4
  3. whisper_live/trt_server.py +25 -131
llm_service.py CHANGED
@@ -155,6 +155,8 @@ class MistralTensorRTLLM:
155
 
156
  def run(
157
  self,
 
 
158
  transcription_queue=None,
159
  llm_queue=None,
160
  input_text=None,
@@ -166,9 +168,10 @@ class MistralTensorRTLLM:
166
  debug=False,
167
  ):
168
  self.initialize_model(
169
- "/root/TensorRT-LLM/examples/llama/tmp/mistral/7B/trt_engines/fp16/1-gpu",
170
- "teknium/OpenHermes-2.5-Mistral-7B",
171
  )
 
172
  print("Loaded LLM...")
173
  while True:
174
 
 
155
 
156
  def run(
157
  self,
158
+ model_path,
159
+ tokenizer_path,
160
  transcription_queue=None,
161
  llm_queue=None,
162
  input_text=None,
 
168
  debug=False,
169
  ):
170
  self.initialize_model(
171
+ model_path,
172
+ tokenizer_path,
173
  )
174
+
175
  print("Loaded LLM...")
176
  while True:
177
 
main.py CHANGED
@@ -1,6 +1,5 @@
1
- from whisper_live.trt_server import TranscriptionServer
2
- from llm_service import MistralTensorRTLLM
3
  import multiprocessing
 
4
  import threading
5
  import ssl
6
  import time
@@ -9,8 +8,39 @@ import functools
9
 
10
  from multiprocessing import Process, Manager, Value, Queue
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
14
  multiprocessing.set_start_method('spawn')
15
 
16
  lock = multiprocessing.Lock()
@@ -23,12 +53,29 @@ if __name__ == "__main__":
23
 
24
 
25
  whisper_server = TranscriptionServer()
26
- whisper_process = multiprocessing.Process(target=whisper_server.run, args=("0.0.0.0", 6006, transcription_queue, llm_queue))
 
 
 
 
 
 
 
 
 
27
  whisper_process.start()
28
 
29
  llm_provider = MistralTensorRTLLM()
30
  # llm_provider = MistralTensorRTLLMProvider()
31
- llm_process = multiprocessing.Process(target=llm_provider.run, args=(transcription_queue, llm_queue))
 
 
 
 
 
 
 
 
32
  llm_process.start()
33
 
34
  llm_process.join()
 
 
 
1
  import multiprocessing
2
+ import argparse
3
  import threading
4
  import ssl
5
  import time
 
8
 
9
  from multiprocessing import Process, Manager, Value, Queue
10
 
11
+ from whisper_live.trt_server import TranscriptionServer
12
+ from llm_service import MistralTensorRTLLM
13
+
14
+
15
def parse_arguments():
    """Parse the command-line options used to launch the pipeline.

    Options:
        --whisper_tensorrt_path: Whisper TensorRT model path (default None).
        --mistral_tensorrt_path: Mistral TensorRT model path (default None).
        --mistral_tokenizer_path: Mistral tokenizer path; defaults to
            "teknium/OpenHermes-2.5-Mistral-7B".

    Returns:
        argparse.Namespace: the parsed options, read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--whisper_tensorrt_path',
                        type=str,
                        default=None,
                        help='Whisper TensorRT model path')
    parser.add_argument('--mistral_tensorrt_path',
                        type=str,
                        default=None,
                        help='Mistral TensorRT model path')
    # This help text used to repeat the model-path description verbatim,
    # which made `--help` misleading for the tokenizer option.
    parser.add_argument('--mistral_tokenizer_path',
                        type=str,
                        default="teknium/OpenHermes-2.5-Mistral-7B",
                        help='Mistral tokenizer path')
    return parser.parse_args()
30
+
31
 
32
  if __name__ == "__main__":
33
+ args = parse_arguments()
34
+ if not args.whisper_tensorrt_path:
35
+ raise ValueError("Please provide whisper_tensorrt_path to run the pipeline.")
36
+ import sys
37
+ sys.exit(0)
38
+
39
+ if not args.mistral_tensorrt_path or not args.mistral_tokenizer_path:
40
+ raise ValueError("Please provide mistral_tensorrt_path and mistral_tokenizer_path to run the pipeline.")
41
+ import sys
42
+ sys.exit(0)
43
+
44
  multiprocessing.set_start_method('spawn')
45
 
46
  lock = multiprocessing.Lock()
 
53
 
54
 
55
  whisper_server = TranscriptionServer()
56
+ whisper_process = multiprocessing.Process(
57
+ target=whisper_server.run,
58
+ args=(
59
+ "0.0.0.0",
60
+ 6006,
61
+ transcription_queue,
62
+ llm_queue,
63
+ args.whisper_tensorrt_path
64
+ )
65
+ )
66
  whisper_process.start()
67
 
68
  llm_provider = MistralTensorRTLLM()
69
  # llm_provider = MistralTensorRTLLMProvider()
70
+ llm_process = multiprocessing.Process(
71
+ target=llm_provider.run,
72
+ args=(
73
+ args.mistral_tensorrt_path,
74
+ args.mistral_tokenizer_path,
75
+ transcription_queue,
76
+ llm_queue,
77
+ )
78
+ )
79
  llm_process.start()
80
 
81
  llm_process.join()
whisper_live/trt_server.py CHANGED
@@ -73,7 +73,7 @@ class TranscriptionServer:
73
 
74
  return wait_time / 60
75
 
76
- def recv_audio(self, websocket, transcription_queue=None, llm_queue=None):
77
  """
78
  Receive audio chunks from a client in an infinite loop.
79
 
@@ -121,7 +121,8 @@ class TranscriptionServer:
121
  task=options["task"],
122
  client_uid=options["uid"],
123
  transcription_queue=transcription_queue,
124
- llm_queue=llm_queue
 
125
  )
126
 
127
  self.clients[websocket] = client
@@ -132,16 +133,16 @@ class TranscriptionServer:
132
  try:
133
  frame_data = websocket.recv()
134
  frame_np = np.frombuffer(frame_data, dtype=np.float32)
135
- # print(frame_np.shape)
136
  # VAD
137
  try:
138
  speech_prob = self.vad_model(torch.from_numpy(frame_np.copy()), self.RATE).item()
139
  if speech_prob < self.vad_threshold:
140
  no_voice_activity_chunks += 1
141
- # print("No speech", no_voice_activity_chunks, self.clients[websocket].eos)
142
  if no_voice_activity_chunks > 2:
143
  if not self.clients[websocket].eos:
144
  self.clients[websocket].set_eos(True)
 
145
  continue
146
  no_voice_activity_chunks = 0
147
  self.clients[websocket].set_eos(False)
@@ -172,7 +173,7 @@ class TranscriptionServer:
172
  del websocket
173
  break
174
 
175
- def run(self, host, port=9090, transcription_queue=None, llm_queue=None):
176
  """
177
  Run the transcription server.
178
 
@@ -181,7 +182,12 @@ class TranscriptionServer:
181
  port (int): The port number to bind the server.
182
  """
183
  with serve(
184
- functools.partial(self.recv_audio, transcription_queue=transcription_queue, llm_queue=llm_queue),
 
 
 
 
 
185
  host,
186
  port
187
  ) as server:
@@ -231,6 +237,7 @@ class ServeClient:
231
  client_uid=None,
232
  transcription_queue=None,
233
  llm_queue=None,
 
234
  ):
235
  """
236
  Initialize a ServeClient instance.
@@ -254,9 +261,7 @@ class ServeClient:
254
  self.frames = b""
255
  self.language = language if multilingual else "en"
256
  self.task = task
257
- device = "cuda" if torch.cuda.is_available() else "cpu"
258
- self.transcriber = WhisperTRTLLM(
259
- "whisper_small_en", False, "assets", device="cuda")
260
 
261
  self.timestamp_offset = 0.0
262
  self.frames_np = None
@@ -295,8 +300,6 @@ class ServeClient:
295
 
296
  def set_eos(self, eos):
297
  self.lock.acquire()
298
- # if self.eos != eos:
299
- # logging.info(f"[WhisperLive:] setting eos: {eos}")
300
  self.eos = eos
301
  self.lock.release()
302
 
@@ -345,13 +348,10 @@ class ServeClient:
345
  """
346
  while True:
347
  try:
348
- start = time.time()
349
  if self.llm_queue is not None:
350
  llm_output = self.llm_queue.get_nowait()
351
  if llm_output:
352
  self.websocket.send(json.dumps(llm_output))
353
- end = time.time()
354
- # print(f"Time to check LLM output {end - start}")
355
  except queue.Empty:
356
  pass
357
 
@@ -360,8 +360,7 @@ class ServeClient:
360
  break
361
 
362
  if self.frames_np is None:
363
- # print("frames is None..")
364
- time.sleep(0.05)
365
  continue
366
 
367
  # clip audio if the current chunk exceeds 30 seconds, this basically implies that
@@ -373,24 +372,19 @@ class ServeClient:
373
  samples_take = max(0, (self.timestamp_offset - self.frames_offset)*self.RATE)
374
  input_bytes = self.frames_np[int(samples_take):].copy()
375
  duration = input_bytes.shape[0] / self.RATE
376
- if duration<1.0:
377
  continue
378
 
379
  try:
380
  input_sample = input_bytes.copy()
381
- # save_wav(input_sample)
382
- # whisper transcribe with prompt
383
  mel, duration = self.transcriber.log_mel_spectrogram(input_sample)
384
  last_segment = self.transcriber.transcribe(mel)
385
-
386
  if len(last_segment):
387
- if len(self.transcript) < self.send_last_n_segments:
388
- segments = self.transcript
389
- else:
390
- segments = self.transcript[-self.send_last_n_segments:]
391
  segments.append({"text": last_segment})
392
  try:
393
- # print(f"Sending... {segments}")
394
  self.websocket.send(
395
  json.dumps({
396
  "uid": self.client_uid,
@@ -399,115 +393,22 @@ class ServeClient:
399
  })
400
  )
401
  if self.eos:
402
- self.append_segment(last_segment)
403
  self.timestamp_offset += duration
404
  self.prompt = ' '.join(segment['text'] for segment in segments)
405
  self.transcription_queue.put({"uid": self.client_uid, "prompt": self.prompt})
406
- self.transcript = []
407
  self.set_eos(False)
408
 
 
 
 
 
 
409
  except Exception as e:
410
  logging.error(f"[ERROR]: {e}")
411
 
412
  except Exception as e:
413
  logging.error(f"[ERROR]: {e}")
414
- time.sleep(0.01)
415
-
416
- def append_segment(self, result):
417
- # print("adding to trasncript: ", result)
418
- if not len(self.transcript):
419
- self.transcript.append({"text": result + " "})
420
- else:
421
- if self.transcript[-1]["text"].strip()[-1] == ".":
422
- if result[0] >= "a" and result[0] <= "z":
423
- self.transcript[-1]["text"] = replace_last_occurrence(
424
- self.transcript[-1]["text"], ".", ","
425
- )
426
- elif self.transcript[-1]["text"].strip()[-1] == "?":
427
- if result[0] >= "a" and result[0] <= "z":
428
- self.transcript[-1]["text"] = replace_last_occurrence(
429
- self.transcript[-1]["text"], "?", ","
430
- )
431
-
432
- self.transcript.append({"text": result + " "})
433
-
434
-
435
- def update_segments(self, segments, duration):
436
- """
437
- Processes the segments from whisper. Appends all the segments to the list
438
- except for the last segment assuming that it is incomplete.
439
-
440
- Updates the ongoing transcript with transcribed segments, including their start and end times.
441
- Complete segments are appended to the transcript in chronological order. Incomplete segments
442
- (assumed to be the last one) are processed to identify repeated content. If the same incomplete
443
- segment is seen multiple times, it updates the offset and appends the segment to the transcript.
444
- A threshold is used to detect repeated content and ensure it is only included once in the transcript.
445
- The timestamp offset is updated based on the duration of processed segments. The method returns the
446
- last processed segment, allowing it to be sent to the client for real-time updates.
447
-
448
- Args:
449
- segments(dict) : dictionary of segments as returned by whisper
450
- duration(float): duration of the current chunk
451
-
452
- Returns:
453
- dict or None: The last processed segment with its start time, end time, and transcribed text.
454
- Returns None if there are no valid segments to process.
455
- """
456
- offset = None
457
- self.current_out = ''
458
- last_segment = None
459
- # process complete segments
460
- if len(segments) > 1:
461
- for i, s in enumerate(segments[:-1]):
462
- text_ = s.text
463
- self.text.append(text_)
464
- start, end = self.timestamp_offset + s.start, self.timestamp_offset + min(duration, s.end)
465
- self.transcript.append(
466
- {
467
- 'start': start,
468
- 'end': end,
469
- 'text': text_
470
- }
471
- )
472
-
473
- offset = min(duration, s.end)
474
-
475
- self.current_out += segments[-1].text
476
- last_segment = {
477
- 'start': self.timestamp_offset + segments[-1].start,
478
- 'end': self.timestamp_offset + min(duration, segments[-1].end),
479
- 'text': self.current_out
480
- }
481
-
482
- # if same incomplete segment is seen multiple times then update the offset
483
- # and append the segment to the list
484
- if self.current_out.strip() == self.prev_out.strip() and self.current_out != '':
485
- self.same_output_threshold += 1
486
- else:
487
- self.same_output_threshold = 0
488
-
489
- if self.same_output_threshold > 5:
490
- if not len(self.text) or self.text[-1].strip().lower()!=self.current_out.strip().lower():
491
- self.text.append(self.current_out)
492
- self.transcript.append(
493
- {
494
- 'start': self.timestamp_offset,
495
- 'end': self.timestamp_offset + duration,
496
- 'text': self.current_out
497
- }
498
- )
499
- self.current_out = ''
500
- offset = duration
501
- self.same_output_threshold = 0
502
- last_segment = None
503
- else:
504
- self.prev_out = self.current_out
505
-
506
- # update offset
507
- if offset is not None:
508
- self.timestamp_offset += offset
509
-
510
- return last_segment
511
 
512
  def disconnect(self):
513
  """
@@ -538,10 +439,3 @@ class ServeClient:
538
  logging.info("Cleaning up.")
539
  self.exit = True
540
  self.transcriber.destroy()
541
-
542
- def replace_last_occurrence(input_str, old_char, new_char):
543
- parts = input_str.rsplit(old_char, 1)
544
- if len(parts) == 2:
545
- return parts[0] + new_char + parts[1]
546
- else:
547
- return input_str
 
73
 
74
  return wait_time / 60
75
 
76
+ def recv_audio(self, websocket, transcription_queue=None, llm_queue=None, whisper_tensorrt_path=None):
77
  """
78
  Receive audio chunks from a client in an infinite loop.
79
 
 
121
  task=options["task"],
122
  client_uid=options["uid"],
123
  transcription_queue=transcription_queue,
124
+ llm_queue=llm_queue,
125
+ model_path=whisper_tensorrt_path
126
  )
127
 
128
  self.clients[websocket] = client
 
133
  try:
134
  frame_data = websocket.recv()
135
  frame_np = np.frombuffer(frame_data, dtype=np.float32)
136
+
137
  # VAD
138
  try:
139
  speech_prob = self.vad_model(torch.from_numpy(frame_np.copy()), self.RATE).item()
140
  if speech_prob < self.vad_threshold:
141
  no_voice_activity_chunks += 1
 
142
  if no_voice_activity_chunks > 2:
143
  if not self.clients[websocket].eos:
144
  self.clients[websocket].set_eos(True)
145
+ time.sleep(0.25) # EOS stop receiving frames for a 250ms(to send output to LLM.)
146
  continue
147
  no_voice_activity_chunks = 0
148
  self.clients[websocket].set_eos(False)
 
173
  del websocket
174
  break
175
 
176
+ def run(self, host, port=9090, transcription_queue=None, llm_queue=None, whisper_tensorrt_path=None):
177
  """
178
  Run the transcription server.
179
 
 
182
  port (int): The port number to bind the server.
183
  """
184
  with serve(
185
+ functools.partial(
186
+ self.recv_audio,
187
+ transcription_queue=transcription_queue,
188
+ llm_queue=llm_queue,
189
+ whisper_tensorrt_path=whisper_tensorrt_path
190
+ ),
191
  host,
192
  port
193
  ) as server:
 
237
  client_uid=None,
238
  transcription_queue=None,
239
  llm_queue=None,
240
+ model_path=None
241
  ):
242
  """
243
  Initialize a ServeClient instance.
 
261
  self.frames = b""
262
  self.language = language if multilingual else "en"
263
  self.task = task
264
+ self.transcriber = WhisperTRTLLM(model_path, False, "assets", device="cuda")
 
 
265
 
266
  self.timestamp_offset = 0.0
267
  self.frames_np = None
 
300
 
301
  def set_eos(self, eos):
302
  self.lock.acquire()
 
 
303
  self.eos = eos
304
  self.lock.release()
305
 
 
348
  """
349
  while True:
350
  try:
 
351
  if self.llm_queue is not None:
352
  llm_output = self.llm_queue.get_nowait()
353
  if llm_output:
354
  self.websocket.send(json.dumps(llm_output))
 
 
355
  except queue.Empty:
356
  pass
357
 
 
360
  break
361
 
362
  if self.frames_np is None:
363
+ time.sleep(0.01) # wait for any audio to arrive
 
364
  continue
365
 
366
  # clip audio if the current chunk exceeds 30 seconds, this basically implies that
 
372
  samples_take = max(0, (self.timestamp_offset - self.frames_offset)*self.RATE)
373
  input_bytes = self.frames_np[int(samples_take):].copy()
374
  duration = input_bytes.shape[0] / self.RATE
375
+ if duration<0.4:
376
  continue
377
 
378
  try:
379
  input_sample = input_bytes.copy()
380
+
 
381
  mel, duration = self.transcriber.log_mel_spectrogram(input_sample)
382
  last_segment = self.transcriber.transcribe(mel)
383
+ segments = []
384
  if len(last_segment):
 
 
 
 
385
  segments.append({"text": last_segment})
386
  try:
387
+ print(f"Sending... {segments}")
388
  self.websocket.send(
389
  json.dumps({
390
  "uid": self.client_uid,
 
393
  })
394
  )
395
  if self.eos:
396
+ # self.append_segment(last_segment)
397
  self.timestamp_offset += duration
398
  self.prompt = ' '.join(segment['text'] for segment in segments)
399
  self.transcription_queue.put({"uid": self.client_uid, "prompt": self.prompt})
 
400
  self.set_eos(False)
401
 
402
+ logging.info(
403
+ f"[INFO:] \
404
+ Processed : {self.timestamp_offset} seconds / {self.frames_np.shape[0] / self.RATE} seconds"
405
+ )
406
+
407
  except Exception as e:
408
  logging.error(f"[ERROR]: {e}")
409
 
410
  except Exception as e:
411
  logging.error(f"[ERROR]: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
 
413
  def disconnect(self):
414
  """
 
439
  logging.info("Cleaning up.")
440
  self.exit = True
441
  self.transcriber.destroy()