Merge master
- .github/ISSUE_TEMPLATE/bug_report.md +1 -1
- .github/ISSUE_TEMPLATE/feature_request.md +1 -1
- .github/ISSUE_TEMPLATE/hallucination.md +12 -0
- .github/pull_request_template.md +5 -0
- .github/workflows/{shell-scrpit-test.yml → ci-shell.yml} +23 -19
- .github/workflows/ci.yml +41 -0
- .gitignore +2 -0
- app.py +32 -22
- configs/default_parameters.yaml +4 -4
- modules/translation/deepl_api.py +26 -26
- modules/translation/nllb_inference.py +13 -3
- modules/translation/translation_base.py +9 -7
- modules/utils/subtitle_manager.py +0 -3
- modules/utils/youtube_manager.py +19 -1
- modules/whisper/faster_whisper_inference.py +3 -3
- modules/whisper/insanely_fast_whisper_inference.py +17 -10
- modules/whisper/whisper_Inference.py +2 -2
- modules/whisper/whisper_base.py +29 -32
- modules/whisper/whisper_parameter.py +64 -54
- requirements.txt +5 -4
- tests/test_bgm_separation.py +53 -0
- tests/test_config.py +17 -0
- tests/test_diarization.py +31 -0
- tests/test_srt.srt +7 -0
- tests/test_transcription.py +97 -0
- tests/test_translation.py +52 -0
- tests/test_vad.py +26 -0
- tests/test_vtt.vtt +6 -0
.github/ISSUE_TEMPLATE/bug_report.md
CHANGED
@@ -3,7 +3,7 @@ name: Bug report
 about: Create a report to help us improve
 title: ''
 labels: bug
-assignees:
+assignees: jhj0517

 ---
.github/ISSUE_TEMPLATE/feature_request.md
CHANGED
@@ -3,7 +3,7 @@ name: Feature request
 about: Any feature you want
 title: ''
 labels: enhancement
-assignees:
+assignees: jhj0517

 ---
.github/ISSUE_TEMPLATE/hallucination.md
ADDED
@@ -0,0 +1,12 @@
+---
+name: Hallucination
+about: Whisper hallucinations. ( Repeating certain words or subtitles starting too
+  early, etc. )
+title: ''
+labels: hallucination
+assignees: jhj0517
+
+---
+
+**Download URL for sample audio**
+- Please upload download URL for sample audio file so I can test with some settings for better result. You can use https://easyupload.io/ or any other service to share.
.github/pull_request_template.md
ADDED
@@ -0,0 +1,5 @@
+## Related issues
+- #0
+
+## Changed
+1. Changes
.github/workflows/{shell-scrpit-test.yml → ci-shell.yml}
RENAMED
@@ -1,38 +1,42 @@
-name: Shell Script
+name: CI-Shell Script

 on:
+  workflow_dispatch:
+
   push:
-    branches:
+    branches:
+      - master
+  pull_request:
+    branches:
+      - master

 jobs:
   test-shell-script:
+
     runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python: [ "3.10" ]
+
     steps:
+      - name: Clean up space for action
+        run: rm -rf /opt/hostedtoolcache

+      - uses: actions/checkout@v4
+      - name: Setup Python
+        uses: actions/setup-python@v5
         with:
+          python-version: ${{ matrix.python }}

-        id: setup-ffmpeg
-        with:
-          ffmpeg-version: release
-          architecture: 'arm64'
-          linking-type: static
+      - name: Install git and ffmpeg
+        run: sudo apt-get update && sudo apt-get install -y git ffmpeg

+      - name: Execute Install.sh
        run: |
          chmod +x ./Install.sh
          ./Install.sh

+      - name: Execute start-webui.sh
        run: |
          chmod +x ./start-webui.sh
          timeout 60s ./start-webui.sh || true
.github/workflows/ci.yml
ADDED
@@ -0,0 +1,41 @@
+name: CI
+
+on:
+  workflow_dispatch:
+
+  push:
+    branches:
+      - master
+  pull_request:
+    branches:
+      - master
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python: ["3.10"]
+
+    env:
+      DEEPL_API_KEY: ${{ secrets.DEEPL_API_KEY }}
+
+    steps:
+      - name: Clean up space for action
+        run: rm -rf /opt/hostedtoolcache
+
+      - uses: actions/checkout@v4
+      - name: Setup Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python }}
+
+      - name: Install git and ffmpeg
+        run: sudo apt-get update && sudo apt-get install -y git ffmpeg
+
+      - name: Install dependencies
+        run: pip install -r requirements.txt pytest
+
+      - name: Run test
+        run: python -m pytest -rs tests
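For reference, once git and ffmpeg are installed the new CI job reduces to two commands. A minimal local mirror of those steps (a sketch; assumes it is run from the repository root, exactly as the workflow does):

import subprocess

# Same commands as the "Install dependencies" and "Run test" stages in ci.yml.
subprocess.run(["pip", "install", "-r", "requirements.txt", "pytest"], check=True)
subprocess.run(["python", "-m", "pytest", "-rs", "tests"], check=True)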
.gitignore
CHANGED
@@ -2,6 +2,8 @@
|
|
2 |
*.png
|
3 |
*.mp4
|
4 |
*.mp3
|
|
|
|
|
5 |
venv/
|
6 |
modules/ui/__pycache__/
|
7 |
outputs/
|
|
|
2 |
*.png
|
3 |
*.mp4
|
4 |
*.mp3
|
5 |
+
.idea/
|
6 |
+
.pytest_cache/
|
7 |
venv/
|
8 |
modules/ui/__pycache__/
|
9 |
outputs/
|
app.py
CHANGED
@@ -21,7 +21,7 @@ from modules.whisper.whisper_parameter import *
     def __init__(self, args):
         self.args = args
-        self.app = gr.Blocks(css=CSS, theme=self.args.theme)
+        self.app = gr.Blocks(css=CSS, theme=self.args.theme, delete_cache=(60, 3600))
         self.whisper_inf = WhisperFactory.create_whisper_inference(
             whisper_type=self.args.whisper_type,
             whisper_model_dir=self.args.whisper_model_dir,
@@ -59,6 +59,7 @@ class App:
             cb_timestamp = gr.Checkbox(value=whisper_params["add_timestamp"], label="Add a timestamp to the end of the filename",
                                        interactive=True)
+
         with gr.Accordion("Advanced Parameters", open=False):
             nb_beam_size = gr.Number(label="Beam Size", value=whisper_params["beam_size"], precision=0, interactive=True,
                                      info="Beam size to use for decoding.")
@@ -68,6 +69,7 @@ class App:
             dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types,
                                           value=self.whisper_inf.current_compute_type, interactive=True,
+                                          allow_custom_value=True,
                                           info="Select the type of computation to perform.")
             nb_best_of = gr.Number(label="Best Of", value=whisper_params["best_of"], interactive=True,
                                    info="Number of candidates when sampling with non-zero temperature.")
@@ -88,6 +90,9 @@ class App:
             nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=whisper_params["compression_ratio_threshold"],
                                                        interactive=True,
                                                        info="If the gzip compression ratio is above this value, treat as failed.")
+            nb_chunk_length = gr.Number(label="Chunk Length (s)", value=lambda: whisper_params["chunk_length"],
+                                        precision=0,
+                                        info="The length of audio segments. If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.")
             with gr.Group(visible=isinstance(self.whisper_inf, FasterWhisperInference)):
                 nb_length_penalty = gr.Number(label="Length Penalty", value=whisper_params["length_penalty"],
                                               info="Exponential length penalty constant.")
@@ -113,9 +118,6 @@ class App:
                 nb_max_new_tokens = gr.Number(label="Max New Tokens", value=lambda: whisper_params["max_new_tokens"],
                                               precision=0,
                                               info="Maximum number of new tokens to generate per-chunk. If not set, the maximum will be set by the default max_length.")
-                nb_chunk_length = gr.Number(label="Chunk Length", value=lambda: whisper_params["chunk_length"],
-                                            precision=0,
-                                            info="The length of audio segments. If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.")
                 nb_hallucination_silence_threshold = gr.Number(label="Hallucination Silence Threshold (sec)",
                                                                value=lambda: whisper_params["hallucination_silence_threshold"],
                                                                info="When 'Word Timestamps' is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected.")
@@ -127,32 +129,37 @@ class App:
             with gr.Group(visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
-                nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=whisper_params["chunk_length_s"],
-                                              precision=0)
                 nb_batch_size = gr.Number(label="Batch Size", value=whisper_params["batch_size"], precision=0)

-            with gr.Accordion("
-                cb_bgm_separation = gr.Checkbox(label="Enable
-                                                interactive=True
+            with gr.Accordion("Background Music Remover Filter", open=False):
+                cb_bgm_separation = gr.Checkbox(label="Enable Background Music Remover Filter", value=uvr_params["is_separate_bgm"],
+                                                interactive=True,
+                                                info="Enabling this will remove background music by submodel before"
+                                                     " transcribing ")
                 dd_uvr_device = gr.Dropdown(label="Device", value=self.whisper_inf.music_separator.device,
                                             choices=self.whisper_inf.music_separator.available_devices)
                 dd_uvr_model_size = gr.Dropdown(label="Model", value=uvr_params["model_size"],
                                                 choices=self.whisper_inf.music_separator.available_models)
                 nb_uvr_segment_size = gr.Number(label="Segment Size", value=uvr_params["segment_size"], precision=0)
                 cb_uvr_save_file = gr.Checkbox(label="Save separated files to output", value=uvr_params["save_file"])
+                cb_uvr_enable_offload = gr.Checkbox(label="Offload sub model after removing background music",
+                                                    value=uvr_params["enable_offload"])

-            with gr.Accordion("
+            with gr.Accordion("Voice Detection Filter", open=False):
                 cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=vad_params["vad_filter"],
-                                            interactive=True
+                                            interactive=True,
+                                            info="Enable this to transcribe only detected voice parts by submodel.")
+                sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold",
+                                         value=vad_params["threshold"],
                                          info="Lower it to be more sensitive to small sounds.")
-                nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0,
+                nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0,
+                                                      value=vad_params["min_speech_duration_ms"],
                                                       info="Final speech chunks shorter than this time are thrown out")
-                nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)",
+                nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)",
+                                                     value=vad_params["max_speech_duration_s"],
+                                                     info="Maximum duration of speech chunks in \"seconds\".")
+                nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0,
+                                                       value=vad_params["min_silence_duration_ms"],
                                                        info="In the end of each speech chunk wait for this time"
                                                             " before separating it")
                 nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=vad_params["speech_pad_ms"],
@@ -161,7 +168,10 @@ class App:
         with gr.Accordion("Diarization", open=False):
             cb_diarize = gr.Checkbox(label="Enable Diarization", value=diarization_params["is_diarize"])
             tb_hf_token = gr.Text(label="HuggingFace Token", value=diarization_params["hf_token"],
-                                  info="This is only needed the first time you download the model. If you already have
+                                  info="This is only needed the first time you download the model. If you already have"
+                                       " models, you don't need to enter. To download the model, you must manually go "
+                                       "to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and agree to"
+                                       " their requirement.")
             dd_diarization_device = gr.Dropdown(label="Device",
                                                 choices=self.whisper_inf.diarizer.get_available_device(),
                                                 value=self.whisper_inf.diarizer.get_device())
@@ -177,19 +187,19 @@ class App:
             temperature=sd_temperature, compression_ratio_threshold=nb_compression_ratio_threshold,
             vad_filter=cb_vad_filter, threshold=sd_threshold, min_speech_duration_ms=nb_min_speech_duration_ms,
             max_speech_duration_s=nb_max_speech_duration_s, min_silence_duration_ms=nb_min_silence_duration_ms,
-            speech_pad_ms=nb_speech_pad_ms,
+            speech_pad_ms=nb_speech_pad_ms, chunk_length=nb_chunk_length, batch_size=nb_batch_size,
             is_diarize=cb_diarize, hf_token=tb_hf_token, diarization_device=dd_diarization_device,
             length_penalty=nb_length_penalty, repetition_penalty=nb_repetition_penalty,
             no_repeat_ngram_size=nb_no_repeat_ngram_size, prefix=tb_prefix, suppress_blank=cb_suppress_blank,
             suppress_tokens=tb_suppress_tokens, max_initial_timestamp=nb_max_initial_timestamp,
             word_timestamps=cb_word_timestamps, prepend_punctuations=tb_prepend_punctuations,
             append_punctuations=tb_append_punctuations, max_new_tokens=nb_max_new_tokens,
             hallucination_silence_threshold=nb_hallucination_silence_threshold, hotwords=tb_hotwords,
             language_detection_threshold=nb_language_detection_threshold,
             language_detection_segments=nb_language_detection_segments,
             prompt_reset_on_temperature=sld_prompt_reset_on_temperature, is_bgm_separate=cb_bgm_separation,
             uvr_device=dd_uvr_device, uvr_model_size=dd_uvr_model_size, uvr_segment_size=nb_uvr_segment_size,
-            uvr_save_file=cb_uvr_save_file
+            uvr_save_file=cb_uvr_save_file, uvr_enable_offload=cb_uvr_enable_offload
         ),
         dd_file_format,
         cb_timestamp
configs/default_parameters.yaml
CHANGED
@@ -12,7 +12,7 @@ whisper:
   initial_prompt: null
   temperature: 0
   compression_ratio_threshold: 2.4
-
+  chunk_length: 30
   batch_size: 24
   length_penalty: 1
   repetition_penalty: 1
@@ -25,7 +25,6 @@ whisper:
   prepend_punctuations: "\"'“¿([{-"
   append_punctuations: "\"'.。,,!!??::”)]}、"
   max_new_tokens: null
-  chunk_length: null
  hallucination_silence_threshold: null
  hotwords: null
  language_detection_threshold: null
@@ -37,8 +36,8 @@ vad:
  threshold: 0.5
  min_speech_duration_ms: 250
  max_speech_duration_s: 9999
-  min_silence_duration_ms:
-  speech_pad_ms:
+  min_silence_duration_ms: 1000
+  speech_pad_ms: 2000

 diarization:
   is_diarize: false
@@ -49,6 +48,7 @@ bgm_separation:
   model_size: "UVR-MDX-NET-Inst_HQ_4"
   segment_size: 256
   save_file: false
+  enable_offload: true

 translation:
   deepl:
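A quick way to sanity-check the new defaults is to load the file with ruamel.yaml (already pinned in requirements.txt); the reader below is a sketch and assumes the configs/default_parameters.yaml path used by this repo:

from ruamel.yaml import YAML

yaml = YAML(typ="safe")
with open("configs/default_parameters.yaml") as f:
    config = yaml.load(f)

# Values introduced or changed in this commit.
print(config["whisper"]["chunk_length"])            # 30
print(config["vad"]["min_silence_duration_ms"])     # 1000
print(config["vad"]["speech_pad_ms"])               # 2000
print(config["bgm_separation"]["enable_offload"])   # True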
modules/translation/deepl_api.py
CHANGED
@@ -98,8 +98,8 @@ class DeepLAPI:
                        fileobjs: list,
                        source_lang: str,
                        target_lang: str,
-                       is_pro: bool,
-                       add_timestamp: bool,
+                       is_pro: bool = False,
+                       add_timestamp: bool = True,
                        progress=gr.Progress()) -> list:
         """
         Translate subtitle files using DeepL API
@@ -126,6 +129,9 @@ class DeepLAPI:
         String to return to gr.Textbox()
         Files to return to gr.Files()
         """
+        if fileobjs and isinstance(fileobjs[0], gr.utils.NamedString):
+            fileobjs = [fileobj.name for fileobj in fileobjs]
+
         self.cache_parameters(
             api_key=auth_key,
             is_pro=is_pro,
@@ -136,37 +139,28 @@ class DeepLAPI:
         files_info = {}
         for fileobj in fileobjs:
-            file_path = fileobj
-            file_name, file_ext = os.path.splitext(os.path.basename(fileobj
+            file_path = fileobj
+            file_name, file_ext = os.path.splitext(os.path.basename(fileobj))

             if file_ext == ".srt":
                 parsed_dicts = parse_srt(file_path=file_path)
-
-                batch_size = self.max_text_batch_size
-                for batch_start in range(0, len(parsed_dicts), batch_size):
-                    batch_end = min(batch_start + batch_size, len(parsed_dicts))
-                    sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
-                    translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
-                                                                    target_lang, is_pro)
-                    for i, translated_text in enumerate(translated_texts):
-                        parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
-                    progress(batch_end / len(parsed_dicts), desc="Translating..")
-
-                subtitle = get_serialized_srt(parsed_dicts)

             elif file_ext == ".vtt":
                 parsed_dicts = parse_vtt(file_path=file_path)

+            batch_size = self.max_text_batch_size
+            for batch_start in range(0, len(parsed_dicts), batch_size):
+                batch_end = min(batch_start + batch_size, len(parsed_dicts))
+                sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
+                translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
+                                                                target_lang, is_pro)
+                for i, translated_text in enumerate(translated_texts):
+                    parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
+                progress(batch_end / len(parsed_dicts), desc="Translating..")

+            if file_ext == ".srt":
+                subtitle = get_serialized_srt(parsed_dicts)
+            elif file_ext == ".vtt":
                 subtitle = get_serialized_vtt(parsed_dicts)

             if add_timestamp:
@@ -193,8 +187,14 @@ class DeepLAPI:
                                 text: list,
                                 source_lang: str,
                                 target_lang: str,
-                                is_pro: bool):
+                                is_pro: bool = False):
         """Request API response to DeepL server"""
+        if source_lang not in list(DEEPL_AVAILABLE_SOURCE_LANGS.keys()):
+            raise ValueError(f"Source language {source_lang} is not supported."
+                             f"Use one of {list(DEEPL_AVAILABLE_SOURCE_LANGS.keys())}")
+        if target_lang not in list(DEEPL_AVAILABLE_TARGET_LANGS.keys()):
+            raise ValueError(f"Target language {target_lang} is not supported."
+                             f"Use one of {list(DEEPL_AVAILABLE_TARGET_LANGS.keys())}")

         url = 'https://api.deepl.com/v2/translate' if is_pro else 'https://api-free.deepl.com/v2/translate'
         headers = {
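The translation loop above now batches subtitle lines in chunks of max_text_batch_size before each API call, regardless of subtitle format. The standalone sketch below shows just that batching arithmetic; the batch size and dummy sentences are illustrative:

max_text_batch_size = 50
parsed_dicts = [{"sentence": f"line {i}"} for i in range(120)]

for batch_start in range(0, len(parsed_dicts), max_text_batch_size):
    batch_end = min(batch_start + max_text_batch_size, len(parsed_dicts))
    batch = [d["sentence"] for d in parsed_dicts[batch_start:batch_end]]
    # request_deepl_translate(...) would be called with `batch` here.
    print(batch_start, batch_end, len(batch))   # (0, 50, 50), (50, 100, 50), (100, 120, 20)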
modules/translation/nllb_inference.py
CHANGED
@@ -38,8 +38,19 @@ class NLLBInference(TranslationBase):
                        model_size: str,
                        src_lang: str,
                        tgt_lang: str,
-                       progress: gr.Progress
+                       progress: gr.Progress = gr.Progress()
                        ):
+        def validate_language(lang: str) -> str:
+            if lang in NLLB_AVAILABLE_LANGS:
+                return NLLB_AVAILABLE_LANGS[lang]
+            elif lang not in NLLB_AVAILABLE_LANGS.values():
+                raise ValueError(
+                    f"Language '{lang}' is not supported. Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
+            return lang
+
+        src_lang = validate_language(src_lang)
+        tgt_lang = validate_language(tgt_lang)
+
         if model_size != self.current_model_size or self.model is None:
             print("\nInitializing NLLB Model..\n")
             progress(0, desc="Initializing NLLB Model..")
@@ -51,8 +62,7 @@ class NLLBInference(TranslationBase):
         self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
                                                        cache_dir=os.path.join(self.model_dir, "tokenizers"),
                                                        local_files_only=local_files_only)
-
-        tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
+
         self.pipeline = pipeline("translation",
                                  model=self.model,
                                  tokenizer=self.tokenizer,
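To illustrate what the new validate_language helper accepts, here is a self-contained sketch with a two-entry mapping (the real NLLB_AVAILABLE_LANGS in this repo is much larger):

NLLB_AVAILABLE_LANGS = {"English": "eng_Latn", "Korean": "kor_Hang"}

def validate_language(lang: str) -> str:
    if lang in NLLB_AVAILABLE_LANGS:
        return NLLB_AVAILABLE_LANGS[lang]          # display name -> NLLB code
    elif lang not in NLLB_AVAILABLE_LANGS.values():
        raise ValueError(f"Language '{lang}' is not supported. "
                         f"Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
    return lang                                    # already a valid NLLB code

print(validate_language("Korean"))      # kor_Hang
print(validate_language("eng_Latn"))    # eng_Latn passes through unchanged
# validate_language("Klingon") would raise ValueError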
modules/translation/translation_base.py
CHANGED
@@ -40,7 +40,7 @@ class TranslationBase(ABC):
                        model_size: str,
                        src_lang: str,
                        tgt_lang: str,
-                       progress: gr.Progress
+                       progress: gr.Progress = gr.Progress()
                        ):
         pass
@@ -50,8 +50,8 @@ class TranslationBase(ABC):
                        model_size: str,
                        src_lang: str,
                        tgt_lang: str,
-                       max_length: int,
-                       add_timestamp: bool,
+                       max_length: int = 200,
+                       add_timestamp: bool = True,
                        progress=gr.Progress()) -> list:
         """
         Translate subtitle file from source language to target language
@@ -81,6 +81,9 @@ class TranslationBase(ABC):
         Files to return to gr.Files()
         """
         try:
+            if fileobjs and isinstance(fileobjs[0], gr.utils.NamedString):
+                fileobjs = [file.name for file in fileobjs]
+
             self.cache_parameters(model_size=model_size,
                                   src_lang=src_lang,
                                   tgt_lang=tgt_lang,
@@ -94,10 +97,9 @@ class TranslationBase(ABC):
             files_info = {}
             for fileobj in fileobjs:
-
-                file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
+                file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
                 if file_ext == ".srt":
-                    parsed_dicts = parse_srt(file_path=
+                    parsed_dicts = parse_srt(file_path=fileobj)
                     total_progress = len(parsed_dicts)
                     for index, dic in enumerate(parsed_dicts):
                         progress(index / total_progress, desc="Translating..")
@@ -106,7 +108,7 @@ class TranslationBase(ABC):
                     subtitle = get_serialized_srt(parsed_dicts)

                 elif file_ext == ".vtt":
-                    parsed_dicts = parse_vtt(file_path=
+                    parsed_dicts = parse_vtt(file_path=fileobj)
                     total_progress = len(parsed_dicts)
                     for index, dic in enumerate(parsed_dicts):
                         progress(index / total_progress, desc="Translating..")
modules/utils/subtitle_manager.py
CHANGED
@@ -121,11 +121,8 @@ def get_serialized_vtt(dicts):
 @spaces.GPU(duration=120)
 def safe_filename(name):
-    from app import _args
     INVALID_FILENAME_CHARS = r'[<>:"/\\|?*\x00-\x1f]'
     safe_name = re.sub(INVALID_FILENAME_CHARS, '_', name)
-    if not _args.colab:
-        return safe_name
     # Truncate the filename if it exceeds the max_length (20)
     if len(safe_name) > 20:
         file_extension = safe_name.split('.')[-1]
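safe_filename now sanitises and truncates on every platform, not only on Colab. A standalone sketch of that behaviour; the regex and the 20-character limit come from the code above, while the exact basename truncation is illustrative:

import re

INVALID_FILENAME_CHARS = r'[<>:"/\\|?*\x00-\x1f]'

def sketch_safe_filename(name: str) -> str:
    safe_name = re.sub(INVALID_FILENAME_CHARS, '_', name)
    if len(safe_name) > 20:
        file_extension = safe_name.split('.')[-1]
        # Keep the extension and trim the rest so the result stays within 20 characters.
        safe_name = safe_name[:20 - len(file_extension) - 1] + '.' + file_extension
    return safe_name

print(sketch_safe_filename('my video: part 1, final mix.srt'))   # 'my video_ part 1.srt'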
modules/utils/youtube_manager.py
CHANGED
@@ -1,4 +1,5 @@
 from pytubefix import YouTube
+import subprocess
 import os

@@ -12,4 +13,21 @@ def get_ytmetas(link):

 def get_ytaudio(ytdata: YouTube):
-
+    # Somehow the audio is corrupted so need to convert to valid audio file.
+    # Fix for : https://github.com/jhj0517/Whisper-WebUI/issues/304
+
+    audio_path = ytdata.streams.get_audio_only().download(filename=os.path.join("modules", "yt_tmp.wav"))
+    temp_audio_path = os.path.join("modules", "yt_tmp_fixed.wav")
+
+    try:
+        subprocess.run([
+            'ffmpeg', '-y',
+            '-i', audio_path,
+            temp_audio_path
+        ], check=True)
+
+        os.replace(temp_audio_path, audio_path)
+        return audio_path
+    except subprocess.CalledProcessError as e:
+        print(f"Error during ffmpeg conversion: {e}")
+        return None
modules/whisper/faster_whisper_inference.py
CHANGED
@@ -40,7 +40,7 @@ class FasterWhisperInference(WhisperBase):
     def transcribe(self,
                    audio: Union[str, BinaryIO, np.ndarray],
-                   progress: gr.Progress,
+                   progress: gr.Progress = gr.Progress(),
                    *whisper_params,
                    ) -> Tuple[List[dict], float]:
         """
@@ -126,7 +126,7 @@ class FasterWhisperInference(WhisperBase):
     def update_model(self,
                      model_size: str,
                      compute_type: str,
-                     progress: gr.Progress
+                     progress: gr.Progress = gr.Progress()
                      ):
         """
         Update current model setting
@@ -159,7 +159,7 @@ class FasterWhisperInference(WhisperBase):
         ----------
         Name list of models
         """
-        model_paths = {model:model for model in
+        model_paths = {model:model for model in faster_whisper.available_models()}
         faster_whisper_prefix = "models--Systran--faster-whisper-"

         existing_models = os.listdir(self.model_dir)
modules/whisper/insanely_fast_whisper_inference.py
CHANGED
@@ -39,7 +39,7 @@ class InsanelyFastWhisperInference(WhisperBase):
     def transcribe(self,
                    audio: Union[str, np.ndarray, torch.Tensor],
-                   progress: gr.Progress,
+                   progress: gr.Progress = gr.Progress(),
                    *whisper_params,
                    ) -> Tuple[List[dict], float]:
         """
@@ -75,18 +75,25 @@ class InsanelyFastWhisperInference(WhisperBase):
         ) as progress:
             progress.add_task("[yellow]Transcribing...", total=None)

+            kwargs = {
+                "no_speech_threshold": params.no_speech_threshold,
+                "temperature": params.temperature,
+                "compression_ratio_threshold": params.compression_ratio_threshold,
+                "logprob_threshold": params.log_prob_threshold,
+            }
+
+            if self.current_model_size.endswith(".en"):
+                pass
+            else:
+                kwargs["language"] = params.lang
+                kwargs["task"] = "translate" if params.is_translate else "transcribe"
+
             segments = self.model(
                 inputs=audio,
                 return_timestamps=True,
-                chunk_length_s=params.
+                chunk_length_s=params.chunk_length,
                 batch_size=params.batch_size,
-                generate_kwargs=
-                    "language": params.lang,
-                    "task": "translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
-                    "no_speech_threshold": params.no_speech_threshold,
-                    "temperature": params.temperature,
-                    "compression_ratio_threshold": params.compression_ratio_threshold
-                }
+                generate_kwargs=kwargs
             )

             segments_result = self.format_result(
@@ -98,7 +105,7 @@ class InsanelyFastWhisperInference(WhisperBase):
     def update_model(self,
                      model_size: str,
                      compute_type: str,
-                     progress: gr.Progress,
+                     progress: gr.Progress = gr.Progress(),
                      ):
         """
         Update current model setting
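The new generate_kwargs handling skips the language and task keys for English-only checkpoints, since models whose name ends in ".en" do not take them. A small sketch of that guard (the helper name and model ids are illustrative):

def build_generate_kwargs(model_size: str, lang: str, is_translate: bool) -> dict:
    kwargs = {}
    if not model_size.endswith(".en"):      # English-only models take neither key
        kwargs["language"] = lang
        kwargs["task"] = "translate" if is_translate else "transcribe"
    return kwargs

print(build_generate_kwargs("openai/whisper-small.en", "english", False))   # {}
print(build_generate_kwargs("openai/whisper-large-v3", "korean", True))     # language + task set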
modules/whisper/whisper_Inference.py
CHANGED
@@ -28,7 +28,7 @@ class WhisperInference(WhisperBase):
     def transcribe(self,
                    audio: Union[str, np.ndarray, torch.Tensor],
-                   progress: gr.Progress,
+                   progress: gr.Progress = gr.Progress(),
                    *whisper_params,
                    ) -> Tuple[List[dict], float]:
         """
@@ -79,7 +79,7 @@ class WhisperInference(WhisperBase):
     def update_model(self,
                      model_size: str,
                      compute_type: str,
-                     progress: gr.Progress,
+                     progress: gr.Progress = gr.Progress(),
                      ):
         """
         Update current model setting
modules/whisper/whisper_base.py
CHANGED
@@ -53,7 +53,7 @@ class WhisperBase(ABC):
     @abstractmethod
     def transcribe(self,
                    audio: Union[str, BinaryIO, np.ndarray],
-                   progress: gr.Progress,
+                   progress: gr.Progress = gr.Progress(),
                    *whisper_params,
                    ):
         """Inference whisper model to transcribe"""
@@ -63,7 +63,7 @@ class WhisperBase(ABC):
     def update_model(self,
                      model_size: str,
                      compute_type: str,
-                     progress: gr.Progress
+                     progress: gr.Progress = gr.Progress()
                      ):
         """Initialize whisper model"""
         pass
@@ -104,7 +104,9 @@ class WhisperBase(ABC):
             add_timestamp=add_timestamp
         )

-        if params.lang
+        if params.lang is None:
+            pass
+        elif params.lang == "Automatic Detection":
             params.lang = None
         else:
             language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
@@ -128,11 +130,12 @@ class WhisperBase(ABC):
             origin_sample_rate = self.music_separator.audio_info.sample_rate
             audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)

-
+            if params.uvr_enable_offload:
+                self.music_separator.offload()

         if params.vad_filter:
             # Explicit value set for float('inf') from gr.Number()
-            if params.max_speech_duration_s >= 9999:
+            if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
                 params.max_speech_duration_s = float('inf')

             vad_options = VadOptions(
@@ -171,10 +174,10 @@ class WhisperBase(ABC):
         return result, elapsed_time

     def transcribe_file(self,
-                        files:
-                        input_folder_path: str,
-                        file_format: str,
-                        add_timestamp: bool,
+                        files: Optional[List] = None,
+                        input_folder_path: Optional[str] = None,
+                        file_format: str = "SRT",
+                        add_timestamp: bool = True,
                         progress=gr.Progress(),
                         *whisper_params,
                         ) -> list:
@@ -207,18 +210,21 @@ class WhisperBase(ABC):
         try:
             if input_folder_path:
                 files = get_media_files(input_folder_path)
-
+            if isinstance(files, str):
+                files = [files]
+            if files and isinstance(files[0], gr.utils.NamedString):
+                files = [file.name for file in files]

             files_info = {}
             for file in files:
                 transcribed_segments, time_for_task = self.run(
-                    file
+                    file,
                     progress,
                     add_timestamp,
                     *whisper_params,
                 )

-                file_name, file_ext = os.path.splitext(os.path.basename(file
+                file_name, file_ext = os.path.splitext(os.path.basename(file))
                 subtitle, file_path = self.generate_and_write_file(
                     file_name=file_name,
                     transcribed_segments=transcribed_segments,
@@ -245,13 +251,11 @@ class WhisperBase(ABC):
             print(f"Error transcribing file: {e}")
         finally:
             self.release_cuda_memory()
-            if not files:
-                self.remove_input_files([file.name for file in files])

     def transcribe_mic(self,
                        mic_audio: str,
-                       file_format: str,
-                       add_timestamp: bool,
+                       file_format: str = "SRT",
+                       add_timestamp: bool = True,
                        progress=gr.Progress(),
                        *whisper_params,
                        ) -> list:
@@ -302,12 +306,11 @@ class WhisperBase(ABC):
             print(f"Error transcribing file: {e}")
         finally:
             self.release_cuda_memory()
-            self.remove_input_files([mic_audio])

     def transcribe_youtube(self,
                            youtube_link: str,
-                           file_format: str,
-                           add_timestamp: bool,
+                           file_format: str = "SRT",
+                           add_timestamp: bool = True,
                            progress=gr.Progress(),
                            *whisper_params,
                            ) -> list:
@@ -358,22 +361,15 @@ class WhisperBase(ABC):
             )
             result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"

+            if os.path.exists(audio):
+                os.remove(audio)
+
             return [result_str, result_file_path]

         except Exception as e:
             print(f"Error transcribing file: {e}")
         finally:
-
-                if 'yt' not in locals():
-                    yt = get_ytdata(youtube_link)
-                    file_path = get_ytaudio(yt)
-                else:
-                    file_path = get_ytaudio(yt)
-
-                self.release_cuda_memory()
-                self.remove_input_files([file_path])
-            except Exception as cleanup_error:
-                pass
+            self.release_cuda_memory()

     @staticmethod
     def generate_and_write_file(file_name: str,
@@ -411,11 +407,12 @@ class WhisperBase(ABC):
         else:
             output_path = os.path.join(output_dir, f"{file_name}")

-
+        file_format = file_format.strip().lower()
+        if file_format == "srt":
             content = get_srt(transcribed_segments)
             output_path += '.srt'

-        elif file_format == "
+        elif file_format == "webvtt":
             content = get_vtt(transcribed_segments)
             output_path += '.vtt'
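generate_and_write_file now normalises the requested format before choosing an extension; the sketch below isolates that mapping (the helper name and the error branch are illustrative):

def pick_extension(file_format: str) -> str:
    file_format = file_format.strip().lower()
    if file_format == "srt":
        return ".srt"
    elif file_format == "webvtt":
        return ".vtt"
    raise ValueError(f"Unsupported subtitle format: {file_format}")

print(pick_extension("SRT"))        # .srt
print(pick_extension(" WebVTT "))   # .vtt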
modules/whisper/whisper_parameter.py
CHANGED
@@ -26,7 +26,6 @@ class WhisperParameters:
     max_speech_duration_s: gr.Number
     min_silence_duration_ms: gr.Number
     speech_pad_ms: gr.Number
-    chunk_length_s: gr.Number
     batch_size: gr.Number
     is_diarize: gr.Checkbox
     hf_token: gr.Textbox
@@ -52,6 +51,7 @@ class WhisperParameters:
     uvr_device: gr.Dropdown
     uvr_segment_size: gr.Number
     uvr_save_file: gr.Checkbox
+    uvr_enable_offload: gr.Checkbox
     """
     A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
     This data class is used to mitigate the key-value problem between Gradio components and function parameters.
@@ -136,10 +136,6 @@ class WhisperParameters:
     speech_pad_ms: gr.Number
         This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side

-    chunk_length_s: gr.Number
-        This parameter is related with insanely-fast-whisper pipe.
-        Maximum length of each chunk
-
     batch_size: gr.Number
         This parameter is related with insanely-fast-whisper pipe. Batch size to pass to the pipe
@@ -193,8 +189,8 @@ class WhisperParameters:
         the maximum will be set by the default max_length.

     chunk_length: gr.Number
-        This parameter is related to faster-whisper. The length of audio segments
-
+        This parameter is related to faster-whisper and insanely-fast-whisper. The length of audio segments in seconds.
+        If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.

     hallucination_silence_threshold: gr.Number
         This parameter is related to faster-whisper. When word_timestamps is True, skip silent periods longer than this threshold
@@ -223,6 +219,10 @@ class WhisperParameters:
     uvr_save_file: gr.Checkbox
         This parameter is related to UVR. Boolean value that determines whether to save the file or not.
+
+    uvr_enable_offload: gr.Checkbox
+        This parameter is related to UVR. Boolean value that determines whether to offload the UVR model or not
+        after each transcription.
     """

     def as_list(self) -> list:
@@ -252,52 +252,52 @@ class WhisperParameters:
 @dataclass
 class WhisperValues:
-    model_size: str
-    lang: str
-    is_translate: bool
-    beam_size: int
-    log_prob_threshold: float
-    no_speech_threshold: float
-    compute_type: str
-    best_of: int
-    patience: float
-    condition_on_previous_text: bool
-    prompt_reset_on_temperature: float
-    initial_prompt: Optional[str]
-    temperature: float
-    compression_ratio_threshold: float
-    vad_filter: bool
-    threshold: float
-    min_speech_duration_ms: int
-    max_speech_duration_s: float
-    min_silence_duration_ms: int
-    speech_pad_ms: int
+    model_size: str = "large-v2"
+    lang: Optional[str] = None
+    is_translate: bool = False
+    beam_size: int = 5
+    log_prob_threshold: float = -1.0
+    no_speech_threshold: float = 0.6
+    compute_type: str = "float16"
+    best_of: int = 5
+    patience: float = 1.0
+    condition_on_previous_text: bool = True
+    prompt_reset_on_temperature: float = 0.5
+    initial_prompt: Optional[str] = None
+    temperature: float = 0.0
+    compression_ratio_threshold: float = 2.4
+    vad_filter: bool = False
+    threshold: float = 0.5
+    min_speech_duration_ms: int = 250
+    max_speech_duration_s: float = float("inf")
+    min_silence_duration_ms: int = 2000
+    speech_pad_ms: int = 400
+    batch_size: int = 24
+    is_diarize: bool = False
+    hf_token: str = ""
+    diarization_device: str = "cuda"
+    length_penalty: float = 1.0
+    repetition_penalty: float = 1.0
+    no_repeat_ngram_size: int = 0
+    prefix: Optional[str] = None
+    suppress_blank: bool = True
+    suppress_tokens: Optional[str] = "[-1]"
+    max_initial_timestamp: float = 0.0
+    word_timestamps: bool = False
+    prepend_punctuations: Optional[str] = "\"'“¿([{-"
+    append_punctuations: Optional[str] = "\"'.。,,!!??::”)]}、"
+    max_new_tokens: Optional[int] = None
+    chunk_length: Optional[int] = 30
+    hallucination_silence_threshold: Optional[float] = None
+    hotwords: Optional[str] = None
+    language_detection_threshold: Optional[float] = None
+    language_detection_segments: int = 1
+    is_bgm_separate: bool = False
+    uvr_model_size: str = "UVR-MDX-NET-Inst_HQ_4"
+    uvr_device: str = "cuda"
+    uvr_segment_size: int = 256
+    uvr_save_file: bool = False
+    uvr_enable_offload: bool = True
     """
     A data class to use Whisper parameters.
     """
@@ -318,7 +318,6 @@ class WhisperValues:
             "initial_prompt": None if not self.initial_prompt else self.initial_prompt,
             "temperature": self.temperature,
             "compression_ratio_threshold": self.compression_ratio_threshold,
-            "chunk_length_s": None if self.chunk_length_s is None else self.chunk_length_s,
             "batch_size": self.batch_size,
             "length_penalty": self.length_penalty,
             "repetition_penalty": self.repetition_penalty,
@@ -354,6 +353,17 @@ class WhisperValues:
                 "model_size": self.uvr_model_size,
                 "segment_size": self.uvr_segment_size,
                 "save_file": self.uvr_save_file,
+                "enable_offload": self.uvr_enable_offload
             },
         }
         return data
+
+    def as_list(self) -> list:
+        """
+        Converts the data class attributes into a list
+
+        Returns
+        ----------
+        A list of Whisper parameters
+        """
+        return [getattr(self, f.name) for f in fields(self)]
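Because every WhisperValues field now carries a default, tests can construct a parameter set with only the overrides they need and flatten it with the new as_list() helper. A short usage sketch (argument values are illustrative):

from modules.whisper.whisper_parameter import WhisperValues

params = WhisperValues(model_size="tiny", compute_type="float32", vad_filter=True)
flat = params.as_list()        # positional list in dataclass field order
print(len(flat), flat[0])      # field count, 'tiny'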
requirements.txt
CHANGED
@@ -2,14 +2,15 @@
 # If you're using it, update url to your CUDA version (CUDA 12.1 is minimum requirement):
 # For CUDA 12.1, use : https://download.pytorch.org/whl/cu121
 # For CUDA 12.4, use : https://download.pytorch.org/whl/cu124
---extra-index-url https://download.pytorch.org/whl/
+--extra-index-url https://download.pytorch.org/whl/cu121


-torch
+torch==2.3.1
+torchaudio==2.3.1
 git+https://github.com/jhj0517/jhj0517-whisper.git
 faster-whisper==1.0.3
-transformers
-gradio
+transformers
+gradio
 pytubefix
 ruamel.yaml==0.18.6
 pyannote.audio==3.3.1
tests/test_bgm_separation.py
ADDED
@@ -0,0 +1,53 @@
from modules.utils.paths import *
from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.whisper_parameter import WhisperValues
from test_config import *
from test_transcription import download_file, test_transcribe

import gradio as gr
import pytest
import torch
import os


@pytest.mark.skipif(
    not is_cuda_available(),
    reason="Skipping because the test only works on GPU"
)
@pytest.mark.parametrize(
    "whisper_type,vad_filter,bgm_separation,diarization",
    [
        ("whisper", False, True, False),
        ("faster-whisper", False, True, False),
        ("insanely_fast_whisper", False, True, False)
    ]
)
def test_bgm_separation_pipeline(
    whisper_type: str,
    vad_filter: bool,
    bgm_separation: bool,
    diarization: bool,
):
    test_transcribe(whisper_type, vad_filter, bgm_separation, diarization)


@pytest.mark.skipif(
    not is_cuda_available(),
    reason="Skipping because the test only works on GPU"
)
@pytest.mark.parametrize(
    "whisper_type,vad_filter,bgm_separation,diarization",
    [
        ("whisper", True, True, False),
        ("faster-whisper", True, True, False),
        ("insanely_fast_whisper", True, True, False)
    ]
)
def test_bgm_separation_with_vad_pipeline(
    whisper_type: str,
    vad_filter: bool,
    bgm_separation: bool,
    diarization: bool,
):
    test_transcribe(whisper_type, vad_filter, bgm_separation, diarization)
tests/test_config.py
ADDED
@@ -0,0 +1,17 @@
from modules.utils.paths import *

import os
import torch

TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav")
TEST_YOUTUBE_URL = "https://www.youtube.com/watch?v=4WEQtgnBu0I&ab_channel=AndriaFitzer"
TEST_WHISPER_MODEL = "tiny"
TEST_UVR_MODEL = "UVR-MDX-NET-Inst_HQ_4"
TEST_NLLB_MODEL = "facebook/nllb-200-distilled-600M"
TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt")
TEST_SUBTITLE_VTT_PATH = os.path.join(WEBUI_DIR, "tests", "test_vtt.vtt")


def is_cuda_available():
    return torch.cuda.is_available()
tests/test_diarization.py
ADDED
@@ -0,0 +1,31 @@
from modules.utils.paths import *
from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.whisper_parameter import WhisperValues
from test_config import *
from test_transcription import download_file, test_transcribe

import gradio as gr
import pytest
import os


@pytest.mark.skipif(
    not is_cuda_available(),
    reason="Skipping because the test only works on GPU"
)
@pytest.mark.parametrize(
    "whisper_type,vad_filter,bgm_separation,diarization",
    [
        ("whisper", False, False, True),
        ("faster-whisper", False, False, True),
        ("insanely_fast_whisper", False, False, True)
    ]
)
def test_diarization_pipeline(
    whisper_type: str,
    vad_filter: bool,
    bgm_separation: bool,
    diarization: bool,
):
    test_transcribe(whisper_type, vad_filter, bgm_separation, diarization)
tests/test_srt.srt
ADDED
@@ -0,0 +1,7 @@
1
00:00:00,000 --> 00:00:02,240
You've got

2
00:00:02,240 --> 00:00:04,160
a friend in me.
tests/test_transcription.py
ADDED
@@ -0,0 +1,97 @@
from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.whisper_parameter import WhisperValues
from modules.utils.paths import WEBUI_DIR
from test_config import *

import requests
import pytest
import gradio as gr
import os


@pytest.mark.parametrize(
    "whisper_type,vad_filter,bgm_separation,diarization",
    [
        ("whisper", False, False, False),
        ("faster-whisper", False, False, False),
        ("insanely_fast_whisper", False, False, False)
    ]
)
def test_transcribe(
    whisper_type: str,
    vad_filter: bool,
    bgm_separation: bool,
    diarization: bool,
):
    audio_path_dir = os.path.join(WEBUI_DIR, "tests")
    audio_path = os.path.join(audio_path_dir, "jfk.wav")
    if not os.path.exists(audio_path):
        download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir)

    whisper_inferencer = WhisperFactory.create_whisper_inference(
        whisper_type=whisper_type,
    )
    print(
        f"""Whisper Device : {whisper_inferencer.device}\n"""
        f"""BGM Separation Device: {whisper_inferencer.music_separator.device}\n"""
        f"""Diarization Device: {whisper_inferencer.diarizer.device}"""
    )

    hparams = WhisperValues(
        model_size=TEST_WHISPER_MODEL,
        vad_filter=vad_filter,
        is_bgm_separate=bgm_separation,
        compute_type=whisper_inferencer.current_compute_type,
        uvr_enable_offload=True,
        is_diarize=diarization,
    ).as_list()

    subtitle_str, file_path = whisper_inferencer.transcribe_file(
        [audio_path],
        None,
        "SRT",
        False,
        gr.Progress(),
        *hparams,
    )

    assert isinstance(subtitle_str, str) and subtitle_str
    assert isinstance(file_path[0], str) and file_path

    whisper_inferencer.transcribe_youtube(
        TEST_YOUTUBE_URL,
        "SRT",
        False,
        gr.Progress(),
        *hparams,
    )
    assert isinstance(subtitle_str, str) and subtitle_str
    assert isinstance(file_path[0], str) and file_path

    whisper_inferencer.transcribe_mic(
        audio_path,
        "SRT",
        False,
        gr.Progress(),
        *hparams,
    )
    assert isinstance(subtitle_str, str) and subtitle_str
    assert isinstance(file_path[0], str) and file_path


def download_file(url, save_dir):
    if os.path.exists(TEST_FILE_PATH):
        return

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    file_name = url.split("/")[-1]
    file_path = os.path.join(save_dir, file_name)

    response = requests.get(url)

    with open(file_path, "wb") as file:
        file.write(response.content)

    print(f"File downloaded to: {file_path}")
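A minimal sketch (assumed, not part of the commit) of one way to run the new test module from Python rather than the pytest CLI; it assumes it is executed from the repository root so the modules package and the tests directory are importable:

# Sketch only -- runs the parameterized transcription tests via pytest's API.
import pytest

if __name__ == "__main__":
    # Verbose run of the transcription tests; GPU-only suites elsewhere in tests/
    # are skipped automatically by their is_cuda_available() skipif markers.
    raise SystemExit(pytest.main(["-v", "tests/test_transcription.py"]))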
tests/test_translation.py
ADDED
@@ -0,0 +1,52 @@
from modules.translation.deepl_api import DeepLAPI
from modules.translation.nllb_inference import NLLBInference
from test_config import *

import os
import pytest


@pytest.mark.parametrize("model_size, file_path", [
    (TEST_NLLB_MODEL, TEST_SUBTITLE_SRT_PATH),
    (TEST_NLLB_MODEL, TEST_SUBTITLE_VTT_PATH),
])
def test_nllb_inference(
    model_size: str,
    file_path: str
):
    nllb_inferencer = NLLBInference()
    print(f"NLLB Device : {nllb_inferencer.device}")

    result_str, file_paths = nllb_inferencer.translate_file(
        fileobjs=[file_path],
        model_size=model_size,
        src_lang="eng_Latn",
        tgt_lang="kor_Hang",
    )

    assert isinstance(result_str, str)
    assert isinstance(file_paths[0], str)


@pytest.mark.parametrize("file_path", [
    TEST_SUBTITLE_SRT_PATH,
    TEST_SUBTITLE_VTT_PATH,
])
def test_deepl_api(
    file_path: str
):
    deepl_api = DeepLAPI()

    api_key = os.getenv("DEEPL_API_KEY")

    result_str, file_paths = deepl_api.translate_deepl(
        auth_key=api_key,
        fileobjs=[file_path],
        source_lang="English",
        target_lang="Korean",
        is_pro=False,
        add_timestamp=True,
    )

    assert isinstance(result_str, str)
    assert isinstance(file_paths[0], str)
tests/test_vad.py
ADDED
@@ -0,0 +1,26 @@
from modules.utils.paths import *
from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.whisper_parameter import WhisperValues
from test_config import *
from test_transcription import download_file, test_transcribe

import gradio as gr
import pytest
import os


@pytest.mark.parametrize(
    "whisper_type,vad_filter,bgm_separation,diarization",
    [
        ("whisper", True, False, False),
        ("faster-whisper", True, False, False),
        ("insanely_fast_whisper", True, False, False)
    ]
)
def test_vad_pipeline(
    whisper_type: str,
    vad_filter: bool,
    bgm_separation: bool,
    diarization: bool,
):
    test_transcribe(whisper_type, vad_filter, bgm_separation, diarization)
tests/test_vtt.vtt
ADDED
@@ -0,0 +1,6 @@
WEBVTT
00:00:00.500 --> 00:00:02.000
You've got

00:00:02.500 --> 00:00:04.300
a friend in me.