msaelices commited on
Commit
43b181d
1 Parent(s): bb08fe0

Make batch_size configurable for WhisperX

Browse files
Files changed (2) hide show
  1. app.py +5 -1
  2. engines.py +1 -1
app.py CHANGED
@@ -38,9 +38,13 @@ def main():
38
  compute_type = os.environ.get('PYTORCH_COMPUTE_TYPE') or st.selectbox(
39
  'Select a compute type:', ['int8', 'float16']
40
  )
 
 
 
41
  else:
42
  device = None
43
  compute_type = None
 
44
 
45
  engine_api_key = os.environ.get(
46
  f'{engine_type.upper()}_API_KEY'
@@ -62,7 +66,7 @@ def main():
62
  if uploaded_audio:
63
  if openai_api_key:
64
  st.markdown('Transcribing the audio...')
65
- engine = get_engine(engine_type, api_key=engine_api_key, device=device, compute_type=compute_type)
66
  transcription = api.transcribe(engine, language, uploaded_audio)
67
 
68
  st.markdown(
 
38
  compute_type = os.environ.get('PYTORCH_COMPUTE_TYPE') or st.selectbox(
39
  'Select a compute type:', ['int8', 'float16']
40
  )
41
+ batch_size = os.environ.get('PYTORCH_BATCH_SIZE') or st.selectbox(
42
+ 'Select a batch size:', [4, 8, 16, 32, 64]
43
+ )
44
  else:
45
  device = None
46
  compute_type = None
47
+ batch_size = None
48
 
49
  engine_api_key = os.environ.get(
50
  f'{engine_type.upper()}_API_KEY'
 
66
  if uploaded_audio:
67
  if openai_api_key:
68
  st.markdown('Transcribing the audio...')
69
+ engine = get_engine(engine_type, api_key=engine_api_key, device=device, compute_type=compute_type, batch_size=batch_size)
70
  transcription = api.transcribe(engine, language, uploaded_audio)
71
 
72
  st.markdown(
engines.py CHANGED
@@ -20,7 +20,7 @@ class AssemblyAI:
20
  transcript = 'https://api.assemblyai.com/v2/transcript'
21
  upload = 'https://api.assemblyai.com/v2/upload'
22
 
23
- def __init__(self, api_key: str):
24
  self.api_key = api_key
25
 
26
  def transcribe(self, language, audio_file: BytesIO) -> str:
 
20
  transcript = 'https://api.assemblyai.com/v2/transcript'
21
  upload = 'https://api.assemblyai.com/v2/upload'
22
 
23
+ def __init__(self, api_key: str, **kwargs: Any):
24
  self.api_key = api_key
25
 
26
  def transcribe(self, language, audio_file: BytesIO) -> str: