Spaces:
Runtime error
Make batch_size configurable for WhisperX
Browse files
- app.py +5 -1
- engines.py +1 -1
app.py
CHANGED
@@ -38,9 +38,13 @@ def main():
|
|
38 |
compute_type = os.environ.get('PYTORCH_COMPUTE_TYPE') or st.selectbox(
|
39 |
'Select a compute type:', ['int8', 'float16']
|
40 |
)
|
|
|
|
|
|
|
41 |
else:
|
42 |
device = None
|
43 |
compute_type = None
|
|
|
44 |
|
45 |
engine_api_key = os.environ.get(
|
46 |
f'{engine_type.upper()}_API_KEY'
|
@@ -62,7 +66,7 @@ def main():
|
|
62 |
if uploaded_audio:
|
63 |
if openai_api_key:
|
64 |
st.markdown('Transcribing the audio...')
|
65 |
-
engine = get_engine(engine_type, api_key=engine_api_key, device=device, compute_type=compute_type)
|
66 |
transcription = api.transcribe(engine, language, uploaded_audio)
|
67 |
|
68 |
st.markdown(
|
|
|
38 |
compute_type = os.environ.get('PYTORCH_COMPUTE_TYPE') or st.selectbox(
|
39 |
'Select a compute type:', ['int8', 'float16']
|
40 |
)
|
41 |
+
batch_size = os.environ.get('PYTORCH_BATCH_SIZE') or st.selectbox(
|
42 |
+
'Select a batch size:', [4, 8, 16, 32, 64]
|
43 |
+
)
|
44 |
else:
|
45 |
device = None
|
46 |
compute_type = None
|
47 |
+
batch_size = None
|
48 |
|
49 |
engine_api_key = os.environ.get(
|
50 |
f'{engine_type.upper()}_API_KEY'
|
|
|
66 |
if uploaded_audio:
|
67 |
if openai_api_key:
|
68 |
st.markdown('Transcribing the audio...')
|
69 |
+
engine = get_engine(engine_type, api_key=engine_api_key, device=device, compute_type=compute_type, batch_size=batch_size)
|
70 |
transcription = api.transcribe(engine, language, uploaded_audio)
|
71 |
|
72 |
st.markdown(
|
engines.py
CHANGED
@@ -20,7 +20,7 @@ class AssemblyAI:
|
|
20 |
transcript = 'https://api.assemblyai.com/v2/transcript'
|
21 |
upload = 'https://api.assemblyai.com/v2/upload'
|
22 |
|
23 |
-
def __init__(self, api_key: str):
|
24 |
self.api_key = api_key
|
25 |
|
26 |
def transcribe(self, language, audio_file: BytesIO) -> str:
|
|
|
20 |
transcript = 'https://api.assemblyai.com/v2/transcript'
|
21 |
upload = 'https://api.assemblyai.com/v2/upload'
|
22 |
|
23 |
+
def __init__(self, api_key: str, **kwargs: Any):
|
24 |
self.api_key = api_key
|
25 |
|
26 |
def transcribe(self, language, audio_file: BytesIO) -> str:
|