makaveli10 committed
Commit 9dcd6a2 • 1 Parent(s): f2683ae

optimizations

Files changed:
- llm_service.py (+35 -7)
- whisper_live/trt_server.py (+16 -1)
- whisper_live/trt_transcriber.py (+1 -1)
llm_service.py
CHANGED
@@ -137,19 +137,26 @@ class MistralTensorRTLLM:
         output_ids,
         input_lengths,
         sequence_lengths,
+        transcription_queue
     ):
         batch_size, num_beams, _ = output_ids.size()
         for batch_idx in range(batch_size):
-            inputs = output_ids[batch_idx][0][
-                :input_lengths[batch_idx]].tolist()
+            if transcription_queue.qsize() != 0:
+                return None
+
+            inputs = output_ids[batch_idx][0][:input_lengths[batch_idx]].tolist()
             input_text = self.tokenizer.decode(inputs)
             output = []
             for beam in range(num_beams):
+                if transcription_queue.qsize() != 0:
+                    return None
+
                 output_begin = input_lengths[batch_idx]
                 output_end = sequence_lengths[batch_idx][beam]
                 outputs = output_ids[batch_idx][beam][
                     output_begin:output_end].tolist()
                 output_text = self.tokenizer.decode(outputs)
+                print("[LLM] output:", output_text)
                 output.append(output_text)
         return output
 
@@ -179,15 +186,27 @@ class MistralTensorRTLLM:
             tokenizer_path,
         )
 
-        print("
+        print("[LLM] loaded: True")
         while True:
-
-            #
+
+            # Get the last transcription output from the queue
             transcription_output = transcription_queue.get()
+            if transcription_queue.qsize() != 0:
+                print("[LLM] transcription queue size:", transcription_queue.qsize())
+                continue
+            # while True:
+            #     try:
+            #         transcription_output = transcription_queue.get_nowait()
+            #     except Exception as e:
+            #         print("[Queue] exception", e)
+            #         break
+
+            # transcription_output = transcription_queue.get()
+
             prompt = transcription_output['prompt'].strip()
             input_text=[self.format_prompt_qa(prompt)]
 
-            print("Whisper: 
+            print("[Whisper] prompt:", prompt)
             batch_input_ids = self.parse_input(
                 input_text=input_text,
                 add_special_tokens=True,
@@ -225,8 +244,16 @@ class MistralTensorRTLLM:
                     output = self.decode_tokens(
                         output_ids,
                         input_lengths,
-                        sequence_lengths
+                        sequence_lengths,
+                        transcription_queue
                     )
+
+                    if output is None:
+                        break
+                # Interrupted by transcription queue
+                if output is None:
+                    print("[LLM] interrupted by transcription queue!!!!!!!!!!!!!!!!!!!!!!!!", transcription_queue.qsize())
+                    continue
             else:
                 output_ids = outputs['output_ids']
                 sequence_lengths = outputs['sequence_lengths']
@@ -239,6 +266,7 @@ class MistralTensorRTLLM:
                     output_ids,
                     input_lengths,
                     sequence_lengths,
+                    transcription_queue
                 )
                 llm_queue.put({"uid": transcription_output["uid"], "llm_output": output})
                 audio_queue.put(output)
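In short, the llm_service.py changes make generation interruptible: the worker drops a prompt if newer ones are already queued behind it, and decode_tokens polls the transcription queue during decoding, returning None so the serving loop can abandon stale output and restart on fresh input. Below is a minimal, runnable sketch of this interrupt-by-queue pattern; every name in it (generate_step, worker, tq, lq) is an illustrative placeholder, not code from this repository.

    # Sketch of interruptible generation driven by a multiprocessing queue.
    import multiprocessing
    import time

    def generate_step(prompt, step):
        # Stand-in for one incremental LLM decoding step.
        time.sleep(0.01)
        return f"{prompt}-tok{step}"

    def worker(transcription_queue, llm_queue):
        while True:
            prompt = transcription_queue.get()        # block until a prompt arrives
            if transcription_queue.qsize() != 0:      # a newer prompt is already waiting
                continue                              # skip this stale one
            tokens = []
            for step in range(20):
                if transcription_queue.qsize() != 0:  # interrupted mid-generation
                    tokens = None
                    break
                tokens.append(generate_step(prompt, step))
            if tokens is None:
                continue                              # stale output is never emitted
            llm_queue.put({"prompt": prompt, "llm_output": tokens})

    if __name__ == "__main__":
        tq, lq = multiprocessing.Queue(), multiprocessing.Queue()
        multiprocessing.Process(target=worker, args=(tq, lq), daemon=True).start()
        tq.put("hello")
        print(lq.get())

One caveat: multiprocessing.Queue.qsize() is only approximate and raises NotImplementedError on macOS, which may be why the diff also carries a commented-out alternative that drains the queue with get_nowait() instead of polling its size.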
whisper_live/trt_server.py
CHANGED
@@ -263,6 +263,9 @@ class ServeClient:
         self.task = task
         self.transcriber = WhisperTRTLLM(model_path, False, "assets", device="cuda")
 
+
+        self.last_prompt = None
+
         self.timestamp_offset = 0.0
         self.frames_np = None
         self.frames_offset = 0.0
@@ -396,10 +399,22 @@ class ServeClient:
                     # self.append_segment(last_segment)
                     self.timestamp_offset += duration
                     self.prompt = ' '.join(segment['text'] for segment in segments)
-                    self.transcription_queue.put({"uid": self.client_uid, "prompt": self.prompt})
+                    if self.last_prompt != self.prompt:
+                        self.transcription_queue.put({"uid": self.client_uid, "prompt": self.prompt})
+
+                    self.last_prompt = None
                     # self.set_eos(False)
                     logging.info(f"[INFO:] Processed : {self.timestamp_offset} seconds / {self.frames_np.shape[0] / self.RATE} seconds"
                     )
+                else:
+                    self.prompt = ' '.join(segment['text'] for segment in segments)
+
+                    if self.last_prompt != self.prompt:
+                        self.transcription_queue.put({"uid": self.client_uid, "prompt": self.prompt})
+
+                    self.last_prompt = self.prompt
+
+
 
             except Exception as e:
                 logging.error(f"[ERROR]: {e}")
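This is the producer-side half of the optimization: ServeClient now remembers the last prompt it pushed and only enqueues text that actually changed, so the LLM worker is not flooded with duplicate prompts; last_prompt is reset to None once a segment is finalized so the finalized text always goes out. A condensed sketch of that dedup logic follows; PromptSender is a hypothetical standalone class, not part of this repo.

    # Sketch of the prompt deduplication added to ServeClient.
    import multiprocessing

    class PromptSender:
        def __init__(self, queue, uid):
            self.queue = queue
            self.uid = uid
            self.last_prompt = None

        def send(self, segments, finalized):
            prompt = ' '.join(segment['text'] for segment in segments)
            if self.last_prompt != prompt:            # only push changed text
                self.queue.put({"uid": self.uid, "prompt": prompt})
            # After a finalized segment, forget the prompt so future text is
            # always sent; for interim updates, remember it to suppress repeats.
            self.last_prompt = None if finalized else prompt

    q = multiprocessing.Queue()
    sender = PromptSender(q, uid="client-1")
    sender.send([{"text": "hello"}], finalized=False)
    sender.send([{"text": "hello"}], finalized=False)  # suppressed: unchanged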
whisper_live/trt_transcriber.py
CHANGED
@@ -199,7 +199,7 @@ class WhisperTRTLLM(object):
         self.device = device
         self.tokenizer = get_tokenizer(
             False,
-            num_languages=self.encoder.num_languages,
+            # num_languages=self.encoder.num_languages,
             language="en",
             task="transcribe",
         )
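The trt_transcriber.py change comments out the num_languages keyword, most likely because the installed openai-whisper's get_tokenizer does not accept it (the argument only exists in newer releases that added the 100-language tokenizer for large-v3). A version-tolerant sketch, assuming whisper.tokenizer.get_tokenizer from openai-whisper; make_tokenizer is an invented helper, not repo code:

    # Pass num_languages only when the installed openai-whisper supports it.
    import inspect
    from whisper.tokenizer import get_tokenizer

    def make_tokenizer(multilingual=False, num_languages=None,
                       language="en", task="transcribe"):
        kwargs = {"language": language, "task": task}
        supported = inspect.signature(get_tokenizer).parameters
        if num_languages is not None and "num_languages" in supported:
            kwargs["num_languages"] = num_languages
        return get_tokenizer(multilingual, **kwargs)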