msaelices committed on
Commit
ab9ec7c
1 Parent(s): ce51475

Allow customizing the Whisper model

Browse files
Files changed (2) hide show
  1. app.py +12 -1
  2. engines.py +4 -4
app.py CHANGED
@@ -41,10 +41,14 @@ def main():
41
  batch_size = os.environ.get('PYTORCH_BATCH_SIZE') or st.selectbox(
42
  'Select a batch size:', [4, 8, 16, 32, 64]
43
  )
 
 
 
44
  else:
45
  device = None
46
  compute_type = None
47
  batch_size = None
 
48
 
49
  engine_api_key = os.environ.get(
50
  f'{engine_type.upper()}_API_KEY'
@@ -69,7 +73,14 @@ def main():
69
  if uploaded_audio:
70
  if openai_api_key:
71
  st.markdown('Transcribing the audio...')
72
- engine = get_engine(engine_type, api_key=engine_api_key, device=device, compute_type=compute_type, batch_size=batch_size)
 
 
 
 
 
 
 
73
  transcription = api.transcribe(engine, language, uploaded_audio)
74
 
75
  st.markdown(
 
41
  batch_size = os.environ.get('PYTORCH_BATCH_SIZE') or st.selectbox(
42
  'Select a batch size:', [4, 8, 16, 32, 64]
43
  )
44
+ whisper_model = os.environ.get('WHISPER_MODEL') or st.selectbox(
45
+ 'Select a Whisper model:', ['large-v2', 'base']
46
+ )
47
  else:
48
  device = None
49
  compute_type = None
50
  batch_size = None
51
+ whisper_model = None
52
 
53
  engine_api_key = os.environ.get(
54
  f'{engine_type.upper()}_API_KEY'
 
73
  if uploaded_audio:
74
  if openai_api_key:
75
  st.markdown('Transcribing the audio...')
76
+ engine = get_engine(
77
+ engine_type,
78
+ api_key=engine_api_key,
79
+ device=device,
80
+ compute_type=compute_type,
81
+ batch_size=batch_size,
82
+ whisper_model=whisper_model,
83
+ )
84
  transcription = api.transcribe(engine, language, uploaded_audio)
85
 
86
  st.markdown(
engines.py CHANGED
@@ -57,12 +57,12 @@ class AssemblyAI:
57
 
58
 
59
  class WhisperX:
60
- def __init__(self, api_key: str, device: str = 'cuda', compute_type: str = 'int8', batch_size: int = 8):
61
  self.api_key = api_key # HuggingFace API key
62
  self.device = device
63
  self.compute_type = compute_type
64
  self.batch_size = batch_size
65
- _setup_whisperx(self.device, self.compute_type)
66
 
67
  def transcribe(self, language, audio_file: BytesIO) -> str:
68
  global _whisperx_model
@@ -113,7 +113,7 @@ _whisperx_model = None
113
  _whisperx_model_a = None
114
  _whisperx_model_a_metadata = None
115
 
116
- def _setup_whisperx(device, compute_type):
117
  global _whisperx_initialized, _whisperx_model, _whisperx_model_a, _whisperx_model_a_metadata
118
  if _whisperx_initialized:
119
  return
@@ -123,4 +123,4 @@ def _setup_whisperx(device, compute_type):
123
  dev = torch.device(device)
124
  torch.nn.functional.conv2d(torch.zeros(s, s, s, s, device=dev), torch.zeros(s, s, s, s, device=dev))
125
 
126
- _whisperx_model = whisperx.load_model('large-v2', device, compute_type=compute_type)
 
57
 
58
 
59
  class WhisperX:
60
+ def __init__(self, api_key: str, device: str = 'cuda', compute_type: str = 'int8', batch_size: int = 8, whisper_model: str = 'large-v2', **kwargs: Any):
61
  self.api_key = api_key # HuggingFace API key
62
  self.device = device
63
  self.compute_type = compute_type
64
  self.batch_size = batch_size
65
+ _setup_whisperx(self.device, self.compute_type, whisper_model=whisper_model)
66
 
67
  def transcribe(self, language, audio_file: BytesIO) -> str:
68
  global _whisperx_model
 
113
  _whisperx_model_a = None
114
  _whisperx_model_a_metadata = None
115
 
116
+ def _setup_whisperx(device, compute_type, whisper_model='large-v2'):
117
  global _whisperx_initialized, _whisperx_model, _whisperx_model_a, _whisperx_model_a_metadata
118
  if _whisperx_initialized:
119
  return
 
123
  dev = torch.device(device)
124
  torch.nn.functional.conv2d(torch.zeros(s, s, s, s, device=dev), torch.zeros(s, s, s, s, device=dev))
125
 
126
+ _whisperx_model = whisperx.load_model(whisper_model, device, compute_type=compute_type)