jhj0517 committed
Commit 201b316 · 1 Parent(s): 595b5f3

add `--diarization_model_dir` cli arg

app.py CHANGED
@@ -36,23 +36,27 @@ class App:
         if whisper_type in ["faster_whisper", "faster-whisper", "fasterwhisper"]:
             whisper_inf = FasterWhisperInference(
                 model_dir=self.args.faster_whisper_model_dir,
-                output_dir=self.args.output_dir
+                output_dir=self.args.output_dir,
+                args=self.args
             )
         elif whisper_type in ["whisper"]:
             whisper_inf = WhisperInference(
                 model_dir=self.args.whisper_model_dir,
-                output_dir=self.args.output_dir
+                output_dir=self.args.output_dir,
+                args=self.args
             )
         elif whisper_type in ["insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
                               "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"]:
             whisper_inf = InsanelyFastWhisperInference(
                 model_dir=self.args.insanely_fast_whisper_model_dir,
-                output_dir=self.args.output_dir
+                output_dir=self.args.output_dir,
+                args=self.args
             )
         else:
             whisper_inf = FasterWhisperInference(
                 model_dir=self.args.faster_whisper_model_dir,
-                output_dir=self.args.output_dir
+                output_dir=self.args.output_dir,
+                args=self.args
             )
         return whisper_inf
 
@@ -90,7 +94,7 @@ class App:
                 cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
             with gr.Row():
                 cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename", interactive=True)
-            with gr.Accordion("Advanced_Parameters", open=False):
+            with gr.Accordion("Advanced Parameters", open=False):
                 nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
                 nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
                 nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
@@ -101,7 +105,7 @@ class App:
                 tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
                 sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
                 nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True)
-            with gr.Accordion("VAD Options", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
+            with gr.Accordion("VAD", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
                 cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
                 sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
                 nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
@@ -109,12 +113,14 @@ class App:
                 nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
                 nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
                 nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
+            with gr.Accordion("Diarization", open=False):
+                cb_diarize = gr.Checkbox(label="Enable Diarization")
+                tb_hf_token = gr.Text(label="HuggingFace Token", value="",
+                                      info="This is only needed the first time you download the model. If you already have models, you don't need to enter.")
+                dd_diarization_device = gr.Dropdown(label="Device", choices=self.whisper_inf.diarizer.get_available_device(), value=self.whisper_inf.diarizer.get_device())
             with gr.Accordion("Insanely Fast Whisper Parameters", open=False, visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
                 nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
                 nb_batch_size = gr.Number(label="Batch Size", value=24, precision=0)
-            with gr.Accordion("Diarization Parameters", open=False):
-                cb_diarize = gr.Checkbox(label="Enable Diarization")
-                tb_hf_token = gr.Text(label="HuggingFace Token", value="")
             with gr.Row():
                 btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
             with gr.Row():
@@ -146,7 +152,8 @@ class App:
                 chunk_length_s=nb_chunk_length_s,
                 batch_size=nb_batch_size,
                 is_diarize=cb_diarize,
-                hf_token=tb_hf_token)
+                hf_token=tb_hf_token,
+                diarization_device=dd_diarization_device)
 
             btn_run.click(fn=self.whisper_inf.transcribe_file,
                           inputs=params + whisper_params.as_list(),
@@ -174,7 +181,7 @@ class App:
             with gr.Row():
                 cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
                                            interactive=True)
-            with gr.Accordion("Advanced_Parameters", open=False):
+            with gr.Accordion("Advanced Parameters", open=False):
                 nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
                 nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
                 nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
@@ -185,7 +192,7 @@ class App:
                 tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
                 sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
                 nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True)
-            with gr.Accordion("VAD Options", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
+            with gr.Accordion("VAD", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
                 cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
                 sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
                 nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
@@ -193,6 +200,11 @@ class App:
                 nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
                 nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
                 nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
+            with gr.Accordion("Diarization", open=False):
+                cb_diarize = gr.Checkbox(label="Enable Diarization")
+                tb_hf_token = gr.Text(label="HuggingFace Token", value="",
+                                      info="This is only needed the first time you download the model. If you already have models, you don't need to enter.")
+                dd_diarization_device = gr.Dropdown(label="Device", choices=self.whisper_inf.diarizer.get_available_device(), value=self.whisper_inf.diarizer.get_device())
             with gr.Accordion("Insanely Fast Whisper Parameters", open=False,
                               visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
                 nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
@@ -228,7 +240,8 @@ class App:
                 chunk_length_s=nb_chunk_length_s,
                 batch_size=nb_batch_size,
                 is_diarize=cb_diarize,
-                hf_token=tb_hf_token)
+                hf_token=tb_hf_token,
+                diarization_device=dd_diarization_device)
 
             btn_run.click(fn=self.whisper_inf.transcribe_youtube,
                           inputs=params + whisper_params.as_list(),
@@ -249,7 +262,7 @@ class App:
                 dd_file_format = gr.Dropdown(["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
             with gr.Row():
                 cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
-            with gr.Accordion("Advanced_Parameters", open=False):
+            with gr.Accordion("Advanced Parameters", open=False):
                 nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
                 nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
                 nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
@@ -259,7 +272,7 @@ class App:
                 cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=True, interactive=True)
                 tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
                 sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
-            with gr.Accordion("VAD Options", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
+            with gr.Accordion("VAD", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
                 cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
                 sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
                 nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
@@ -267,6 +280,11 @@ class App:
                 nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
                 nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
                 nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
+            with gr.Accordion("Diarization", open=False):
+                cb_diarize = gr.Checkbox(label="Enable Diarization")
+                tb_hf_token = gr.Text(label="HuggingFace Token", value="",
+                                      info="This is only needed the first time you download the model. If you already have models, you don't need to enter.")
+                dd_diarization_device = gr.Dropdown(label="Device", choices=self.whisper_inf.diarizer.get_available_device(), value=self.whisper_inf.diarizer.get_device())
             with gr.Accordion("Insanely Fast Whisper Parameters", open=False,
                               visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
                 nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
@@ -302,7 +320,8 @@ class App:
                 chunk_length_s=nb_chunk_length_s,
                 batch_size=nb_batch_size,
                 is_diarize=cb_diarize,
-                hf_token=tb_hf_token)
+                hf_token=tb_hf_token,
+                diarization_device=dd_diarization_device)
 
             btn_run.click(fn=self.whisper_inf.transcribe_mic,
                           inputs=params + whisper_params.as_list(),
@@ -404,6 +423,7 @@ parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=Tru
 parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"), help='Directory path of the whisper model')
 parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"), help='Directory path of the faster-whisper model')
 parser.add_argument('--insanely_fast_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "insanely-fast-whisper"), help='Directory path of the insanely-fast-whisper model')
+parser.add_argument('--diarization_model_dir', type=str, default=os.path.join("models", "Diarization"), help='Directory path of the diarization model')
 parser.add_argument('--nllb_model_dir', type=str, default=os.path.join("models", "NLLB"), help='Directory path of the Facebook NLLB model')
 parser.add_argument('--output_dir', type=str, default=os.path.join("outputs"), help='Directory path of the outputs')
 _args = parser.parse_args()
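
Taken together, the app.py changes add a `--diarization_model_dir` option and pass the parsed arguments all the way down to the diarizer. The sketch below is a minimal stand-in (not the project's real classes) that mirrors only the names from this diff (the CLI flag, the `args: Namespace` constructor parameter, and the `args.diarization_model_dir` lookup) to show how the value travels; passing the whole Namespace keeps the constructors stable if more CLI options need to reach the base class later.

# Minimal sketch of the new argument flow: argparse -> inference class -> Diarizer.
# The class bodies are illustrative stand-ins, not the repository's implementations.
import os
from argparse import ArgumentParser, Namespace

class Diarizer:
    def __init__(self, model_dir: str = os.path.join("models", "Diarization")):
        self.model_dir = model_dir  # directory where diarization weights are cached

class WhisperBase:
    def __init__(self, model_dir: str, output_dir: str, args: Namespace):
        # The whole Namespace is handed down so the base class can read
        # args.diarization_model_dir without adding another constructor parameter.
        self.diarizer = Diarizer(model_dir=args.diarization_model_dir)

parser = ArgumentParser()
parser.add_argument('--diarization_model_dir', type=str,
                    default=os.path.join("models", "Diarization"))
args = parser.parse_args(['--diarization_model_dir', 'models/MyDiarization'])

inf = WhisperBase(model_dir=os.path.join("models", "Whisper"),
                  output_dir="outputs", args=args)
print(inf.diarizer.model_dir)  # models/MyDiarization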
modules/diarize_pipeline.py CHANGED
@@ -11,7 +11,7 @@ class DiarizationPipeline:
     def __init__(
         self,
         model_name="pyannote/speaker-diarization-3.1",
-        cache_dir: str = os.path.join("models", "Whisper", "whisperx"),
+        cache_dir: str = os.path.join("models", "Diarization"),
         use_auth_token=None,
         device: Optional[Union[str, torch.device]] = "cpu",
     ):
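
The only change here is the default `cache_dir`, which now points at models/Diarization instead of the old whisperx subfolder. For orientation, a cache directory like this is what pyannote.audio uses when it downloads the gated speaker-diarization pipeline; the sketch below shows the usual loading pattern, but the exact call inside DiarizationPipeline is not part of this diff, so treat the parameter usage as an assumption.

# Sketch only: typical pyannote.audio loading with an explicit cache directory.
# Whether DiarizationPipeline forwards cache_dir exactly like this is assumed,
# not shown by the diff.
import os
from pyannote.audio import Pipeline

cache_dir = os.path.join("models", "Diarization")
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token="hf_xxx",  # placeholder HuggingFace token for the gated model
    cache_dir=cache_dir,      # assumed: where the downloaded weights are stored
)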
modules/diarizer.py CHANGED
@@ -9,7 +9,7 @@ from modules.diarize_pipeline import DiarizationPipeline
 
 class Diarizer:
     def __init__(self,
-                 model_dir: str = os.path.join("models", "Whisper", "whisperx")
+                 model_dir: str = os.path.join("models", "Diarization")
                  ):
         self.device = self.get_device()
         self.available_device = self.get_available_device()
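
app.py reads `diarizer.get_available_device()` and `diarizer.get_device()` to populate the new Device dropdown; those helpers are referenced but not shown in this diff. A plausible torch-based sketch of what they might return (hypothetical, the real implementation in modules/diarizer.py may differ):

# Hypothetical device helpers backing the Device dropdown; not taken from the repo.
import torch

def get_device() -> str:
    # Prefer CUDA when available, otherwise fall back to CPU.
    return "cuda" if torch.cuda.is_available() else "cpu"

def get_available_device() -> list:
    # Dropdown choices: every device the diarization pipeline could run on.
    return ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"]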
modules/faster_whisper_inference.py CHANGED
@@ -7,6 +7,7 @@ from faster_whisper.vad import VadOptions
 import ctranslate2
 import whisper
 import gradio as gr
+from argparse import Namespace
 
 from modules.whisper_parameter import *
 from modules.whisper_base import WhisperBase
@@ -15,11 +16,13 @@ from modules.whisper_base import WhisperBase
 class FasterWhisperInference(WhisperBase):
     def __init__(self,
                  model_dir: str,
-                 output_dir: str
+                 output_dir: str,
+                 args: Namespace
                  ):
         super().__init__(
             model_dir=model_dir,
-            output_dir=output_dir
+            output_dir=output_dir,
+            args=args
         )
         self.model_paths = self.get_model_paths()
         self.available_models = self.model_paths.keys()
modules/insanely_fast_whisper_inference.py CHANGED
@@ -9,6 +9,7 @@ import gradio as gr
 from huggingface_hub import hf_hub_download
 import whisper
 from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
+from argparse import Namespace
 
 from modules.whisper_parameter import *
 from modules.whisper_base import WhisperBase
@@ -17,11 +18,13 @@ from modules.whisper_base import WhisperBase
 class InsanelyFastWhisperInference(WhisperBase):
     def __init__(self,
                  model_dir: str,
-                 output_dir: str
+                 output_dir: str,
+                 args: Namespace
                  ):
         super().__init__(
             model_dir=model_dir,
-            output_dir=output_dir
+            output_dir=output_dir,
+            args=args
         )
         openai_models = whisper.available_models()
         distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
modules/whisper_Inference.py CHANGED
@@ -5,6 +5,7 @@ import os
 from typing import BinaryIO, Union, Tuple, List
 import numpy as np
 import torch
+from argparse import Namespace
 
 from modules.whisper_base import WhisperBase
 from modules.whisper_parameter import *
@@ -13,11 +14,13 @@ from modules.whisper_parameter import *
 class WhisperInference(WhisperBase):
     def __init__(self,
                  model_dir: str,
-                 output_dir: str
+                 output_dir: str,
+                 args: Namespace
                  ):
         super().__init__(
             model_dir=model_dir,
-            output_dir=output_dir
+            output_dir=output_dir,
+            args=args
         )
 
     def transcribe(self,
modules/whisper_base.py CHANGED
@@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
 from typing import BinaryIO, Union, Tuple, List
 import numpy as np
 from datetime import datetime
+from argparse import Namespace
 import time
 
 from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
@@ -18,7 +19,8 @@ from modules.diarizer import Diarizer
 class WhisperBase(ABC):
     def __init__(self,
                  model_dir: str,
-                 output_dir: str
+                 output_dir: str,
+                 args: Namespace
                  ):
         self.model = None
         self.current_model_size = None
@@ -32,7 +34,9 @@ class WhisperBase(ABC):
         self.device = self.get_device()
         self.available_compute_types = ["float16", "float32"]
         self.current_compute_type = "float16" if self.device == "cuda" else "float32"
-        self.diarizer = Diarizer()
+        self.diarizer = Diarizer(
+            model_dir=args.diarization_model_dir
+        )
 
     @abstractmethod
     def transcribe(self,
modules/whisper_parameter.py CHANGED
@@ -29,6 +29,7 @@ class WhisperParameters:
     batch_size: gr.Number
     is_diarize: gr.Checkbox
     hf_token: gr.Textbox
+    diarization_device: gr.Dropdown
     """
     A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
     This data class is used to mitigate the key-value problem between Gradio components and function parameters.
@@ -131,6 +132,9 @@ class WhisperParameters:
     hf_token: gr.Textbox
         This parameter is related with whisperx. Huggingface token is needed to download diarization models.
         Read more about : https://huggingface.co/pyannote/speaker-diarization-3.1#requirements
+
+    diarization_device: gr.Dropdown
+        This parameter is related with whisperx. Device to run diarization model
     """
 
     def as_list(self) -> list:
@@ -180,6 +184,7 @@ class WhisperParameters:
             batch_size=args[21],
             is_diarize=args[22],
             hf_token=args[23],
+            diarization_device=args[24]
         )
 
 
@@ -209,6 +214,7 @@ class WhisperValues:
     batch_size: int
     is_diarize: bool
    hf_token: str
+    diarization_device: str
     """
     A data class to use Whisper parameters.
     """