jhj0517 committed
Commit 00efe30 · 1 Parent(s): f962fd9

add compute_type dropdown

app.py CHANGED
@@ -59,6 +59,7 @@ class App:
             nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
             nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
             nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
+            dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types, value=self.whisper_inf.current_compute_type, interactive=True)
             with gr.Row():
                 btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
             with gr.Row():
@@ -66,7 +67,7 @@ class App:
                 btn_openfolder = gr.Button('📂', scale=2)

             params = [input_file, dd_model, dd_lang, dd_subformat, cb_translate, cb_timestamp]
-            advanced_params = [nb_beam_size, nb_log_prob_threshold, nb_no_speech_threshold]
+            advanced_params = [nb_beam_size, nb_log_prob_threshold, nb_no_speech_threshold, dd_compute_type]
             btn_run.click(fn=self.whisper_inf.transcribe_file,
                           inputs=params + advanced_params,
                           outputs=[tb_indicator])
@@ -97,6 +98,7 @@ class App:
             nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
             nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
             nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
+            dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types, value=self.whisper_inf.current_compute_type, interactive=True)
             with gr.Row():
                 btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
             with gr.Row():
@@ -104,7 +106,7 @@ class App:
                 btn_openfolder = gr.Button('📂', scale=2)

             params = [tb_youtubelink, dd_model, dd_lang, dd_subformat, cb_translate, cb_timestamp]
-            advanced_params = [nb_beam_size, nb_log_prob_threshold, nb_no_speech_threshold]
+            advanced_params = [nb_beam_size, nb_log_prob_threshold, nb_no_speech_threshold, dd_compute_type]
             btn_run.click(fn=self.whisper_inf.transcribe_youtube,
                           inputs=params + advanced_params,
                           outputs=[tb_indicator])
@@ -128,6 +130,7 @@ class App:
             nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
             nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
             nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
+            dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types, value=self.whisper_inf.current_compute_type, interactive=True)
             with gr.Row():
                 btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
             with gr.Row():
@@ -135,7 +138,7 @@ class App:
                 btn_openfolder = gr.Button('📂', scale=2)

             params = [mic_input, dd_model, dd_lang, dd_subformat, cb_translate]
-            advanced_params = [nb_beam_size, nb_log_prob_threshold, nb_no_speech_threshold]
+            advanced_params = [nb_beam_size, nb_log_prob_threshold, nb_no_speech_threshold, dd_compute_type]
             btn_run.click(fn=self.whisper_inf.transcribe_mic,
                           inputs=params + advanced_params,
                           outputs=[tb_indicator])
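
The wiring pattern in all three tabs is the same: appending a component to the handler's inputs list is all it takes for its value to arrive as an extra positional argument. A minimal standalone sketch of that pattern, with illustrative names and choices (not taken from the repo):

import gradio as gr

def transcribe(file_path: str, compute_type: str) -> str:
    # stand-in for whisper_inf.transcribe_file; just echoes its inputs
    return f"would transcribe {file_path} with compute_type={compute_type}"

with gr.Blocks() as demo:
    input_file = gr.Textbox(label="Audio Path")
    # same wiring as dd_compute_type above: the dropdown's current value
    # is passed through to the click handler
    dd_compute_type = gr.Dropdown(label="Compute Type",
                                  choices=["float16", "float32"],
                                  value="float32", interactive=True)
    btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
    tb_indicator = gr.Textbox(label="Output")
    # the dropdown is the last entry in inputs, so its value is delivered
    # as the last positional argument of `transcribe`
    btn_run.click(fn=transcribe, inputs=[input_file, dd_compute_type],
                  outputs=[tb_indicator])

demo.launch()
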
modules/faster_whisper_inference.py CHANGED
@@ -24,9 +24,10 @@ class FasterWhisperInference(BaseInterface):
         self.available_models = whisper.available_models()
         self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
         self.translatable_models = ["large", "large-v1", "large-v2"]
-        self.default_beam_size = 1
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.compute_type = "float16" if self.device == "cuda" else "float32"
+        self.available_compute_types = ["int8", "int8_float32", "int8_float16", "int8_bfloat16", "int16", "float16", "bfloat16", "float32"]
+        self.current_compute_type = "float16" if self.device == "cuda" else "float32"
+        self.default_beam_size = 1

     def transcribe_file(self,
                         fileobjs: list,
@@ -38,6 +39,7 @@ class FasterWhisperInference(BaseInterface):
                         beam_size: int,
                         log_prob_threshold: float,
                         no_speech_threshold: float,
+                        compute_type: str,
                         progress=gr.Progress()
                         ) -> str:
         """
@@ -67,6 +69,9 @@ class FasterWhisperInference(BaseInterface):
             float value from gr.Number(). If the no_speech probability is higher than this value AND
             the average log probability over sampled tokens is below `log_prob_threshold`,
             consider the segment as silent.
+        compute_type: str
+            compute type from gr.Dropdown().
+            see more info : https://opennmt.net/CTranslate2/quantization.html
         progress: gr.Progress
             Indicator to show progress directly in gradio.

@@ -75,8 +80,7 @@ class FasterWhisperInference(BaseInterface):
             String to return to gr.Textbox()
         """
         try:
-            if model_size != self.current_model_size or self.model is None:
-                self.initialize_model(model_size=model_size, progress=progress)
+            self.update_model_if_needed(model_size=model_size, compute_type=compute_type, progress=progress)

             if lang == "Automatic Detection":
                 lang = None
@@ -129,6 +133,7 @@ class FasterWhisperInference(BaseInterface):
                            beam_size: int,
                            log_prob_threshold: float,
                            no_speech_threshold: float,
+                           compute_type: str,
                            progress=gr.Progress()
                            ) -> str:
         """
@@ -158,6 +163,9 @@ class FasterWhisperInference(BaseInterface):
             float value from gr.Number(). If the no_speech probability is higher than this value AND
             the average log probability over sampled tokens is below `log_prob_threshold`,
             consider the segment as silent.
+        compute_type: str
+            compute type from gr.Dropdown().
+            see more info : https://opennmt.net/CTranslate2/quantization.html
         progress: gr.Progress
             Indicator to show progress directly in gradio.

@@ -166,8 +174,7 @@ class FasterWhisperInference(BaseInterface):
             String to return to gr.Textbox()
         """
         try:
-            if model_size != self.current_model_size or self.model is None:
-                self.initialize_model(model_size=model_size, progress=progress)
+            self.update_model_if_needed(model_size=model_size, compute_type=compute_type, progress=progress)

             if lang == "Automatic Detection":
                 lang = None
@@ -220,6 +227,7 @@ class FasterWhisperInference(BaseInterface):
                        beam_size: int,
                        log_prob_threshold: float,
                        no_speech_threshold: float,
+                       compute_type: str,
                        progress=gr.Progress()
                        ) -> str:
         """
@@ -246,6 +254,9 @@ class FasterWhisperInference(BaseInterface):
         no_speech_threshold: float
             float value from gr.Number(). If the no_speech probability is higher than this value AND
             the average log probability over sampled tokens is below `log_prob_threshold`,
             consider the segment as silent.
+        compute_type: str
+            compute type from gr.Dropdown().
+            see more info : https://opennmt.net/CTranslate2/quantization.html
         progress: gr.Progress
             Indicator to show progress directly in gradio.
@@ -255,8 +266,7 @@ class FasterWhisperInference(BaseInterface):
             String to return to gr.Textbox()
         """
         try:
-            if model_size != self.current_model_size or self.model is None:
-                self.initialize_model(model_size=model_size, progress=progress)
+            self.update_model_if_needed(model_size=model_size, compute_type=compute_type, progress=progress)

             if lang == "Automatic Detection":
                 lang = None
@@ -353,21 +363,24 @@ class FasterWhisperInference(BaseInterface):
         elapsed_time = time.time() - start_time
         return segments_result, elapsed_time

-    def initialize_model(self,
-                         model_size: str,
-                         progress: gr.Progress
-                         ):
+    def update_model_if_needed(self,
+                               model_size: str,
+                               compute_type: str,
+                               progress: gr.Progress
+                               ):
         """
-        Initialize model if it doesn't match with current model size
+        Initialize model if it doesn't match with current model setting
         """
-        progress(0, desc="Initializing Model..")
-        self.current_model_size = model_size
-        self.model = faster_whisper.WhisperModel(
-            device=self.device,
-            model_size_or_path=model_size,
-            download_root=os.path.join("models", "Whisper", "faster-whisper"),
-            compute_type=self.compute_type
-        )
+        if model_size != self.current_model_size or self.model is None or self.current_compute_type != compute_type:
+            progress(0, desc="Initializing Model..")
+            self.current_model_size = model_size
+            self.current_compute_type = compute_type
+            self.model = faster_whisper.WhisperModel(
+                device=self.device,
+                model_size_or_path=model_size,
+                download_root=os.path.join("models", "Whisper", "faster-whisper"),
+                compute_type=self.current_compute_type
+            )

     @staticmethod
     def generate_and_write_subtitle(file_name: str,
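
The compute types listed in available_compute_types are everything CTranslate2 can quantize to, but not every type runs on every device (float16, for instance, needs a GPU). A small sketch of how the locally supported set could be queried, assuming the ctranslate2 package that backs faster-whisper is installed:

import ctranslate2
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Returns the set of compute types this device can actually execute,
# e.g. {"int8", "int8_float32", "float32"} on a typical x86 CPU.
print(ctranslate2.get_supported_compute_types(device))

CTranslate2's quantization docs (linked in the docstrings above) describe converting to the closest supported type when the requested one is unavailable, so the float16-on-GPU / float32-on-CPU default chosen for current_compute_type is a safe starting value.
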
modules/whisper_Inference.py CHANGED
@@ -22,6 +22,8 @@ class WhisperInference(BaseInterface):
         self.available_models = whisper.available_models()
         self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.available_compute_types = ["float16", "float32"]
+        self.current_compute_type = "float16" if self.device == "cuda" else "float32"
         self.default_beam_size = 1

     def transcribe_file(self,
@@ -34,6 +36,7 @@ class WhisperInference(BaseInterface):
                         beam_size: int,
                         log_prob_threshold: float,
                         no_speech_threshold: float,
+                        compute_type: str,
                         progress=gr.Progress()):
         """
         Write subtitle file from Files
@@ -62,14 +65,15 @@ class WhisperInference(BaseInterface):
             float value from gr.Number(). If the no_speech probability is higher than this value AND
             the average log probability over sampled tokens is below `log_prob_threshold`,
             consider the segment as silent.
+        compute_type: str
+            compute type from gr.Dropdown().
         progress: gr.Progress
             Indicator to show progress directly in gradio.
             I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
         """

         try:
-            if model_size != self.current_model_size or self.model is None:
-                self.initialize_model(model_size=model_size, progress=progress)
+            self.update_model_if_needed(model_size=model_size, compute_type=compute_type, progress=progress)

             files_info = {}
             for fileobj in fileobjs:
@@ -82,7 +86,9 @@ class WhisperInference(BaseInterface):
                                                    beam_size=beam_size,
                                                    log_prob_threshold=log_prob_threshold,
                                                    no_speech_threshold=no_speech_threshold,
-                                                   progress=progress)
+                                                   compute_type=compute_type,
+                                                   progress=progress
+                                                   )
             progress(1, desc="Completed!")

             file_name, file_ext = os.path.splitext(os.path.basename(fileobj.orig_name))
@@ -122,6 +128,7 @@ class WhisperInference(BaseInterface):
                            beam_size: int,
                            log_prob_threshold: float,
                            no_speech_threshold: float,
+                           compute_type: str,
                            progress=gr.Progress()):
         """
         Write subtitle file from Youtube
@@ -150,13 +157,14 @@ class WhisperInference(BaseInterface):
             float value from gr.Number(). If the no_speech probability is higher than this value AND
             the average log probability over sampled tokens is below `log_prob_threshold`,
             consider the segment as silent.
+        compute_type: str
+            compute type from gr.Dropdown().
         progress: gr.Progress
             Indicator to show progress directly in gradio.
             I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
         """
         try:
-            if model_size != self.current_model_size or self.model is None:
-                self.initialize_model(model_size=model_size, progress=progress)
+            self.update_model_if_needed(model_size=model_size, compute_type=compute_type, progress=progress)

             progress(0, desc="Loading Audio from Youtube..")
             yt = get_ytdata(youtubelink)
@@ -168,6 +176,7 @@ class WhisperInference(BaseInterface):
                                                    beam_size=beam_size,
                                                    log_prob_threshold=log_prob_threshold,
                                                    no_speech_threshold=no_speech_threshold,
+                                                   compute_type=compute_type,
                                                    progress=progress)
             progress(1, desc="Completed!")

@@ -205,6 +214,7 @@ class WhisperInference(BaseInterface):
                        beam_size: int,
                        log_prob_threshold: float,
                        no_speech_threshold: float,
+                       compute_type: str,
                        progress=gr.Progress()):
         """
         Write subtitle file from microphone
@@ -231,14 +241,15 @@ class WhisperInference(BaseInterface):
             float value from gr.Number(). If the no_speech probability is higher than this value AND
             the average log probability over sampled tokens is below `log_prob_threshold`,
             consider the segment as silent.
+        compute_type: str
+            compute type from gr.Dropdown().
         progress: gr.Progress
             Indicator to show progress directly in gradio.
             I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
         """

         try:
-            if model_size != self.current_model_size or self.model is None:
-                self.initialize_model(model_size=model_size, progress=progress)
+            self.update_model_if_needed(model_size=model_size, compute_type=compute_type, progress=progress)

             result, elapsed_time = self.transcribe(audio=micaudio,
                                                    lang=lang,
@@ -246,6 +257,7 @@ class WhisperInference(BaseInterface):
                                                    beam_size=beam_size,
                                                    log_prob_threshold=log_prob_threshold,
                                                    no_speech_threshold=no_speech_threshold,
+                                                   compute_type=compute_type,
                                                    progress=progress)
             progress(1, desc="Completed!")

@@ -271,6 +283,7 @@ class WhisperInference(BaseInterface):
                   beam_size: int,
                   log_prob_threshold: float,
                   no_speech_threshold: float,
+                  compute_type: str,
                   progress: gr.Progress
                   ) -> Tuple[list[dict], float]:
         """
@@ -294,6 +307,8 @@ class WhisperInference(BaseInterface):
             float value from gr.Number(). If the no_speech probability is higher than this value AND
             the average log probability over sampled tokens is below `log_prob_threshold`,
             consider the segment as silent.
+        compute_type: str
+            compute type from gr.Dropdown().
         progress: gr.Progress
             Indicator to show progress directly in gradio.

@@ -320,21 +335,30 @@ class WhisperInference(BaseInterface):
                                               logprob_threshold=log_prob_threshold,
                                               no_speech_threshold=no_speech_threshold,
                                               task="translate" if istranslate and self.current_model_size in translatable_model else "transcribe",
+                                              fp16=True if compute_type == "float16" else False,
                                               progress_callback=progress_callback)["segments"]
         elapsed_time = time.time() - start_time

         return segments_result, elapsed_time

-    def initialize_model(self,
-                         model_size: str,
-                         progress: gr.Progress
-                         ):
+    def update_model_if_needed(self,
+                               model_size: str,
+                               compute_type: str,
+                               progress: gr.Progress,
+                               ):
         """
-        Initialize model if it doesn't match with current model size
+        Initialize model if it doesn't match with current model setting
         """
-        progress(0, desc="Initializing Model..")
-        self.current_model_size = model_size
-        self.model = whisper.load_model(name=model_size, download_root=os.path.join("models", "Whisper"))
+        if compute_type != self.current_compute_type:
+            self.current_compute_type = compute_type
+        if model_size != self.current_model_size or self.model is None:
+            progress(0, desc="Initializing Model..")
+            self.current_model_size = model_size
+            self.model = whisper.load_model(
+                name=model_size,
+                device=self.device,
+                download_root=os.path.join("models", "Whisper")
+            )

     @staticmethod
     def generate_and_write_subtitle(file_name: str,
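
Unlike the faster-whisper backend, stock openai-whisper has no general compute_type knob; precision is controlled only by the boolean fp16 decode option, which is what the mapping above reduces the dropdown to (hence available_compute_types is just ["float16", "float32"] in this class). A minimal sketch with stock whisper for illustration only; the repo itself uses the fork linked in the docstrings for progress_callback, and "audio.mp3" is a placeholder:

import os
import whisper

# load_model accepts a device and a cache directory, mirroring
# update_model_if_needed() above
model = whisper.load_model(name="base", device="cpu",
                           download_root=os.path.join("models", "Whisper"))
# fp16=False selects float32 decoding; whisper itself warns and falls back
# to float32 if fp16 is requested on CPU
result = model.transcribe("audio.mp3", fp16=False)
print(result["text"])
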