makaveli10 committed
Commit
6c3262d
1 Parent(s): 06d32c0

trt whisper live

requirements.txt ADDED
@@ -0,0 +1,4 @@
+ PyAudio
+ faster-whisper==0.9.0
+ websockets
+ onnxruntime==1.16.0
run_faster_whisper_server.py ADDED
@@ -0,0 +1,5 @@
+ from whisper_live.server import TranscriptionServer
+
+ if __name__ == "__main__":
+     server = TranscriptionServer()
+     server.run("0.0.0.0", 6006)
run_trt_server.py ADDED
@@ -0,0 +1,5 @@
+ from whisper_live.trt_server import TranscriptionServer
+
+ if __name__ == "__main__":
+     server = TranscriptionServer()
+     server.run("0.0.0.0", 6006)
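Both launcher scripts bind a `TranscriptionServer` to `0.0.0.0:6006`; the faster-whisper and TensorRT variants expose the same WebSocket protocol, so either can be exercised with the packaged client. A minimal sketch, assuming the server is reachable on localhost, a microphone is available, and the client-side packages imported by `whisper_live/client.py` (PyAudio, websocket-client, ffmpeg-python, scipy, numpy) are installed:

```python
from whisper_live.client import TranscriptionClient

# Connect to the server started by either launcher script (port 6006).
# With no arguments, __call__ records from the default microphone and
# prints transcribed segments as they arrive from the server.
client = TranscriptionClient(host="localhost", port=6006, is_multilingual=False, lang="en")
client()
```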
whisper_live/__init__.py ADDED
File without changes
whisper_live/__version__.py ADDED
@@ -0,0 +1 @@
+ __version__ = "0.0.9"
whisper_live/client.py ADDED
@@ -0,0 +1,528 @@
+ import os
+ import wave
+
+ import numpy as np
+ import scipy.io.wavfile
+ import ffmpeg
+ import pyaudio
+ import threading
+ import textwrap
+ import json
+ import websocket
+ import uuid
+ import time
+
+
+ def resample(file: str, sr: int = 16000):
+     """
+     # https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/audio.py#L22
+     Open an audio file, read it as a mono waveform, resample it as necessary,
+     and save the resampled audio.
+
+     Args:
+         file (str): The audio file to open.
+         sr (int): The sample rate to resample the audio to, if necessary.
+
+     Returns:
+         resampled_file (str): The resampled audio file.
+     """
+     try:
+         # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
+         # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
+         out, _ = (
+             ffmpeg.input(file, threads=0)
+             .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
+             .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
+         )
+     except ffmpeg.Error as e:
+         raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+     np_buffer = np.frombuffer(out, dtype=np.int16)
+
+     resampled_file = f"{file.split('.')[0]}_resampled.wav"
+     scipy.io.wavfile.write(resampled_file, sr, np_buffer.astype(np.int16))
+     return resampled_file
+
+
+ class Client:
+     """
+     Handles audio recording, streaming, and communication with a server using WebSocket.
+     """
+     INSTANCES = {}
+
+     def __init__(
+         self, host=None, port=None, is_multilingual=False, lang=None, translate=False
+     ):
+         """
+         Initializes a Client instance for audio recording and streaming to a server.
+
+         If host and port are not provided, the WebSocket connection will not be established.
+         When translate is True, the task will be set to "translate" instead of "transcribe".
+         The audio recording starts immediately upon initialization.
+
+         Args:
+             host (str): The hostname or IP address of the server.
+             port (int): The port number for the WebSocket server.
+             is_multilingual (bool, optional): Specifies if multilingual transcription is enabled. Default is False.
+             lang (str, optional): The selected language for transcription when multilingual is disabled. Default is None.
+             translate (bool, optional): Specifies if the task is translation. Default is False.
+         """
+         self.chunk = 1024
+         self.format = pyaudio.paInt16
+         self.channels = 1
+         self.rate = 16000
+         self.record_seconds = 60000
+         self.recording = False
+         self.multilingual = False
+         self.language = None
+         self.task = "transcribe"
+         self.uid = str(uuid.uuid4())
+         self.waiting = False
+         self.last_response_received = None
+         self.disconnect_if_no_response_for = 15
+         self.multilingual = is_multilingual
+         self.language = lang if is_multilingual else "en"
+         if translate:
+             self.task = "translate"
+
+         self.timestamp_offset = 0.0
+         self.audio_bytes = None
+         self.p = pyaudio.PyAudio()
+         self.stream = self.p.open(
+             format=self.format,
+             channels=self.channels,
+             rate=self.rate,
+             input=True,
+             frames_per_buffer=self.chunk,
+         )
+
+         if host is not None and port is not None:
+             socket_url = f"ws://{host}:{port}"
+             self.client_socket = websocket.WebSocketApp(
+                 socket_url,
+                 on_open=lambda ws: self.on_open(ws),
+                 on_message=lambda ws, message: self.on_message(ws, message),
+                 on_error=lambda ws, error: self.on_error(ws, error),
+                 on_close=lambda ws, close_status_code, close_msg: self.on_close(
+                     ws, close_status_code, close_msg
+                 ),
+             )
+         else:
+             print("[ERROR]: No host or port specified.")
+             return
+
+         Client.INSTANCES[self.uid] = self
+
+         # start the websocket client in a daemon thread
+         self.ws_thread = threading.Thread(target=self.client_socket.run_forever)
+         self.ws_thread.daemon = True
+         self.ws_thread.start()
+
+         self.frames = b""
+         print("[INFO]: * recording")
+
+     def on_message(self, ws, message):
+         """
+         Callback function called when a message is received from the server.
+
+         It updates various attributes of the client based on the received message, including
+         recording status, language detection, and server messages. If a disconnect message
+         is received, it sets the recording status to False.
+
+         Args:
+             ws (websocket.WebSocketApp): The WebSocket client instance.
+             message (str): The received message from the server.
+
+         """
+         self.last_response_received = time.time()
+         message = json.loads(message)
+
+         if self.uid != message.get("uid"):
+             print("[ERROR]: invalid client uid")
+             return
+
+         if "status" in message.keys() and message["status"] == "WAIT":
+             self.waiting = True
+             print(
+                 f"[INFO]: Server is full. Estimated wait time {round(message['message'])} minutes."
+             )
+
+         if "message" in message.keys() and message["message"] == "DISCONNECT":
+             print("[INFO]: Server overtime disconnected.")
+             self.recording = False
+
+         if "message" in message.keys() and message["message"] == "SERVER_READY":
+             self.recording = True
+             return
+
+         if "language" in message.keys():
+             self.language = message.get("language")
+             lang_prob = message.get("language_prob")
+             print(
+                 f"[INFO]: Server detected language {self.language} with probability {lang_prob}"
+             )
+             return
+
+         if "segments" not in message.keys():
+             return
+
+         message = message["segments"]
+         text = []
+         if len(message):
+             for seg in message:
+                 if text and text[-1] == seg["text"]:
+                     # already got it
+                     continue
+                 text.append(seg["text"])
+         # keep only the last 3 segments
+         if len(text) > 3:
+             text = text[-3:]
+         wrapper = textwrap.TextWrapper(width=60)
+         word_list = wrapper.wrap(text="".join(text))
+         # Print each line.
+         if os.name == "nt":
+             os.system("cls")
+         else:
+             os.system("clear")
+         for element in word_list:
+             print(element)
+
+     def on_error(self, ws, error):
+         print(error)
+
+     def on_close(self, ws, close_status_code, close_msg):
+         print(f"[INFO]: Websocket connection closed: {close_status_code}: {close_msg}")
+
+     def on_open(self, ws):
+         """
+         Callback function called when the WebSocket connection is successfully opened.
+
+         Sends an initial configuration message to the server, including the client UID, multilingual mode,
+         language selection, and task type.
+
+         Args:
+             ws (websocket.WebSocketApp): The WebSocket client instance.
+
+         """
+         print(self.multilingual, self.language, self.task)
+
+         print("[INFO]: Opened connection")
+         ws.send(
+             json.dumps(
+                 {
+                     "uid": self.uid,
+                     "multilingual": self.multilingual,
+                     "language": self.language,
+                     "task": self.task,
+                 }
+             )
+         )
+
+     @staticmethod
+     def bytes_to_float_array(audio_bytes):
+         """
+         Convert audio data from bytes to a NumPy float array.
+
+         It assumes that the audio data is in 16-bit PCM format. The audio data is normalized to
+         have values between -1 and 1.
+
+         Args:
+             audio_bytes (bytes): Audio data in bytes.
+
+         Returns:
+             np.ndarray: A NumPy array containing the audio data as float values normalized between -1 and 1.
+         """
+         raw_data = np.frombuffer(buffer=audio_bytes, dtype=np.int16)
+         return raw_data.astype(np.float32) / 32768.0
+
+     def send_packet_to_server(self, message):
+         """
+         Send an audio packet to the server using WebSocket.
+
+         Args:
+             message (bytes): The audio data packet in bytes to be sent to the server.
+
+         """
+         try:
+             self.client_socket.send(message, websocket.ABNF.OPCODE_BINARY)
+         except Exception as e:
+             print(e)
+
+     def play_file(self, filename):
+         """
+         Play an audio file and send it to the server for processing.
+
+         Reads an audio file, plays it through the audio output, and simultaneously sends
+         the audio data to the server for processing. It uses PyAudio to create an audio
+         stream for playback. The audio data is read from the file in chunks, converted to
+         floating-point format, and sent to the server using WebSocket communication.
+         This method is typically used when you want to process pre-recorded audio and send it
+         to the server in real time.
+
+         Args:
+             filename (str): The path to the audio file to be played and sent to the server.
+         """
+
+         # read audio and create a pyaudio stream
+         with wave.open(filename, "rb") as wavfile:
+             self.stream = self.p.open(
+                 format=self.p.get_format_from_width(wavfile.getsampwidth()),
+                 channels=wavfile.getnchannels(),
+                 rate=wavfile.getframerate(),
+                 input=True,
+                 output=True,
+                 frames_per_buffer=self.chunk,
+             )
+             try:
+                 while self.recording:
+                     data = wavfile.readframes(self.chunk)
+                     if data == b"":
+                         break
+
+                     audio_array = self.bytes_to_float_array(data)
+                     self.send_packet_to_server(audio_array.tobytes())
+                     self.stream.write(data)
+
+                 wavfile.close()
+
+                 assert self.last_response_received
+                 while time.time() - self.last_response_received < self.disconnect_if_no_response_for:
+                     continue
+                 self.stream.close()
+                 self.close_websocket()
+
+             except KeyboardInterrupt:
+                 wavfile.close()
+                 self.stream.stop_stream()
+                 self.stream.close()
+                 self.p.terminate()
+                 self.close_websocket()
+                 print("[INFO]: Keyboard interrupt.")
+
+     def close_websocket(self):
+         """
+         Close the WebSocket connection and join the WebSocket thread.
+
+         First attempts to close the WebSocket connection using `self.client_socket.close()`. After
+         closing the connection, it joins the WebSocket thread to ensure proper termination.
+
+         """
+         try:
+             self.client_socket.close()
+         except Exception as e:
+             print("[ERROR]: Error closing WebSocket:", e)
+
+         try:
+             self.ws_thread.join()
+         except Exception as e:
+             print("[ERROR]: Error joining WebSocket thread:", e)
+
+     def get_client_socket(self):
+         """
+         Get the WebSocket client socket instance.
+
+         Returns:
+             WebSocketApp: The WebSocket client socket instance currently in use by the client.
+         """
+         return self.client_socket
+
+     def write_audio_frames_to_file(self, frames, file_name):
+         """
+         Write audio frames to a WAV file.
+
+         The WAV file is created or overwritten with the specified name. The audio frames should be
+         in the correct format and match the specified channel count, sample width, and sample rate.
+
+         Args:
+             frames (bytes): The audio frames to be written to the file.
+             file_name (str): The name of the WAV file to which the frames will be written.
+
+         """
+         with wave.open(file_name, "wb") as wavfile:
+             wavfile: wave.Wave_write
+             wavfile.setnchannels(self.channels)
+             wavfile.setsampwidth(2)
+             wavfile.setframerate(self.rate)
+             wavfile.writeframes(frames)
+
+     def process_hls_stream(self, hls_url):
+         """
+         Connect to an HLS source, process the audio stream, and send it for transcription.
+
+         Args:
+             hls_url (str): The URL of the HLS stream source.
+         """
+         print("[INFO]: Connecting to HLS stream...")
+         process = None  # Initialize process to None
+
+         try:
+             # Connect to the HLS stream using ffmpeg-python
+             process = (
+                 ffmpeg
+                 .input(hls_url, threads=0)
+                 .output('-', format='s16le', acodec='pcm_s16le', ac=1, ar=self.rate)
+                 .run_async(pipe_stdout=True, pipe_stderr=True)
+             )
+
+             # Process the stream
+             while True:
+                 in_bytes = process.stdout.read(self.chunk * 2)  # 2 bytes per sample
+                 if not in_bytes:
+                     break
+                 audio_array = self.bytes_to_float_array(in_bytes)
+                 self.send_packet_to_server(audio_array.tobytes())
+
+         except Exception as e:
+             print(f"[ERROR]: Failed to connect to HLS stream: {e}")
+         finally:
+             if process:
+                 process.kill()
+
+         print("[INFO]: HLS stream processing finished.")
+
+
+     def record(self, out_file="output_recording.wav"):
+         """
+         Record audio data from the input stream and save it to a WAV file.
+
+         Continuously records audio data from the input stream, sends it to the server via a WebSocket
+         connection, and simultaneously saves it to multiple WAV files in chunks. It stops recording when
+         the `record_seconds` duration is reached or when the `recording` flag is set to `False`.
+
+         Audio data is saved in chunks to the "chunks" directory. Each chunk is saved as a separate WAV file.
+         The recording continues until the specified duration is reached or until the `recording` flag is set to `False`.
+         The recording process can be interrupted by sending a KeyboardInterrupt (e.g., pressing Ctrl+C). After recording,
+         the method combines all the saved audio chunks into the specified `out_file`.
+
+         Args:
+             out_file (str, optional): The name of the output WAV file to save the entire recording. Default is "output_recording.wav".
+
+         """
+         n_audio_file = 0
+         if not os.path.exists("chunks"):
+             os.makedirs("chunks", exist_ok=True)
+         try:
+             for _ in range(0, int(self.rate / self.chunk * self.record_seconds)):
+                 if not self.recording:
+                     break
+                 data = self.stream.read(self.chunk)
+                 self.frames += data
+
+                 audio_array = Client.bytes_to_float_array(data)
+
+                 self.send_packet_to_server(audio_array.tobytes())
+
+                 # save frames if more than a minute
+                 if len(self.frames) > 60 * self.rate:
+                     t = threading.Thread(
+                         target=self.write_audio_frames_to_file,
+                         args=(
+                             self.frames[:],
+                             f"chunks/{n_audio_file}.wav",
+                         ),
+                     )
+                     t.start()
+                     n_audio_file += 1
+                     self.frames = b""
+
+         except KeyboardInterrupt:
+             if len(self.frames):
+                 self.write_audio_frames_to_file(
+                     self.frames[:], f"chunks/{n_audio_file}.wav"
+                 )
+                 n_audio_file += 1
+             self.stream.stop_stream()
+             self.stream.close()
+             self.p.terminate()
+             self.close_websocket()
+
+             self.write_output_recording(n_audio_file, out_file)
+
+     def write_output_recording(self, n_audio_file, out_file):
+         """
+         Combine and save recorded audio chunks into a single WAV file.
+
+         The individual audio chunk files are expected to be located in the "chunks" directory. Reads each chunk
+         file, appends its audio data to the final recording, and then deletes the chunk file. After combining
+         and saving, the final recording is stored in the specified `out_file`.
+
+
+         Args:
+             n_audio_file (int): The number of audio chunk files to combine.
+             out_file (str): The name of the output WAV file to save the final recording.
+
+         """
+         input_files = [
+             f"chunks/{i}.wav"
+             for i in range(n_audio_file)
+             if os.path.exists(f"chunks/{i}.wav")
+         ]
+         with wave.open(out_file, "wb") as wavfile:
+             wavfile: wave.Wave_write
+             wavfile.setnchannels(self.channels)
+             wavfile.setsampwidth(2)
+             wavfile.setframerate(self.rate)
+             for in_file in input_files:
+                 with wave.open(in_file, "rb") as wav_in:
+                     while True:
+                         data = wav_in.readframes(self.chunk)
+                         if data == b"":
+                             break
+                         wavfile.writeframes(data)
+                 # remove this file
+                 os.remove(in_file)
+         wavfile.close()
+
+
+ class TranscriptionClient:
+     """
+     Client for handling audio transcription tasks via a WebSocket connection.
+
+     Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used
+     to send audio data for transcription to a server and receive transcribed text segments.
+
+     Args:
+         host (str): The hostname or IP address of the server.
+         port (int): The port number to connect to on the server.
+         is_multilingual (bool, optional): Indicates whether the transcription should support multiple languages (default is False).
+         lang (str, optional): The primary language for transcription (used if `is_multilingual` is False). Default is None, which defaults to English ('en').
+         translate (bool, optional): Indicates whether translation tasks are required (default is False).
+
+     Attributes:
+         client (Client): An instance of the underlying Client class responsible for handling the WebSocket connection.
+
+     Example:
+         To create a TranscriptionClient and start transcription on microphone audio:
+         ```python
+         transcription_client = TranscriptionClient(host="localhost", port=9090, is_multilingual=True)
+         transcription_client()
+         ```
+     """
+     def __init__(self, host, port, is_multilingual=False, lang=None, translate=False):
+         self.client = Client(host, port, is_multilingual, lang, translate)
+
+     def __call__(self, audio=None, hls_url=None):
+         """
+         Start the transcription process.
+
+         Initiates the transcription process by connecting to the server via a WebSocket. It waits for the server
+         to be ready to receive audio data and then sends audio for transcription. If an HLS URL is provided, the
+         stream is transcribed; if an audio file is provided, it is played and streamed to the server; otherwise,
+         live recording from the microphone is performed.
+
+         Args:
+             audio (str, optional): Path to an audio file for transcription. Default is None, which triggers live recording.
+             hls_url (str, optional): URL of an HLS stream to transcribe. Default is None.
+
+         """
+         print("[INFO]: Waiting for server ready ...")
+         while not self.client.recording:
+             if self.client.waiting:
+                 self.client.close_websocket()
+                 return
+             pass
+         print("[INFO]: Server Ready!")
+         if hls_url is not None:
+             self.client.process_hls_stream(hls_url)
+         elif audio is not None:
+             resampled_file = resample(audio)
+             self.client.play_file(resampled_file)
+         else:
+             self.client.record()
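The same `TranscriptionClient` can also stream a pre-recorded file (converted to 16 kHz mono by `resample()` before playback) or an HLS source instead of the microphone. A short sketch; the file path and stream URL below are placeholders:

```python
from whisper_live.client import TranscriptionClient

client = TranscriptionClient(host="localhost", port=6006, is_multilingual=True)

# Transcribe a local recording: resample() writes a 16 kHz mono copy and
# play_file() streams it to the server while playing it back locally.
client(audio="path/to/recording.wav")  # placeholder path

# Alternatively, transcribe a live HLS stream (audio is pulled via ffmpeg):
# client(hls_url="https://example.com/stream.m3u8")  # placeholder URL
```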
whisper_live/server.py ADDED
@@ -0,0 +1,498 @@
+ import websockets
+ import time
+ import threading
+ import json
+ import textwrap
+
+ import logging
+ logging.basicConfig(level=logging.INFO)
+
+ from websockets.sync.server import serve
+
+ import torch
+ import numpy as np
+ from whisper_live.transcriber import WhisperModel
+
+
+ class TranscriptionServer:
+     """
+     Represents a transcription server that handles incoming audio from clients.
+
+     Attributes:
+         RATE (int): The audio sampling rate (constant) set to 16000.
+         clients (dict): A dictionary to store connected clients.
+         websockets (dict): A dictionary to store WebSocket connections.
+         clients_start_time (dict): A dictionary to track client start times.
+         max_clients (int): Maximum allowed connected clients.
+         max_connection_time (int): Maximum allowed connection time in seconds.
+     """
+
+     RATE = 16000
+
+     def __init__(self):
+         self.clients = {}
+         self.websockets = {}
+         self.clients_start_time = {}
+         self.max_clients = 4
+         self.max_connection_time = 600
+
+     def get_wait_time(self):
+         """
+         Calculate and return the estimated wait time for clients.
+
+         Returns:
+             float: The estimated wait time in minutes.
+         """
+         wait_time = None
+
+         for k, v in self.clients_start_time.items():
+             current_client_time_remaining = self.max_connection_time - (time.time() - v)
+
+             if wait_time is None or current_client_time_remaining < wait_time:
+                 wait_time = current_client_time_remaining
+
+         return wait_time / 60
+
+     def recv_audio(self, websocket):
+         """
+         Receive audio chunks from a client in an infinite loop.
+
+         Continuously receives audio frames from a connected client
+         over a WebSocket connection. It processes the audio frames using a
+         voice activity detection (VAD) model to determine whether they contain speech.
+         If an audio frame contains speech, it is added to the client's
+         audio data for ASR.
+         If the maximum number of clients is reached, the method sends a
+         "WAIT" status to the client, indicating that they should wait
+         until a slot is available.
+         If a client's connection exceeds the maximum allowed time, it will
+         be disconnected, and the client's resources will be cleaned up.
+
+         Args:
+             websocket (WebSocket): The WebSocket connection for the client.
+
+         Raises:
+             Exception: If there is an error during the audio frame processing.
+         """
+         logging.info("New client connected")
+         options = websocket.recv()
+         options = json.loads(options)
+
+         if len(self.clients) >= self.max_clients:
+             logging.warning("Client Queue Full. Asking client to wait ...")
+             wait_time = self.get_wait_time()
+             response = {
+                 "uid": options["uid"],
+                 "status": "WAIT",
+                 "message": wait_time,
+             }
+             websocket.send(json.dumps(response))
+             websocket.close()
+             del websocket
+             return
+
+         client = ServeClient(
+             websocket,
+             multilingual=options["multilingual"],
+             language=options["language"],
+             task=options["task"],
+             client_uid=options["uid"]
+         )
+
+         self.clients[websocket] = client
+         self.clients_start_time[websocket] = time.time()
+
+         while True:
+             try:
+                 frame_data = websocket.recv()
+                 frame_np = np.frombuffer(frame_data, dtype=np.float32)
+
+                 self.clients[websocket].add_frames(frame_np)
+
+                 elapsed_time = time.time() - self.clients_start_time[websocket]
+                 if elapsed_time >= self.max_connection_time:
+                     self.clients[websocket].disconnect()
+                     logging.warning(f"{self.clients[websocket]} Client disconnected due to overtime.")
+                     self.clients[websocket].cleanup()
+                     self.clients.pop(websocket)
+                     self.clients_start_time.pop(websocket)
+                     websocket.close()
+                     del websocket
+                     break
+
+             except Exception as e:
+                 logging.error(e)
+                 self.clients[websocket].cleanup()
+                 self.clients.pop(websocket)
+                 self.clients_start_time.pop(websocket)
+                 logging.info("Connection Closed.")
+                 logging.info(self.clients)
+                 del websocket
+                 break
+
+     def run(self, host, port=9090):
+         """
+         Run the transcription server.
+
+         Args:
+             host (str): The host address to bind the server to.
+             port (int): The port number to bind the server to.
+         """
+         with serve(self.recv_audio, host, port) as server:
+             server.serve_forever()
+
+
+ class ServeClient:
+     """
+     Attributes:
+         RATE (int): The audio sampling rate (constant) set to 16000.
+         SERVER_READY (str): A constant message indicating that the server is ready.
+         DISCONNECT (str): A constant message indicating that the client should disconnect.
+         client_uid (str): A unique identifier for the client.
+         data (bytes): Accumulated audio data.
+         frames (bytes): Accumulated audio frames.
+         language (str): The language for transcription.
+         task (str): The task type, e.g., "transcribe".
+         transcriber (WhisperModel): The Whisper model for speech-to-text.
+         timestamp_offset (float): The offset in audio timestamps.
+         frames_np (numpy.ndarray): NumPy array to store audio frames.
+         frames_offset (float): The offset in audio frames.
+         text (list): List of transcribed text segments.
+         current_out (str): The current incomplete transcription.
+         prev_out (str): The previous incomplete transcription.
+         t_start (float): Timestamp for the start of transcription.
+         exit (bool): A flag to exit the transcription thread.
+         same_output_threshold (int): Threshold for consecutive same output segments.
+         show_prev_out_thresh (int): Threshold for showing previous output segments.
+         add_pause_thresh (int): Threshold for adding a pause (blank) segment.
+         transcript (list): List of transcribed segments.
+         send_last_n_segments (int): Number of last segments to send to the client.
+         wrapper (textwrap.TextWrapper): Text wrapper for formatting text.
+         pick_previous_segments (int): Number of previous segments to include in the output.
+         websocket: The WebSocket connection for the client.
+     """
+     RATE = 16000
+     SERVER_READY = "SERVER_READY"
+     DISCONNECT = "DISCONNECT"
+
+     def __init__(self, websocket, task="transcribe", device=None, multilingual=False, language=None, client_uid=None):
+         """
+         Initialize a ServeClient instance.
+         The Whisper model is initialized based on the client's language and device availability.
+         The transcription thread is started upon initialization. A "SERVER_READY" message is sent
+         to the client to indicate that the server is ready.
+
+         Args:
+             websocket (WebSocket): The WebSocket connection for the client.
+             task (str, optional): The task type, e.g., "transcribe". Defaults to "transcribe".
+             device (str, optional): The device type for Whisper, "cuda" or "cpu". Defaults to None.
+             multilingual (bool, optional): Whether the client supports multilingual transcription. Defaults to False.
+             language (str, optional): The language for transcription. Defaults to None.
+             client_uid (str, optional): A unique identifier for the client. Defaults to None.
+
+         """
+         self.client_uid = client_uid
+         self.data = b""
+         self.frames = b""
+         self.language = language if multilingual else "en"
+         self.task = task
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.transcriber = WhisperModel(
+             "small" if multilingual else "small.en",
+             device=device,
+             compute_type="int8" if device == "cpu" else "float16",
+             local_files_only=False,
+         )
+
+         self.timestamp_offset = 0.0
+         self.frames_np = None
+         self.frames_offset = 0.0
+         self.text = []
+         self.current_out = ''
+         self.prev_out = ''
+         self.t_start = None
+         self.exit = False
+         self.same_output_threshold = 0
+         self.show_prev_out_thresh = 5  # if there is a pause (no output from whisper), show previous output for 5 seconds
+         self.add_pause_thresh = 3  # add a blank to the segment list as a pause (no speech) after 3 seconds
+         self.transcript = []
+         self.send_last_n_segments = 10
+
+         # text formatting
+         self.wrapper = textwrap.TextWrapper(width=50)
+         self.pick_previous_segments = 2
+
+         # threading
+         self.websocket = websocket
+         self.trans_thread = threading.Thread(target=self.speech_to_text)
+         self.trans_thread.start()
+         self.websocket.send(
+             json.dumps(
+                 {
+                     "uid": self.client_uid,
+                     "message": self.SERVER_READY
+                 }
+             )
+         )
+
+     def fill_output(self, output):
+         """
+         Format the current incomplete transcription output by combining it with previous complete segments.
+         The resulting transcription is wrapped into two lines, each containing a maximum of 50 characters.
+
+         It ensures that the combined transcription fits within two lines, with a maximum of 50 characters per line.
+         Segments are concatenated in the order they exist in the list of previous segments, with the most
+         recent complete segment first and older segments prepended as needed to maintain the character limit.
+         If a 3-second pause is detected in the previous segments, any text preceding it is discarded to ensure
+         the transcription starts with the most recent complete content. The resulting transcription is returned
+         as a single string.
+
+         Args:
+             output (str): The current incomplete transcription segment.
+
+         Returns:
+             str: A formatted transcription wrapped in two lines.
+         """
+         text = ''
+         pick_prev = min(len(self.text), self.pick_previous_segments)
+         for seg in self.text[-pick_prev:]:
+             # discard everything before a 3 second pause
+             if seg == '':
+                 text = ''
+             else:
+                 text += seg
+         wrapped = "".join(text + output)
+         return wrapped
+
+     def add_frames(self, frame_np):
+         """
+         Add audio frames to the ongoing audio stream buffer.
+
+         This method is responsible for maintaining the audio stream buffer, allowing the continuous addition
+         of audio frames as they are received. It also ensures that the buffer does not exceed a specified size
+         to prevent excessive memory usage.
+
+         If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds
+         of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the provided
+         audio frame. The audio stream buffer is used for real-time processing of audio data for transcription.
+
+         Args:
+             frame_np (numpy.ndarray): The audio frame data as a NumPy array.
+
+         """
+         if self.frames_np is not None and self.frames_np.shape[0] > 45 * self.RATE:
+             self.frames_offset += 30.0
+             self.frames_np = self.frames_np[int(30 * self.RATE):]
+         if self.frames_np is None:
+             self.frames_np = frame_np.copy()
+         else:
+             self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)
+
+     def speech_to_text(self):
+         """
+         Process an audio stream in an infinite loop, continuously transcribing the speech.
+
+         This method continuously receives audio frames, performs real-time transcription, and sends
+         transcribed segments to the client via a WebSocket connection.
+
+         If the client's language is not detected, it waits for 30 seconds of audio input to make a language prediction.
+         It utilizes the Whisper ASR model to transcribe the audio, continuously processing and streaming results. Segments
+         are sent to the client in real time, and a history of segments is maintained to provide context. Pauses in speech
+         (no output from Whisper) are handled by showing the previous output for a set duration. A blank segment is added if
+         there is no speech for a specified duration to indicate a pause.
+
+         Raises:
+             Exception: If there is an issue with audio processing or WebSocket communication.
+
+         """
+         while True:
+             if self.exit:
+                 logging.info("Exiting speech to text thread")
+                 break
+
+             if self.frames_np is None:
+                 continue
+
+             # clip audio if the current chunk exceeds 30 seconds; this basically implies that
+             # whisper has produced no valid segment for the last 30 seconds
+             if self.frames_np[int((self.timestamp_offset - self.frames_offset) * self.RATE):].shape[0] > 25 * self.RATE:
+                 duration = self.frames_np.shape[0] / self.RATE
+                 self.timestamp_offset = self.frames_offset + duration - 5
+
+             samples_take = max(0, (self.timestamp_offset - self.frames_offset) * self.RATE)
+             input_bytes = self.frames_np[int(samples_take):].copy()
+             duration = input_bytes.shape[0] / self.RATE
+             if duration < 1.0:
+                 continue
+             try:
+                 input_sample = input_bytes.copy()
+
+                 # whisper transcribe with prompt
+                 result, info = self.transcriber.transcribe(
+                     input_sample,
+                     initial_prompt=None,
+                     language=self.language,
+                     task=self.task,
+                     vad_filter=True,
+                     vad_parameters={"threshold": 0.5}
+                 )
+
+                 if self.language is None:
+                     if info.language_probability > 0.5:
+                         self.language = info.language
+                         logging.info(f"Detected language {self.language} with probability {info.language_probability}")
+                         self.websocket.send(json.dumps(
+                             {"uid": self.client_uid, "language": self.language, "language_prob": info.language_probability}))
+                     else:
+                         # detect language again
+                         continue
+
+                 if len(result):
+                     self.t_start = None
+                     last_segment = self.update_segments(result, duration)
+                     if len(self.transcript) < self.send_last_n_segments:
+                         segments = self.transcript
+                     else:
+                         segments = self.transcript[-self.send_last_n_segments:]
+                     if last_segment is not None:
+                         segments = segments + [last_segment]
+                 else:
+                     # show previous output if there is a pause, i.e. no output from whisper
+                     segments = []
+                     if self.t_start is None: self.t_start = time.time()
+                     if time.time() - self.t_start < self.show_prev_out_thresh:
+                         if len(self.transcript) < self.send_last_n_segments:
+                             segments = self.transcript
+                         else:
+                             segments = self.transcript[-self.send_last_n_segments:]
+
+                     # add a blank if there is no speech for 3 seconds
+                     if len(self.text) and self.text[-1] != '':
+                         if time.time() - self.t_start > self.add_pause_thresh:
+                             self.text.append('')
+
+                 try:
+                     self.websocket.send(
+                         json.dumps({
+                             "uid": self.client_uid,
+                             "segments": segments
+                         })
+                     )
+                 except Exception as e:
+                     logging.error(f"[ERROR]: {e}")
+
+             except Exception as e:
+                 logging.error(f"[ERROR]: {e}")
+                 time.sleep(0.01)
+
+     def update_segments(self, segments, duration):
+         """
+         Processes the segments from whisper. Appends all the segments to the transcript
+         except for the last one, assuming that it is incomplete.
+
+         Updates the ongoing transcript with transcribed segments, including their start and end times.
+         Complete segments are appended to the transcript in chronological order. The incomplete segment
+         (assumed to be the last one) is checked for repeated content. If the same incomplete
+         segment is seen multiple times, it updates the offset and appends the segment to the transcript.
+         A threshold is used to detect repeated content and ensure it is only included once in the transcript.
+         The timestamp offset is updated based on the duration of processed segments. The method returns the
+         last processed segment, allowing it to be sent to the client for real-time updates.
+
+         Args:
+             segments (list): List of segments as returned by whisper.
+             duration (float): Duration of the current chunk.
+
+         Returns:
+             dict or None: The last processed segment with its start time, end time, and transcribed text.
+                 Returns None if there are no valid segments to process.
+         """
+         offset = None
+         self.current_out = ''
+         last_segment = None
+         # process complete segments
+         if len(segments) > 1:
+             for i, s in enumerate(segments[:-1]):
+                 text_ = s.text
+                 self.text.append(text_)
+                 start, end = self.timestamp_offset + s.start, self.timestamp_offset + min(duration, s.end)
+                 self.transcript.append(
+                     {
+                         'start': start,
+                         'end': end,
+                         'text': text_
+                     }
+                 )
+
+                 offset = min(duration, s.end)
+
+         self.current_out += segments[-1].text
+         last_segment = {
+             'start': self.timestamp_offset + segments[-1].start,
+             'end': self.timestamp_offset + min(duration, segments[-1].end),
+             'text': self.current_out
+         }
+
+         # if the same incomplete segment is seen multiple times, update the offset
+         # and append the segment to the list
+         if self.current_out.strip() == self.prev_out.strip() and self.current_out != '':
+             self.same_output_threshold += 1
+         else:
+             self.same_output_threshold = 0
+
+         if self.same_output_threshold > 5:
+             if not len(self.text) or self.text[-1].strip().lower() != self.current_out.strip().lower():
+                 self.text.append(self.current_out)
+                 self.transcript.append(
+                     {
+                         'start': self.timestamp_offset,
+                         'end': self.timestamp_offset + duration,
+                         'text': self.current_out
+                     }
+                 )
+             self.current_out = ''
+             offset = duration
+             self.same_output_threshold = 0
+             last_segment = None
+         else:
+             self.prev_out = self.current_out
+
+         # update offset
+         if offset is not None:
+             self.timestamp_offset += offset
+
+         return last_segment
+
+     def disconnect(self):
+         """
+         Notify the client of disconnection and send a disconnect message.
+
+         This method sends a disconnect message to the client via the WebSocket connection to notify them
+         that the transcription service is disconnecting gracefully.
+
+         """
+         self.websocket.send(
+             json.dumps(
+                 {
+                     "uid": self.client_uid,
+                     "message": self.DISCONNECT
+                 }
+             )
+         )
+
+     def cleanup(self):
+         """
+         Perform cleanup tasks before exiting the transcription service.
+
+         This method performs necessary cleanup tasks, including stopping the transcription thread, marking
+         the exit flag to indicate the transcription thread should exit gracefully, and destroying resources
+         associated with the transcription process.
+
+         """
+         logging.info("Cleaning up.")
+         self.exit = True
+         self.transcriber.destroy()
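The wire protocol implemented by `recv_audio` is small: the first WebSocket message from a client is a JSON options object (`uid`, `multilingual`, `language`, `task`), and every subsequent message is a binary frame of float32 PCM samples at 16 kHz, exactly what `Client.bytes_to_float_array` produces. A minimal handshake sketch using the same `websocket-client` package as `whisper_live/client.py`; the uid and the silent audio frame are illustrative only:

```python
import json
import uuid

import numpy as np
import websocket  # websocket-client, as used by whisper_live/client.py

ws = websocket.create_connection("ws://localhost:6006")

# 1) Handshake: JSON options, mirroring Client.on_open().
ws.send(json.dumps({
    "uid": str(uuid.uuid4()),
    "multilingual": False,
    "language": "en",
    "task": "transcribe",
}))
print(ws.recv())  # expect {"uid": ..., "message": "SERVER_READY"}

# 2) Audio: binary frames of float32 samples in [-1, 1] at 16 kHz.
silence = np.zeros(16000, dtype=np.float32)  # one second of silence
ws.send_binary(silence.tobytes())
```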
whisper_live/transcriber.py ADDED
@@ -0,0 +1,1023 @@
+ # original https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/transcribe.py
+
+ import itertools
+ import logging
+ import os
+ import zlib
+
+ from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union
+
+ import ctranslate2
+ import numpy as np
+ import tokenizers
+
+ from faster_whisper.audio import decode_audio
+ from faster_whisper.feature_extractor import FeatureExtractor
+ from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
+ from faster_whisper.utils import download_model, format_timestamp, get_logger
+ from faster_whisper.vad import (
+     SpeechTimestampsMap,
+     VadOptions,
+     collect_chunks,
+     get_speech_timestamps,
+ )
+
+
+ class Word(NamedTuple):
+     start: float
+     end: float
+     word: str
+     probability: float
+
+
+ class Segment(NamedTuple):
+     id: int
+     seek: int
+     start: float
+     end: float
+     text: str
+     tokens: List[int]
+     temperature: float
+     avg_logprob: float
+     compression_ratio: float
+     no_speech_prob: float
+     words: Optional[List[Word]]
+
+
+ class TranscriptionOptions(NamedTuple):
+     beam_size: int
+     best_of: int
+     patience: float
+     length_penalty: float
+     repetition_penalty: float
+     no_repeat_ngram_size: int
+     log_prob_threshold: Optional[float]
+     no_speech_threshold: Optional[float]
+     compression_ratio_threshold: Optional[float]
+     condition_on_previous_text: bool
+     prompt_reset_on_temperature: float
+     temperatures: List[float]
+     initial_prompt: Optional[Union[str, Iterable[int]]]
+     prefix: Optional[str]
+     suppress_blank: bool
+     suppress_tokens: Optional[List[int]]
+     without_timestamps: bool
+     max_initial_timestamp: float
+     word_timestamps: bool
+     prepend_punctuations: str
+     append_punctuations: str
+
+
+ class TranscriptionInfo(NamedTuple):
+     language: str
+     language_probability: float
+     duration: float
+     duration_after_vad: float
+     all_language_probs: Optional[List[Tuple[str, float]]]
+     transcription_options: TranscriptionOptions
+     vad_options: VadOptions
+
+
+ class WhisperModel:
+     def __init__(
+         self,
+         model_size_or_path: str,
+         device: str = "auto",
+         device_index: Union[int, List[int]] = 0,
+         compute_type: str = "default",
+         cpu_threads: int = 0,
+         num_workers: int = 1,
+         download_root: Optional[str] = None,
+         local_files_only: bool = False,
+     ):
+         """Initializes the Whisper model.
+
+         Args:
+           model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
+             small, small.en, medium, medium.en, large-v1, large-v2, or large), a path to a converted
+             model directory, or a CTranslate2-converted Whisper model ID from the Hugging Face Hub.
+             When a size or a model ID is configured, the converted model is downloaded
+             from the Hugging Face Hub.
+           device: Device to use for computation ("cpu", "cuda", "auto").
+           device_index: Device ID to use.
+             The model can also be loaded on multiple GPUs by passing a list of IDs
+             (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in parallel
+             when transcribe() is called from multiple Python threads (see also num_workers).
+           compute_type: Type to use for computation.
+             See https://opennmt.net/CTranslate2/quantization.html.
+           cpu_threads: Number of threads to use when running on CPU (4 by default).
+             A non zero value overrides the OMP_NUM_THREADS environment variable.
+           num_workers: When transcribe() is called from multiple Python threads,
+             having multiple workers enables true parallelism when running the model
+             (concurrent calls to self.model.generate() will run in parallel).
+             This can improve the global throughput at the cost of increased memory usage.
+           download_root: Directory where the models should be saved. If not set, the models
+             are saved in the standard Hugging Face cache directory.
+           local_files_only: If True, avoid downloading the file and return the path to the
+             local cached file if it exists.
+         """
+         self.logger = get_logger()
+
+         if os.path.isdir(model_size_or_path):
+             model_path = model_size_or_path
+         else:
+             model_path = download_model(
+                 model_size_or_path,
+                 local_files_only=local_files_only,
+                 cache_dir=download_root,
+             )
+
+         self.model = ctranslate2.models.Whisper(
+             model_path,
+             device=device,
+             device_index=device_index,
+             compute_type=compute_type,
+             intra_threads=cpu_threads,
+             inter_threads=num_workers,
+         )
+
+         tokenizer_file = os.path.join(model_path, "tokenizer.json")
+         if os.path.isfile(tokenizer_file):
+             self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
+         else:
+             self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
+                 "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
+             )
+
+         self.feature_extractor = FeatureExtractor()
+         self.num_samples_per_token = self.feature_extractor.hop_length * 2
+         self.frames_per_second = (
+             self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
+         )
+         self.tokens_per_second = (
+             self.feature_extractor.sampling_rate // self.num_samples_per_token
+         )
+         self.input_stride = 2
+         self.time_precision = 0.02
+         self.max_length = 448
+
+     @property
+     def supported_languages(self) -> List[str]:
+         """The languages supported by the model."""
+         return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"]
+
+     def transcribe(
+         self,
+         audio: Union[str, BinaryIO, np.ndarray],
+         language: Optional[str] = None,
+         task: str = "transcribe",
+         beam_size: int = 5,
+         best_of: int = 5,
+         patience: float = 1,
+         length_penalty: float = 1,
+         repetition_penalty: float = 1,
+         no_repeat_ngram_size: int = 0,
+         temperature: Union[float, List[float], Tuple[float, ...]] = [
+             0.0,
+             0.2,
+             0.4,
+             0.6,
+             0.8,
+             1.0,
+         ],
+         compression_ratio_threshold: Optional[float] = 2.4,
+         log_prob_threshold: Optional[float] = -1.0,
+         no_speech_threshold: Optional[float] = 0.6,
+         condition_on_previous_text: bool = True,
+         prompt_reset_on_temperature: float = 0.5,
+         initial_prompt: Optional[Union[str, Iterable[int]]] = None,
+         prefix: Optional[str] = None,
+         suppress_blank: bool = True,
+         suppress_tokens: Optional[List[int]] = [-1],
+         without_timestamps: bool = False,
+         max_initial_timestamp: float = 1.0,
+         word_timestamps: bool = False,
+         prepend_punctuations: str = "\"'“¿([{-",
+         append_punctuations: str = "\"'.。,,!!??::”)]}、",
+         vad_filter: bool = False,
+         vad_parameters: Optional[Union[dict, VadOptions]] = None,
+     ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
+         """Transcribes an input file.
+
+         Arguments:
+           audio: Path to the input file (or a file-like object), or the audio waveform.
+           language: The language spoken in the audio. It should be a language code such
+             as "en" or "fr". If not set, the language will be detected in the first 30 seconds
+             of audio.
+           task: Task to execute (transcribe or translate).
+           beam_size: Beam size to use for decoding.
+           best_of: Number of candidates when sampling with non-zero temperature.
+           patience: Beam search patience factor.
+           length_penalty: Exponential length penalty constant.
+           repetition_penalty: Penalty applied to the score of previously generated tokens
+             (set > 1 to penalize).
+           no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
+           temperature: Temperature for sampling. It can be a tuple of temperatures,
+             which will be successively used upon failures according to either
+             `compression_ratio_threshold` or `log_prob_threshold`.
+           compression_ratio_threshold: If the gzip compression ratio is above this value,
+             treat as failed.
+           log_prob_threshold: If the average log probability over sampled tokens is
+             below this value, treat as failed.
+           no_speech_threshold: If the no_speech probability is higher than this value AND
+             the average log probability over sampled tokens is below `log_prob_threshold`,
+             consider the segment as silent.
+           condition_on_previous_text: If True, the previous output of the model is provided
+             as a prompt for the next window; disabling may make the text inconsistent across
+             windows, but the model becomes less prone to getting stuck in a failure loop,
+             such as repetition looping or timestamps going out of sync.
+           prompt_reset_on_temperature: Resets prompt if temperature is above this value.
+             Arg has effect only if condition_on_previous_text is True.
+           initial_prompt: Optional text string or iterable of token ids to provide as a
+             prompt for the first window.
+           prefix: Optional text to provide as a prefix for the first window.
+           suppress_blank: Suppress blank outputs at the beginning of the sampling.
+           suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
+             of symbols as defined in the model config.json file.
+           without_timestamps: Only sample text tokens.
+           max_initial_timestamp: The initial timestamp cannot be later than this.
+           word_timestamps: Extract word-level timestamps using the cross-attention pattern
+             and dynamic time warping, and include the timestamps for each word in each segment.
+           prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
+             with the next word
+           append_punctuations: If word_timestamps is True, merge these punctuation symbols
+             with the previous word
+           vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
+             without speech. This step is using the Silero VAD model
+             https://github.com/snakers4/silero-vad.
+           vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
+             parameters and default values in the class `VadOptions`).
+
+         Returns:
+           A tuple with:
+
+             - a generator over transcribed segments
+             - an instance of TranscriptionInfo
+         """
+         sampling_rate = self.feature_extractor.sampling_rate
+
+         if not isinstance(audio, np.ndarray):
+             audio = decode_audio(audio, sampling_rate=sampling_rate)
+
+         duration = audio.shape[0] / sampling_rate
+         duration_after_vad = duration
+
+         self.logger.info(
+             "Processing audio with duration %s", format_timestamp(duration)
+         )
+
+         if vad_filter:
+             if vad_parameters is None:
+                 vad_parameters = VadOptions()
+             elif isinstance(vad_parameters, dict):
+                 vad_parameters = VadOptions(**vad_parameters)
+             speech_chunks = get_speech_timestamps(audio, vad_parameters)
+             audio = collect_chunks(audio, speech_chunks)
+             duration_after_vad = audio.shape[0] / sampling_rate
+
+             self.logger.info(
+                 "VAD filter removed %s of audio",
+                 format_timestamp(duration - duration_after_vad),
+             )
+
+             if self.logger.isEnabledFor(logging.DEBUG):
+                 self.logger.debug(
+                     "VAD filter kept the following audio segments: %s",
+                     ", ".join(
+                         "[%s -> %s]"
+                         % (
+                             format_timestamp(chunk["start"] / sampling_rate),
+                             format_timestamp(chunk["end"] / sampling_rate),
+                         )
+                         for chunk in speech_chunks
+                     ),
+                 )
+
+         else:
+             speech_chunks = None
+
+         features = self.feature_extractor(audio)
+
+         encoder_output = None
+         all_language_probs = None
+
+         if language is None:
+             if not self.model.is_multilingual:
+                 language = "en"
+                 language_probability = 1
+             else:
+                 segment = features[:, : self.feature_extractor.nb_max_frames]
+                 encoder_output = self.encode(segment)
+                 # results is a list of tuple[str, float] with language names and
+                 # probabilities.
+                 results = self.model.detect_language(encoder_output)[0]
+                 # Parse language names to strip out markers
+                 all_language_probs = [(token[2:-2], prob) for (token, prob) in results]
+                 # Get top language token and probability
+                 language, language_probability = all_language_probs[0]
+
+                 self.logger.info(
+                     "Detected language '%s' with probability %.2f",
+                     language,
+                     language_probability,
+                 )
+         else:
+             if not self.model.is_multilingual and language != "en":
+                 self.logger.warning(
+                     "The current model is English-only but the language parameter is set to '%s'; "
+                     "using 'en' instead." % language
+                 )
+                 language = "en"
+
+             language_probability = 1
+
+         tokenizer = Tokenizer(
+             self.hf_tokenizer,
+             self.model.is_multilingual,
+             task=task,
+             language=language,
+         )
+
+         options = TranscriptionOptions(
+             beam_size=beam_size,
+             best_of=best_of,
+             patience=patience,
+             length_penalty=length_penalty,
+             repetition_penalty=repetition_penalty,
+             no_repeat_ngram_size=no_repeat_ngram_size,
+             log_prob_threshold=log_prob_threshold,
+             no_speech_threshold=no_speech_threshold,
+             compression_ratio_threshold=compression_ratio_threshold,
+             condition_on_previous_text=condition_on_previous_text,
+             prompt_reset_on_temperature=prompt_reset_on_temperature,
+             temperatures=(
+                 temperature if isinstance(temperature, (list, tuple)) else [temperature]
+             ),
+             initial_prompt=initial_prompt,
+             prefix=prefix,
+             suppress_blank=suppress_blank,
+             suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
+             without_timestamps=without_timestamps,
+             max_initial_timestamp=max_initial_timestamp,
+             word_timestamps=word_timestamps,
+             prepend_punctuations=prepend_punctuations,
+             append_punctuations=append_punctuations,
+         )
+
+         segments = self.generate_segments(features, tokenizer, options, encoder_output)
+
+         if speech_chunks:
+             segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
+
+         info = TranscriptionInfo(
+             language=language,
+             language_probability=language_probability,
+             duration=duration,
+             duration_after_vad=duration_after_vad,
+             transcription_options=options,
+             vad_options=vad_parameters,
+             all_language_probs=all_language_probs,
+         )
+
+         return segments, info
+
+     def generate_segments(
+         self,
+         features: np.ndarray,
+         tokenizer: Tokenizer,
+         options: TranscriptionOptions,
+         encoder_output: Optional[ctranslate2.StorageView] = None,
+     ) -> Iterable[Segment]:
+         content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
+         idx = 0
+         seek = 0
+         all_tokens = []
+         prompt_reset_since = 0
+
+         if options.initial_prompt is not None:
+             if isinstance(options.initial_prompt, str):
+                 initial_prompt = " " + options.initial_prompt.strip()
+                 initial_prompt_tokens = tokenizer.encode(initial_prompt)
+                 all_tokens.extend(initial_prompt_tokens)
+             else:
+                 all_tokens.extend(options.initial_prompt)
+
+         last_speech_timestamp = 0.0
+         all_segments = []
+         while seek < content_frames:
+             time_offset = seek * self.feature_extractor.time_per_frame
+             segment = features[:, seek : seek + self.feature_extractor.nb_max_frames]
+             segment_size = min(
+                 self.feature_extractor.nb_max_frames, content_frames - seek
+             )
+             segment_duration = segment_size * self.feature_extractor.time_per_frame
+
+             if self.logger.isEnabledFor(logging.DEBUG):
+                 self.logger.debug(
+                     "Processing segment at %s", format_timestamp(time_offset)
+                 )
+
+             previous_tokens = all_tokens[prompt_reset_since:]
+             prompt = self.get_prompt(
+                 tokenizer,
+                 previous_tokens,
+                 without_timestamps=options.without_timestamps,
+                 prefix=options.prefix if seek == 0 else None,
+             )
+
+             if seek > 0 or encoder_output is None:
+                 encoder_output = self.encode(segment)
+
+             (
+                 result,
+                 avg_logprob,
+                 temperature,
+                 compression_ratio,
+             ) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options)
+
+             if options.no_speech_threshold is not None:
+                 # no voice activity check
+                 should_skip = result.no_speech_prob > options.no_speech_threshold
+
+                 if (
+                     options.log_prob_threshold is not None
+                     and avg_logprob > options.log_prob_threshold
+                 ):
+                     # don't skip if the logprob is high enough, despite the no_speech_prob
+                     should_skip = False
+
+                 if should_skip:
+                     self.logger.debug(
+                         "No speech threshold is met (%f > %f)",
+                         result.no_speech_prob,
+                         options.no_speech_threshold,
+                     )
+
+                     # fast-forward to the next segment boundary
+                     seek += segment_size
+                     continue
+
+             tokens = result.sequences_ids[0]
+
+             previous_seek = seek
+             current_segments = []
+
+             single_timestamp_ending = (
+                 len(tokens) >= 2
+                 and tokens[-2] < tokenizer.timestamp_begin
+                 and tokens[-1] >= tokenizer.timestamp_begin
+             )
+
+             consecutive_timestamps = [
+                 i
+                 for i in range(len(tokens))
+                 if i > 0
+                 and tokens[i] >= tokenizer.timestamp_begin
+                 and tokens[i - 1] >= tokenizer.timestamp_begin
+             ]
+
+             if len(consecutive_timestamps) > 0:
+                 slices = list(consecutive_timestamps)
+                 if single_timestamp_ending:
+                     slices.append(len(tokens))
+
+                 last_slice = 0
+                 for current_slice in slices:
+                     sliced_tokens = tokens[last_slice:current_slice]
+                     start_timestamp_position = (
+                         sliced_tokens[0] - tokenizer.timestamp_begin
+                     )
+                     end_timestamp_position = (
+                         sliced_tokens[-1] - tokenizer.timestamp_begin
+                     )
+                     start_time = (
+                         time_offset + start_timestamp_position * self.time_precision
+                     )
496
+ end_time = (
497
+ time_offset + end_timestamp_position * self.time_precision
498
+ )
499
+
500
+ current_segments.append(
501
+ dict(
502
+ seek=seek,
503
+ start=start_time,
504
+ end=end_time,
505
+ tokens=sliced_tokens,
506
+ )
507
+ )
508
+ last_slice = current_slice
509
+
510
+ if single_timestamp_ending:
511
+ # single timestamp at the end means no speech after the last timestamp.
512
+ seek += segment_size
513
+ else:
514
+ # otherwise, ignore the unfinished segment and seek to the last timestamp
515
+ last_timestamp_position = (
516
+ tokens[last_slice - 1] - tokenizer.timestamp_begin
517
+ )
518
+ seek += last_timestamp_position * self.input_stride
519
+
520
+ else:
521
+ duration = segment_duration
522
+ timestamps = [
523
+ token for token in tokens if token >= tokenizer.timestamp_begin
524
+ ]
525
+ if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin:
526
+ last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin
527
+ duration = last_timestamp_position * self.time_precision
528
+
529
+ current_segments.append(
530
+ dict(
531
+ seek=seek,
532
+ start=time_offset,
533
+ end=time_offset + duration,
534
+ tokens=tokens,
535
+ )
536
+ )
537
+
538
+ seek += segment_size
539
+
540
+ if options.word_timestamps:
541
+ self.add_word_timestamps(
542
+ current_segments,
543
+ tokenizer,
544
+ encoder_output,
545
+ segment_size,
546
+ options.prepend_punctuations,
547
+ options.append_punctuations,
548
+ last_speech_timestamp=last_speech_timestamp,
549
+ )
550
+
551
+ word_end_timestamps = [
552
+ w["end"] for s in current_segments for w in s["words"]
553
+ ]
554
+ if len(word_end_timestamps) > 0:
555
+ last_speech_timestamp = word_end_timestamps[-1]
556
+ if not single_timestamp_ending and len(word_end_timestamps) > 0:
557
+ seek_shift = round(
558
+ (word_end_timestamps[-1] - time_offset) * self.frames_per_second
559
+ )
560
+
561
+ if seek_shift > 0:
562
+ seek = previous_seek + seek_shift
563
+
564
+ for segment in current_segments:
565
+ tokens = segment["tokens"]
566
+ text = tokenizer.decode(tokens)
567
+
568
+ if segment["start"] == segment["end"] or not text.strip():
569
+ continue
570
+
571
+ all_tokens.extend(tokens)
572
+ idx += 1
573
+
574
+ all_segments.append(Segment(
575
+ id=idx,
576
+ seek=seek,
577
+ start=segment["start"],
578
+ end=segment["end"],
579
+ text=text,
580
+ tokens=tokens,
581
+ temperature=temperature,
582
+ avg_logprob=avg_logprob,
583
+ compression_ratio=compression_ratio,
584
+ no_speech_prob=result.no_speech_prob,
585
+ words=(
586
+ [Word(**word) for word in segment["words"]]
587
+ if options.word_timestamps
588
+ else None
589
+ ),
590
+ ))
591
+
592
+ if (
593
+ not options.condition_on_previous_text
594
+ or temperature > options.prompt_reset_on_temperature
595
+ ):
596
+ if options.condition_on_previous_text:
597
+ self.logger.debug(
598
+ "Reset prompt. prompt_reset_on_temperature threshold is met %f > %f",
599
+ temperature,
600
+ options.prompt_reset_on_temperature,
601
+ )
602
+
603
+ prompt_reset_since = len(all_tokens)
604
+ return all_segments
605
+
606
+ def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
607
+ # When the model is running on multiple GPUs, the encoder output should be moved
608
+ # to the CPU since we don't know which GPU will handle the next job.
609
+ to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
610
+
611
+ features = np.expand_dims(features, 0)
612
+ features = get_ctranslate2_storage(features)
613
+
614
+ return self.model.encode(features, to_cpu=to_cpu)
615
+
616
+ def generate_with_fallback(
617
+ self,
618
+ encoder_output: ctranslate2.StorageView,
619
+ prompt: List[int],
620
+ tokenizer: Tokenizer,
621
+ options: TranscriptionOptions,
622
+ ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
623
+ decode_result = None
624
+ all_results = []
625
+ below_cr_threshold_results = []
626
+
627
+ max_initial_timestamp_index = int(
628
+ round(options.max_initial_timestamp / self.time_precision)
629
+ )
630
+
631
+ for temperature in options.temperatures:
632
+ if temperature > 0:
633
+ kwargs = {
634
+ "beam_size": 1,
635
+ "num_hypotheses": options.best_of,
636
+ "sampling_topk": 0,
637
+ "sampling_temperature": temperature,
638
+ }
639
+ else:
640
+ kwargs = {
641
+ "beam_size": options.beam_size,
642
+ "patience": options.patience,
643
+ }
644
+
645
+ result = self.model.generate(
646
+ encoder_output,
647
+ [prompt],
648
+ length_penalty=options.length_penalty,
649
+ repetition_penalty=options.repetition_penalty,
650
+ no_repeat_ngram_size=options.no_repeat_ngram_size,
651
+ max_length=self.max_length,
652
+ return_scores=True,
653
+ return_no_speech_prob=True,
654
+ suppress_blank=options.suppress_blank,
655
+ suppress_tokens=options.suppress_tokens,
656
+ max_initial_timestamp_index=max_initial_timestamp_index,
657
+ **kwargs,
658
+ )[0]
659
+
660
+ tokens = result.sequences_ids[0]
661
+
662
+ # Recover the average log prob from the returned score.
663
+ seq_len = len(tokens)
664
+ cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
665
+ avg_logprob = cum_logprob / (seq_len + 1)
666
+
667
+ text = tokenizer.decode(tokens).strip()
668
+ compression_ratio = get_compression_ratio(text)
669
+
670
+ decode_result = (
671
+ result,
672
+ avg_logprob,
673
+ temperature,
674
+ compression_ratio,
675
+ )
676
+ all_results.append(decode_result)
677
+
678
+ needs_fallback = False
679
+
680
+ if options.compression_ratio_threshold is not None:
681
+ if compression_ratio > options.compression_ratio_threshold:
682
+ needs_fallback = True # too repetitive
683
+
684
+ self.logger.debug(
685
+ "Compression ratio threshold is not met with temperature %.1f (%f > %f)",
686
+ temperature,
687
+ compression_ratio,
688
+ options.compression_ratio_threshold,
689
+ )
690
+ else:
691
+ below_cr_threshold_results.append(decode_result)
692
+
693
+ if (
694
+ options.log_prob_threshold is not None
695
+ and avg_logprob < options.log_prob_threshold
696
+ ):
697
+ needs_fallback = True # average log probability is too low
698
+
699
+ self.logger.debug(
700
+ "Log probability threshold is not met with temperature %.1f (%f < %f)",
701
+ temperature,
702
+ avg_logprob,
703
+ options.log_prob_threshold,
704
+ )
705
+
706
+ if (
707
+ options.no_speech_threshold is not None
708
+ and result.no_speech_prob > options.no_speech_threshold
709
+ ):
710
+ needs_fallback = False # silence
711
+
712
+ if not needs_fallback:
713
+ break
714
+ else:
715
+ # all failed, select the result with the highest average log probability
716
+ decode_result = max(
717
+ below_cr_threshold_results or all_results, key=lambda x: x[1]
718
+ )
719
+
720
+ return decode_result
721
+
722
+ def get_prompt(
723
+ self,
724
+ tokenizer: Tokenizer,
725
+ previous_tokens: List[int],
726
+ without_timestamps: bool = False,
727
+ prefix: Optional[str] = None,
728
+ ) -> List[int]:
729
+ prompt = []
730
+
731
+ if previous_tokens:
732
+ prompt.append(tokenizer.sot_prev)
733
+ prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
734
+
735
+ prompt.extend(tokenizer.sot_sequence)
736
+
737
+ if without_timestamps:
738
+ prompt.append(tokenizer.no_timestamps)
739
+
740
+ if prefix:
741
+ prefix_tokens = tokenizer.encode(" " + prefix.strip())
742
+ if len(prefix_tokens) >= self.max_length // 2:
743
+ prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
744
+ if not without_timestamps:
745
+ prompt.append(tokenizer.timestamp_begin)
746
+ prompt.extend(prefix_tokens)
747
+
748
+ return prompt
749
+
750
+ def add_word_timestamps(
751
+ self,
752
+ segments: List[dict],
753
+ tokenizer: Tokenizer,
754
+ encoder_output: ctranslate2.StorageView,
755
+ num_frames: int,
756
+ prepend_punctuations: str,
757
+ append_punctuations: str,
758
+ last_speech_timestamp: float,
759
+ ) -> None:
760
+ if len(segments) == 0:
761
+ return
762
+
763
+ text_tokens_per_segment = [
764
+ [token for token in segment["tokens"] if token < tokenizer.eot]
765
+ for segment in segments
766
+ ]
767
+
768
+ text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
769
+ alignment = self.find_alignment(
770
+ tokenizer, text_tokens, encoder_output, num_frames
771
+ )
772
+ word_durations = np.array([word["end"] - word["start"] for word in alignment])
773
+ word_durations = word_durations[word_durations.nonzero()]
774
+ median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
775
+ max_duration = median_duration * 2
776
+
777
+ # hack: truncate long words at sentence boundaries.
778
+ # a better segmentation algorithm based on VAD should be able to replace this.
779
+ if len(word_durations) > 0:
780
+ sentence_end_marks = ".。!!??"
781
+ # ensure words at sentence boundaries
782
+ # are not longer than twice the median word duration.
783
+ for i in range(1, len(alignment)):
784
+ if alignment[i]["end"] - alignment[i]["start"] > max_duration:
785
+ if alignment[i]["word"] in sentence_end_marks:
786
+ alignment[i]["end"] = alignment[i]["start"] + max_duration
787
+ elif alignment[i - 1]["word"] in sentence_end_marks:
788
+ alignment[i]["start"] = alignment[i]["end"] - max_duration
789
+
790
+ merge_punctuations(alignment, prepend_punctuations, append_punctuations)
791
+
792
+ time_offset = (
793
+ segments[0]["seek"]
794
+ * self.feature_extractor.hop_length
795
+ / self.feature_extractor.sampling_rate
796
+ )
797
+
798
+ word_index = 0
799
+
800
+ for segment, text_tokens in zip(segments, text_tokens_per_segment):
801
+ saved_tokens = 0
802
+ words = []
803
+
804
+ while word_index < len(alignment) and saved_tokens < len(text_tokens):
805
+ timing = alignment[word_index]
806
+
807
+ if timing["word"]:
808
+ words.append(
809
+ dict(
810
+ word=timing["word"],
811
+ start=round(time_offset + timing["start"], 2),
812
+ end=round(time_offset + timing["end"], 2),
813
+ probability=timing["probability"],
814
+ )
815
+ )
816
+
817
+ saved_tokens += len(timing["tokens"])
818
+ word_index += 1
819
+
820
+ # hack: truncate long words at segment boundaries.
821
+ # a better segmentation algorithm based on VAD should be able to replace this.
822
+ if len(words) > 0:
823
+ # ensure the first and second word after a pause is not longer than
824
+ # twice the median word duration.
825
+ if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
826
+ words[0]["end"] - words[0]["start"] > max_duration
827
+ or (
828
+ len(words) > 1
829
+ and words[1]["end"] - words[0]["start"] > max_duration * 2
830
+ )
831
+ ):
832
+ if (
833
+ len(words) > 1
834
+ and words[1]["end"] - words[1]["start"] > max_duration
835
+ ):
836
+ boundary = max(
837
+ words[1]["end"] / 2, words[1]["end"] - max_duration
838
+ )
839
+ words[0]["end"] = words[1]["start"] = boundary
840
+ words[0]["start"] = max(0, words[0]["end"] - max_duration)
841
+
842
+ # prefer the segment-level start timestamp if the first word is too long.
843
+ if (
844
+ segment["start"] < words[0]["end"]
845
+ and segment["start"] - 0.5 > words[0]["start"]
846
+ ):
847
+ words[0]["start"] = max(
848
+ 0, min(words[0]["end"] - median_duration, segment["start"])
849
+ )
850
+ else:
851
+ segment["start"] = words[0]["start"]
852
+
853
+ # prefer the segment-level end timestamp if the last word is too long.
854
+ if (
855
+ segment["end"] > words[-1]["start"]
856
+ and segment["end"] + 0.5 < words[-1]["end"]
857
+ ):
858
+ words[-1]["end"] = max(
859
+ words[-1]["start"] + median_duration, segment["end"]
860
+ )
861
+ else:
862
+ segment["end"] = words[-1]["end"]
863
+
864
+ last_speech_timestamp = segment["end"]
865
+
866
+ segment["words"] = words
867
+
868
+ def find_alignment(
869
+ self,
870
+ tokenizer: Tokenizer,
871
+ text_tokens: List[int],
872
+ encoder_output: ctranslate2.StorageView,
873
+ num_frames: int,
874
+ median_filter_width: int = 7,
875
+ ) -> List[dict]:
876
+ if len(text_tokens) == 0:
877
+ return []
878
+
879
+ result = self.model.align(
880
+ encoder_output,
881
+ tokenizer.sot_sequence,
882
+ [text_tokens],
883
+ num_frames,
884
+ median_filter_width=median_filter_width,
885
+ )[0]
886
+
887
+ text_token_probs = result.text_token_probs
888
+
889
+ alignments = result.alignments
890
+ text_indices = np.array([pair[0] for pair in alignments])
891
+ time_indices = np.array([pair[1] for pair in alignments])
892
+
893
+ words, word_tokens = tokenizer.split_to_word_tokens(
894
+ text_tokens + [tokenizer.eot]
895
+ )
896
+ word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
897
+ if len(word_boundaries) <= 1:
898
+ return []
899
+
900
+ jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
901
+ jump_times = time_indices[jumps] / self.tokens_per_second
902
+ start_times = jump_times[word_boundaries[:-1]]
903
+ end_times = jump_times[word_boundaries[1:]]
904
+ word_probabilities = [
905
+ np.mean(text_token_probs[i:j])
906
+ for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
907
+ ]
908
+
909
+ return [
910
+ dict(
911
+ word=word, tokens=tokens, start=start, end=end, probability=probability
912
+ )
913
+ for word, tokens, start, end, probability in zip(
914
+ words, word_tokens, start_times, end_times, word_probabilities
915
+ )
916
+ ]
917
+
918
+ def destroy(self):
919
+ del self.model
920
+
921
+
922
+ def restore_speech_timestamps(
923
+ segments: Iterable[Segment],
924
+ speech_chunks: List[dict],
925
+ sampling_rate: int,
926
+ ) -> Iterable[Segment]:
927
+ ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
928
+
929
+ for segment in segments:
930
+ if segment.words:
931
+ words = []
932
+ for word in segment.words:
933
+ # Ensure the word start and end times are resolved to the same chunk.
934
+ middle = (word.start + word.end) / 2
935
+ chunk_index = ts_map.get_chunk_index(middle)
936
+ word = word._replace(
937
+ start=ts_map.get_original_time(word.start, chunk_index),
938
+ end=ts_map.get_original_time(word.end, chunk_index),
939
+ )
940
+ words.append(word)
941
+
942
+ segment = segment._replace(
943
+ start=words[0].start,
944
+ end=words[-1].end,
945
+ words=words,
946
+ )
947
+
948
+ else:
949
+ segment = segment._replace(
950
+ start=ts_map.get_original_time(segment.start),
951
+ end=ts_map.get_original_time(segment.end),
952
+ )
953
+
954
+ return segments
955
+
956
+
957
+ def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
958
+ segment = np.ascontiguousarray(segment)
959
+ segment = ctranslate2.StorageView.from_array(segment)
960
+ return segment
961
+
962
+
963
+ def get_compression_ratio(text: str) -> float:
964
+ text_bytes = text.encode("utf-8")
965
+ return len(text_bytes) / len(zlib.compress(text_bytes))
966
+
967
+
968
+ def get_suppressed_tokens(
969
+ tokenizer: Tokenizer,
970
+ suppress_tokens: Optional[List[int]],
971
+ ) -> Optional[List[int]]:
972
+ if not suppress_tokens or -1 in suppress_tokens:
973
+ return suppress_tokens
974
+
975
+ suppress_tokens = list(suppress_tokens)
976
+
977
+ # Ensure the following special tokens are suppressed when the user does
978
+ # not use the default set (-1).
979
+ suppress_tokens.extend(
980
+ [
981
+ tokenizer.transcribe,
982
+ tokenizer.translate,
983
+ tokenizer.sot,
984
+ tokenizer.sot_prev,
985
+ tokenizer.sot_lm,
986
+ ]
987
+ )
988
+
989
+ return sorted(set(suppress_tokens))
990
+
991
+
992
+ def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None:
993
+ # merge prepended punctuations
994
+ i = len(alignment) - 2
995
+ j = len(alignment) - 1
996
+ while i >= 0:
997
+ previous = alignment[i]
998
+ following = alignment[j]
999
+ if previous["word"].startswith(" ") and previous["word"].strip() in prepended:
1000
+ # prepend it to the following word
1001
+ following["word"] = previous["word"] + following["word"]
1002
+ following["tokens"] = previous["tokens"] + following["tokens"]
1003
+ previous["word"] = ""
1004
+ previous["tokens"] = []
1005
+ else:
1006
+ j = i
1007
+ i -= 1
1008
+
1009
+ # merge appended punctuations
1010
+ i = 0
1011
+ j = 1
1012
+ while j < len(alignment):
1013
+ previous = alignment[i]
1014
+ following = alignment[j]
1015
+ if not previous["word"].endswith(" ") and following["word"] in appended:
1016
+ # append it to the previous word
1017
+ previous["word"] = previous["word"] + following["word"]
1018
+ previous["tokens"] = previous["tokens"] + following["tokens"]
1019
+ following["word"] = ""
1020
+ following["tokens"] = []
1021
+ else:
1022
+ i = j
1023
+ j += 1
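
The temperature fallback in generate_with_fallback() above leans on get_compression_ratio(): text that zlib compresses unusually well is treated as degenerate, repetitive output and triggers a retry at the next temperature. A minimal standalone sketch of that heuristic (the 2.4 cutoff is only an illustrative value in the spirit of Whisper-style defaults; the actual threshold is whatever compression_ratio_threshold the caller passes):

    import zlib

    def get_compression_ratio(text: str) -> float:
        text_bytes = text.encode("utf-8")
        return len(text_bytes) / len(zlib.compress(text_bytes))

    # Short natural text stays low; looped output compresses heavily and scores far above 2.4.
    print(round(get_compression_ratio("the quick brown fox jumps over the lazy dog"), 2))
    print(round(get_compression_ratio("la la la " * 50), 2))

generate_with_fallback() retries with the next temperature whenever this ratio exceeds compression_ratio_threshold (or the average log probability is too low), unless no_speech_prob already marks the window as silence.
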
whisper_live/trt_server.py ADDED
@@ -0,0 +1,496 @@
1
+ import websockets
2
+ import time
3
+ import threading
4
+ import json
5
+ import textwrap
6
+
7
+ import logging
8
+ logging.basicConfig(level=logging.INFO)
9
+
10
+ from websockets.sync.server import serve
11
+
12
+ import torch
13
+ import numpy as np
14
+ import os
15
+ from whisper_live.vad import VoiceActivityDetection
16
+ from whisper_live.trt_transcriber import WhisperTRTLLM
17
+
18
+
19
+ from scipy.io.wavfile import write
20
+ import functools
21
+
22
+ save_counter = 0
23
+ def save_wav(normalized_float32):
24
+ global save_counter
25
+ scaled_int16 = (normalized_float32 * 32768).astype(np.int16)
26
+ os.makedirs("outputs", exist_ok=True)
+ write(f"outputs/output{save_counter}.wav", 16000, scaled_int16)
27
+ save_counter += 1
28
+
29
+
30
+
31
+ class TranscriptionServer:
32
+ """
33
+ Represents a transcription server that handles incoming audio from clients.
34
+
35
+ Attributes:
36
+ RATE (int): The audio sampling rate (constant) set to 16000.
37
+ vad_model (VoiceActivityDetection): The voice activity detection model.
38
+ vad_threshold (float): The voice activity detection threshold.
39
+ clients (dict): A dictionary to store connected clients.
40
+ websockets (dict): A dictionary to store WebSocket connections.
41
+ clients_start_time (dict): A dictionary to track client start times.
42
+ max_clients (int): Maximum allowed connected clients.
43
+ max_connection_time (int): Maximum allowed connection time in seconds.
44
+ """
45
+
46
+ RATE = 16000
47
+
48
+ def __init__(self):
49
+ # voice activity detection model
50
+ self.vad_model = VoiceActivityDetection()
51
+ self.vad_threshold = 0.5
52
+ self.clients = {}
53
+ self.websockets = {}
54
+ self.clients_start_time = {}
55
+ self.max_clients = 4
56
+ self.max_connection_time = 600
57
+
58
+ def get_wait_time(self):
59
+ """
60
+ Calculate and return the estimated wait time for clients.
61
+
62
+ Returns:
63
+ float: The estimated wait time in minutes.
64
+ """
65
+ wait_time = None
66
+
67
+ for k, v in self.clients_start_time.items():
68
+ current_client_time_remaining = self.max_connection_time - (time.time() - v)
69
+
70
+ if wait_time is None or current_client_time_remaining < wait_time:
71
+ wait_time = current_client_time_remaining
72
+
73
+ return wait_time / 60
74
+
75
+ def recv_audio(self, websocket):
76
+ """
77
+ Receive audio chunks from a client in an infinite loop.
78
+
79
+ Continuously receives audio frames from a connected client
80
+ over a WebSocket connection. It processes the audio frames using a
81
+ voice activity detection (VAD) model to determine if they contain speech
82
+ or not. If the audio frame contains speech, it is added to the client's
83
+ audio data for ASR.
84
+ If the maximum number of clients is reached, the method sends a
85
+ "WAIT" status to the client, indicating that they should wait
86
+ until a slot is available.
87
+ If a client's connection exceeds the maximum allowed time, it will
88
+ be disconnected, and the client's resources will be cleaned up.
89
+
90
+ Args:
91
+ websocket (WebSocket): The WebSocket connection for the client.
92
+
93
+ Raises:
94
+ Exception: If there is an error during the audio frame processing.
95
+ """
96
+ logging.info("New client connected")
97
+ options = websocket.recv()
98
+ options = json.loads(options)
99
+
100
+ if len(self.clients) >= self.max_clients:
101
+ logging.warning("Client Queue Full. Asking client to wait ...")
102
+ wait_time = self.get_wait_time()
103
+ response = {
104
+ "uid": options["uid"],
105
+ "status": "WAIT",
106
+ "message": wait_time,
107
+ }
108
+ websocket.send(json.dumps(response))
109
+ websocket.close()
110
+ del websocket
111
+ return
112
+
113
+ client = ServeClient(
114
+ websocket,
115
+ multilingual=options["multilingual"],
116
+ language=options["language"],
117
+ task=options["task"],
118
+ client_uid=options["uid"]
119
+ )
120
+
121
+ self.clients[websocket] = client
122
+ self.clients_start_time[websocket] = time.time()
123
+ no_voice_activity_chunks = 0
124
+ while True:
125
+ try:
126
+ frame_data = websocket.recv()
127
+ frame_np = np.frombuffer(frame_data, dtype=np.float32)
128
+ # VAD
129
+ try:
130
+ speech_prob = self.vad_model(torch.from_numpy(frame_np.copy()), self.RATE).item()
131
+ if speech_prob < self.vad_threshold:
132
+ no_voice_activity_chunks += 1
133
+ print("No speech", no_voice_activity_chunks)
134
+ if no_voice_activity_chunks > 2:
135
+ if not self.clients[websocket].eos:
136
+ self.clients[websocket].set_eos(True)
137
+ continue
138
+ no_voice_activity_chunks = 0
139
+ self.clients[websocket].set_eos(False)
140
+
141
+ except Exception as e:
142
+ logging.error(e)
143
+ return
144
+ self.clients[websocket].add_frames(frame_np)
145
+
146
+ elapsed_time = time.time() - self.clients_start_time[websocket]
147
+ if elapsed_time >= self.max_connection_time:
148
+ self.clients[websocket].disconnect()
149
+ logging.warning(f"{self.clients[websocket]} Client disconnected due to overtime.")
150
+ self.clients[websocket].cleanup()
151
+ self.clients.pop(websocket)
152
+ self.clients_start_time.pop(websocket)
153
+ websocket.close()
154
+ del websocket
155
+ break
156
+
157
+ except Exception as e:
158
+ logging.error(e)
159
+ self.clients[websocket].cleanup()
160
+ self.clients.pop(websocket)
161
+ self.clients_start_time.pop(websocket)
162
+ logging.info("Connection Closed.")
163
+ logging.info(self.clients)
164
+ del websocket
165
+ break
166
+
167
+ def run(self, host, port=9090):
168
+ """
169
+ Run the transcription server.
170
+
171
+ Args:
172
+ host (str): The host address to bind the server.
173
+ port (int): The port number to bind the server.
174
+ """
175
+ with serve(self.recv_audio, host, port) as server:
176
+ server.serve_forever()
177
+
178
+
179
+ class ServeClient:
180
+ """
181
+ Attributes:
182
+ RATE (int): The audio sampling rate (constant) set to 16000.
183
+ SERVER_READY (str): A constant message indicating that the server is ready.
184
+ DISCONNECT (str): A constant message indicating that the client should disconnect.
185
+ client_uid (str): A unique identifier for the client.
186
+ data (bytes): Accumulated audio data.
187
+ frames (bytes): Accumulated audio frames.
188
+ language (str): The language for transcription.
189
+ task (str): The task type, e.g., "transcribe."
190
+ transcriber (WhisperTRTLLM): The Whisper TensorRT-LLM engine used for speech-to-text.
191
+ timestamp_offset (float): The offset in audio timestamps.
192
+ frames_np (numpy.ndarray): NumPy array to store audio frames.
193
+ frames_offset (float): The offset in audio frames.
194
+ text (list): List of transcribed text segments.
195
+ current_out (str): The current incomplete transcription.
196
+ prev_out (str): The previous incomplete transcription.
197
+ t_start (float): Timestamp for the start of transcription.
198
+ exit (bool): A flag to exit the transcription thread.
199
+ same_output_threshold (int): Threshold for consecutive same output segments.
200
+ show_prev_out_thresh (int): Threshold for showing previous output segments.
201
+ add_pause_thresh (int): Threshold for adding a pause (blank) segment.
202
+ transcript (list): List of transcribed segments.
203
+ send_last_n_segments (int): Number of last segments to send to the client.
204
+ wrapper (textwrap.TextWrapper): Text wrapper for formatting text.
205
+ pick_previous_segments (int): Number of previous segments to include in the output.
206
+ websocket: The WebSocket connection for the client.
207
+ """
208
+ RATE = 16000
209
+ SERVER_READY = "SERVER_READY"
210
+ DISCONNECT = "DISCONNECT"
211
+
212
+ def __init__(self, websocket, task="transcribe", device=None, multilingual=False, language=None, client_uid=None):
213
+ """
214
+ Initialize a ServeClient instance.
215
+ The Whisper model is initialized based on the client's language and device availability.
216
+ The transcription thread is started upon initialization. A "SERVER_READY" message is sent
217
+ to the client to indicate that the server is ready.
218
+
219
+ Args:
220
+ websocket (WebSocket): The WebSocket connection for the client.
221
+ task (str, optional): The task type, e.g., "transcribe." Defaults to "transcribe".
222
+ device (str, optional): The device type for Whisper, "cuda" or "cpu". Defaults to None.
223
+ multilingual (bool, optional): Whether the client supports multilingual transcription. Defaults to False.
224
+ language (str, optional): The language for transcription. Defaults to None.
225
+ client_uid (str, optional): A unique identifier for the client. Defaults to None.
226
+
227
+ """
228
+ self.client_uid = client_uid
229
+ self.data = b""
230
+ self.frames = b""
231
+ self.language = language if multilingual else "en"
232
+ self.task = task
233
+ device = "cuda" if torch.cuda.is_available() else "cpu"
234
+ self.transcriber = WhisperTRTLLM(
235
+ "whisper_small_en", False, "assets", device="cuda")
236
+
237
+ self.timestamp_offset = 0.0
238
+ self.frames_np = None
239
+ self.frames_offset = 0.0
240
+ self.text = []
241
+ self.current_out = ''
242
+ self.prev_out = ''
243
+ self.t_start=None
244
+ self.exit = False
245
+ self.same_output_threshold = 0
246
+ self.show_prev_out_thresh = 5 # if there is a pause (no output from whisper), keep showing the previous output for 5 seconds
248
+ self.add_pause_thresh = 3 # add a blank segment to the list as a pause (no speech) after 3 seconds
248
+ self.transcript = []
249
+ self.send_last_n_segments = 10
250
+
251
+ # text formatting
252
+ self.wrapper = textwrap.TextWrapper(width=50)
253
+ self.pick_previous_segments = 2
254
+
255
+ # threading
256
+ self.websocket = websocket
257
+ self.lock = threading.Lock()
258
+ self.eos = False
259
+ self.trans_thread = threading.Thread(target=self.speech_to_text)
260
+ self.trans_thread.start()
261
+ self.websocket.send(
262
+ json.dumps(
263
+ {
264
+ "uid": self.client_uid,
265
+ "message": self.SERVER_READY
266
+ }
267
+ )
268
+ )
269
+
270
+ def set_eos(self, eos):
271
+ self.lock.acquire()
272
+ # if self.eos != eos:
273
+ # logging.info(f"[WhisperLive:] setting eos: {eos}")
274
+ self.eos = eos
275
+ self.lock.release()
276
+
277
+ def add_frames(self, frame_np):
278
+ """
279
+ Add audio frames to the ongoing audio stream buffer.
280
+
281
+ This method is responsible for maintaining the audio stream buffer, allowing the continuous addition
282
+ of audio frames as they are received. It also ensures that the buffer does not exceed a specified size
283
+ to prevent excessive memory usage.
284
+
285
+ If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds
286
+ of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the provided
287
+ audio frame. The audio stream buffer is used for real-time processing of audio data for transcription.
288
+
289
+ Args:
290
+ frame_np (numpy.ndarray): The audio frame data as a NumPy array.
291
+
292
+ """
293
+ self.lock.acquire()
294
+ if self.frames_np is not None and self.frames_np.shape[0] > 45*self.RATE:
295
+ self.frames_offset += 30.0
296
+ self.frames_np = self.frames_np[int(30*self.RATE):]
297
+ if self.frames_np is None:
298
+ self.frames_np = frame_np.copy()
299
+ else:
300
+ self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)
301
+ self.lock.release()
302
+
303
+ def speech_to_text(self):
304
+ """
305
+ Process an audio stream in an infinite loop, continuously transcribing the speech.
306
+
307
+ This method continuously receives audio frames, performs real-time transcription, and sends
308
+ transcribed segments to the client via a WebSocket connection.
309
+
310
+ It utilizes the Whisper TensorRT-LLM engine to transcribe the audio, continuously processing and streaming
311
+ results. Segments are sent to the client in real-time, and a history of segments is maintained to provide
312
+ context. Pauses in speech
313
+ (no output from Whisper) are handled by showing the previous output for a set duration. A blank segment is added if
314
+ there is no speech for a specified duration to indicate a pause.
315
+
316
+ Raises:
317
+ Exception: If there is an issue with audio processing or WebSocket communication.
318
+
319
+ """
320
+ while True:
321
+ if self.exit:
322
+ logging.info("Exiting speech to text thread")
323
+ break
324
+
325
+ if self.frames_np is None:
326
+ continue
327
+
328
+ # clip audio if the current chunk exceeds 25 seconds; this implies that
330
+ # whisper produced no valid segment for the last 25 seconds
330
+ if self.frames_np[int((self.timestamp_offset - self.frames_offset)*self.RATE):].shape[0] > 25 * self.RATE:
331
+ duration = self.frames_np.shape[0] / self.RATE
332
+ self.timestamp_offset = self.frames_offset + duration - 5
333
+
334
+ samples_take = max(0, (self.timestamp_offset - self.frames_offset)*self.RATE)
335
+ input_bytes = self.frames_np[int(samples_take):].copy()
336
+ duration = input_bytes.shape[0] / self.RATE
337
+ if duration<1.0 or not self.eos:
338
+ continue
339
+
340
+ try:
341
+ input_sample = input_bytes.copy()
342
+ save_wav(input_sample)
343
+ # whisper transcribe with prompt
344
+ mel, duration = self.transcriber.log_mel_spectrogram(input_sample)
345
+ print(mel.shape, duration)
346
+ result = self.transcriber.transcribe(mel)
347
+ self.append_segment(result)
348
+ self.set_eos(False)
349
+ self.timestamp_offset += duration
350
+ if len(result):
351
+ segments = self.transcript[-self.send_last_n_segments:]
352
+ try:
353
+ self.websocket.send(
354
+ json.dumps({
355
+ "uid": self.client_uid,
356
+ "segments": segments
357
+ })
358
+ )
359
+ except Exception as e:
360
+ logging.error(f"[ERROR]: {e}")
361
+
362
+ except Exception as e:
363
+ logging.error(f"[ERROR]: {e}")
364
+ time.sleep(0.01)
365
+
366
+ def append_segment(self, result):
367
+ if not len(self.transcript):
368
+ self.transcript.append({"text": result + " "})
369
+ else:
370
+ if self.transcript[-1]["text"].strip()[-1] == ".":
371
+ if result[0] >= "a" and result[0] <= "z":
372
+ self.transcript[-1]["text"] = replace_last_occurrence(
373
+ self.transcript[-1]["text"], ".", ","
374
+ )
375
+ elif self.transcript[-1]["text"].strip()[-1] == "?":
376
+ if result[0] >= "a" and result[0] <= "z":
377
+ self.transcript[-1]["text"] = replace_last_occurrence(
378
+ self.transcript[-1]["text"], "?", ","
379
+ )
380
+
381
+ self.transcript.append({"text": result + " "})
382
+
383
+
384
+ def update_segments(self, segments, duration):
385
+ """
386
+ Processes the segments from whisper. Appends all the segments to the list
387
+ except for the last segment assuming that it is incomplete.
388
+
389
+ Updates the ongoing transcript with transcribed segments, including their start and end times.
390
+ Complete segments are appended to the transcript in chronological order. Incomplete segments
391
+ (assumed to be the last one) are processed to identify repeated content. If the same incomplete
392
+ segment is seen multiple times, it updates the offset and appends the segment to the transcript.
393
+ A threshold is used to detect repeated content and ensure it is only included once in the transcript.
394
+ The timestamp offset is updated based on the duration of processed segments. The method returns the
395
+ last processed segment, allowing it to be sent to the client for real-time updates.
396
+
397
+ Args:
398
+ segments(dict) : dictionary of segments as returned by whisper
399
+ duration(float): duration of the current chunk
400
+
401
+ Returns:
402
+ dict or None: The last processed segment with its start time, end time, and transcribed text.
403
+ Returns None if there are no valid segments to process.
404
+ """
405
+ offset = None
406
+ self.current_out = ''
407
+ last_segment = None
408
+ # process complete segments
409
+ if len(segments) > 1:
410
+ for i, s in enumerate(segments[:-1]):
411
+ text_ = s.text
412
+ self.text.append(text_)
413
+ start, end = self.timestamp_offset + s.start, self.timestamp_offset + min(duration, s.end)
414
+ self.transcript.append(
415
+ {
416
+ 'start': start,
417
+ 'end': end,
418
+ 'text': text_
419
+ }
420
+ )
421
+
422
+ offset = min(duration, s.end)
423
+
424
+ self.current_out += segments[-1].text
425
+ last_segment = {
426
+ 'start': self.timestamp_offset + segments[-1].start,
427
+ 'end': self.timestamp_offset + min(duration, segments[-1].end),
428
+ 'text': self.current_out
429
+ }
430
+
431
+ # if same incomplete segment is seen multiple times then update the offset
432
+ # and append the segment to the list
433
+ if self.current_out.strip() == self.prev_out.strip() and self.current_out != '':
434
+ self.same_output_threshold += 1
435
+ else:
436
+ self.same_output_threshold = 0
437
+
438
+ if self.same_output_threshold > 5:
439
+ if not len(self.text) or self.text[-1].strip().lower()!=self.current_out.strip().lower():
440
+ self.text.append(self.current_out)
441
+ self.transcript.append(
442
+ {
443
+ 'start': self.timestamp_offset,
444
+ 'end': self.timestamp_offset + duration,
445
+ 'text': self.current_out
446
+ }
447
+ )
448
+ self.current_out = ''
449
+ offset = duration
450
+ self.same_output_threshold = 0
451
+ last_segment = None
452
+ else:
453
+ self.prev_out = self.current_out
454
+
455
+ # update offset
456
+ if offset is not None:
457
+ self.timestamp_offset += offset
458
+
459
+ return last_segment
460
+
461
+ def disconnect(self):
462
+ """
463
+ Notify the client of disconnection and send a disconnect message.
464
+
465
+ This method sends a disconnect message to the client via the WebSocket connection to notify them
466
+ that the transcription service is disconnecting gracefully.
467
+
468
+ """
469
+ self.websocket.send(
470
+ json.dumps(
471
+ {
472
+ "uid": self.client_uid,
473
+ "message": self.DISCONNECT
474
+ }
475
+ )
476
+ )
477
+
478
+ def cleanup(self):
479
+ """
480
+ Perform cleanup tasks before exiting the transcription service.
481
+
482
+ This method performs necessary cleanup tasks, including stopping the transcription thread, marking
483
+ the exit flag to indicate the transcription thread should exit gracefully, and destroying resources
484
+ associated with the transcription process.
485
+
486
+ """
487
+ logging.info("Cleaning up.")
488
+ self.exit = True
489
+ self.transcriber.destroy()
490
+
491
+ def replace_last_occurrence(input_str, old_char, new_char):
492
+ parts = input_str.rsplit(old_char, 1)
493
+ if len(parts) == 2:
494
+ return parts[0] + new_char + parts[1]
495
+ else:
496
+ return input_str
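
For context on the wire protocol implemented by recv_audio() above: the first client message is a JSON options blob ("uid", "multilingual", "language", "task"), every subsequent message is raw float32 PCM at 16 kHz, and the server answers with JSON ("SERVER_READY", then transcribed segments). A minimal client handshake sketch, assuming the server is listening on ws://localhost:9090 (the run() default) and using the websocket-client package purely for illustration:

    import json
    import uuid
    import numpy as np
    import websocket  # pip install websocket-client

    ws = websocket.create_connection("ws://localhost:9090")
    ws.send(json.dumps({
        "uid": str(uuid.uuid4()),
        "multilingual": False,
        "language": "en",
        "task": "transcribe",
    }))
    print(ws.recv())  # expected: {"uid": ..., "message": "SERVER_READY"}

    # Placeholder audio: replace with real 16 kHz float32 samples to get segments back.
    audio = np.zeros(16000 * 2, dtype=np.float32)
    for start in range(0, audio.shape[0], 4096):
        ws.send_binary(audio[start:start + 4096].tobytes())
    ws.close()

On the server side these frames land in recv_audio(), which runs VAD on each one and only then buffers them for transcription.
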
whisper_live/trt_transcriber.py ADDED
@@ -0,0 +1,347 @@
1
+ import argparse
2
+ import json
3
+ import re
4
+ import time
5
+ from collections import OrderedDict
6
+ from pathlib import Path
7
+ from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
8
+
9
+ import torch
+ import torch.nn.functional as F
10
+ import numpy as np
11
+ from whisper.tokenizer import get_tokenizer
12
+ from whisper_live.whisper_utils import (mel_filters, store_transcripts,
13
+ write_error_stats, load_audio, load_audio_wav_format,
14
+ pad_or_trim)
15
+
16
+ import tensorrt_llm
17
+ import tensorrt_llm.logger as logger
18
+ from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt,
19
+ trt_dtype_to_torch)
20
+ from tensorrt_llm.runtime import ModelConfig, SamplingConfig
21
+ from tensorrt_llm.runtime.session import Session, TensorInfo
22
+
23
+
24
+ SAMPLE_RATE = 16000
25
+ N_FFT = 400
26
+ HOP_LENGTH = 160
27
+ CHUNK_LENGTH = 30
28
+ N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
29
+
30
+
31
+ class WhisperEncoding:
32
+
33
+ def __init__(self, engine_dir):
34
+ self.session = self.get_session(engine_dir)
35
+
36
+ def get_session(self, engine_dir):
37
+ config_path = engine_dir / 'encoder_config.json'
38
+ with open(config_path, 'r') as f:
39
+ config = json.load(f)
40
+
41
+ use_gpt_attention_plugin = config['plugin_config'][
42
+ 'gpt_attention_plugin']
43
+ dtype = config['builder_config']['precision']
44
+ n_mels = config['builder_config']['n_mels']
45
+ num_languages = config['builder_config']['num_languages']
46
+
47
+ self.dtype = dtype
48
+ self.n_mels = n_mels
49
+ self.num_languages = num_languages
50
+
51
+ serialize_path = engine_dir / f'whisper_encoder_{self.dtype}_tp1_rank0.engine'
52
+
53
+ with open(serialize_path, 'rb') as f:
54
+ session = Session.from_serialized_engine(f.read())
55
+
56
+ return session
57
+
58
+ def get_audio_features(self, mel):
59
+ inputs = OrderedDict()
60
+ output_list = []
61
+
62
+ inputs.update({'x': mel})
63
+ output_list.append(
64
+ TensorInfo('x', str_dtype_to_trt(self.dtype), mel.shape))
65
+
66
+ output_info = (self.session).infer_shapes(output_list)
67
+
68
+ logger.debug(f'output info {output_info}')
69
+ outputs = {
70
+ t.name: torch.empty(tuple(t.shape),
71
+ dtype=trt_dtype_to_torch(t.dtype),
72
+ device='cuda')
73
+ for t in output_info
74
+ }
75
+ stream = torch.cuda.current_stream()
76
+ ok = self.session.run(inputs=inputs,
77
+ outputs=outputs,
78
+ stream=stream.cuda_stream)
79
+ assert ok, 'Engine execution failed'
80
+ stream.synchronize()
81
+ audio_features = outputs['output']
82
+ return audio_features
83
+
84
+
85
+ class WhisperDecoding:
86
+
87
+ def __init__(self, engine_dir, runtime_mapping, debug_mode=False):
88
+
89
+ self.decoder_config = self.get_config(engine_dir)
90
+ self.decoder_generation_session = self.get_session(
91
+ engine_dir, runtime_mapping, debug_mode)
92
+
93
+ def get_config(self, engine_dir):
94
+ config_path = engine_dir / 'decoder_config.json'
95
+ with open(config_path, 'r') as f:
96
+ config = json.load(f)
97
+ decoder_config = OrderedDict()
98
+ decoder_config.update(config['plugin_config'])
99
+ decoder_config.update(config['builder_config'])
100
+ return decoder_config
101
+
102
+ def get_session(self, engine_dir, runtime_mapping, debug_mode=False):
103
+ dtype = self.decoder_config['precision']
104
+ serialize_path = engine_dir / f'whisper_decoder_{dtype}_tp1_rank0.engine'
105
+ with open(serialize_path, "rb") as f:
106
+ decoder_engine_buffer = f.read()
107
+
108
+ decoder_model_config = ModelConfig(
109
+ num_heads=self.decoder_config['num_heads'],
110
+ num_kv_heads=self.decoder_config['num_heads'],
111
+ hidden_size=self.decoder_config['hidden_size'],
112
+ vocab_size=self.decoder_config['vocab_size'],
113
+ num_layers=self.decoder_config['num_layers'],
114
+ gpt_attention_plugin=self.decoder_config['gpt_attention_plugin'],
115
+ remove_input_padding=self.decoder_config['remove_input_padding'],
116
+ cross_attention=self.decoder_config['cross_attention'],
117
+ has_position_embedding=self.
118
+ decoder_config['has_position_embedding'],
119
+ has_token_type_embedding=self.
120
+ decoder_config['has_token_type_embedding'],
121
+ )
122
+ decoder_generation_session = tensorrt_llm.runtime.GenerationSession(
123
+ decoder_model_config,
124
+ decoder_engine_buffer,
125
+ runtime_mapping,
126
+ debug_mode=debug_mode)
127
+
128
+ return decoder_generation_session
129
+
130
+ def generate(self,
131
+ decoder_input_ids,
132
+ encoder_outputs,
133
+ eot_id,
134
+ max_new_tokens=40,
135
+ num_beams=1):
136
+ encoder_input_lengths = torch.tensor(
137
+ [encoder_outputs.shape[1] for x in range(encoder_outputs.shape[0])],
138
+ dtype=torch.int32,
139
+ device='cuda')
140
+
141
+ decoder_input_lengths = torch.tensor([
142
+ decoder_input_ids.shape[-1]
143
+ for _ in range(decoder_input_ids.shape[0])
144
+ ],
145
+ dtype=torch.int32,
146
+ device='cuda')
147
+ decoder_max_input_length = torch.max(decoder_input_lengths).item()
148
+
149
+ # generation config
150
+ sampling_config = SamplingConfig(end_id=eot_id,
151
+ pad_id=eot_id,
152
+ num_beams=num_beams)
153
+ self.decoder_generation_session.setup(
154
+ decoder_input_lengths.size(0),
155
+ decoder_max_input_length,
156
+ max_new_tokens,
157
+ beam_width=num_beams,
158
+ encoder_max_input_length=encoder_outputs.shape[1])
159
+
160
+ torch.cuda.synchronize()
161
+
162
+ decoder_input_ids = decoder_input_ids.type(torch.int32).cuda()
163
+ output_ids = self.decoder_generation_session.decode(
164
+ decoder_input_ids,
165
+ decoder_input_lengths,
166
+ sampling_config,
167
+ encoder_output=encoder_outputs,
168
+ encoder_input_lengths=encoder_input_lengths,
169
+ )
170
+ torch.cuda.synchronize()
171
+
172
+ # get the list of int from output_ids tensor
173
+ output_ids = output_ids.cpu().numpy().tolist()
174
+ return output_ids
175
+
176
+
177
+ class WhisperTRTLLM(object):
178
+
179
+ def __init__(
180
+ self,
181
+ engine_dir,
182
+ debug_mode=False,
183
+ assets_dir=None,
184
+ device=None
185
+ ):
186
+ world_size = 1
187
+ runtime_rank = tensorrt_llm.mpi_rank()
188
+ runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank)
189
+ torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
190
+ engine_dir = Path(engine_dir)
191
+
192
+ self.encoder = WhisperEncoding(engine_dir)
193
+ self.decoder = WhisperDecoding(engine_dir,
194
+ runtime_mapping,
195
+ debug_mode=False)
196
+ self.n_mels = self.encoder.n_mels
197
+ # self.tokenizer = get_tokenizer(num_languages=self.encoder.num_languages,
198
+ # tokenizer_dir=assets_dir)
199
+ self.device = device
200
+ self.tokenizer = get_tokenizer(
201
+ False,
202
+ num_languages=self.encoder.num_languages,
203
+ language="en",
204
+ task="transcribe",
205
+ )
206
+ self.filters = mel_filters(self.device, self.encoder.n_mels, assets_dir)
207
+
208
+ def log_mel_spectrogram(
209
+ self,
210
+ audio: Union[str, np.ndarray, torch.Tensor],
211
+ padding: int = 0,
212
+ return_duration = True
213
+ ):
214
+ """
215
+ Compute the log-Mel spectrogram of the input audio.
216
+
217
+ Parameters
218
+ ----------
219
+ audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
220
+ The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
221
+
222
+ return_duration: bool
223
+ Whether to also return the duration of the input audio in seconds
224
+
225
+ padding: int
226
+ Number of zero samples to pad to the right
227
+
228
+ Note: the number of Mel filters is taken from the encoder config (n_mels), and the
229
+ audio tensor is moved to self.device, if set, before the STFT
230
+
231
+ Returns
232
+ -------
233
+ torch.Tensor, shape = (80 or 128, n_frames)
234
+ A Tensor that contains the Mel spectrogram, plus the audio duration in seconds when return_duration is True
235
+ """
236
+ if not torch.is_tensor(audio):
237
+ if isinstance(audio, str):
238
+ if audio.endswith('.wav'):
239
+ audio, _ = load_audio_wav_format(audio)
240
+ else:
241
+ audio = load_audio(audio)
242
+ assert isinstance(audio,
243
+ np.ndarray), f"Unsupported audio type: {type(audio)}"
244
+ duration = audio.shape[-1] / SAMPLE_RATE
245
+ audio = pad_or_trim(audio, N_SAMPLES)
246
+ audio = audio.astype(np.float32)
247
+ audio = torch.from_numpy(audio)
248
+
249
+ if self.device is not None:
250
+ audio = audio.to(self.device)
251
+ if padding > 0:
252
+ audio = F.pad(audio, (0, padding))
253
+ window = torch.hann_window(N_FFT).to(audio.device)
254
+ stft = torch.stft(audio,
255
+ N_FFT,
256
+ HOP_LENGTH,
257
+ window=window,
258
+ return_complex=True)
259
+ magnitudes = stft[..., :-1].abs()**2
260
+
261
+
262
+ mel_spec = self.filters @ magnitudes
263
+
264
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
265
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
266
+ log_spec = (log_spec + 4.0) / 4.0
267
+ if return_duration:
268
+ return log_spec, duration
269
+ else:
270
+ return log_spec
271
+
272
+
273
+ def process_batch(
274
+ self,
275
+ mel,
276
+ text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
277
+ num_beams=1):
278
+ prompt_id = self.tokenizer.encode(
279
+ text_prefix, allowed_special=set(self.tokenizer.special_tokens.keys()))
280
+
281
+ prompt_id = torch.tensor(prompt_id)
282
+ batch_size = mel.shape[0]
283
+ decoder_input_ids = prompt_id.repeat(batch_size, 1)
284
+
285
+ encoder_output = self.encoder.get_audio_features(mel)
286
+ output_ids = self.decoder.generate(decoder_input_ids,
287
+ encoder_output,
288
+ self.tokenizer.eot,
289
+ max_new_tokens=96,
290
+ num_beams=num_beams)
291
+ texts = []
292
+ for i in range(len(output_ids)):
293
+ text = self.tokenizer.decode(output_ids[i][0]).strip()
294
+ texts.append(text)
295
+ return texts
296
+
297
+ def transcribe(
298
+ self,
299
+ mel,
300
+ text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
301
+ dtype='float16',
302
+ batch_size=1,
303
+ num_beams=1,
304
+ ):
305
+ mel = mel.type(str_dtype_to_torch(dtype))
306
+ mel = mel.unsqueeze(0)
307
+ predictions = self.process_batch(mel, text_prefix, num_beams)
308
+ prediction = predictions[0]
309
+
310
+ # remove all special tokens in the prediction
311
+ prediction = re.sub(r'<\|.*?\|>', '', prediction)
312
+ return prediction.strip()
313
+
314
+
315
+ def decode_wav_file(
316
+ model,
317
+ mel,
318
+ text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
319
+ dtype='float16',
320
+ batch_size=1,
321
+ num_beams=1,
322
+ normalizer=None,
323
+ mel_filters_dir=None):
324
+
325
+ mel = mel.type(str_dtype_to_torch(dtype))
326
+ mel = mel.unsqueeze(0)
327
+ # repeat the mel spectrogram to match the batch size
328
+ mel = mel.repeat(batch_size, 1, 1)
329
+ predictions = model.process_batch(mel, text_prefix, num_beams)
330
+ prediction = predictions[0]
331
+
332
+ # remove all special tokens in the prediction
333
+ prediction = re.sub(r'<\|.*?\|>', '', prediction)
334
+ if normalizer:
335
+ prediction = normalizer(prediction)
336
+
337
+ return prediction.strip()
338
+
339
+
340
+ if __name__=="__main__":
341
+ tensorrt_llm.logger.set_level("error")
342
+ model = WhisperTRTLLM("../whisper_small_en", False, "../assets", device="cuda")
343
+ mel, total_duration = model.log_mel_spectrogram(
344
+ "/root/Code/outputs/output3.wav",
345
+ )
346
+ results = model.transcribe(mel)
347
+ print(results, total_duration)
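
One small detail worth calling out: transcribe() and decode_wav_file() strip Whisper's special tokens from the decoded text with re.sub(r'<\|.*?\|>', '', prediction). A standalone illustration (the sample string is made up):

    import re

    raw = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> hello world<|endoftext|>"
    print(re.sub(r'<\|.*?\|>', '', raw).strip())  # -> "hello world"

The non-greedy .*? keeps each match inside a single <|...|> marker, so ordinary text between markers is preserved.
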
whisper_live/vad.py ADDED
@@ -0,0 +1,114 @@
1
+ # original: https://github.com/snakers4/silero-vad/blob/master/utils_vad.py
2
+
3
+ import os
4
+ import subprocess
5
+ import torch
6
+ import numpy as np
7
+ import onnxruntime
8
+
9
+
10
+ class VoiceActivityDetection():
11
+
12
+ def __init__(self, force_onnx_cpu=True):
13
+ path = self.download()
14
+ opts = onnxruntime.SessionOptions()
15
+ opts.log_severity_level = 3
16
+
17
+ opts.inter_op_num_threads = 1
18
+ opts.intra_op_num_threads = 1
19
+
20
+ if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
21
+ self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
22
+ else:
23
+ self.session = onnxruntime.InferenceSession(path, providers=['CUDAExecutionProvider'], sess_options=opts)
24
+
25
+
26
+ self.reset_states()
27
+ self.sample_rates = [8000, 16000]
28
+
29
+ def _validate_input(self, x, sr: int):
30
+ if x.dim() == 1:
31
+ x = x.unsqueeze(0)
32
+ if x.dim() > 2:
33
+ raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
34
+
35
+ if sr != 16000 and (sr % 16000 == 0):
36
+ step = sr // 16000
37
+ x = x[:,::step]
38
+ sr = 16000
39
+
40
+ if sr not in self.sample_rates:
41
+ raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
42
+
43
+ if sr / x.shape[1] > 31.25:
44
+ raise ValueError("Input audio chunk is too short")
45
+
46
+ return x, sr
47
+
48
+ def reset_states(self, batch_size=1):
49
+ self._h = np.zeros((2, batch_size, 64)).astype('float32')
50
+ self._c = np.zeros((2, batch_size, 64)).astype('float32')
51
+ self._last_sr = 0
52
+ self._last_batch_size = 0
53
+
54
+ def __call__(self, x, sr: int):
55
+
56
+ x, sr = self._validate_input(x, sr)
57
+ batch_size = x.shape[0]
58
+
59
+ if not self._last_batch_size:
60
+ self.reset_states(batch_size)
61
+ if (self._last_sr) and (self._last_sr != sr):
62
+ self.reset_states(batch_size)
63
+ if (self._last_batch_size) and (self._last_batch_size != batch_size):
64
+ self.reset_states(batch_size)
65
+
66
+ if sr in [8000, 16000]:
67
+ ort_inputs = {'input': x.numpy(), 'h': self._h, 'c': self._c, 'sr': np.array(sr, dtype='int64')}
68
+ ort_outs = self.session.run(None, ort_inputs)
69
+ out, self._h, self._c = ort_outs
70
+ else:
71
+ raise ValueError()
72
+
73
+ self._last_sr = sr
74
+ self._last_batch_size = batch_size
75
+
76
+ out = torch.tensor(out)
77
+ return out
78
+
79
+ def audio_forward(self, x, sr: int, num_samples: int = 512):
80
+ outs = []
81
+ x, sr = self._validate_input(x, sr)
82
+
83
+ if x.shape[1] % num_samples:
84
+ pad_num = num_samples - (x.shape[1] % num_samples)
85
+ x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)
86
+
87
+ self.reset_states(x.shape[0])
88
+ for i in range(0, x.shape[1], num_samples):
89
+ wavs_batch = x[:, i:i+num_samples]
90
+ out_chunk = self.__call__(wavs_batch, sr)
91
+ outs.append(out_chunk)
92
+
93
+ stacked = torch.cat(outs, dim=1)
94
+ return stacked.cpu()
95
+
96
+ @staticmethod
97
+ def download(model_url="https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx"):
98
+ target_dir = os.path.expanduser("~/.cache/whisper-live/")
99
+
100
+ # Ensure the target directory exists
101
+ os.makedirs(target_dir, exist_ok=True)
102
+
103
+ # Define the target file path
104
+ model_filename = os.path.join(target_dir, "silero_vad.onnx")
105
+
106
+ # Check if the model file already exists
107
+ if not os.path.exists(model_filename):
108
+ # If it doesn't exist, download the model using wget
109
+ print("Downloading VAD ONNX model...")
110
+ try:
111
+ subprocess.run(["wget", "-O", model_filename, model_url], check=True)
112
+ except subprocess.CalledProcessError:
113
+ print("Failed to download the model using wget.")
114
+ return model_filename
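
A minimal usage sketch of this class, mirroring how recv_audio() in trt_server.py calls it (this assumes the silero_vad.onnx download above succeeds and that the package is importable as whisper_live):

    import torch
    from whisper_live.vad import VoiceActivityDetection

    vad = VoiceActivityDetection()
    chunk = torch.zeros(4096)               # ~256 ms of silence at 16 kHz
    speech_prob = vad(chunk, 16000).item()  # probability that the chunk contains speech
    print(speech_prob < 0.5)                # True for silence, i.e. below the server's vad_threshold

recv_audio() makes the same call for every received frame and flags end-of-speech after three consecutive low-probability chunks.
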
whisper_live/whisper_utils.py ADDED
@@ -0,0 +1,365 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import logging
16
+ import os
17
+ from collections import defaultdict
18
+ from functools import lru_cache
19
+ from pathlib import Path
20
+ from subprocess import CalledProcessError, run
21
+ from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
22
+
23
+ import kaldialign
24
+ import numpy as np
25
+ import soundfile
26
+ import torch
27
+ import torch.nn.functional as F
28
+
29
+ Pathlike = Union[str, Path]
30
+
31
+ SAMPLE_RATE = 16000
32
+ N_FFT = 400
33
+ HOP_LENGTH = 160
34
+ CHUNK_LENGTH = 30
35
+ N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
36
+
37
+
38
+ def load_audio(file: str, sr: int = SAMPLE_RATE):
39
+ """
40
+ Open an audio file and read as mono waveform, resampling as necessary
41
+
42
+ Parameters
43
+ ----------
44
+ file: str
45
+ The audio file to open
46
+
47
+ sr: int
48
+ The sample rate to resample the audio if necessary
49
+
50
+ Returns
51
+ -------
52
+ A NumPy array containing the audio waveform, in float32 dtype.
53
+ """
54
+
55
+ # This launches a subprocess to decode audio while down-mixing
56
+ # and resampling as necessary. Requires the ffmpeg CLI in PATH.
57
+ # fmt: off
58
+ cmd = [
59
+ "ffmpeg", "-nostdin", "-threads", "0", "-i", file, "-f", "s16le", "-ac",
60
+ "1", "-acodec", "pcm_s16le", "-ar",
61
+ str(sr), "-"
62
+ ]
63
+ # fmt: on
64
+ try:
65
+ out = run(cmd, capture_output=True, check=True).stdout
66
+ except CalledProcessError as e:
67
+ raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
68
+
69
+ return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
70
+
71
+
72
+ def load_audio_wav_format(wav_path):
73
+ # make sure the audio is in .wav format
74
+ assert wav_path.endswith(
75
+ '.wav'), f"Only .wav format is supported, but got {wav_path}"
76
+ waveform, sample_rate = soundfile.read(wav_path)
77
+ assert sample_rate == 16000, f"Only a 16 kHz sample rate is supported, but got {sample_rate}"
78
+ return waveform, sample_rate
79
+
80
+
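A quick sketch (not part of this commit) contrasting the two loaders above; the file names are hypothetical. load_audio shells out to the ffmpeg CLI, which must be on PATH, while load_audio_wav_format reads 16 kHz .wav files directly via soundfile.

from whisper_live.whisper_utils import load_audio, load_audio_wav_format

samples = load_audio("clip.mp3")                      # any ffmpeg-readable format -> float32 in [-1, 1] at 16 kHz
waveform, sr = load_audio_wav_format("clip_16k.wav")  # .wav only; soundfile returns float64
print(samples.dtype, waveform.dtype, sr)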
81
+ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
82
+ """
83
+ Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
84
+ """
85
+ if torch.is_tensor(array):
86
+ if array.shape[axis] > length:
87
+ array = array.index_select(dim=axis,
88
+ index=torch.arange(length,
89
+ device=array.device))
90
+
91
+ if array.shape[axis] < length:
92
+ pad_widths = [(0, 0)] * array.ndim
93
+ pad_widths[axis] = (0, length - array.shape[axis])
94
+ array = F.pad(array,
95
+ [pad for sizes in pad_widths[::-1] for pad in sizes])
96
+ else:
97
+ if array.shape[axis] > length:
98
+ array = array.take(indices=range(length), axis=axis)
99
+
100
+ if array.shape[axis] < length:
101
+ pad_widths = [(0, 0)] * array.ndim
102
+ pad_widths[axis] = (0, length - array.shape[axis])
103
+ array = np.pad(array, pad_widths)
104
+
105
+ return array
106
+
107
+
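A small sketch (not part of this commit) of pad_or_trim, using the module constants defined above; the random clips are placeholders.

import numpy as np
from whisper_live.whisper_utils import SAMPLE_RATE, pad_or_trim

short_clip = np.random.randn(SAMPLE_RATE * 5).astype(np.float32)   # 5 s of audio
long_clip = np.random.randn(SAMPLE_RATE * 45).astype(np.float32)   # 45 s of audio
print(pad_or_trim(short_clip).shape)  # (480000,) -- zero-padded on the right to N_SAMPLES
print(pad_or_trim(long_clip).shape)   # (480000,) -- truncated to N_SAMPLES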
108
+ @lru_cache(maxsize=None)
109
+ def mel_filters(device,
110
+ n_mels: int,
111
+ mel_filters_dir: str = None) -> torch.Tensor:
112
+ """
113
+ Load the mel filterbank matrix for projecting an STFT into a Mel spectrogram.
114
+ This decouples the librosa dependency; the filters were saved using:
115
+
116
+ np.savez_compressed(
117
+ "mel_filters.npz",
118
+ mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
119
+ )
120
+ """
121
+ assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
122
+ if mel_filters_dir is None:
123
+ mel_filters_path = os.path.join(os.path.dirname(__file__), "assets",
124
+ "mel_filters.npz")
125
+ else:
126
+ mel_filters_path = os.path.join(mel_filters_dir, "mel_filters.npz")
127
+ with np.load(mel_filters_path) as f:
128
+ return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
129
+
130
+
131
+ def log_mel_spectrogram(
132
+ audio: Union[str, np.ndarray, torch.Tensor],
133
+ n_mels: int,
134
+ padding: int = 0,
135
+ device: Optional[Union[str, torch.device]] = None,
136
+ return_duration: bool = False,
137
+ mel_filters_dir: str = None,
138
+ ):
139
+ """
140
+ Compute the log-Mel spectrogram of the given audio.
141
+
142
+ Parameters
143
+ ----------
144
+ audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
145
+ The path to an audio file, or a NumPy array or Tensor containing the audio waveform sampled at 16 kHz
146
+
147
+ n_mels: int
148
+ The number of Mel-frequency filters, only 80 and 128 are supported
149
+
150
+ padding: int
151
+ Number of zero samples to pad to the right
152
+
153
+ device: Optional[Union[str, torch.device]]
154
+ If given, the audio tensor is moved to this device before STFT
155
+
156
+ Returns
157
+ -------
158
+ torch.Tensor, shape = (80 or 128, n_frames)
159
+ A Tensor that contains the Mel spectrogram
160
+ """
161
+ if not torch.is_tensor(audio):
162
+ if isinstance(audio, str):
163
+ if audio.endswith('.wav'):
164
+ audio, _ = load_audio_wav_format(audio)
165
+ else:
166
+ audio = load_audio(audio)
167
+ assert isinstance(audio,
168
+ np.ndarray), f"Unsupported audio type: {type(audio)}"
169
+ duration = audio.shape[-1] / SAMPLE_RATE
170
+ audio = pad_or_trim(audio, N_SAMPLES)
171
+ audio = audio.astype(np.float32)
172
+ audio = torch.from_numpy(audio)
173
+
174
+ if device is not None:
175
+ audio = audio.to(device)
176
+ if padding > 0:
177
+ audio = F.pad(audio, (0, padding))
178
+ window = torch.hann_window(N_FFT).to(audio.device)
179
+ stft = torch.stft(audio,
180
+ N_FFT,
181
+ HOP_LENGTH,
182
+ window=window,
183
+ return_complex=True)
184
+ magnitudes = stft[..., :-1].abs()**2
185
+
186
+ filters = mel_filters(audio.device, n_mels, mel_filters_dir)
187
+ mel_spec = filters @ magnitudes
188
+
189
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
190
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
191
+ log_spec = (log_spec + 4.0) / 4.0
192
+ if return_duration:
193
+ return log_spec, duration
194
+ else:
195
+ return log_spec
196
+
197
+
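A usage sketch (not part of this commit) for log_mel_spectrogram on a 16 kHz wav file. The input file name and assets path are hypothetical, and the mel_filters.npz asset referenced by mel_filters() must be available, either next to this module under assets/ or via mel_filters_dir.

from whisper_live.whisper_utils import log_mel_spectrogram

mel, duration = log_mel_spectrogram(
    "sample_16k.wav",                  # hypothetical 16 kHz wav
    n_mels=80,
    device="cpu",
    return_duration=True,
    mel_filters_dir="/path/to/assets", # hypothetical directory containing mel_filters.npz
)
print(mel.shape, duration)             # torch.Size([80, 3000]), duration in seconds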
198
+ def store_transcripts(filename: Pathlike, texts: Iterable[Tuple[str, str,
199
+ str]]) -> None:
200
+ """Save predicted results and reference transcripts to a file.
201
+ https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
202
+ Args:
203
+ filename:
204
+ File to save the results to.
205
+ texts:
206
+ An iterable of tuples. The first element is the cur_id, the second is
207
+ the reference transcript and the third element is the predicted result.
208
+ Returns:
209
+ Return None.
210
+ """
211
+ with open(filename, "w") as f:
212
+ for cut_id, ref, hyp in texts:
213
+ print(f"{cut_id}:\tref={ref}", file=f)
214
+ print(f"{cut_id}:\thyp={hyp}", file=f)
215
+
216
+
217
+ def write_error_stats(
218
+ f: TextIO,
219
+ test_set_name: str,
220
+ results: List[Tuple[str, List[str], List[str]]],
221
+ enable_log: bool = True,
222
+ ) -> float:
223
+ """Write statistics based on predicted results and reference transcripts.
224
+ https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
225
+ It will write the following to the given file:
226
+
227
+ - WER
228
+ - number of insertions, deletions, substitutions, corrects and total
229
+ reference words. For example::
230
+
231
+ Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
232
+ reference words (2337 correct)
233
+
234
+ - The difference between the reference transcript and predicted result.
235
+ An instance is given below::
236
+
237
+ THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
238
+
239
+ The above example shows that the reference word is `EDISON`,
240
+ but it is predicted as `ADDISON` (a substitution error).
241
+
242
+ Another example is::
243
+
244
+ FOR THE FIRST DAY (SIR->*) I THINK
245
+
246
+ The reference word `SIR` is missing in the predicted
247
+ results (a deletion error).
248
+ results:
249
+ An iterable of tuples. The first element is the cur_id, the second is
250
+ the reference transcript and the third element is the predicted result.
251
+ enable_log:
252
+ If True, also print detailed WER to the console.
253
+ Otherwise, it is written only to the given file.
254
+ Returns:
255
+ Return the total error rate (WER) as a float, in percent.
256
+ """
257
+ subs: Dict[Tuple[str, str], int] = defaultdict(int)
258
+ ins: Dict[str, int] = defaultdict(int)
259
+ dels: Dict[str, int] = defaultdict(int)
260
+
261
+ # `words` stores counts per word, as follows:
262
+ # corr, ref_sub, hyp_sub, ins, dels
263
+ words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
264
+ num_corr = 0
265
+ ERR = "*"
266
+ for cut_id, ref, hyp in results:
267
+ ali = kaldialign.align(ref, hyp, ERR)
268
+ for ref_word, hyp_word in ali:
269
+ if ref_word == ERR:
270
+ ins[hyp_word] += 1
271
+ words[hyp_word][3] += 1
272
+ elif hyp_word == ERR:
273
+ dels[ref_word] += 1
274
+ words[ref_word][4] += 1
275
+ elif hyp_word != ref_word:
276
+ subs[(ref_word, hyp_word)] += 1
277
+ words[ref_word][1] += 1
278
+ words[hyp_word][2] += 1
279
+ else:
280
+ words[ref_word][0] += 1
281
+ num_corr += 1
282
+ ref_len = sum([len(r) for _, r, _ in results])
283
+ sub_errs = sum(subs.values())
284
+ ins_errs = sum(ins.values())
285
+ del_errs = sum(dels.values())
286
+ tot_errs = sub_errs + ins_errs + del_errs
287
+ tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
288
+
289
+ if enable_log:
290
+ logging.info(f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
291
+ f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
292
+ f"{del_errs} del, {sub_errs} sub ]")
293
+
294
+ print(f"%WER = {tot_err_rate}", file=f)
295
+ print(
296
+ f"Errors: {ins_errs} insertions, {del_errs} deletions, "
297
+ f"{sub_errs} substitutions, over {ref_len} reference "
298
+ f"words ({num_corr} correct)",
299
+ file=f,
300
+ )
301
+ print(
302
+ "Search below for sections starting with PER-UTT DETAILS:, "
303
+ "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
304
+ file=f,
305
+ )
306
+
307
+ print("", file=f)
308
+ print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
309
+ for cut_id, ref, hyp in results:
310
+ ali = kaldialign.align(ref, hyp, ERR)
311
+ combine_successive_errors = True
312
+ if combine_successive_errors:
313
+ ali = [[[x], [y]] for x, y in ali]
314
+ for i in range(len(ali) - 1):
315
+ if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
316
+ ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
317
+ ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
318
+ ali[i] = [[], []]
319
+ ali = [[
320
+ list(filter(lambda a: a != ERR, x)),
321
+ list(filter(lambda a: a != ERR, y)),
322
+ ] for x, y in ali]
323
+ ali = list(filter(lambda x: x != [[], []], ali))
324
+ ali = [[
325
+ ERR if x == [] else " ".join(x),
326
+ ERR if y == [] else " ".join(y),
327
+ ] for x, y in ali]
328
+
329
+ print(
330
+ f"{cut_id}:\t" + " ".join((ref_word if ref_word == hyp_word else
331
+ f"({ref_word}->{hyp_word})"
332
+ for ref_word, hyp_word in ali)),
333
+ file=f,
334
+ )
335
+
336
+ print("", file=f)
337
+ print("SUBSTITUTIONS: count ref -> hyp", file=f)
338
+
339
+ for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()],
340
+ reverse=True):
341
+ print(f"{count} {ref} -> {hyp}", file=f)
342
+
343
+ print("", file=f)
344
+ print("DELETIONS: count ref", file=f)
345
+ for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
346
+ print(f"{count} {ref}", file=f)
347
+
348
+ print("", file=f)
349
+ print("INSERTIONS: count hyp", file=f)
350
+ for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
351
+ print(f"{count} {hyp}", file=f)
352
+
353
+ print("", file=f)
354
+ print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp",
355
+ file=f)
356
+ for _, word, counts in sorted([(sum(v[1:]), k, v)
357
+ for k, v in words.items()],
358
+ reverse=True):
359
+ (corr, ref_sub, hyp_sub, ins, dels) = counts
360
+ tot_errs = ref_sub + hyp_sub + ins + dels
361
+ ref_count = corr + ref_sub + dels
362
+ hyp_count = corr + hyp_sub + ins
363
+
364
+ print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
365
+ return float(tot_err_rate)
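Finally, a sketch (not part of this commit) of how store_transcripts and write_error_stats might be driven together; the utterances and file names are made up. Both expect (cut_id, ref_words, hyp_words) tuples.

from whisper_live.whisper_utils import store_transcripts, write_error_stats

results = [
    ("utt-1", "the cat sat".split(), "the cat sat".split()),
    ("utt-2", "for the first day sir i think".split(), "for the first day i think".split()),
]

store_transcripts("recogs.txt", results)         # one ref/hyp line pair per utterance

with open("errs.txt", "w") as f:
    wer = write_error_stats(f, "dev", results)   # writes the WER breakdown, returns WER in percent
print(f"WER: {wer:.2f}%")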