aadnk committed
Commit 33ee1bb
1 Parent(s): 2698c96

Support CLI into faster-whisper

app.py CHANGED
@@ -126,7 +126,8 @@ class WhisperTranscriber:
         selectedModel = modelName if modelName is not None else "base"
 
         model = create_whisper_container(whisper_implementation=self.app_config.whisper_implementation,
-                                         model_name=selectedModel, cache=self.model_cache, models=self.app_config.models)
+                                         model_name=selectedModel, compute_type=self.app_config.compute_type,
+                                         cache=self.model_cache, models=self.app_config.models)
 
         # Result
         download = []
@@ -518,6 +519,8 @@ if __name__ == '__main__':
                         help="directory to save the outputs")
     parser.add_argument("--whisper_implementation", type=str, default=default_whisper_implementation, choices=["whisper", "faster-whisper"],\
                         help="the Whisper implementation to use")
+    parser.add_argument("--compute_type", type=str, default=default_app_config.compute_type, choices=["int8", "int8_float16", "int16", "float16"], \
+                        help="the compute type to use for inference")
 
     args = parser.parse_args().__dict__
 
cli.py CHANGED
@@ -80,6 +80,8 @@ def cli():
                         help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
     parser.add_argument("--fp16", type=str2bool, default=app_config.fp16, \
                         help="whether to perform inference in fp16; True by default")
+    parser.add_argument("--compute_type", type=str, default=app_config.compute_type, choices=["int8", "int8_float16", "int16", "float16"], \
+                        help="the compute type to use for inference")
 
     parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=app_config.temperature_increment_on_fallback, \
                         help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
@@ -119,12 +121,14 @@ def cli():
     vad_cpu_cores = args.pop("vad_cpu_cores")
     auto_parallel = args.pop("auto_parallel")
 
+    compute_type = args.pop("compute_type")
+
     transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
     transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
     transcriber.set_auto_parallel(auto_parallel)
 
-    model = create_whisper_container(whisper_implementation=whisper_implementation,
-                                     device=device, download_root=model_dir, models=app_config.models)
+    model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
+                                     device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
 
     if (transcriber._has_parallel_devices()):
         print("Using parallel devices:", transcriber.parallel_device_list)
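
A note on the args.pop("compute_type") call above: compute_type configures how the model is loaded, not how decoding runs, so it is removed from the parsed argument dictionary before the remaining options are forwarded. A minimal sketch of that pop-before-forward pattern, with an illustrative argument dict (only compute_type is taken from the diff; the other key is a placeholder):

# Hedged sketch of the pattern used in cli.py above.
args = {"compute_type": "int8_float16", "temperature": 0.0}   # illustrative parsed CLI arguments
compute_type = args.pop("compute_type")    # consumed when constructing the model container
assert "compute_type" not in args          # whatever remains can be passed on as decode options
print(compute_type)                        # int8_float16
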
config.json5 CHANGED
@@ -104,7 +104,7 @@
     // Number of beams in beam search, only applicable when temperature is zero
     "beam_size": 5,
     // Optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search
-    "patience": null,
+    "patience": 1,
     // Optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default
     "length_penalty": null,
     // Comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations
@@ -115,6 +115,8 @@
     "condition_on_previous_text": true,
     // Whether to perform inference in fp16; True by default
     "fp16": true,
+    // The compute type used by faster-whisper. Can be "int8", "int16" or "float16".
+    "compute_type": "float16",
     // Temperature to increase when falling back when the decoding fails to meet either of the thresholds below
     "temperature_increment_on_fallback": 0.2,
     // If the gzip compression ratio is higher than this value, treat the decoding as failed
src/config.py CHANGED
@@ -39,12 +39,10 @@ class ApplicationConfig:
                  patience: float = None, length_penalty: float = None,
                  suppress_tokens: str = "-1", initial_prompt: str = None,
                  condition_on_previous_text: bool = True, fp16: bool = True,
+                 compute_type: str = "float16",
                  temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
                  logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6):
 
-        if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-
         self.models = models
 
         # WebUI settings
@@ -82,6 +80,7 @@ class ApplicationConfig:
         self.initial_prompt = initial_prompt
         self.condition_on_previous_text = condition_on_previous_text
         self.fp16 = fp16
+        self.compute_type = compute_type
         self.temperature_increment_on_fallback = temperature_increment_on_fallback
         self.compression_ratio_threshold = compression_ratio_threshold
         self.logprob_threshold = logprob_threshold
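
A quick sanity check of the new ApplicationConfig field; this is a hedged sketch that assumes every other constructor parameter keeps the default shown in (or implied by) the signature above:

# Minimal sketch: compute_type threads through ApplicationConfig like the other decode settings.
from src.config import ApplicationConfig

config = ApplicationConfig(models=[], compute_type="int8_float16")
print(config.compute_type)   # "int8_float16"; falls back to "float16" when the argument is omitted
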
src/whisper/abstractWhisperContainer.py CHANGED
@@ -33,10 +33,12 @@ class AbstractWhisperCallback:
         return prompt1 + " " + prompt2
 
 class AbstractWhisperContainer:
-    def __init__(self, model_name: str, device: str = None, download_root: str = None,
-                 cache: ModelCache = None, models: List[ModelConfig] = []):
+    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
+                 download_root: str = None,
+                 cache: ModelCache = None, models: List[ModelConfig] = []):
         self.model_name = model_name
         self.device = device
+        self.compute_type = compute_type
         self.download_root = download_root
         self.cache = cache
 
@@ -87,13 +89,20 @@ class AbstractWhisperContainer:
 
     # This is required for multiprocessing
    def __getstate__(self):
-        return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root, "models": self.models }
+        return {
+            "model_name": self.model_name,
+            "device": self.device,
+            "download_root": self.download_root,
+            "models": self.models,
+            "compute_type": self.compute_type
+        }
 
     def __setstate__(self, state):
         self.model_name = state["model_name"]
         self.device = state["device"]
         self.download_root = state["download_root"]
         self.models = state["models"]
+        self.compute_type = state["compute_type"]
         self.model = None
         # Depickled objects must use the global cache
         self.cache = GLOBAL_MODEL_CACHE
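
The expanded __getstate__/__setstate__ pair is what lets compute_type survive the pickling that multiprocessing performs on the container. A hedged round-trip sketch (constructor arguments are illustrative, and faster-whisper is assumed to be installed):

# Sketch: the picklable state now carries compute_type, so a worker process
# rebuilds the container with the same precision setting.
import pickle
from src.whisper.fasterWhisperContainer import FasterWhisperContainer

container = FasterWhisperContainer("base", device="cuda", compute_type="int8_float16")
restored = pickle.loads(pickle.dumps(container))
assert restored.compute_type == "int8_float16"   # restored via __getstate__/__setstate__
assert restored.model is None                    # the loaded model itself is never pickled
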
src/whisper/fasterWhisperContainer.py CHANGED
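
The diff below routes the container's compute_type into WhisperModel and adapts the decode options to what faster-whisper accepts. As an aside, the _split_suppress_tokens helper it adds normalizes the comma-separated suppress_tokens string from the CLI/config into a list of ints; a standalone sketch of that behaviour (token ids are arbitrary examples):

# Same logic as the helper added below, copied out for illustration.
from typing import List, Union

def split_suppress_tokens(suppress_tokens: Union[str, List[int], None]):
    if suppress_tokens is None:
        return None
    if isinstance(suppress_tokens, list):
        return suppress_tokens
    return [int(token) for token in suppress_tokens.split(",")]

print(split_suppress_tokens("-1"))            # [-1]
print(split_suppress_tokens("220,50256"))     # [220, 50256]
print(split_suppress_tokens([220, 50256]))    # passed through unchanged
print(split_suppress_tokens(None))            # None
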
@@ -1,5 +1,5 @@
 import os
-from typing import List
+from typing import List, Union
 
 from faster_whisper import WhisperModel, download_model
 from src.config import ModelConfig
@@ -8,10 +8,10 @@ from src.modelCache import ModelCache
 from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
 
 class FasterWhisperContainer(AbstractWhisperContainer):
-    def __init__(self, model_name: str, device: str = None, download_root: str = None,
-                 cache: ModelCache = None,
-                 models: List[ModelConfig] = []):
-        super().__init__(model_name, device, download_root, cache, models)
+    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
+                 download_root: str = None,
+                 cache: ModelCache = None, models: List[ModelConfig] = []):
+        super().__init__(model_name, device, compute_type, download_root, cache, models)
 
     def ensure_downloaded(self):
         """
@@ -35,7 +35,7 @@ class FasterWhisperContainer(AbstractWhisperContainer):
         return None
 
     def _create_model(self):
-        print("Loading faster whisper model " + self.model_name)
+        print("Loading faster whisper model " + self.model_name + " for device " + str(self.device))
         model_config = self._get_model_config()
 
         if model_config.type == "whisper" and model_config.url not in ["tiny", "base", "small", "medium", "large", "large-v2"]:
@@ -46,7 +46,7 @@ class FasterWhisperContainer(AbstractWhisperContainer):
         if (device is None):
             device = "auto"
 
-        model = WhisperModel(model_config.url, device=device, compute_type="float16")
+        model = WhisperModel(model_config.url, device=device, compute_type=self.compute_type)
         return model
 
     def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
@@ -96,10 +96,33 @@ class FasterWhisperCallback(AbstractWhisperCallback):
         model: WhisperModel = self.model_container.get_model()
         language_code = self._lookup_language_code(self.language) if self.language else None
 
+        # Copy decode options and remove options that are not supported by faster-whisper
+        decodeOptions = self.decodeOptions.copy()
+        verbose = decodeOptions.pop("verbose", None)
+
+        logprob_threshold = decodeOptions.pop("logprob_threshold", None)
+
+        patience = decodeOptions.pop("patience", None)
+        length_penalty = decodeOptions.pop("length_penalty", None)
+        suppress_tokens = decodeOptions.pop("suppress_tokens", None)
+
+        if (decodeOptions.pop("fp16", None) is not None):
+            print("WARNING: fp16 option is ignored by faster-whisper - use compute_type instead.")
+
+        # Fix up decode options
+        if (logprob_threshold is not None):
+            decodeOptions["log_prob_threshold"] = logprob_threshold
+
+        decodeOptions["patience"] = float(patience) if patience is not None else 1.0
+        decodeOptions["length_penalty"] = float(length_penalty) if length_penalty is not None else 1.0
+
+        # See if suppress_tokens is a string - if so, convert it to a list of ints
+        decodeOptions["suppress_tokens"] = self._split_suppress_tokens(suppress_tokens)
+
         segments_generator, info = model.transcribe(audio, \
             language=language_code if language_code else detected_language, task=self.task, \
             initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
-            **self.decodeOptions
+            **decodeOptions
         )
 
         segments = []
@@ -109,6 +132,8 @@ class FasterWhisperCallback(AbstractWhisperCallback):
 
             if progress_listener is not None:
                 progress_listener.on_progress(segment.end, info.duration)
+            if verbose:
+                print(segment.text)
 
         text = " ".join([segment.text for segment in segments])
 
@@ -141,6 +166,14 @@ class FasterWhisperCallback(AbstractWhisperCallback):
             progress_listener.on_finished()
         return result
 
+    def _split_suppress_tokens(self, suppress_tokens: Union[str, List[int]]):
+        if (suppress_tokens is None):
+            return None
+        if (isinstance(suppress_tokens, list)):
+            return suppress_tokens
+
+        return [int(token) for token in suppress_tokens.split(",")]
+
     def _lookup_language_code(self, language: str):
         lookup = {
             "english": "en", "chinese": "zh-cn", "german": "de", "spanish": "es", "russian": "ru", "korean": "ko",
src/whisper/whisperContainer.py CHANGED
@@ -4,6 +4,7 @@ import os
 import sys
 from typing import List
 from urllib.parse import urlparse
+import torch
 import urllib3
 from src.hooks.progressListener import ProgressListener
 
@@ -18,9 +19,12 @@ from src.utils import download_file
 from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
 
 class WhisperContainer(AbstractWhisperContainer):
-    def __init__(self, model_name: str, device: str = None, download_root: str = None,
-                 cache: ModelCache = None, models: List[ModelConfig] = []):
-        super().__init__(model_name, device, download_root, cache, models)
+    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
+                 download_root: str = None,
+                 cache: ModelCache = None, models: List[ModelConfig] = []):
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+        super().__init__(model_name, device, compute_type, download_root, cache, models)
 
     def ensure_downloaded(self):
         """
@@ -184,8 +188,14 @@ class WhisperCallback(AbstractWhisperCallback):
         return self._transcribe(model, audio, segment_index, prompt, detected_language)
 
     def _transcribe(self, model: Whisper, audio, segment_index: int, prompt: str, detected_language: str):
+        decodeOptions = self.decodeOptions.copy()
+
+        # Add fp16
+        if self.model_container.compute_type in ["fp16", "float16"]:
+            decodeOptions["fp16"] = True
+
         return model.transcribe(audio, \
             language=self.language if self.language else detected_language, task=self.task, \
            initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
-            **self.decodeOptions
+            **decodeOptions
         )
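
The fp16 mapping above keeps the original openai/whisper backend in step with the shared compute_type setting: only "fp16"/"float16" turn the fp16 decode flag on. A small standalone sketch (the option dict is illustrative):

# Sketch of the compute_type -> fp16 translation performed in WhisperCallback._transcribe.
decodeOptions = {}
compute_type = "float16"                     # illustrative; read from the container in the real code
if compute_type in ["fp16", "float16"]:
    decodeOptions["fp16"] = True             # openai/whisper only understands the fp16 flag
print(decodeOptions)                         # {'fp16': True}; stays empty for e.g. "int8"
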
src/whisper/whisperFactory.py CHANGED
@@ -4,15 +4,16 @@ from src.config import ModelConfig
 from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
 
 def create_whisper_container(whisper_implementation: str,
-                             model_name: str, device: str = None, download_root: str = None,
+                             model_name: str, device: str = None, compute_type: str = "float16",
+                             download_root: str = None,
                              cache: modelCache = None, models: List[ModelConfig] = []) -> AbstractWhisperContainer:
     print("Creating whisper container for " + whisper_implementation)
 
     if (whisper_implementation == "whisper"):
         from src.whisper.whisperContainer import WhisperContainer
-        return WhisperContainer(model_name, device, download_root, cache, models)
+        return WhisperContainer(model_name=model_name, device=device, compute_type=compute_type, download_root=download_root, cache=cache, models=models)
     elif (whisper_implementation == "faster-whisper" or whisper_implementation == "faster_whisper"):
         from src.whisper.fasterWhisperContainer import FasterWhisperContainer
-        return FasterWhisperContainer(model_name, device, download_root, cache, models)
+        return FasterWhisperContainer(model_name=model_name, device=device, compute_type=compute_type, download_root=download_root, cache=cache, models=models)
     else:
         raise ValueError("Unknown Whisper implementation: " + whisper_implementation)
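
An example call against the updated factory signature; the model name, device and empty model list are placeholders, and faster-whisper is assumed to be installed:

# Hedged example: building a faster-whisper backed container with the new keyword.
from src.whisper.whisperFactory import create_whisper_container

container = create_whisper_container(whisper_implementation="faster-whisper",
                                     model_name="base", device="cuda",
                                     compute_type="int8_float16",
                                     download_root=None, cache=None, models=[])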