jhj0517 committed
Commit 451ca33 · 2 Parent(s): 602c60a 20f48aa

Merge master

.github/ISSUE_TEMPLATE/bug_report.md CHANGED
@@ -3,7 +3,7 @@ name: Bug report
 about: Create a report to help us improve
 title: ''
 labels: bug
-assignees: ''
+assignees: jhj0517
 
 ---
 
.github/ISSUE_TEMPLATE/feature_request.md CHANGED
@@ -3,7 +3,7 @@ name: Feature request
 about: Any feature you want
 title: ''
 labels: enhancement
-assignees: ''
+assignees: jhj0517
 
 ---
 
.github/ISSUE_TEMPLATE/hallucination.md ADDED
@@ -0,0 +1,12 @@
+---
+name: Hallucination
+about: Whisper hallucinations. ( Repeating certain words or subtitles starting too
+  early, etc. )
+title: ''
+labels: hallucination
+assignees: jhj0517
+
+---
+
+**Download URL for sample audio**
+- Please upload download URL for sample audio file so I can test with some settings for better result. You can use https://easyupload.io/ or any other service to share.
.github/pull_request_template.md ADDED
@@ -0,0 +1,5 @@
+## Related issues
+- #0
+
+## Changed
+1. Changes
.github/workflows/{shell-scrpit-test.yml → ci-shell.yml} RENAMED
@@ -1,38 +1,42 @@
-name: Shell Script Test
+name: CI-Shell Script
 
 on:
+  workflow_dispatch:
+
   push:
-    branches: ["feature/shell-script"]
-
-env:
-  PYTHON_VERSION: '3.9'
+    branches:
+      - master
+  pull_request:
+    branches:
+      - master
 
 jobs:
   test-shell-script:
+
     runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python: [ "3.10" ]
+
     steps:
-    - name: 'Checkout GitHub Action'
-      uses: actions/checkout@v3
+    - name: Clean up space for action
+      run: rm -rf /opt/hostedtoolcache
 
-    - name: Setup Python ${{ env.PYTHON_VERSION }} Environment
-      uses: actions/setup-python@v4
+    - uses: actions/checkout@v4
+    - name: Setup Python
+      uses: actions/setup-python@v5
       with:
-        python-version: ${{ env.PYTHON_VERSION }}
+        python-version: ${{ matrix.python }}
 
-    - name: 'Setup FFmpeg'
-      uses: FedericoCarboni/setup-ffmpeg@v3
-      id: setup-ffmpeg
-      with:
-        ffmpeg-version: release
-        architecture: 'arm64'
-        linking-type: static
+    - name: Install git and ffmpeg
+      run: sudo apt-get update && sudo apt-get install -y git ffmpeg
 
-    - name: 'Execute Install.sh'
+    - name: Execute Install.sh
      run: |
        chmod +x ./Install.sh
        ./Install.sh
 
-    - name: 'Execute start-webui.sh'
+    - name: Execute start-webui.sh
      run: |
        chmod +x ./start-webui.sh
        timeout 60s ./start-webui.sh || true
.github/workflows/ci.yml ADDED
@@ -0,0 +1,41 @@
+name: CI
+
+on:
+  workflow_dispatch:
+
+  push:
+    branches:
+      - master
+  pull_request:
+    branches:
+      - master
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python: ["3.10"]
+
+    env:
+      DEEPL_API_KEY: ${{ secrets.DEEPL_API_KEY }}
+
+    steps:
+    - name: Clean up space for action
+      run: rm -rf /opt/hostedtoolcache
+
+    - uses: actions/checkout@v4
+    - name: Setup Python
+      uses: actions/setup-python@v5
+      with:
+        python-version: ${{ matrix.python }}
+
+    - name: Install git and ffmpeg
+      run: sudo apt-get update && sudo apt-get install -y git ffmpeg
+
+    - name: Install dependencies
+      run: pip install -r requirements.txt pytest
+
+    - name: Run test
+      run: python -m pytest -rs tests
.gitignore CHANGED
@@ -2,6 +2,8 @@
 *.png
 *.mp4
 *.mp3
+.idea/
+.pytest_cache/
 venv/
 modules/ui/__pycache__/
 outputs/
app.py CHANGED
@@ -21,7 +21,7 @@ from modules.whisper.whisper_parameter import *
 class App:
     def __init__(self, args):
         self.args = args
-        self.app = gr.Blocks(css=CSS, theme=self.args.theme)
+        self.app = gr.Blocks(css=CSS, theme=self.args.theme, delete_cache=(60, 3600))
         self.whisper_inf = WhisperFactory.create_whisper_inference(
             whisper_type=self.args.whisper_type,
             whisper_model_dir=self.args.whisper_model_dir,
@@ -59,6 +59,7 @@ class App:
             with gr.Row():
                 cb_timestamp = gr.Checkbox(value=whisper_params["add_timestamp"], label="Add a timestamp to the end of the filename",
                                            interactive=True)
+
             with gr.Accordion("Advanced Parameters", open=False):
                 nb_beam_size = gr.Number(label="Beam Size", value=whisper_params["beam_size"], precision=0, interactive=True,
                                          info="Beam size to use for decoding.")
@@ -68,6 +69,7 @@ class App:
                                                   info="If the no speech probability is higher than this value AND the average log probability over sampled tokens is below 'Log Prob Threshold', consider the segment as silent.")
                 dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types,
                                               value=self.whisper_inf.current_compute_type, interactive=True,
+                                              allow_custom_value=True,
                                               info="Select the type of computation to perform.")
                 nb_best_of = gr.Number(label="Best Of", value=whisper_params["best_of"], interactive=True,
                                        info="Number of candidates when sampling with non-zero temperature.")
@@ -88,6 +90,9 @@ class App:
                 nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=whisper_params["compression_ratio_threshold"],
                                                            interactive=True,
                                                            info="If the gzip compression ratio is above this value, treat as failed.")
+                nb_chunk_length = gr.Number(label="Chunk Length (s)", value=lambda: whisper_params["chunk_length"],
+                                            precision=0,
+                                            info="The length of audio segments. If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.")
                 with gr.Group(visible=isinstance(self.whisper_inf, FasterWhisperInference)):
                     nb_length_penalty = gr.Number(label="Length Penalty", value=whisper_params["length_penalty"],
                                                   info="Exponential length penalty constant.")
@@ -113,9 +118,6 @@ class App:
                     nb_max_new_tokens = gr.Number(label="Max New Tokens", value=lambda: whisper_params["max_new_tokens"],
                                                   precision=0,
                                                   info="Maximum number of new tokens to generate per-chunk. If not set, the maximum will be set by the default max_length.")
-                    nb_chunk_length = gr.Number(label="Chunk Length", value=lambda: whisper_params["chunk_length"],
-                                                precision=0,
-                                                info="The length of audio segments. If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.")
                     nb_hallucination_silence_threshold = gr.Number(label="Hallucination Silence Threshold (sec)",
                                                                    value=lambda: whisper_params["hallucination_silence_threshold"],
                                                                    info="When 'Word Timestamps' is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected.")
@@ -127,32 +129,37 @@ class App:
                                                                 precision=0,
                                                                 info="Number of segments to consider for the language detection.")
                 with gr.Group(visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
-                    nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=whisper_params["chunk_length_s"],
-                                                  precision=0)
                     nb_batch_size = gr.Number(label="Batch Size", value=whisper_params["batch_size"], precision=0)
 
-            with gr.Accordion("BGM Separation", open=False):
-                cb_bgm_separation = gr.Checkbox(label="Enable BGM Separation Filter", value=uvr_params["is_separate_bgm"],
-                                                interactive=True)
+            with gr.Accordion("Background Music Remover Filter", open=False):
+                cb_bgm_separation = gr.Checkbox(label="Enable Background Music Remover Filter", value=uvr_params["is_separate_bgm"],
+                                                interactive=True,
+                                                info="Enabling this will remove background music by submodel before"
+                                                     " transcribing ")
                 dd_uvr_device = gr.Dropdown(label="Device", value=self.whisper_inf.music_separator.device,
                                             choices=self.whisper_inf.music_separator.available_devices)
                 dd_uvr_model_size = gr.Dropdown(label="Model", value=uvr_params["model_size"],
                                                 choices=self.whisper_inf.music_separator.available_models)
                 nb_uvr_segment_size = gr.Number(label="Segment Size", value=uvr_params["segment_size"], precision=0)
                 cb_uvr_save_file = gr.Checkbox(label="Save separated files to output", value=uvr_params["save_file"])
+                cb_uvr_enable_offload = gr.Checkbox(label="Offload sub model after removing background music",
+                                                    value=uvr_params["enable_offload"])
 
-            with gr.Accordion("VAD", open=False):
+            with gr.Accordion("Voice Detection Filter", open=False):
                 cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=vad_params["vad_filter"],
-                                            interactive=True)
-                sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=vad_params["threshold"],
+                                            interactive=True,
+                                            info="Enable this to transcribe only detected voice parts by submodel.")
+                sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold",
+                                         value=vad_params["threshold"],
                                          info="Lower it to be more sensitive to small sounds.")
-                nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=vad_params["min_speech_duration_ms"],
+                nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0,
+                                                      value=vad_params["min_speech_duration_ms"],
                                                       info="Final speech chunks shorter than this time are thrown out")
-                nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=vad_params["max_speech_duration_s"],
-                                                     info="Maximum duration of speech chunks in \"seconds\". Chunks longer"
-                                                          " than this time will be split at the timestamp of the last silence that"
-                                                          " lasts more than 100ms (if any), to prevent aggressive cutting.")
-                nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=vad_params["min_silence_duration_ms"],
+                nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)",
+                                                     value=vad_params["max_speech_duration_s"],
+                                                     info="Maximum duration of speech chunks in \"seconds\".")
+                nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0,
+                                                       value=vad_params["min_silence_duration_ms"],
                                                        info="In the end of each speech chunk wait for this time"
                                                             " before separating it")
                 nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=vad_params["speech_pad_ms"],
@@ -161,7 +168,10 @@ class App:
             with gr.Accordion("Diarization", open=False):
                 cb_diarize = gr.Checkbox(label="Enable Diarization", value=diarization_params["is_diarize"])
                 tb_hf_token = gr.Text(label="HuggingFace Token", value=diarization_params["hf_token"],
-                                      info="This is only needed the first time you download the model. If you already have models, you don't need to enter. To download the model, you must manually go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and agree to their requirement.")
+                                      info="This is only needed the first time you download the model. If you already have"
+                                           " models, you don't need to enter. To download the model, you must manually go "
+                                           "to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and agree to"
+                                           " their requirement.")
                 dd_diarization_device = gr.Dropdown(label="Device",
                                                     choices=self.whisper_inf.diarizer.get_available_device(),
                                                     value=self.whisper_inf.diarizer.get_device())
@@ -177,19 +187,19 @@ class App:
                     temperature=sd_temperature, compression_ratio_threshold=nb_compression_ratio_threshold,
                     vad_filter=cb_vad_filter, threshold=sd_threshold, min_speech_duration_ms=nb_min_speech_duration_ms,
                     max_speech_duration_s=nb_max_speech_duration_s, min_silence_duration_ms=nb_min_silence_duration_ms,
-                    speech_pad_ms=nb_speech_pad_ms, chunk_length_s=nb_chunk_length_s, batch_size=nb_batch_size,
+                    speech_pad_ms=nb_speech_pad_ms, chunk_length=nb_chunk_length, batch_size=nb_batch_size,
                    is_diarize=cb_diarize, hf_token=tb_hf_token, diarization_device=dd_diarization_device,
                    length_penalty=nb_length_penalty, repetition_penalty=nb_repetition_penalty,
                    no_repeat_ngram_size=nb_no_repeat_ngram_size, prefix=tb_prefix, suppress_blank=cb_suppress_blank,
                    suppress_tokens=tb_suppress_tokens, max_initial_timestamp=nb_max_initial_timestamp,
                    word_timestamps=cb_word_timestamps, prepend_punctuations=tb_prepend_punctuations,
-                    append_punctuations=tb_append_punctuations, max_new_tokens=nb_max_new_tokens, chunk_length=nb_chunk_length,
+                    append_punctuations=tb_append_punctuations, max_new_tokens=nb_max_new_tokens,
                    hallucination_silence_threshold=nb_hallucination_silence_threshold, hotwords=tb_hotwords,
                    language_detection_threshold=nb_language_detection_threshold,
                    language_detection_segments=nb_language_detection_segments,
                    prompt_reset_on_temperature=sld_prompt_reset_on_temperature, is_bgm_separate=cb_bgm_separation,
                    uvr_device=dd_uvr_device, uvr_model_size=dd_uvr_model_size, uvr_segment_size=nb_uvr_segment_size,
-                    uvr_save_file=cb_uvr_save_file
+                    uvr_save_file=cb_uvr_save_file, uvr_enable_offload=cb_uvr_enable_offload
                ),
                dd_file_format,
                cb_timestamp
configs/default_parameters.yaml CHANGED
@@ -12,7 +12,7 @@ whisper:
   initial_prompt: null
   temperature: 0
   compression_ratio_threshold: 2.4
-  chunk_length_s: 30
+  chunk_length: 30
   batch_size: 24
   length_penalty: 1
   repetition_penalty: 1
@@ -25,7 +25,6 @@ whisper:
   prepend_punctuations: "\"'“¿([{-"
   append_punctuations: "\"'.。,,!!??::”)]}、"
   max_new_tokens: null
-  chunk_length: null
   hallucination_silence_threshold: null
   hotwords: null
   language_detection_threshold: null
@@ -37,8 +36,8 @@ vad:
   threshold: 0.5
   min_speech_duration_ms: 250
   max_speech_duration_s: 9999
-  min_silence_duration_ms: 2000
-  speech_pad_ms: 400
+  min_silence_duration_ms: 1000
+  speech_pad_ms: 2000
 
 diarization:
   is_diarize: false
@@ -49,6 +48,7 @@ bgm_separation:
   model_size: "UVR-MDX-NET-Inst_HQ_4"
   segment_size: 256
   save_file: false
+  enable_offload: true
 
 translation:
   deepl:
modules/translation/deepl_api.py CHANGED
@@ -98,8 +98,8 @@ class DeepLAPI:
                         fileobjs: list,
                         source_lang: str,
                         target_lang: str,
-                        is_pro: bool,
-                        add_timestamp: bool,
+                        is_pro: bool = False,
+                        add_timestamp: bool = True,
                         progress=gr.Progress()) -> list:
         """
         Translate subtitle files using DeepL API
@@ -126,6 +126,9 @@
             String to return to gr.Textbox()
             Files to return to gr.Files()
         """
+        if fileobjs and isinstance(fileobjs[0], gr.utils.NamedString):
+            fileobjs = [fileobj.name for fileobj in fileobjs]
+
         self.cache_parameters(
             api_key=auth_key,
             is_pro=is_pro,
@@ -136,37 +139,28 @@
 
         files_info = {}
         for fileobj in fileobjs:
-            file_path = fileobj.name
-            file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
+            file_path = fileobj
+            file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
 
             if file_ext == ".srt":
                 parsed_dicts = parse_srt(file_path=file_path)
 
-                batch_size = self.max_text_batch_size
-                for batch_start in range(0, len(parsed_dicts), batch_size):
-                    batch_end = min(batch_start + batch_size, len(parsed_dicts))
-                    sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
-                    translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
-                                                                    target_lang, is_pro)
-                    for i, translated_text in enumerate(translated_texts):
-                        parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
-                    progress(batch_end / len(parsed_dicts), desc="Translating..")
-
-                subtitle = get_serialized_srt(parsed_dicts)
-
             elif file_ext == ".vtt":
                 parsed_dicts = parse_vtt(file_path=file_path)
 
-                batch_size = self.max_text_batch_size
-                for batch_start in range(0, len(parsed_dicts), batch_size):
-                    batch_end = min(batch_start + batch_size, len(parsed_dicts))
-                    sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
-                    translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
-                                                                    target_lang, is_pro)
-                    for i, translated_text in enumerate(translated_texts):
-                        parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
-                    progress(batch_end / len(parsed_dicts), desc="Translating..")
+            batch_size = self.max_text_batch_size
+            for batch_start in range(0, len(parsed_dicts), batch_size):
+                batch_end = min(batch_start + batch_size, len(parsed_dicts))
+                sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
+                translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
+                                                                target_lang, is_pro)
+                for i, translated_text in enumerate(translated_texts):
+                    parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
+                progress(batch_end / len(parsed_dicts), desc="Translating..")
 
+            if file_ext == ".srt":
+                subtitle = get_serialized_srt(parsed_dicts)
+            elif file_ext == ".vtt":
                 subtitle = get_serialized_vtt(parsed_dicts)
 
             if add_timestamp:
@@ -193,8 +187,14 @@
                               text: list,
                               source_lang: str,
                               target_lang: str,
-                              is_pro: bool):
+                              is_pro: bool = False):
         """Request API response to DeepL server"""
+        if source_lang not in list(DEEPL_AVAILABLE_SOURCE_LANGS.keys()):
+            raise ValueError(f"Source language {source_lang} is not supported."
+                             f"Use one of {list(DEEPL_AVAILABLE_SOURCE_LANGS.keys())}")
+        if target_lang not in list(DEEPL_AVAILABLE_TARGET_LANGS.keys()):
+            raise ValueError(f"Target language {target_lang} is not supported."
+                             f"Use one of {list(DEEPL_AVAILABLE_TARGET_LANGS.keys())}")
 
        url = 'https://api.deepl.com/v2/translate' if is_pro else 'https://api-free.deepl.com/v2/translate'
        headers = {
modules/translation/nllb_inference.py CHANGED
@@ -38,8 +38,19 @@ class NLLBInference(TranslationBase):
                      model_size: str,
                      src_lang: str,
                      tgt_lang: str,
-                     progress: gr.Progress
+                     progress: gr.Progress = gr.Progress()
                      ):
+        def validate_language(lang: str) -> str:
+            if lang in NLLB_AVAILABLE_LANGS:
+                return NLLB_AVAILABLE_LANGS[lang]
+            elif lang not in NLLB_AVAILABLE_LANGS.values():
+                raise ValueError(
+                    f"Language '{lang}' is not supported. Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
+            return lang
+
+        src_lang = validate_language(src_lang)
+        tgt_lang = validate_language(tgt_lang)
+
         if model_size != self.current_model_size or self.model is None:
             print("\nInitializing NLLB Model..\n")
             progress(0, desc="Initializing NLLB Model..")
@@ -51,8 +62,7 @@
             self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
                                                            cache_dir=os.path.join(self.model_dir, "tokenizers"),
                                                            local_files_only=local_files_only)
-            src_lang = NLLB_AVAILABLE_LANGS[src_lang]
-            tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
+
             self.pipeline = pipeline("translation",
                                      model=self.model,
                                      tokenizer=self.tokenizer,
modules/translation/translation_base.py CHANGED
@@ -40,7 +40,7 @@ class TranslationBase(ABC):
                      model_size: str,
                      src_lang: str,
                      tgt_lang: str,
-                     progress: gr.Progress
+                     progress: gr.Progress = gr.Progress()
                      ):
         pass
 
@@ -50,8 +50,8 @@
                       model_size: str,
                       src_lang: str,
                       tgt_lang: str,
-                      max_length: int,
-                      add_timestamp: bool,
+                      max_length: int = 200,
+                      add_timestamp: bool = True,
                       progress=gr.Progress()) -> list:
         """
         Translate subtitle file from source language to target language
@@ -81,6 +81,9 @@
             Files to return to gr.Files()
         """
         try:
+            if fileobjs and isinstance(fileobjs[0], gr.utils.NamedString):
+                fileobjs = [file.name for file in fileobjs]
+
             self.cache_parameters(model_size=model_size,
                                   src_lang=src_lang,
                                   tgt_lang=tgt_lang,
@@ -94,10 +97,9 @@
 
             files_info = {}
             for fileobj in fileobjs:
-                file_path = fileobj.name
-                file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
+                file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
                 if file_ext == ".srt":
-                    parsed_dicts = parse_srt(file_path=file_path)
+                    parsed_dicts = parse_srt(file_path=fileobj)
                     total_progress = len(parsed_dicts)
                     for index, dic in enumerate(parsed_dicts):
                         progress(index / total_progress, desc="Translating..")
@@ -106,7 +108,7 @@
                     subtitle = get_serialized_srt(parsed_dicts)
 
                 elif file_ext == ".vtt":
-                    parsed_dicts = parse_vtt(file_path=file_path)
+                    parsed_dicts = parse_vtt(file_path=fileobj)
                     total_progress = len(parsed_dicts)
                     for index, dic in enumerate(parsed_dicts):
                         progress(index / total_progress, desc="Translating..")
modules/utils/subtitle_manager.py CHANGED
@@ -121,11 +121,8 @@ def get_serialized_vtt(dicts):
 
 @spaces.GPU(duration=120)
 def safe_filename(name):
-    from app import _args
     INVALID_FILENAME_CHARS = r'[<>:"/\\|?*\x00-\x1f]'
     safe_name = re.sub(INVALID_FILENAME_CHARS, '_', name)
-    if not _args.colab:
-        return safe_name
     # Truncate the filename if it exceeds the max_length (20)
     if len(safe_name) > 20:
         file_extension = safe_name.split('.')[-1]
modules/utils/youtube_manager.py CHANGED
@@ -1,4 +1,5 @@
 from pytubefix import YouTube
+import subprocess
 import os
 
 
@@ -12,4 +13,21 @@ def get_ytmetas(link):
 
 
 def get_ytaudio(ytdata: YouTube):
-    return ytdata.streams.get_audio_only().download(filename=os.path.join("modules", "yt_tmp.wav"))
+    # Somehow the audio is corrupted so need to convert to valid audio file.
+    # Fix for : https://github.com/jhj0517/Whisper-WebUI/issues/304
+
+    audio_path = ytdata.streams.get_audio_only().download(filename=os.path.join("modules", "yt_tmp.wav"))
+    temp_audio_path = os.path.join("modules", "yt_tmp_fixed.wav")
+
+    try:
+        subprocess.run([
+            'ffmpeg', '-y',
+            '-i', audio_path,
+            temp_audio_path
+        ], check=True)
+
+        os.replace(temp_audio_path, audio_path)
+        return audio_path
+    except subprocess.CalledProcessError as e:
+        print(f"Error during ffmpeg conversion: {e}")
+        return None
modules/whisper/faster_whisper_inference.py CHANGED
@@ -40,7 +40,7 @@ class FasterWhisperInference(WhisperBase):
 
     def transcribe(self,
                    audio: Union[str, BinaryIO, np.ndarray],
-                   progress: gr.Progress,
+                   progress: gr.Progress = gr.Progress(),
                    *whisper_params,
                    ) -> Tuple[List[dict], float]:
         """
@@ -126,7 +126,7 @@
     def update_model(self,
                      model_size: str,
                      compute_type: str,
-                     progress: gr.Progress
+                     progress: gr.Progress = gr.Progress()
                      ):
         """
         Update current model setting
@@ -159,7 +159,7 @@
         ----------
         Name list of models
         """
-        model_paths = {model:model for model in whisper.available_models()}
+        model_paths = {model:model for model in faster_whisper.available_models()}
         faster_whisper_prefix = "models--Systran--faster-whisper-"
 
         existing_models = os.listdir(self.model_dir)
modules/whisper/insanely_fast_whisper_inference.py CHANGED
@@ -39,7 +39,7 @@ class InsanelyFastWhisperInference(WhisperBase):
 
     def transcribe(self,
                    audio: Union[str, np.ndarray, torch.Tensor],
-                   progress: gr.Progress,
+                   progress: gr.Progress = gr.Progress(),
                    *whisper_params,
                    ) -> Tuple[List[dict], float]:
         """
@@ -75,18 +75,25 @@
         ) as progress:
             progress.add_task("[yellow]Transcribing...", total=None)
 
+            kwargs = {
+                "no_speech_threshold": params.no_speech_threshold,
+                "temperature": params.temperature,
+                "compression_ratio_threshold": params.compression_ratio_threshold,
+                "logprob_threshold": params.log_prob_threshold,
+            }
+
+            if self.current_model_size.endswith(".en"):
+                pass
+            else:
+                kwargs["language"] = params.lang
+                kwargs["task"] = "translate" if params.is_translate else "transcribe"
+
             segments = self.model(
                 inputs=audio,
                 return_timestamps=True,
-                chunk_length_s=params.chunk_length_s,
+                chunk_length_s=params.chunk_length,
                 batch_size=params.batch_size,
-                generate_kwargs={
-                    "language": params.lang,
-                    "task": "translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
-                    "no_speech_threshold": params.no_speech_threshold,
-                    "temperature": params.temperature,
-                    "compression_ratio_threshold": params.compression_ratio_threshold
-                }
+                generate_kwargs=kwargs
             )
 
             segments_result = self.format_result(
@@ -98,7 +105,7 @@
     def update_model(self,
                      model_size: str,
                      compute_type: str,
-                     progress: gr.Progress,
+                     progress: gr.Progress = gr.Progress(),
                      ):
         """
         Update current model setting
modules/whisper/whisper_Inference.py CHANGED
@@ -28,7 +28,7 @@ class WhisperInference(WhisperBase):
 
     def transcribe(self,
                    audio: Union[str, np.ndarray, torch.Tensor],
-                   progress: gr.Progress,
+                   progress: gr.Progress = gr.Progress(),
                    *whisper_params,
                    ) -> Tuple[List[dict], float]:
         """
@@ -79,7 +79,7 @@
     def update_model(self,
                      model_size: str,
                      compute_type: str,
-                     progress: gr.Progress,
+                     progress: gr.Progress = gr.Progress(),
                      ):
         """
         Update current model setting
modules/whisper/whisper_base.py CHANGED
@@ -53,7 +53,7 @@ class WhisperBase(ABC):
     @abstractmethod
     def transcribe(self,
                    audio: Union[str, BinaryIO, np.ndarray],
-                   progress: gr.Progress,
+                   progress: gr.Progress = gr.Progress(),
                    *whisper_params,
                    ):
         """Inference whisper model to transcribe"""
@@ -63,7 +63,7 @@
     def update_model(self,
                      model_size: str,
                      compute_type: str,
-                     progress: gr.Progress
+                     progress: gr.Progress = gr.Progress()
                      ):
         """Initialize whisper model"""
         pass
@@ -104,7 +104,9 @@
             add_timestamp=add_timestamp
         )
 
-        if params.lang == "Automatic Detection":
+        if params.lang is None:
+            pass
+        elif params.lang == "Automatic Detection":
             params.lang = None
         else:
             language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
@@ -128,11 +130,12 @@
                 origin_sample_rate = self.music_separator.audio_info.sample_rate
                 audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
 
-            self.music_separator.offload()
+            if params.uvr_enable_offload:
+                self.music_separator.offload()
 
         if params.vad_filter:
             # Explicit value set for float('inf') from gr.Number()
-            if params.max_speech_duration_s >= 9999:
+            if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
                 params.max_speech_duration_s = float('inf')
 
             vad_options = VadOptions(
@@ -171,10 +174,10 @@
         return result, elapsed_time
 
     def transcribe_file(self,
-                        files: list,
-                        input_folder_path: str,
-                        file_format: str,
-                        add_timestamp: bool,
+                        files: Optional[List] = None,
+                        input_folder_path: Optional[str] = None,
+                        file_format: str = "SRT",
+                        add_timestamp: bool = True,
                         progress=gr.Progress(),
                         *whisper_params,
                         ) -> list:
@@ -207,18 +210,21 @@
         try:
             if input_folder_path:
                 files = get_media_files(input_folder_path)
-                files = format_gradio_files(files)
+            if isinstance(files, str):
+                files = [files]
+            if files and isinstance(files[0], gr.utils.NamedString):
+                files = [file.name for file in files]
 
             files_info = {}
             for file in files:
                 transcribed_segments, time_for_task = self.run(
-                    file.name,
+                    file,
                     progress,
                     add_timestamp,
                    *whisper_params,
                )
 
-                file_name, file_ext = os.path.splitext(os.path.basename(file.name))
+                file_name, file_ext = os.path.splitext(os.path.basename(file))
                 subtitle, file_path = self.generate_and_write_file(
                     file_name=file_name,
                     transcribed_segments=transcribed_segments,
@@ -245,13 +251,11 @@
             print(f"Error transcribing file: {e}")
         finally:
             self.release_cuda_memory()
-            if not files:
-                self.remove_input_files([file.name for file in files])
 
     def transcribe_mic(self,
                        mic_audio: str,
-                       file_format: str,
-                       add_timestamp: bool,
+                       file_format: str = "SRT",
+                       add_timestamp: bool = True,
                        progress=gr.Progress(),
                        *whisper_params,
                        ) -> list:
@@ -302,12 +306,11 @@
             print(f"Error transcribing file: {e}")
         finally:
             self.release_cuda_memory()
-            self.remove_input_files([mic_audio])
 
     def transcribe_youtube(self,
                            youtube_link: str,
-                           file_format: str,
-                           add_timestamp: bool,
+                           file_format: str = "SRT",
+                           add_timestamp: bool = True,
                            progress=gr.Progress(),
                            *whisper_params,
                            ) -> list:
@@ -358,22 +361,15 @@
             )
             result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
 
+            if os.path.exists(audio):
+                os.remove(audio)
+
             return [result_str, result_file_path]
 
         except Exception as e:
             print(f"Error transcribing file: {e}")
         finally:
-            try:
-                if 'yt' not in locals():
-                    yt = get_ytdata(youtube_link)
-                    file_path = get_ytaudio(yt)
-                else:
-                    file_path = get_ytaudio(yt)
-
-                self.release_cuda_memory()
-                self.remove_input_files([file_path])
-            except Exception as cleanup_error:
-                pass
+            self.release_cuda_memory()
 
     @staticmethod
     def generate_and_write_file(file_name: str,
@@ -411,11 +407,12 @@
         else:
             output_path = os.path.join(output_dir, f"{file_name}")
 
-        if file_format == "SRT":
+        file_format = file_format.strip().lower()
+        if file_format == "srt":
             content = get_srt(transcribed_segments)
             output_path += '.srt'
 
-        elif file_format == "WebVTT":
+        elif file_format == "webvtt":
             content = get_vtt(transcribed_segments)
             output_path += '.vtt'
 
modules/whisper/whisper_parameter.py CHANGED
@@ -26,7 +26,6 @@ class WhisperParameters:
     max_speech_duration_s: gr.Number
     min_silence_duration_ms: gr.Number
     speech_pad_ms: gr.Number
-    chunk_length_s: gr.Number
     batch_size: gr.Number
     is_diarize: gr.Checkbox
     hf_token: gr.Textbox
@@ -52,6 +51,7 @@
     uvr_device: gr.Dropdown
     uvr_segment_size: gr.Number
     uvr_save_file: gr.Checkbox
+    uvr_enable_offload: gr.Checkbox
     """
     A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
     This data class is used to mitigate the key-value problem between Gradio components and function parameters.
@@ -136,10 +136,6 @@
     speech_pad_ms: gr.Number
         This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side
 
-    chunk_length_s: gr.Number
-        This parameter is related with insanely-fast-whisper pipe.
-        Maximum length of each chunk
-
     batch_size: gr.Number
         This parameter is related with insanely-fast-whisper pipe. Batch size to pass to the pipe
 
@@ -193,8 +189,8 @@
        the maximum will be set by the default max_length.
 
     chunk_length: gr.Number
-        This parameter is related to faster-whisper. The length of audio segments. If it is not None, it will overwrite the
-        default chunk_length of the FeatureExtractor.
+        This parameter is related to faster-whisper and insanely-fast-whisper. The length of audio segments in seconds.
+        If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.
 
     hallucination_silence_threshold: gr.Number
         This parameter is related to faster-whisper. When word_timestamps is True, skip silent periods longer than this threshold
@@ -223,6 +219,10 @@
 
     uvr_save_file: gr.Checkbox
         This parameter is related to UVR. Boolean value that determines whether to save the file or not.
+
+    uvr_enable_offload: gr.Checkbox
+        This parameter is related to UVR. Boolean value that determines whether to offload the UVR model or not
+        after each transcription.
     """
 
     def as_list(self) -> list:
@@ -252,52 +252,52 @@
 
 @dataclass
 class WhisperValues:
-    model_size: str
-    lang: str
-    is_translate: bool
-    beam_size: int
-    log_prob_threshold: float
-    no_speech_threshold: float
-    compute_type: str
-    best_of: int
-    patience: float
-    condition_on_previous_text: bool
-    prompt_reset_on_temperature: float
-    initial_prompt: Optional[str]
-    temperature: float
-    compression_ratio_threshold: float
-    vad_filter: bool
-    threshold: float
-    min_speech_duration_ms: int
-    max_speech_duration_s: float
-    min_silence_duration_ms: int
-    speech_pad_ms: int
-    chunk_length_s: int
-    batch_size: int
-    is_diarize: bool
-    hf_token: str
-    diarization_device: str
-    length_penalty: float
-    repetition_penalty: float
-    no_repeat_ngram_size: int
-    prefix: Optional[str]
-    suppress_blank: bool
-    suppress_tokens: Optional[str]
-    max_initial_timestamp: float
-    word_timestamps: bool
-    prepend_punctuations: Optional[str]
-    append_punctuations: Optional[str]
-    max_new_tokens: Optional[int]
-    chunk_length: Optional[int]
-    hallucination_silence_threshold: Optional[float]
-    hotwords: Optional[str]
-    language_detection_threshold: Optional[float]
-    language_detection_segments: int
-    is_bgm_separate: bool
-    uvr_model_size: str
-    uvr_device: str
-    uvr_segment_size: int
-    uvr_save_file: bool
+    model_size: str = "large-v2"
+    lang: Optional[str] = None
+    is_translate: bool = False
+    beam_size: int = 5
+    log_prob_threshold: float = -1.0
+    no_speech_threshold: float = 0.6
+    compute_type: str = "float16"
+    best_of: int = 5
+    patience: float = 1.0
+    condition_on_previous_text: bool = True
+    prompt_reset_on_temperature: float = 0.5
+    initial_prompt: Optional[str] = None
+    temperature: float = 0.0
+    compression_ratio_threshold: float = 2.4
+    vad_filter: bool = False
+    threshold: float = 0.5
+    min_speech_duration_ms: int = 250
+    max_speech_duration_s: float = float("inf")
+    min_silence_duration_ms: int = 2000
+    speech_pad_ms: int = 400
+    batch_size: int = 24
+    is_diarize: bool = False
+    hf_token: str = ""
+    diarization_device: str = "cuda"
+    length_penalty: float = 1.0
+    repetition_penalty: float = 1.0
+    no_repeat_ngram_size: int = 0
+    prefix: Optional[str] = None
+    suppress_blank: bool = True
+    suppress_tokens: Optional[str] = "[-1]"
+    max_initial_timestamp: float = 0.0
+    word_timestamps: bool = False
+    prepend_punctuations: Optional[str] = "\"'“¿([{-"
+    append_punctuations: Optional[str] = "\"'.。,,!!??::”)]}、"
+    max_new_tokens: Optional[int] = None
+    chunk_length: Optional[int] = 30
+    hallucination_silence_threshold: Optional[float] = None
+    hotwords: Optional[str] = None
+    language_detection_threshold: Optional[float] = None
+    language_detection_segments: int = 1
+    is_bgm_separate: bool = False
+    uvr_model_size: str = "UVR-MDX-NET-Inst_HQ_4"
+    uvr_device: str = "cuda"
+    uvr_segment_size: int = 256
+    uvr_save_file: bool = False
+    uvr_enable_offload: bool = True
     """
     A data class to use Whisper parameters.
     """
@@ -318,7 +318,6 @@
             "initial_prompt": None if not self.initial_prompt else self.initial_prompt,
             "temperature": self.temperature,
             "compression_ratio_threshold": self.compression_ratio_threshold,
-            "chunk_length_s": None if self.chunk_length_s is None else self.chunk_length_s,
             "batch_size": self.batch_size,
             "length_penalty": self.length_penalty,
             "repetition_penalty": self.repetition_penalty,
@@ -354,6 +353,17 @@
                 "model_size": self.uvr_model_size,
                 "segment_size": self.uvr_segment_size,
                 "save_file": self.uvr_save_file,
+                "enable_offload": self.uvr_enable_offload
             },
         }
         return data
+
+    def as_list(self) -> list:
+        """
+        Converts the data class attributes into a list
+
+        Returns
+        ----------
+        A list of Whisper parameters
+        """
+        return [getattr(self, f.name) for f in fields(self)]
requirements.txt CHANGED
@@ -2,14 +2,15 @@
 # If you're using it, update url to your CUDA version (CUDA 12.1 is minimum requirement):
 # For CUDA 12.1, use : https://download.pytorch.org/whl/cu121
 # For CUDA 12.4, use : https://download.pytorch.org/whl/cu124
---extra-index-url https://download.pytorch.org/whl/cu124
+--extra-index-url https://download.pytorch.org/whl/cu121
 
 
-torch
+torch==2.3.1
+torchaudio==2.3.1
 git+https://github.com/jhj0517/jhj0517-whisper.git
 faster-whisper==1.0.3
-transformers==4.42.3
-gradio==4.43.0
+transformers
+gradio
 pytubefix
 ruamel.yaml==0.18.6
 pyannote.audio==3.3.1
tests/test_bgm_separation.py ADDED
@@ -0,0 +1,53 @@
+from modules.utils.paths import *
+from modules.whisper.whisper_factory import WhisperFactory
+from modules.whisper.whisper_parameter import WhisperValues
+from test_config import *
+from test_transcription import download_file, test_transcribe
+
+import gradio as gr
+import pytest
+import torch
+import os
+
+
+@pytest.mark.skipif(
+    not is_cuda_available(),
+    reason="Skipping because the test only works on GPU"
+)
+@pytest.mark.parametrize(
+    "whisper_type,vad_filter,bgm_separation,diarization",
+    [
+        ("whisper", False, True, False),
+        ("faster-whisper", False, True, False),
+        ("insanely_fast_whisper", False, True, False)
+    ]
+)
+def test_bgm_separation_pipeline(
+    whisper_type: str,
+    vad_filter: bool,
+    bgm_separation: bool,
+    diarization: bool,
+):
+    test_transcribe(whisper_type, vad_filter, bgm_separation, diarization)
+
+
+@pytest.mark.skipif(
+    not is_cuda_available(),
+    reason="Skipping because the test only works on GPU"
+)
+@pytest.mark.parametrize(
+    "whisper_type,vad_filter,bgm_separation,diarization",
+    [
+        ("whisper", True, True, False),
+        ("faster-whisper", True, True, False),
+        ("insanely_fast_whisper", True, True, False)
+    ]
+)
+def test_bgm_separation_with_vad_pipeline(
+    whisper_type: str,
+    vad_filter: bool,
+    bgm_separation: bool,
+    diarization: bool,
+):
+    test_transcribe(whisper_type, vad_filter, bgm_separation, diarization)
+
tests/test_config.py ADDED
@@ -0,0 +1,17 @@
+from modules.utils.paths import *
+
+import os
+import torch
+
+TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
+TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav")
+TEST_YOUTUBE_URL = "https://www.youtube.com/watch?v=4WEQtgnBu0I&ab_channel=AndriaFitzer"
+TEST_WHISPER_MODEL = "tiny"
+TEST_UVR_MODEL = "UVR-MDX-NET-Inst_HQ_4"
+TEST_NLLB_MODEL = "facebook/nllb-200-distilled-600M"
+TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt")
+TEST_SUBTITLE_VTT_PATH = os.path.join(WEBUI_DIR, "tests", "test_vtt.vtt")
+
+
+def is_cuda_available():
+    return torch.cuda.is_available()
tests/test_diarization.py ADDED
@@ -0,0 +1,31 @@
+from modules.utils.paths import *
+from modules.whisper.whisper_factory import WhisperFactory
+from modules.whisper.whisper_parameter import WhisperValues
+from test_config import *
+from test_transcription import download_file, test_transcribe
+
+import gradio as gr
+import pytest
+import os
+
+
+@pytest.mark.skipif(
+    not is_cuda_available(),
+    reason="Skipping because the test only works on GPU"
+)
+@pytest.mark.parametrize(
+    "whisper_type,vad_filter,bgm_separation,diarization",
+    [
+        ("whisper", False, False, True),
+        ("faster-whisper", False, False, True),
+        ("insanely_fast_whisper", False, False, True)
+    ]
+)
+def test_diarization_pipeline(
+    whisper_type: str,
+    vad_filter: bool,
+    bgm_separation: bool,
+    diarization: bool,
+):
+    test_transcribe(whisper_type, vad_filter, bgm_separation, diarization)
+
tests/test_srt.srt ADDED
@@ -0,0 +1,7 @@
+1
+00:00:00,000 --> 00:00:02,240
+You've got
+
+2
+00:00:02,240 --> 00:00:04,160
+a friend in me.
tests/test_transcription.py ADDED
@@ -0,0 +1,97 @@
+from modules.whisper.whisper_factory import WhisperFactory
+from modules.whisper.whisper_parameter import WhisperValues
+from modules.utils.paths import WEBUI_DIR
+from test_config import *
+
+import requests
+import pytest
+import gradio as gr
+import os
+
+
+@pytest.mark.parametrize(
+    "whisper_type,vad_filter,bgm_separation,diarization",
+    [
+        ("whisper", False, False, False),
+        ("faster-whisper", False, False, False),
+        ("insanely_fast_whisper", False, False, False)
+    ]
+)
+def test_transcribe(
+    whisper_type: str,
+    vad_filter: bool,
+    bgm_separation: bool,
+    diarization: bool,
+):
+    audio_path_dir = os.path.join(WEBUI_DIR, "tests")
+    audio_path = os.path.join(audio_path_dir, "jfk.wav")
+    if not os.path.exists(audio_path):
+        download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir)
+
+    whisper_inferencer = WhisperFactory.create_whisper_inference(
+        whisper_type=whisper_type,
+    )
+    print(
+        f"""Whisper Device : {whisper_inferencer.device}\n"""
+        f"""BGM Separation Device: {whisper_inferencer.music_separator.device}\n"""
+        f"""Diarization Device: {whisper_inferencer.diarizer.device}"""
+    )
+
+    hparams = WhisperValues(
+        model_size=TEST_WHISPER_MODEL,
+        vad_filter=vad_filter,
+        is_bgm_separate=bgm_separation,
+        compute_type=whisper_inferencer.current_compute_type,
+        uvr_enable_offload=True,
+        is_diarize=diarization,
+    ).as_list()
+
+    subtitle_str, file_path = whisper_inferencer.transcribe_file(
+        [audio_path],
+        None,
+        "SRT",
+        False,
+        gr.Progress(),
+        *hparams,
+    )
+
+    assert isinstance(subtitle_str, str) and subtitle_str
+    assert isinstance(file_path[0], str) and file_path
+
+    whisper_inferencer.transcribe_youtube(
+        TEST_YOUTUBE_URL,
+        "SRT",
+        False,
+        gr.Progress(),
+        *hparams,
+    )
+    assert isinstance(subtitle_str, str) and subtitle_str
+    assert isinstance(file_path[0], str) and file_path
+
+    whisper_inferencer.transcribe_mic(
+        audio_path,
+        "SRT",
+        False,
+        gr.Progress(),
+        *hparams,
+    )
+    assert isinstance(subtitle_str, str) and subtitle_str
+    assert isinstance(file_path[0], str) and file_path
+
+
+def download_file(url, save_dir):
+    if os.path.exists(TEST_FILE_PATH):
+        return
+
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+
+    file_name = url.split("/")[-1]
+    file_path = os.path.join(save_dir, file_name)
+
+    response = requests.get(url)
+
+    with open(file_path, "wb") as file:
+        file.write(response.content)
+
+    print(f"File downloaded to: {file_path}")
tests/test_translation.py ADDED
@@ -0,0 +1,52 @@
+from modules.translation.deepl_api import DeepLAPI
+from modules.translation.nllb_inference import NLLBInference
+from test_config import *
+
+import os
+import pytest
+
+
+@pytest.mark.parametrize("model_size, file_path", [
+    (TEST_NLLB_MODEL, TEST_SUBTITLE_SRT_PATH),
+    (TEST_NLLB_MODEL, TEST_SUBTITLE_VTT_PATH),
+])
+def test_nllb_inference(
+    model_size: str,
+    file_path: str
+):
+    nllb_inferencer = NLLBInference()
+    print(f"NLLB Device : {nllb_inferencer.device}")
+
+    result_str, file_paths = nllb_inferencer.translate_file(
+        fileobjs=[file_path],
+        model_size=model_size,
+        src_lang="eng_Latn",
+        tgt_lang="kor_Hang",
+    )
+
+    assert isinstance(result_str, str)
+    assert isinstance(file_paths[0], str)
+
+
+@pytest.mark.parametrize("file_path", [
+    TEST_SUBTITLE_SRT_PATH,
+    TEST_SUBTITLE_VTT_PATH,
+])
+def test_deepl_api(
+    file_path: str
+):
+    deepl_api = DeepLAPI()
+
+    api_key = os.getenv("DEEPL_API_KEY")
+
+    result_str, file_paths = deepl_api.translate_deepl(
+        auth_key=api_key,
+        fileobjs=[file_path],
+        source_lang="English",
+        target_lang="Korean",
+        is_pro=False,
+        add_timestamp=True,
+    )
+
+    assert isinstance(result_str, str)
+    assert isinstance(file_paths[0], str)
tests/test_vad.py ADDED
@@ -0,0 +1,26 @@
+from modules.utils.paths import *
+from modules.whisper.whisper_factory import WhisperFactory
+from modules.whisper.whisper_parameter import WhisperValues
+from test_config import *
+from test_transcription import download_file, test_transcribe
+
+import gradio as gr
+import pytest
+import os
+
+
+@pytest.mark.parametrize(
+    "whisper_type,vad_filter,bgm_separation,diarization",
+    [
+        ("whisper", True, False, False),
+        ("faster-whisper", True, False, False),
+        ("insanely_fast_whisper", True, False, False)
+    ]
+)
+def test_vad_pipeline(
+    whisper_type: str,
+    vad_filter: bool,
+    bgm_separation: bool,
+    diarization: bool,
+):
+    test_transcribe(whisper_type, vad_filter, bgm_separation, diarization)
tests/test_vtt.vtt ADDED
@@ -0,0 +1,6 @@
+WEBVTT
+00:00:00.500 --> 00:00:02.000
+You've got
+
+00:00:02.500 --> 00:00:04.300
+a friend in me.