Commit: "Allow to customize the whisper model"
Changed files: app.py (+12, -1), engines.py (+4, -4)
(Space status at time of capture: Runtime error)
app.py
CHANGED
@@ -41,10 +41,14 @@ def main():
         batch_size = os.environ.get('PYTORCH_BATCH_SIZE') or st.selectbox(
             'Select a batch size:', [4, 8, 16, 32, 64]
         )
+        whisper_model = os.environ.get('WHISPER_MODEL') or st.selectbox(
+            'Select a Whisper model:', ['large-v2', 'base']
+        )
     else:
         device = None
         compute_type = None
         batch_size = None
+        whisper_model = None

     engine_api_key = os.environ.get(
         f'{engine_type.upper()}_API_KEY'
@@ -69,7 +73,14 @@ def main():
     if uploaded_audio:
         if openai_api_key:
             st.markdown('Transcribing the audio...')
-            engine = get_engine(engine_type, api_key=engine_api_key, device=device, compute_type=compute_type, batch_size=batch_size)
+            engine = get_engine(
+                engine_type,
+                api_key=engine_api_key,
+                device=device,
+                compute_type=compute_type,
+                batch_size=batch_size,
+                whisper_model=whisper_model,
+            )
             transcription = api.transcribe(engine, language, uploaded_audio)

             st.markdown(
engines.py
CHANGED
@@ -57,12 +57,12 @@ class AssemblyAI:


 class WhisperX:
-    def __init__(self, api_key: str, device: str = 'cuda', compute_type: str = 'int8', batch_size: int = 8):
+    def __init__(self, api_key: str, device: str = 'cuda', compute_type: str = 'int8', batch_size: int = 8, whisper_model: str = 'large-v2', **kwargs: Any):
         self.api_key = api_key  # HuggingFace API key
         self.device = device
         self.compute_type = compute_type
         self.batch_size = batch_size
-        _setup_whisperx(self.device, self.compute_type)
+        _setup_whisperx(self.device, self.compute_type, whisper_model=whisper_model)

     def transcribe(self, language, audio_file: BytesIO) -> str:
         global _whisperx_model
@@ -113,7 +113,7 @@ _whisperx_model = None
 _whisperx_model_a = None
 _whisperx_model_a_metadata = None

-def _setup_whisperx(device, compute_type):
+def _setup_whisperx(device, compute_type, whisper_model='large-v2'):
     global _whisperx_initialized, _whisperx_model, _whisperx_model_a, _whisperx_model_a_metadata
     if _whisperx_initialized:
         return
@@ -123,4 +123,4 @@ def _setup_whisperx(device, compute_type):
     dev = torch.device(device)
     torch.nn.functional.conv2d(torch.zeros(s, s, s, s, device=dev), torch.zeros(s, s, s, s, device=dev))

-    _whisperx_model = whisperx.load_model('large-v2', device, compute_type=compute_type)
+    _whisperx_model = whisperx.load_model(whisper_model, device, compute_type=compute_type)