makaveli10 committed
Commit: 9dcd6a2
Parent: f2683ae

optimizations

llm_service.py CHANGED
@@ -137,19 +137,26 @@ class MistralTensorRTLLM:
         output_ids,
         input_lengths,
         sequence_lengths,
+        transcription_queue
     ):
         batch_size, num_beams, _ = output_ids.size()
         for batch_idx in range(batch_size):
-            inputs = output_ids[batch_idx][0][:input_lengths[batch_idx]].tolist(
-            )
+            if transcription_queue.qsize() != 0:
+                return None
+
+            inputs = output_ids[batch_idx][0][:input_lengths[batch_idx]].tolist()
             input_text = self.tokenizer.decode(inputs)
             output = []
             for beam in range(num_beams):
+                if transcription_queue.qsize() != 0:
+                    return None
+
                 output_begin = input_lengths[batch_idx]
                 output_end = sequence_lengths[batch_idx][beam]
                 outputs = output_ids[batch_idx][beam][
                     output_begin:output_end].tolist()
                 output_text = self.tokenizer.decode(outputs)
+                print("[LLM] output:", output_text)
                 output.append(output_text)
         return output
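The new transcription_queue argument makes decode_tokens cooperatively interruptible: the moment fresh transcriptions are waiting, it returns None instead of finishing a stale generation. A minimal sketch of that pattern under illustrative names (decode_interruptibly, work_queue, and decode_fn are not from the repo):

import queue

def decode_interruptibly(token_batches, work_queue, decode_fn):
    # Mirrors the qsize() checks added to decode_tokens(): qsize() is
    # advisory rather than atomic, but it is a cheap "newer work is
    # already waiting" signal, and the caller treats None as interrupted.
    outputs = []
    for tokens in token_batches:
        if work_queue.qsize() != 0:
            return None
        outputs.append(decode_fn(tokens))
    return outputs

if __name__ == "__main__":
    q = queue.Queue()
    print(decode_interruptibly([[1, 2], [3]], q, sum))  # [3, 3]
    q.put("newer prompt")
    print(decode_interruptibly([[1, 2], [3]], q, sum))  # None

Returning None rather than raising keeps the caller simple: a single falsy check decides between publishing the output and going back to the queue.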
 
@@ -179,15 +186,27 @@ class MistralTensorRTLLM:
             tokenizer_path,
         )
 
-        print("Loaded LLM...")
+        print("[LLM] loaded: True")
         while True:
-
-            # while transcription
+
+            # Get the last transcription output from the queue
             transcription_output = transcription_queue.get()
+            if transcription_queue.qsize() != 0:
+                print("[LLM] transcription queue size:", transcription_queue.qsize())
+                continue
+            # while True:
+            #     try:
+            #         transcription_output = transcription_queue.get_nowait()
+            #     except Exception as e:
+            #         print("[Queue] exception", e)
+            #         break
+
+            # transcription_output = transcription_queue.get()
+
             prompt = transcription_output['prompt'].strip()
             input_text=[self.format_prompt_qa(prompt)]
 
-            print("Whisper: ", prompt)
+            print("[Whisper] prompt:", prompt)
             batch_input_ids = self.parse_input(
                 input_text=input_text,
                 add_special_tokens=True,
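
The serving loop now pairs the blocking get() with a qsize() check and skips any prompt that already has a successor queued, so only the freshest transcription reaches the model. The commented-out block preserves the alternative it replaces: drain with get_nowait() until the queue is empty and keep the last item, which never discards a prompt without a newer replacement. A standalone sketch of that drain variant (the helper name is mine):

import queue

def get_latest(q):
    # Block for the first item, then let any newer items supersede it.
    item = q.get()
    while True:
        try:
            item = q.get_nowait()
        except queue.Empty:
            return item

if __name__ == "__main__":
    q = queue.Queue()
    for prompt in ("hel", "hello", "hello wor"):
        q.put(prompt)
    print(get_latest(q))  # hello wor
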
@@ -225,8 +244,16 @@ class MistralTensorRTLLM:
                 output = self.decode_tokens(
                     output_ids,
                     input_lengths,
-                    sequence_lengths
+                    sequence_lengths,
+                    transcription_queue
                 )
+
+                if output is None:
+                    break
+                # Interrupted by transcription queue
+                if output is None:
+                    print("[LLM] interrupted by transcription queue!!!!!!!!!!!!!!!!!!!!!!!!", transcription_queue.qsize())
+                    continue
             else:
                 output_ids = outputs['output_ids']
                 sequence_lengths = outputs['sequence_lengths']
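
One wrinkle as committed: the first `if output is None: break` runs before the logging branch below it, so the `continue` path with the interruption message is unreachable. A sketch of the flow the second branch appears to intend, with the two checks consolidated (generate_once is a hypothetical stand-in for the generation step):

def serve(transcription_queue, generate_once):
    # Consolidated interrupt handling: a None result means a newer
    # transcription arrived mid-generation, so log it and go back to
    # waiting on the queue instead of leaving the loop.
    while True:
        prompt = transcription_queue.get()
        output = generate_once(prompt)
        if output is None:
            print("[LLM] interrupted by transcription queue:",
                  transcription_queue.qsize())
            continue
        print("[LLM] output:", output)
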
@@ -239,6 +266,7 @@ class MistralTensorRTLLM:
                     output_ids,
                     input_lengths,
                     sequence_lengths,
+                    transcription_queue
                 )
             llm_queue.put({"uid": transcription_output["uid"], "llm_output": output})
             audio_queue.put(output)
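
Taken together, the loop threads three queues through the pipeline: transcription_queue (Whisper to LLM), llm_queue (LLM back to the server), and audio_queue (LLM to TTS). A toy wiring sketch under those assumptions, with the TensorRT-LLM call stubbed out:

import queue
import threading

transcription_queue = queue.Queue()
llm_queue = queue.Queue()
audio_queue = queue.Queue()

def llm_worker():
    # Stub for MistralTensorRTLLM.run(): echo the prompt as "output".
    work = transcription_queue.get()
    output = [f"reply to: {work['prompt']}"]
    llm_queue.put({"uid": work["uid"], "llm_output": output})
    audio_queue.put(output)

threading.Thread(target=llm_worker, daemon=True).start()
transcription_queue.put({"uid": "client-1", "prompt": "hello"})
print(llm_queue.get())   # {'uid': 'client-1', 'llm_output': [...]}
print(audio_queue.get())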
 
whisper_live/trt_server.py CHANGED
@@ -263,6 +263,9 @@ class ServeClient:
         self.task = task
         self.transcriber = WhisperTRTLLM(model_path, False, "assets", device="cuda")
 
+
+        self.last_prompt = None
+
         self.timestamp_offset = 0.0
         self.frames_np = None
         self.frames_offset = 0.0
@@ -396,10 +399,22 @@ class ServeClient:
                     # self.append_segment(last_segment)
                     self.timestamp_offset += duration
                     self.prompt = ' '.join(segment['text'] for segment in segments)
-                    self.transcription_queue.put({"uid": self.client_uid, "prompt": self.prompt})
+                    if self.last_prompt != self.prompt:
+                        self.transcription_queue.put({"uid": self.client_uid, "prompt": self.prompt})
+
+                    self.last_prompt = None
                     # self.set_eos(False)
                     logging.info(f"[INFO:] Processed : {self.timestamp_offset} seconds / {self.frames_np.shape[0] / self.RATE} seconds"
                     )
+                else:
+                    self.prompt = ' '.join(segment['text'] for segment in segments)
+
+                    if self.last_prompt != self.prompt:
+                        self.transcription_queue.put({"uid": self.client_uid, "prompt": self.prompt})
+
+                    self.last_prompt = self.prompt
+
+
 
             except Exception as e:
                 logging.error(f"[ERROR]: {e}")
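
Both branches now gate the put on self.last_prompt, but they re-arm differently: the completed-segment branch resets last_prompt to None so the next prompt always goes through, while the interim branch remembers the prompt to suppress exact repeats. A compact sketch of that gating (the class name and send callback are illustrative):

class PromptGate:
    def __init__(self, send):
        self.last_prompt = None
        self.send = send  # e.g. transcription_queue.put

    def push(self, prompt, finalized):
        # Enqueue only when the text actually changed since last push.
        if self.last_prompt != prompt:
            self.send(prompt)
        # Finalized segments clear the gate; interim ones arm it.
        self.last_prompt = None if finalized else prompt

if __name__ == "__main__":
    gate = PromptGate(print)
    gate.push("hello", finalized=False)        # sent
    gate.push("hello", finalized=False)        # suppressed
    gate.push("hello there", finalized=True)   # sent, gate cleared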
 
whisper_live/trt_transcriber.py CHANGED
@@ -199,7 +199,7 @@ class WhisperTRTLLM(object):
         self.device = device
         self.tokenizer = get_tokenizer(
             False,
-            num_languages=self.encoder.num_languages,
+            # num_languages=self.encoder.num_languages,
             language="en",
             task="transcribe",
         )
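
Commenting out num_languages leaves the helper on its default. Assuming this get_tokenizer is openai-whisper's whisper.tokenizer.get_tokenizer (only the call site is visible here), num_languages is a keyword-only parameter with a default and the first positional argument is multilingual, so an English-only tokenizer builds fine without it; the explicit value was likely dropped because this encoder object does not expose a num_languages attribute. A hedged sketch of the resulting call:

# Assumes openai-whisper's tokenizer helper is the one in use here.
from whisper.tokenizer import get_tokenizer

tokenizer = get_tokenizer(
    False,  # multilingual=False: English-only vocabulary
    # num_languages left at the library default, as in the commit.
    language="en",
    task="transcribe",
)
print(tokenizer.decode(tokenizer.encode("hello world")))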
 