Add configuration file and support for custom models
Custom models can be added to the configuration file, under the "models" section. See the comments for more details.
- .gitignore +1 -0
- app.py +37 -17
- cli.py +72 -38
- config.json5 +62 -0
- requirements.txt +3 -1
- src/config.py +134 -0
- src/conversion/hf_converter.py +67 -0
- src/whisperContainer.py +29 -3
.gitignore
CHANGED
@@ -1,5 +1,6 @@
 # Byte-compiled / optimized / DLL files
 __pycache__/
+.vscode/
 flagged/
 *.py[cod]
 *$py.class
app.py
CHANGED
@@ -11,6 +11,7 @@ import zipfile
 import numpy as np
 
 import torch
+from src.config import ApplicationConfig
 from src.modelCache import ModelCache
 from src.source import get_audio_source_collection
 from src.vadParallel import ParallelContext, ParallelTranscription
@@ -62,7 +63,8 @@ WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large
 
 class WhisperTranscriber:
     def __init__(self, input_audio_max_duration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, vad_process_timeout: float = None,
-                 vad_cpu_cores: int = 1, delete_uploaded_files: bool = DELETE_UPLOADED_FILES, output_dir: str = None
+                 vad_cpu_cores: int = 1, delete_uploaded_files: bool = DELETE_UPLOADED_FILES, output_dir: str = None,
+                 app_config: ApplicationConfig = None):
         self.model_cache = ModelCache()
         self.parallel_device_list = None
         self.gpu_parallel_context = None
@@ -75,6 +77,8 @@ class WhisperTranscriber:
         self.deleteUploadedFiles = delete_uploaded_files
         self.output_dir = output_dir
 
+        self.app_config = app_config
+
     def set_parallel_devices(self, vad_parallel_devices: str):
         self.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None
 
@@ -115,7 +119,7 @@ class WhisperTranscriber:
         selectedLanguage = languageName.lower() if len(languageName) > 0 else None
         selectedModel = modelName if modelName is not None else "base"
 
-        model = WhisperContainer(model_name=selectedModel, cache=self.model_cache)
+        model = WhisperContainer(model_name=selectedModel, cache=self.model_cache, models=self.app_config.models)
 
         # Result
         download = []
@@ -360,8 +364,8 @@ class WhisperTranscriber:
 def create_ui(input_audio_max_duration, share=False, server_name: str = None, server_port: int = 7860,
               default_model_name: str = "medium", default_vad: str = None, vad_parallel_devices: str = None,
               vad_process_timeout: float = None, vad_cpu_cores: int = 1, auto_parallel: bool = False,
-              output_dir: str = None):
-    ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout, vad_cpu_cores, DELETE_UPLOADED_FILES, output_dir)
+              output_dir: str = None, app_config: ApplicationConfig = None):
+    ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout, vad_cpu_cores, DELETE_UPLOADED_FILES, output_dir, app_config)
 
     # Specify a list of devices to use for parallel processing
     ui.set_parallel_devices(vad_parallel_devices)
@@ -378,8 +382,10 @@ def create_ui(input_audio_max_duration, share=False, server_name: str = None, se
 
     ui_article = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)"
 
+    whisper_models = app_config.get_model_names()
+
     simple_inputs = lambda : [
-        gr.Dropdown(choices=
+        gr.Dropdown(choices=whisper_models, value=default_model_name, label="Model"),
         gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
         gr.Text(label="URL (YouTube, etc.)"),
         gr.File(label="Upload Files", file_count="multiple"),
@@ -429,18 +435,32 @@ def create_ui(input_audio_max_duration, share=False, server_name: str = None, se
     ui.close()
 
 if __name__ == '__main__':
+    app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))
+    whisper_models = app_config.get_model_names()
+
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument("--input_audio_max_duration", type=int, default=
-    parser.add_argument("--
-    parser.add_argument("--
-    parser.add_argument("--
-    parser.add_argument("--
-    parser.add_argument("--
+    parser.add_argument("--input_audio_max_duration", type=int, default=app_config.input_audio_max_duration, \
+                        help="Maximum audio file length in seconds, or -1 for no limit.") # 600
+    parser.add_argument("--share", type=bool, default=app_config.share, \
+                        help="True to share the app on HuggingFace.") # False
+    parser.add_argument("--server_name", type=str, default=app_config.server_name, \
+                        help="The host or IP to bind to. If None, bind to localhost.") # None
+    parser.add_argument("--server_port", type=int, default=app_config.server_port, \
+                        help="The port to bind to.") # 7860
+    parser.add_argument("--default_model_name", type=str, choices=whisper_models, default=app_config.default_model_name, \
+                        help="The default model name.") # medium
+    parser.add_argument("--default_vad", type=str, default=app_config.default_vad, \
+                        help="The default VAD.") # silero-vad
+    parser.add_argument("--vad_parallel_devices", type=str, default=app_config.vad_parallel_devices, \
+                        help="A comma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
+    parser.add_argument("--vad_cpu_cores", type=int, default=app_config.vad_cpu_cores, \
+                        help="The number of CPU cores to use for VAD pre-processing.") # 1
+    parser.add_argument("--vad_process_timeout", type=float, default=app_config.vad_process_timeout, \
+                        help="The number of seconds before inactive processes are terminated. Use 0 to close processes immediately, or None for no timeout.") # 1800
+    parser.add_argument("--auto_parallel", type=bool, default=app_config.auto_parallel, \
+                        help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.") # False
+    parser.add_argument("--output_dir", "-o", type=str, default=app_config.output_dir, \
+                        help="directory to save the outputs") # None
 
     args = parser.parse_args().__dict__
-    create_ui(**args)
+    create_ui(app_config=app_config, **args)
cli.py
CHANGED
@@ -6,48 +6,81 @@ import warnings
 import numpy as np
 
 import torch
-from app import LANGUAGES,
+from app import LANGUAGES, WhisperTranscriber
+from src.config import ApplicationConfig
 from src.download import download_url
 
 from src.utils import optional_float, optional_int, str2bool
 from src.whisperContainer import WhisperContainer
 
 def cli():
+    app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))
+    whisper_models = app_config.get_model_names()
+
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument("audio", nargs="+", type=str,
-    parser.add_argument("--
-    parser.add_argument("--
-    parser.add_argument("--
-    parser.add_argument("--
-    parser.add_argument("--
-    parser.add_argument("--
-    parser.add_argument("--
-    parser.add_argument("--
-    parser.add_argument("--
-    parser.add_argument("--
-    parser.add_argument("--
-    parser.add_argument("--
+    parser.add_argument("audio", nargs="+", type=str, \
+                        help="audio file(s) to transcribe")
+    parser.add_argument("--model", default=app_config.default_model_name, choices=whisper_models, \
+                        help="name of the Whisper model to use") # medium
+    parser.add_argument("--model_dir", type=str, default=None, \
+                        help="the path to save model files; uses ~/.cache/whisper by default")
+    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", \
+                        help="device to use for PyTorch inference")
+    parser.add_argument("--output_dir", "-o", type=str, default=".", \
+                        help="directory to save the outputs")
+    parser.add_argument("--verbose", type=str2bool, default=True, \
+                        help="whether to print out the progress and debug messages")
+
+    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], \
+                        help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
+    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES), \
+                        help="language spoken in the audio, specify None to perform language detection")
+
+    parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
+                        help="The voice activity detection algorithm to use") # silero-vad
+    parser.add_argument("--vad_merge_window", type=optional_float, default=5, \
+                        help="The window size (in seconds) to merge voice segments")
+    parser.add_argument("--vad_max_merge_size", type=optional_float, default=30, \
+                        help="The maximum size (in seconds) of a voice segment")
+    parser.add_argument("--vad_padding", type=optional_float, default=1, \
+                        help="The padding (in seconds) to add to each voice segment")
+    parser.add_argument("--vad_prompt_window", type=optional_float, default=3, \
+                        help="The window size of the prompt to pass to Whisper")
+    parser.add_argument("--vad_cpu_cores", type=int, default=app_config.vad_cpu_cores, \
+                        help="The number of CPU cores to use for VAD pre-processing.") # 1
+    parser.add_argument("--vad_parallel_devices", type=str, default=app_config.vad_parallel_devices, \
+                        help="A comma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
+    parser.add_argument("--auto_parallel", type=bool, default=app_config.auto_parallel, \
+                        help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.") # False
+
+    parser.add_argument("--temperature", type=float, default=0, \
+                        help="temperature to use for sampling")
+    parser.add_argument("--best_of", type=optional_int, default=5, \
+                        help="number of candidates when sampling with non-zero temperature")
+    parser.add_argument("--beam_size", type=optional_int, default=5, \
+                        help="number of beams in beam search, only applicable when temperature is zero")
+    parser.add_argument("--patience", type=float, default=None, \
+                        help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
+    parser.add_argument("--length_penalty", type=float, default=None, \
+                        help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
+
+    parser.add_argument("--suppress_tokens", type=str, default="-1", \
+                        help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
+    parser.add_argument("--initial_prompt", type=str, default=None, \
+                        help="optional text to provide as a prompt for the first window.")
+    parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, \
+                        help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
+    parser.add_argument("--fp16", type=str2bool, default=True, \
+                        help="whether to perform inference in fp16; True by default")
+
+    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, \
+                        help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
+    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, \
+                        help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
+    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, \
+                        help="if the average log probability is lower than this value, treat the decoding as failed")
+    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, \
+                        help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
 
     args = parser.parse_args().__dict__
     model_name: str = args.pop("model")
@@ -74,12 +107,13 @@ def cli():
     vad_prompt_window = args.pop("vad_prompt_window")
     vad_cpu_cores = args.pop("vad_cpu_cores")
     auto_parallel = args.pop("auto_parallel")
-    transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores)
+
+    transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
     transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
     transcriber.set_auto_parallel(auto_parallel)
 
+    model = WhisperContainer(model_name, device=device, download_root=model_dir, models=app_config.models)
+
     if (transcriber._has_parallel_devices()):
         print("Using parallel devices:", transcriber.parallel_device_list)
 
config.json5
ADDED
@@ -0,0 +1,62 @@
+{
+    "models": [
+        // Configuration for the built-in models. You can remove any of these
+        // if you don't want to use the default models.
+        {
+            "name": "tiny",
+            "url": "tiny"
+        },
+        {
+            "name": "base",
+            "url": "base"
+        },
+        {
+            "name": "small",
+            "url": "small"
+        },
+        {
+            "name": "medium",
+            "url": "medium"
+        },
+        {
+            "name": "large",
+            "url": "large"
+        },
+        {
+            "name": "large-v2",
+            "url": "large-v2"
+        },
+        // Uncomment to add custom Japanese models
+        //{
+        //    "name": "whisper-large-v2-mix-jp",
+        //    "url": "vumichien/whisper-large-v2-mix-jp",
+        //    // The type of the model. Can be "huggingface" or "whisper" - "whisper" is the default.
+        //    // HuggingFace models are loaded using the HuggingFace transformers library and then converted to Whisper models.
+        //    "type": "huggingface",
+        //}
+    ],
+    // Configuration options that will be used if they are not specified in the command line arguments.
+
+    // Maximum audio file length in seconds, or -1 for no limit.
+    "input_audio_max_duration": 600,
+    // True to share the app on HuggingFace.
+    "share": false,
+    // The host or IP to bind to. If None, bind to localhost.
+    "server_name": null,
+    // The port to bind to.
+    "server_port": 7860,
+    // The default model name.
+    "default_model_name": "medium",
+    // The default VAD.
+    "default_vad": "silero-vad",
+    // A comma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.
+    "vad_parallel_devices": "",
+    // The number of CPU cores to use for VAD pre-processing.
+    "vad_cpu_cores": 1,
+    // The number of seconds before inactive processes are terminated. Use 0 to close processes immediately, or None for no timeout.
+    "vad_process_timeout": 1800,
+    // True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.
+    "auto_parallel": false,
+    // Directory to save the outputs
+    "output_dir": null
+}
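For orientation, this is roughly how the new configuration file is consumed at runtime (a minimal sketch, not part of the diff; it assumes the working directory is the repository root so that config.json5 and the src package resolve):

import os
from src.config import ApplicationConfig

# Both app.py and cli.py resolve the config path the same way: the
# WHISPER_WEBUI_CONFIG environment variable, falling back to config.json5.
config_path = os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5")
app_config = ApplicationConfig.parse_file(config_path)

# The "models" section becomes a list of ModelConfig objects.
print(app_config.get_model_names())   # e.g. ['tiny', 'base', ..., 'large-v2']
print(app_config.default_model_name)  # "medium" unless changed in the file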
requirements.txt
CHANGED
@@ -1,7 +1,9 @@
+git+https://github.com/huggingface/transformers
 git+https://github.com/openai/whisper.git
 transformers
 ffmpeg-python==0.2.0
 gradio==3.13.0
 yt-dlp
 torchaudio
-altair
+altair
+json5
src/config.py
ADDED
@@ -0,0 +1,134 @@
+import urllib.request
+
+import os
+from typing import List
+from urllib.parse import urlparse
+
+from tqdm import tqdm
+
+from src.conversion.hf_converter import convert_hf_whisper
+
+class ModelConfig:
+    def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
+        """
+        Initialize a model configuration.
+
+        name: Name of the model
+        url: URL to download the model from
+        path: Path to the model file. If not set, the model will be downloaded from the URL.
+        type: Type of model. Can be whisper or huggingface.
+        """
+        self.name = name
+        self.url = url
+        self.path = path
+        self.type = type
+
+    def download_url(self, root_dir: str):
+        import whisper
+
+        # See if path is already set
+        if self.path is not None:
+            return self.path
+
+        if root_dir is None:
+            root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
+
+        model_type = self.type.lower() if self.type is not None else "whisper"
+
+        if model_type in ["huggingface", "hf"]:
+            self.path = self.url
+            destination_target = os.path.join(root_dir, self.name + ".pt")
+
+            # Convert from HuggingFace format to Whisper format
+            if os.path.exists(destination_target):
+                print(f"File {destination_target} already exists, skipping conversion")
+            else:
+                print("Saving HuggingFace model in Whisper format to " + destination_target)
+                convert_hf_whisper(self.url, destination_target)
+
+            self.path = destination_target
+
+        elif model_type in ["whisper", "w"]:
+            self.path = self.url
+
+            # See if URL is just a file
+            if self.url in whisper._MODELS:
+                # No need to download anything - Whisper will handle it
+                self.path = self.url
+            elif self.url.startswith("file://"):
+                # Get file path
+                self.path = urlparse(self.url).path
+            # See if it is a URL
+            elif self.url.startswith("http://") or self.url.startswith("https://"):
+                # Extension (or file name)
+                extension = os.path.splitext(self.url)[-1]
+                download_target = os.path.join(root_dir, self.name + extension)
+
+                if os.path.exists(download_target) and not os.path.isfile(download_target):
+                    raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+                if not os.path.isfile(download_target):
+                    self._download_file(self.url, download_target)
+                else:
+                    print(f"File {download_target} already exists, skipping download")
+
+                self.path = download_target
+            # Must be a local file
+            else:
+                self.path = self.url
+
+        else:
+            raise ValueError(f"Unknown model type {model_type}")
+
+        return self.path
+
+    def _download_file(self, url: str, destination: str):
+        with urllib.request.urlopen(url) as source, open(destination, "wb") as output:
+            with tqdm(
+                total=int(source.info().get("Content-Length")),
+                ncols=80,
+                unit="iB",
+                unit_scale=True,
+                unit_divisor=1024,
+            ) as loop:
+                while True:
+                    buffer = source.read(8192)
+                    if not buffer:
+                        break
+
+                    output.write(buffer)
+                    loop.update(len(buffer))
+
+class ApplicationConfig:
+    def __init__(self, models: List[ModelConfig] = [], input_audio_max_duration: int = 600,
+                 share: bool = False, server_name: str = None, server_port: int = 7860, default_model_name: str = "medium",
+                 default_vad: str = "silero-vad", vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
+                 auto_parallel: bool = False, output_dir: str = None):
+        self.models = models
+        self.input_audio_max_duration = input_audio_max_duration
+        self.share = share
+        self.server_name = server_name
+        self.server_port = server_port
+        self.default_model_name = default_model_name
+        self.default_vad = default_vad
+        self.vad_parallel_devices = vad_parallel_devices
+        self.vad_cpu_cores = vad_cpu_cores
+        self.vad_process_timeout = vad_process_timeout
+        self.auto_parallel = auto_parallel
+        self.output_dir = output_dir
+
+    def get_model_names(self):
+        return [ x.name for x in self.models ]
+
+    @staticmethod
+    def parse_file(config_path: str):
+        import json5
+
+        with open(config_path, "r") as f:
+            # Load using json5
+            data = json5.load(f)
+            data_models = data.pop("models", [])
+
+            models = [ ModelConfig(**x) for x in data_models ]
+
+            return ApplicationConfig(models, **data)
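A minimal sketch of how ModelConfig resolves a custom entry; the model name and URL below are the commented-out example from config.json5, and passing root_dir=None falls back to ~/.cache/whisper:

from src.config import ModelConfig

# Hypothetical custom entry, mirroring the commented-out example in config.json5.
custom = ModelConfig(name="whisper-large-v2-mix-jp",
                     url="vumichien/whisper-large-v2-mix-jp",
                     type="huggingface")

# For "huggingface" models, download_url() converts the checkpoint to Whisper
# format and caches it as <root_dir>/<name>.pt; for "whisper" models it returns
# a built-in name, a local path, or downloads an http(s) URL.
model_path = custom.download_url(root_dir=None)
print(model_path)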
src/conversion/hf_converter.py
ADDED
@@ -0,0 +1,67 @@
+# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets
+
+from copy import deepcopy
+import torch
+from transformers import WhisperForConditionalGeneration
+
+WHISPER_MAPPING = {
+    "layers": "blocks",
+    "fc1": "mlp.0",
+    "fc2": "mlp.2",
+    "final_layer_norm": "mlp_ln",
+    "layers": "blocks",
+    ".self_attn.q_proj": ".attn.query",
+    ".self_attn.k_proj": ".attn.key",
+    ".self_attn.v_proj": ".attn.value",
+    ".self_attn_layer_norm": ".attn_ln",
+    ".self_attn.out_proj": ".attn.out",
+    ".encoder_attn.q_proj": ".cross_attn.query",
+    ".encoder_attn.k_proj": ".cross_attn.key",
+    ".encoder_attn.v_proj": ".cross_attn.value",
+    ".encoder_attn_layer_norm": ".cross_attn_ln",
+    ".encoder_attn.out_proj": ".cross_attn.out",
+    "decoder.layer_norm.": "decoder.ln.",
+    "encoder.layer_norm.": "encoder.ln_post.",
+    "embed_tokens": "token_embedding",
+    "encoder.embed_positions.weight": "encoder.positional_embedding",
+    "decoder.embed_positions.weight": "decoder.positional_embedding",
+    "layer_norm": "ln_post",
+}
+
+
+def rename_keys(s_dict):
+    keys = list(s_dict.keys())
+    for key in keys:
+        new_key = key
+        for k, v in WHISPER_MAPPING.items():
+            if k in key:
+                new_key = new_key.replace(k, v)
+
+        print(f"{key} -> {new_key}")
+
+        s_dict[new_key] = s_dict.pop(key)
+    return s_dict
+
+
+def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str):
+    transformer_model = WhisperForConditionalGeneration.from_pretrained(hf_model_name_or_path)
+    config = transformer_model.config
+
+    # first build dims
+    dims = {
+        'n_mels': config.num_mel_bins,
+        'n_vocab': config.vocab_size,
+        'n_audio_ctx': config.max_source_positions,
+        'n_audio_state': config.d_model,
+        'n_audio_head': config.encoder_attention_heads,
+        'n_audio_layer': config.encoder_layers,
+        'n_text_ctx': config.max_target_positions,
+        'n_text_state': config.d_model,
+        'n_text_head': config.decoder_attention_heads,
+        'n_text_layer': config.decoder_layers
+    }
+
+    state_dict = deepcopy(transformer_model.model.state_dict())
+    state_dict = rename_keys(state_dict)
+
+    torch.save({"dims": dims, "model_state_dict": state_dict}, whisper_state_path)
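The converter can also be used on its own. A short sketch follows; the model id is the example from config.json5, the output filename is arbitrary, and whisper.load_model accepts a path to a checkpoint file:

import whisper
from src.conversion.hf_converter import convert_hf_whisper

# Convert a HuggingFace Whisper checkpoint into a native Whisper state dict...
convert_hf_whisper("vumichien/whisper-large-v2-mix-jp", "whisper-large-v2-mix-jp.pt")

# ...then load it like any other Whisper model.
model = whisper.load_model("whisper-large-v2-mix-jp.pt")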
src/whisperContainer.py
CHANGED
@@ -1,11 +1,14 @@
 # External programs
 import os
+from typing import List
 import whisper
+from src.config import ModelConfig
 
 from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
 
 class WhisperContainer:
-    def __init__(self, model_name: str, device: str = None, download_root: str = None,
+    def __init__(self, model_name: str, device: str = None, download_root: str = None,
+                 cache: ModelCache = None, models: List[ModelConfig] = []):
         self.model_name = model_name
         self.device = device
         self.download_root = download_root
@@ -13,6 +16,9 @@ class WhisperContainer:
 
         # Will be created on demand
         self.model = None
+
+        # List of known models
+        self.models = models
 
     def get_model(self):
         if self.model is None:
@@ -32,21 +38,40 @@ class WhisperContainer:
         # Warning: Using private API here
         try:
             root_dir = self.download_root
+            model_config = self.get_model_config()
 
             if root_dir is None:
                 root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
 
             if self.model_name in whisper._MODELS:
                 whisper._download(whisper._MODELS[self.model_name], root_dir, False)
+            else:
+                # If the model is not in the official list, see if it needs to be downloaded
+                model_config.download_url(root_dir)
             return True
+
         except Exception as e:
             # Given that the API is private, it could change at any time. We don't want to crash the program
             print("Error pre-downloading model: " + str(e))
             return False
 
+    def get_model_config(self) -> ModelConfig:
+        """
+        Get the model configuration for the model.
+        """
+        for model in self.models:
+            if model.name == self.model_name:
+                return model
+        return None
+
     def _create_model(self):
         print("Loading whisper model " + self.model_name)
+
+        model_config = self.get_model_config()
+        # Note that the model will not be downloaded in the case of an official Whisper model
+        model_path = model_config.download_url(self.download_root)
+
+        return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
 
     def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
         """
@@ -71,12 +96,13 @@ class WhisperContainer:
 
     # This is required for multiprocessing
     def __getstate__(self):
-        return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root }
+        return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root, "models": self.models }
 
     def __setstate__(self, state):
        self.model_name = state["model_name"]
        self.device = state["device"]
        self.download_root = state["download_root"]
+       self.models = state["models"]
        self.model = None
        # Depickled objects must use the global cache
        self.cache = GLOBAL_MODEL_CACHE
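Putting the pieces together, a rough end-to-end sketch (the model name and device are illustrative; it assumes the defaults in this diff, where the container looks the name up in the configured models and downloads or converts it on demand):

from src.config import ApplicationConfig
from src.whisperContainer import WhisperContainer

app_config = ApplicationConfig.parse_file("config.json5")

# Any model listed in the config (built-in or custom) can be resolved by name.
container = WhisperContainer("large-v2", device=None, models=app_config.models)
model = container.get_model()  # get_model_config() + download_url() happen here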