Add configuration file and support for custom models
Custom models can be added to the configuration file, under the "models" section. See the comments for more details.
- .gitignore +1 -0
- app.py +37 -17
- cli.py +72 -38
- config.json5 +62 -0
- requirements.txt +3 -1
- src/config.py +134 -0
- src/conversion/hf_converter.py +67 -0
- src/whisperContainer.py +29 -3
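
As a quick illustration of the change, this is roughly how the new configuration is consumed once it is in place (a sketch; the config path and the printed list are illustrative):

from src.config import ApplicationConfig

# Parse config.json5 (JSON5, so the comments in the file are allowed).
app_config = ApplicationConfig.parse_file("config.json5")

# Every entry under "models" becomes selectable by name in the UI and the CLI.
print(app_config.get_model_names())  # e.g. ['tiny', 'base', 'small', 'medium', 'large', 'large-v2']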
.gitignore
CHANGED
@@ -1,5 +1,6 @@
 # Byte-compiled / optimized / DLL files
 __pycache__/
+.vscode/
 flagged/
 *.py[cod]
 *$py.class
app.py
CHANGED
@@ -11,6 +11,7 @@ import zipfile
 import numpy as np
 
 import torch
+from src.config import ApplicationConfig
 from src.modelCache import ModelCache
 from src.source import get_audio_source_collection
 from src.vadParallel import ParallelContext, ParallelTranscription
@@ -62,7 +63,8 @@ WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large
 
 class WhisperTranscriber:
     def __init__(self, input_audio_max_duration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, vad_process_timeout: float = None,
-                 vad_cpu_cores: int = 1, delete_uploaded_files: bool = DELETE_UPLOADED_FILES, output_dir: str = None):
+                 vad_cpu_cores: int = 1, delete_uploaded_files: bool = DELETE_UPLOADED_FILES, output_dir: str = None,
+                 app_config: ApplicationConfig = None):
         self.model_cache = ModelCache()
         self.parallel_device_list = None
         self.gpu_parallel_context = None
@@ -75,6 +77,8 @@ class WhisperTranscriber:
         self.deleteUploadedFiles = delete_uploaded_files
         self.output_dir = output_dir
 
+        self.app_config = app_config
+
     def set_parallel_devices(self, vad_parallel_devices: str):
         self.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None
 
@@ -115,7 +119,7 @@ class WhisperTranscriber:
         selectedLanguage = languageName.lower() if len(languageName) > 0 else None
         selectedModel = modelName if modelName is not None else "base"
 
-        model = WhisperContainer(model_name=selectedModel, cache=self.model_cache)
+        model = WhisperContainer(model_name=selectedModel, cache=self.model_cache, models=self.app_config.models)
 
         # Result
         download = []
@@ -360,8 +364,8 @@ class WhisperTranscriber:
 def create_ui(input_audio_max_duration, share=False, server_name: str = None, server_port: int = 7860,
               default_model_name: str = "medium", default_vad: str = None, vad_parallel_devices: str = None,
               vad_process_timeout: float = None, vad_cpu_cores: int = 1, auto_parallel: bool = False,
-              output_dir: str = None):
-    ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout, vad_cpu_cores, DELETE_UPLOADED_FILES, output_dir)
+              output_dir: str = None, app_config: ApplicationConfig = None):
+    ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout, vad_cpu_cores, DELETE_UPLOADED_FILES, output_dir, app_config)
 
     # Specify a list of devices to use for parallel processing
     ui.set_parallel_devices(vad_parallel_devices)
@@ -378,8 +382,10 @@ def create_ui(input_audio_max_duration, share=False, server_name: str = None, se
 
     ui_article = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)"
 
+    whisper_models = app_config.get_model_names()
+
     simple_inputs = lambda : [
-        gr.Dropdown(choices=…
+        gr.Dropdown(choices=whisper_models, value=default_model_name, label="Model"),
         gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
         gr.Text(label="URL (YouTube, etc.)"),
         gr.File(label="Upload Files", file_count="multiple"),
@@ -429,18 +435,32 @@ def create_ui(input_audio_max_duration, share=False, server_name: str = None, se
     ui.close()
 
 if __name__ == '__main__':
+    app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))
+    whisper_models = app_config.get_model_names()
+
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument("--input_audio_max_duration", type=int, default=…
-    … (the remaining removed argument definitions are truncated in the source view)
+    parser.add_argument("--input_audio_max_duration", type=int, default=app_config.input_audio_max_duration, \
+                        help="Maximum audio file length in seconds, or -1 for no limit.") # 600
+    parser.add_argument("--share", type=bool, default=app_config.share, \
+                        help="True to share the app on HuggingFace.") # False
+    parser.add_argument("--server_name", type=str, default=app_config.server_name, \
+                        help="The host or IP to bind to. If None, bind to localhost.") # None
+    parser.add_argument("--server_port", type=int, default=app_config.server_port, \
+                        help="The port to bind to.") # 7860
+    parser.add_argument("--default_model_name", type=str, choices=whisper_models, default=app_config.default_model_name, \
+                        help="The default model name.") # medium
+    parser.add_argument("--default_vad", type=str, default=app_config.default_vad, \
+                        help="The default VAD.") # silero-vad
+    parser.add_argument("--vad_parallel_devices", type=str, default=app_config.vad_parallel_devices, \
+                        help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
+    parser.add_argument("--vad_cpu_cores", type=int, default=app_config.vad_cpu_cores, \
+                        help="The number of CPU cores to use for VAD pre-processing.") # 1
+    parser.add_argument("--vad_process_timeout", type=float, default=app_config.vad_process_timeout, \
+                        help="The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.") # 1800
+    parser.add_argument("--auto_parallel", type=bool, default=app_config.auto_parallel, \
+                        help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.") # False
+    parser.add_argument("--output_dir", "-o", type=str, default=app_config.output_dir, \
+                        help="directory to save the outputs") # None
 
     args = parser.parse_args().__dict__
-    create_ui(**args)
+    create_ui(app_config=app_config, **args)
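
The argparse wiring above follows one pattern throughout: config.json5 supplies the default for each flag, and an explicit command-line flag still wins. A small self-contained sketch of that pattern (the flag and values are illustrative, not taken from the repository):

import argparse

# Stand-in for a value read from config.json5, e.g. app_config.server_port.
config_default_port = 7860

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--server_port", type=int, default=config_default_port,
                    help="The port to bind to.")

args = parser.parse_args(["--server_port", "8000"])
print(args.server_port)  # 8000 - an explicit flag overrides the config-supplied default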
cli.py
CHANGED
@@ -6,48 +6,81 @@ import warnings
 import numpy as np
 
 import torch
-from app import LANGUAGES, …
+from app import LANGUAGES, WhisperTranscriber
+from src.config import ApplicationConfig
 from src.download import download_url
 
 from src.utils import optional_float, optional_int, str2bool
 from src.whisperContainer import WhisperContainer
 
 def cli():
+    app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))
+    whisper_models = app_config.get_model_names()
+
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument("audio", nargs="+", type=str, …
-    … (the remaining removed argument definitions are truncated in the source view)
+    parser.add_argument("audio", nargs="+", type=str, \
+                        help="audio file(s) to transcribe")
+    parser.add_argument("--model", default=app_config.default_model_name, choices=whisper_models, \
+                        help="name of the Whisper model to use") # medium
+    parser.add_argument("--model_dir", type=str, default=None, \
+                        help="the path to save model files; uses ~/.cache/whisper by default")
+    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", \
+                        help="device to use for PyTorch inference")
+    parser.add_argument("--output_dir", "-o", type=str, default=".", \
+                        help="directory to save the outputs")
+    parser.add_argument("--verbose", type=str2bool, default=True, \
+                        help="whether to print out the progress and debug messages")
+
+    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], \
+                        help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
+    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES), \
+                        help="language spoken in the audio, specify None to perform language detection")
+
+    parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
+                        help="The voice activity detection algorithm to use") # silero-vad
+    parser.add_argument("--vad_merge_window", type=optional_float, default=5, \
+                        help="The window size (in seconds) to merge voice segments")
+    parser.add_argument("--vad_max_merge_size", type=optional_float, default=30,\
+                        help="The maximum size (in seconds) of a voice segment")
+    parser.add_argument("--vad_padding", type=optional_float, default=1, \
+                        help="The padding (in seconds) to add to each voice segment")
+    parser.add_argument("--vad_prompt_window", type=optional_float, default=3, \
+                        help="The window size of the prompt to pass to Whisper")
+    parser.add_argument("--vad_cpu_cores", type=int, default=app_config.vad_cpu_cores, \
+                        help="The number of CPU cores to use for VAD pre-processing.") # 1
+    parser.add_argument("--vad_parallel_devices", type=str, default=app_config.vad_parallel_devices, \
+                        help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
+    parser.add_argument("--auto_parallel", type=bool, default=app_config.auto_parallel, \
+                        help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.") # False
+
+    parser.add_argument("--temperature", type=float, default=0, \
+                        help="temperature to use for sampling")
+    parser.add_argument("--best_of", type=optional_int, default=5, \
+                        help="number of candidates when sampling with non-zero temperature")
+    parser.add_argument("--beam_size", type=optional_int, default=5, \
+                        help="number of beams in beam search, only applicable when temperature is zero")
+    parser.add_argument("--patience", type=float, default=None, \
+                        help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
+    parser.add_argument("--length_penalty", type=float, default=None, \
+                        help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default")
+
+    parser.add_argument("--suppress_tokens", type=str, default="-1", \
+                        help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
+    parser.add_argument("--initial_prompt", type=str, default=None, \
+                        help="optional text to provide as a prompt for the first window.")
+    parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, \
+                        help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
+    parser.add_argument("--fp16", type=str2bool, default=True, \
+                        help="whether to perform inference in fp16; True by default")
+
+    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, \
+                        help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
+    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, \
+                        help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
+    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, \
+                        help="if the average log probability is lower than this value, treat the decoding as failed")
+    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, \
+                        help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
 
     args = parser.parse_args().__dict__
     model_name: str = args.pop("model")
@@ -74,12 +107,13 @@ def cli():
     vad_prompt_window = args.pop("vad_prompt_window")
     vad_cpu_cores = args.pop("vad_cpu_cores")
     auto_parallel = args.pop("auto_parallel")
-    … (two removed lines truncated in the source view)
-    transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores)
+
+    transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
     transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
     transcriber.set_auto_parallel(auto_parallel)
 
+    model = WhisperContainer(model_name, device=device, download_root=model_dir, models=app_config.models)
+
     if (transcriber._has_parallel_devices()):
         print("Using parallel devices:", transcriber.parallel_device_list)
 
config.json5
ADDED
@@ -0,0 +1,62 @@
+{
+    "models": [
+        // Configuration for the built-in models. You can remove any of these
+        // if you don't want to use the default models.
+        {
+            "name": "tiny",
+            "url": "tiny"
+        },
+        {
+            "name": "base",
+            "url": "base"
+        },
+        {
+            "name": "small",
+            "url": "small"
+        },
+        {
+            "name": "medium",
+            "url": "medium"
+        },
+        {
+            "name": "large",
+            "url": "large"
+        },
+        {
+            "name": "large-v2",
+            "url": "large-v2"
+        },
+        // Uncomment to add custom Japanese models
+        //{
+        //    "name": "whisper-large-v2-mix-jp",
+        //    "url": "vumichien/whisper-large-v2-mix-jp",
+        //    // The type of the model. Can be "huggingface" or "whisper" - "whisper" is the default.
+        //    // HuggingFace models are loaded using the HuggingFace transformers library and then converted to Whisper models.
+        //    "type": "huggingface",
+        //}
+    ],
+    // Configuration options that will be used if they are not specified in the command line arguments.
+
+    // Maximum audio file length in seconds, or -1 for no limit.
+    "input_audio_max_duration": 600,
+    // True to share the app on HuggingFace.
+    "share": false,
+    // The host or IP to bind to. If None, bind to localhost.
+    "server_name": null,
+    // The port to bind to.
+    "server_port": 7860,
+    // The default model name.
+    "default_model_name": "medium",
+    // The default VAD.
+    "default_vad": "silero-vad",
+    // A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.
+    "vad_parallel_devices": "",
+    // The number of CPU cores to use for VAD pre-processing.
+    "vad_cpu_cores": 1,
+    // The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.
+    "vad_process_timeout": 1800,
+    // True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.
+    "auto_parallel": false,
+    // Directory to save the outputs
+    "output_dir": null
+}
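
Each object in the "models" list is passed as keyword arguments to ModelConfig by ApplicationConfig.parse_file, so the commented-out custom entry above would surface roughly like this once uncommented (a sketch; the model is the example from the comments, not a tested download):

from src.config import ModelConfig

# Equivalent of the commented-out entry in the "models" section above.
custom = ModelConfig(name="whisper-large-v2-mix-jp",
                     url="vumichien/whisper-large-v2-mix-jp",
                     type="huggingface")

print(custom.name, custom.type)  # whisper-large-v2-mix-jp huggingface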
requirements.txt
CHANGED
@@ -1,7 +1,9 @@
+git+https://github.com/huggingface/transformers
 git+https://github.com/openai/whisper.git
 transformers
 ffmpeg-python==0.2.0
 gradio==3.13.0
 yt-dlp
 torchaudio
-altair
+altair
+json5
src/config.py
ADDED
@@ -0,0 +1,134 @@
+import urllib
+
+import os
+from typing import List
+from urllib.parse import urlparse
+
+from tqdm import tqdm
+
+from src.conversion.hf_converter import convert_hf_whisper
+
+class ModelConfig:
+    def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
+        """
+        Initialize a model configuration.
+
+        name: Name of the model
+        url: URL to download the model from
+        path: Path to the model file. If not set, the model will be downloaded from the URL.
+        type: Type of model. Can be whisper or huggingface.
+        """
+        self.name = name
+        self.url = url
+        self.path = path
+        self.type = type
+
+    def download_url(self, root_dir: str):
+        import whisper
+
+        # See if path is already set
+        if self.path is not None:
+            return self.path
+
+        if root_dir is None:
+            root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
+
+        model_type = self.type.lower() if self.type is not None else "whisper"
+
+        if model_type in ["huggingface", "hf"]:
+            self.path = self.url
+            destination_target = os.path.join(root_dir, self.name + ".pt")
+
+            # Convert from HuggingFace format to Whisper format
+            if os.path.exists(destination_target):
+                print(f"File {destination_target} already exists, skipping conversion")
+            else:
+                print("Saving HuggingFace model in Whisper format to " + destination_target)
+                convert_hf_whisper(self.url, destination_target)
+
+            self.path = destination_target
+
+        elif model_type in ["whisper", "w"]:
+            self.path = self.url
+
+            # See if URL is just a file
+            if self.url in whisper._MODELS:
+                # No need to download anything - Whisper will handle it
+                self.path = self.url
+            elif self.url.startswith("file://"):
+                # Get file path
+                self.path = urlparse(self.url).path
+            # See if it is an URL
+            elif self.url.startswith("http://") or self.url.startswith("https://"):
+                # Extension (or file name)
+                extension = os.path.splitext(self.url)[-1]
+                download_target = os.path.join(root_dir, self.name + extension)
+
+                if os.path.exists(download_target) and not os.path.isfile(download_target):
+                    raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+                if not os.path.isfile(download_target):
+                    self._download_file(self.url, download_target)
+                else:
+                    print(f"File {download_target} already exists, skipping download")
+
+                self.path = download_target
+            # Must be a local file
+            else:
+                self.path = self.url
+
+        else:
+            raise ValueError(f"Unknown model type {model_type}")
+
+        return self.path
+
+    def _download_file(self, url: str, destination: str):
+        with urllib.request.urlopen(url) as source, open(destination, "wb") as output:
+            with tqdm(
+                total=int(source.info().get("Content-Length")),
+                ncols=80,
+                unit="iB",
+                unit_scale=True,
+                unit_divisor=1024,
+            ) as loop:
+                while True:
+                    buffer = source.read(8192)
+                    if not buffer:
+                        break
+
+                    output.write(buffer)
+                    loop.update(len(buffer))
+
+class ApplicationConfig:
+    def __init__(self, models: List[ModelConfig] = [], input_audio_max_duration: int = 600,
+                 share: bool = False, server_name: str = None, server_port: int = 7860, default_model_name: str = "medium",
+                 default_vad: str = "silero-vad", vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
+                 auto_parallel: bool = False, output_dir: str = None):
+        self.models = models
+        self.input_audio_max_duration = input_audio_max_duration
+        self.share = share
+        self.server_name = server_name
+        self.server_port = server_port
+        self.default_model_name = default_model_name
+        self.default_vad = default_vad
+        self.vad_parallel_devices = vad_parallel_devices
+        self.vad_cpu_cores = vad_cpu_cores
+        self.vad_process_timeout = vad_process_timeout
+        self.auto_parallel = auto_parallel
+        self.output_dir = output_dir
+
+    def get_model_names(self):
+        return [ x.name for x in self.models ]
+
+    @staticmethod
+    def parse_file(config_path: str):
+        import json5
+
+        with open(config_path, "r") as f:
+            # Load using json5
+            data = json5.load(f)
+            data_models = data.pop("models", [])
+
+            models = [ ModelConfig(**x) for x in data_models ]
+
+            return ApplicationConfig(models, **data)
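
ModelConfig.download_url resolves the "url" field differently depending on its form: official Whisper names are left for whisper.load_model to fetch, file:// URLs become local paths, http(s) URLs are downloaded into the cache directory, and "huggingface" entries are converted to a local .pt first. A sketch of the two cases that touch nothing on disk (the paths are illustrative):

from src.config import ModelConfig

# Official model name: returned unchanged, whisper.load_model fetches it later.
official = ModelConfig(name="small", url="small")
print(official.download_url(root_dir=None))   # "small"

# file:// URL: mapped to a plain filesystem path, nothing is downloaded.
local = ModelConfig(name="my-model", url="file:///models/my-model.pt")
print(local.download_url(root_dir=None))      # "/models/my-model.pt"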
src/conversion/hf_converter.py
ADDED
@@ -0,0 +1,67 @@
+# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets
+
+from copy import deepcopy
+import torch
+from transformers import WhisperForConditionalGeneration
+
+WHISPER_MAPPING = {
+    "layers": "blocks",
+    "fc1": "mlp.0",
+    "fc2": "mlp.2",
+    "final_layer_norm": "mlp_ln",
+    "layers": "blocks",
+    ".self_attn.q_proj": ".attn.query",
+    ".self_attn.k_proj": ".attn.key",
+    ".self_attn.v_proj": ".attn.value",
+    ".self_attn_layer_norm": ".attn_ln",
+    ".self_attn.out_proj": ".attn.out",
+    ".encoder_attn.q_proj": ".cross_attn.query",
+    ".encoder_attn.k_proj": ".cross_attn.key",
+    ".encoder_attn.v_proj": ".cross_attn.value",
+    ".encoder_attn_layer_norm": ".cross_attn_ln",
+    ".encoder_attn.out_proj": ".cross_attn.out",
+    "decoder.layer_norm.": "decoder.ln.",
+    "encoder.layer_norm.": "encoder.ln_post.",
+    "embed_tokens": "token_embedding",
+    "encoder.embed_positions.weight": "encoder.positional_embedding",
+    "decoder.embed_positions.weight": "decoder.positional_embedding",
+    "layer_norm": "ln_post",
+}
+
+
+def rename_keys(s_dict):
+    keys = list(s_dict.keys())
+    for key in keys:
+        new_key = key
+        for k, v in WHISPER_MAPPING.items():
+            if k in key:
+                new_key = new_key.replace(k, v)
+
+        print(f"{key} -> {new_key}")
+
+        s_dict[new_key] = s_dict.pop(key)
+    return s_dict
+
+
+def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str):
+    transformer_model = WhisperForConditionalGeneration.from_pretrained(hf_model_name_or_path)
+    config = transformer_model.config
+
+    # first build dims
+    dims = {
+        'n_mels': config.num_mel_bins,
+        'n_vocab': config.vocab_size,
+        'n_audio_ctx': config.max_source_positions,
+        'n_audio_state': config.d_model,
+        'n_audio_head': config.encoder_attention_heads,
+        'n_audio_layer': config.encoder_layers,
+        'n_text_ctx': config.max_target_positions,
+        'n_text_state': config.d_model,
+        'n_text_head': config.decoder_attention_heads,
+        'n_text_layer': config.decoder_layers
+    }
+
+    state_dict = deepcopy(transformer_model.model.state_dict())
+    state_dict = rename_keys(state_dict)
+
+    torch.save({"dims": dims, "model_state_dict": state_dict}, whisper_state_path)
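
Converting a fine-tuned HuggingFace checkpoint is a single call: the state-dict keys are renamed to OpenAI's layout and saved as a .pt that whisper.load_model accepts. A sketch (the checkpoint name and output path are illustrative):

import whisper
from src.conversion.hf_converter import convert_hf_whisper

# Convert a HuggingFace Whisper checkpoint into Whisper's native format.
convert_hf_whisper("vumichien/whisper-large-v2-mix-jp", "whisper-large-v2-mix-jp.pt")

# The resulting file loads like any other Whisper checkpoint.
model = whisper.load_model("whisper-large-v2-mix-jp.pt", device="cpu")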
src/whisperContainer.py
CHANGED
@@ -1,11 +1,14 @@
 # External programs
 import os
+from typing import List
 import whisper
+from src.config import ModelConfig
 
 from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
 
 class WhisperContainer:
-    def __init__(self, model_name: str, device: str = None, download_root: str = None, cache: ModelCache = None):
+    def __init__(self, model_name: str, device: str = None, download_root: str = None,
+                 cache: ModelCache = None, models: List[ModelConfig] = []):
         self.model_name = model_name
         self.device = device
         self.download_root = download_root
@@ -13,6 +16,9 @@ class WhisperContainer:
 
         # Will be created on demand
         self.model = None
+
+        # List of known models
+        self.models = models
 
     def get_model(self):
         if self.model is None:
@@ -32,21 +38,40 @@ class WhisperContainer:
         # Warning: Using private API here
         try:
             root_dir = self.download_root
+            model_config = self.get_model_config()
 
             if root_dir is None:
                 root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
 
             if self.model_name in whisper._MODELS:
                 whisper._download(whisper._MODELS[self.model_name], root_dir, False)
+            else:
+                # If the model is not in the official list, see if it needs to be downloaded
+                model_config.download_url(root_dir)
             return True
+
         except Exception as e:
             # Given that the API is private, it could change at any time. We don't want to crash the program
             print("Error pre-downloading model: " + str(e))
             return False
 
+    def get_model_config(self) -> ModelConfig:
+        """
+        Get the model configuration for the model.
+        """
+        for model in self.models:
+            if model.name == self.model_name:
+                return model
+        return None
+
     def _create_model(self):
         print("Loading whisper model " + self.model_name)
-        … (removed line truncated in the source view)
+
+        model_config = self.get_model_config()
+        # Note that the model will not be downloaded in the case of an official Whisper model
+        model_path = model_config.download_url(self.download_root)
+
+        return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
 
     def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
         """
@@ -71,12 +96,13 @@ class WhisperContainer:
 
     # This is required for multiprocessing
     def __getstate__(self):
-        return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root }
+        return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root, "models": self.models }
 
     def __setstate__(self, state):
        self.model_name = state["model_name"]
        self.device = state["device"]
        self.download_root = state["download_root"]
+       self.models = state["models"]
        self.model = None
        # Depickled objects must use the global cache
        self.cache = GLOBAL_MODEL_CACHE
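
Taken together, the container now resolves a model name against the configured list instead of assuming an official Whisper name. A sketch of how app.py and cli.py use it (the chosen name must exist in the "models" section; loading is deferred until the model is actually needed):

from src.config import ApplicationConfig
from src.whisperContainer import WhisperContainer

app_config = ApplicationConfig.parse_file("config.json5")

# The container only stores the name; the heavy model object is built lazily.
container = WhisperContainer("medium", device="cpu", models=app_config.models)

print(container.get_model_config())  # the matching ModelConfig, or None for unknown names

model = container.get_model()        # downloads/converts on first use, then loads the model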