Fix CLI for parallel devices
Browse files- app.py +4 -1
- cli.py +8 -4
- src/vadParallel.py +12 -5
- src/whisperContainer.py +3 -2
app.py
CHANGED
@@ -60,6 +60,9 @@ class WhisperTranscriber:
|
|
60 |
self.inputAudioMaxDuration = input_audio_max_duration
|
61 |
self.deleteUploadedFiles = delete_uploaded_files
|
62 |
|
|
|
|
|
|
|
63 |
def transcribe_webui(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
|
64 |
try:
|
65 |
source, sourceName = self.__get_source(urlData, uploadFile, microphoneData)
|
@@ -255,7 +258,7 @@ def create_ui(input_audio_max_duration, share=False, server_name: str = None, se
|
|
255 |
ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout)
|
256 |
|
257 |
# Specify a list of devices to use for parallel processing
|
258 |
-
ui.
|
259 |
|
260 |
ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
|
261 |
ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
|
|
|
60 |
self.inputAudioMaxDuration = input_audio_max_duration
|
61 |
self.deleteUploadedFiles = delete_uploaded_files
|
62 |
|
63 |
+
def set_parallel_devices(self, vad_parallel_devices: str):
|
64 |
+
self.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None
|
65 |
+
|
66 |
def transcribe_webui(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
|
67 |
try:
|
68 |
source, sourceName = self.__get_source(urlData, uploadFile, microphoneData)
|
|
|
258 |
ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout)
|
259 |
|
260 |
# Specify a list of devices to use for parallel processing
|
261 |
+
ui.set_parallel_devices(vad_parallel_devices)
|
262 |
|
263 |
ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
|
264 |
ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
|
cli.py
CHANGED
@@ -12,6 +12,7 @@ from app import LANGUAGES, WhisperTranscriber
|
|
12 |
from src.download import download_url
|
13 |
|
14 |
from src.utils import optional_float, optional_int, str2bool
|
|
|
15 |
|
16 |
|
17 |
def cli():
|
@@ -31,7 +32,7 @@ def cli():
|
|
31 |
parser.add_argument("--vad_max_merge_size", type=optional_float, default=30, help="The maximum size (in seconds) of a voice segment")
|
32 |
parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
|
33 |
parser.add_argument("--vad_prompt_window", type=optional_float, default=3, help="The window size of the prompt to pass to Whisper")
|
34 |
-
parser.add_argument("--vad_parallel_devices", type=str, default="
|
35 |
|
36 |
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
37 |
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
@@ -73,9 +74,12 @@ def cli():
|
|
73 |
vad_padding = args.pop("vad_padding")
|
74 |
vad_prompt_window = args.pop("vad_prompt_window")
|
75 |
|
76 |
-
model =
|
77 |
transcriber = WhisperTranscriber(delete_uploaded_files=False)
|
78 |
-
transcriber.
|
|
|
|
|
|
|
79 |
|
80 |
for audio_path in args.pop("audio"):
|
81 |
sources = []
|
@@ -99,7 +103,7 @@ def cli():
|
|
99 |
|
100 |
transcriber.write_result(result, source_name, output_dir)
|
101 |
|
102 |
-
transcriber.
|
103 |
|
104 |
def uri_validator(x):
|
105 |
try:
|
|
|
12 |
from src.download import download_url
|
13 |
|
14 |
from src.utils import optional_float, optional_int, str2bool
|
15 |
+
from src.whisperContainer import WhisperContainer
|
16 |
|
17 |
|
18 |
def cli():
|
|
|
32 |
parser.add_argument("--vad_max_merge_size", type=optional_float, default=30, help="The maximum size (in seconds) of a voice segment")
|
33 |
parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
|
34 |
parser.add_argument("--vad_prompt_window", type=optional_float, default=3, help="The window size of the prompt to pass to Whisper")
|
35 |
+
parser.add_argument("--vad_parallel_devices", type=str, default="", help="A commma delimited list of CUDA devices to use for paralell processing. If None, disable parallel processing.")
|
36 |
|
37 |
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
38 |
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
|
|
74 |
vad_padding = args.pop("vad_padding")
|
75 |
vad_prompt_window = args.pop("vad_prompt_window")
|
76 |
|
77 |
+
model = WhisperContainer(model_name, device=device, download_root=model_dir)
|
78 |
transcriber = WhisperTranscriber(delete_uploaded_files=False)
|
79 |
+
transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
|
80 |
+
|
81 |
+
if (transcriber._has_parallel_devices()):
|
82 |
+
print("Using parallel devices:", transcriber.parallel_device_list)
|
83 |
|
84 |
for audio_path in args.pop("audio"):
|
85 |
sources = []
|
|
|
103 |
|
104 |
transcriber.write_result(result, source_name, output_dir)
|
105 |
|
106 |
+
transcriber.close()
|
107 |
|
108 |
def uri_validator(x):
|
109 |
try:
|
src/vadParallel.py
CHANGED
@@ -88,14 +88,20 @@ class ParallelTranscription(AbstractTranscription):
|
|
88 |
|
89 |
# Split into a list for each device
|
90 |
# TODO: Split by time instead of by number of chunks
|
91 |
-
merged_split = self.
|
92 |
|
93 |
# Parameters that will be passed to the transcribe function
|
94 |
parameters = []
|
95 |
segment_index = config.initial_segment_index
|
96 |
|
97 |
for i in range(len(merged_split)):
|
98 |
-
device_segment_list = merged_split[i]
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
# Create a new config with the given device ID
|
101 |
device_config = ParallelTranscriptionConfig(devices[i], device_segment_list, segment_index, config)
|
@@ -159,7 +165,8 @@ class ParallelTranscription(AbstractTranscription):
|
|
159 |
os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
|
160 |
return super().transcribe(audio, whisperCallable, config)
|
161 |
|
162 |
-
def
|
163 |
-
"""
|
164 |
-
|
|
|
165 |
|
|
|
88 |
|
89 |
# Split into a list for each device
|
90 |
# TODO: Split by time instead of by number of chunks
|
91 |
+
merged_split = list(self._split(merged, len(devices)))
|
92 |
|
93 |
# Parameters that will be passed to the transcribe function
|
94 |
parameters = []
|
95 |
segment_index = config.initial_segment_index
|
96 |
|
97 |
for i in range(len(merged_split)):
|
98 |
+
device_segment_list = list(merged_split[i])
|
99 |
+
device_id = devices[i]
|
100 |
+
|
101 |
+
if (len(device_segment_list) <= 0):
|
102 |
+
continue
|
103 |
+
|
104 |
+
print("Device " + device_id + " (index " + str(i) + ") has " + str(len(device_segment_list)) + " segments")
|
105 |
|
106 |
# Create a new config with the given device ID
|
107 |
device_config = ParallelTranscriptionConfig(devices[i], device_segment_list, segment_index, config)
|
|
|
165 |
os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
|
166 |
return super().transcribe(audio, whisperCallable, config)
|
167 |
|
168 |
+
def _split(self, a, n):
|
169 |
+
"""Split a list into n approximately equal parts."""
|
170 |
+
k, m = divmod(len(a), n)
|
171 |
+
return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
|
172 |
|
src/whisperContainer.py
CHANGED
@@ -23,9 +23,10 @@ class WhisperModelCache:
|
|
23 |
GLOBAL_WHISPER_MODEL_CACHE = WhisperModelCache()
|
24 |
|
25 |
class WhisperContainer:
|
26 |
-
def __init__(self, model_name: str, device: str = None, cache: WhisperModelCache = None):
|
27 |
self.model_name = model_name
|
28 |
self.device = device
|
|
|
29 |
self.cache = cache
|
30 |
|
31 |
# Will be created on demand
|
@@ -36,7 +37,7 @@ class WhisperContainer:
|
|
36 |
|
37 |
if (self.cache is None):
|
38 |
print("Loading whisper model " + self.model_name)
|
39 |
-
self.model = whisper.load_model(self.model_name, device=self.device)
|
40 |
else:
|
41 |
self.model = self.cache.get(self.model_name, device=self.device)
|
42 |
return self.model
|
|
|
23 |
GLOBAL_WHISPER_MODEL_CACHE = WhisperModelCache()
|
24 |
|
25 |
class WhisperContainer:
|
26 |
+
def __init__(self, model_name: str, device: str = None, download_root: str = None, cache: WhisperModelCache = None):
|
27 |
self.model_name = model_name
|
28 |
self.device = device
|
29 |
+
self.download_root = download_root
|
30 |
self.cache = cache
|
31 |
|
32 |
# Will be created on demand
|
|
|
37 |
|
38 |
if (self.cache is None):
|
39 |
print("Loading whisper model " + self.model_name)
|
40 |
+
self.model = whisper.load_model(self.model_name, device=self.device, download_root=self.download_root)
|
41 |
else:
|
42 |
self.model = self.cache.get(self.model_name, device=self.device)
|
43 |
return self.model
|