# note-taker/engines.py
import gc
import tempfile
import time
from io import BytesIO
from typing import Protocol

import requests
import torch
import whisperx

class TranscriptEngine(Protocol):
    """Protocol for a transcription engine"""

    def transcribe(self, language: str, audio_file: BytesIO) -> str:
        """Transcribe an audio file to text"""
        ...
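
# Example (hypothetical usage): any class with a matching `transcribe` method
# satisfies the protocol structurally; no inheritance is required:
#
#     engine: TranscriptEngine = AssemblyAI(api_key='...')
#     text = engine.transcribe('en', audio_file)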

class AssemblyAI:
    transcript = 'https://api.assemblyai.com/v2/transcript'
    upload = 'https://api.assemblyai.com/v2/upload'

    def __init__(self, api_key: str):
        self.api_key = api_key

    def transcribe(self, language: str, audio_file: BytesIO) -> str:
        headers = {'authorization': self.api_key}
        # Upload the raw audio bytes, then request a transcript for the uploaded file
        upload_response = requests.post(
            AssemblyAI.upload, headers=headers, data=audio_file
        )
        audio_url = upload_response.json()['upload_url']
        payload = {
            'audio_url': audio_url,
            'iab_categories': True,
            'language_code': language,
            'speaker_labels': True,
        }
        response = requests.post(AssemblyAI.transcript, json=payload, headers=headers)
        if not response.ok:
            raise RuntimeError(f'AssemblyAI transcript request failed: {response.json()}')
        # Poll until the transcription either completes or fails
        polling_endpoint = f'{AssemblyAI.transcript}/{response.json()["id"]}'
        status = 'submitted'
        while status not in ('completed', 'error'):
            time.sleep(3)
            polling_response = requests.get(polling_endpoint, headers=headers)
            status = polling_response.json()['status']
        if status == 'error':
            raise RuntimeError(f'AssemblyAI transcription failed: {polling_response.json()["error"]}')
        return '\n'.join(
            f'{utterance["speaker"]}: {utterance["text"]}'
            for utterance in polling_response.json()['utterances']
        )

class WhisperX:
    def __init__(self, api_key: str, device: str = 'cuda', compute_type: str = 'int8', batch_size: int = 8):
        self.api_key = api_key  # HuggingFace API key, needed for the diarization pipeline
        self.device = device
        self.compute_type = compute_type
        self.batch_size = batch_size
        _setup_whisperx(self.device, self.compute_type)

    def transcribe(self, language: str, audio_file: BytesIO) -> str:
        global _whisperx_initialized, _whisperx_model
        # Reload the whisper model if a previous call released it to free GPU memory
        _setup_whisperx(self.device, self.compute_type)
        # Write the bytes to a temporary file so whisperx can load it from disk
        with tempfile.NamedTemporaryFile() as f:
            f.write(audio_file.read())
            f.seek(0)
            audio = whisperx.load_audio(f.name)
        # 1. Transcribe with the whisper model
        result = _whisperx_model.transcribe(audio, batch_size=self.batch_size, language=language)
        # Release the model to avoid out-of-memory errors on GPUs with little memory
        _whisperx_model = None
        _whisperx_initialized = False
        gc.collect()
        torch.cuda.empty_cache()
        # 2. Align the whisper output to get accurate timestamps
        model_a, metadata = whisperx.load_align_model(language_code=language, device=self.device)
        result = whisperx.align(result['segments'], model_a, metadata, audio, self.device, return_char_alignments=False)
        # Release the alignment model as well
        del model_a
        gc.collect()
        torch.cuda.empty_cache()
        # 3. Assign speaker labels
        diarize_model = whisperx.DiarizationPipeline(use_auth_token=self.api_key, device=self.device)
        # TODO: pass min/max number of speakers if known
        diarize_segments = diarize_model(audio)
        result = whisperx.assign_word_speakers(diarize_segments, result)
        # Segments are now assigned speaker IDs
        return '\n'.join(
            f'{segment.get("speaker", "UNKNOWN")}: {segment["text"]}' for segment in result['segments']
        )

def get_engine(engine_type: str, api_key: str | None) -> TranscriptEngine:
    try:
        engine_cls = {
            'AssemblyAI': AssemblyAI,
            'WhisperX': WhisperX,
        }[engine_type]
    except KeyError:
        raise ValueError(f'Unknown engine type: {engine_type}') from None
    return engine_cls(api_key)

# WhisperX auxiliary functions

_whisperx_initialized = False
_whisperx_model = None

def _setup_whisperx(device: str, compute_type: str) -> None:
    global _whisperx_initialized, _whisperx_model
    if _whisperx_initialized:
        return
    if device == 'cuda':
        # Warm up cuDNN with a dummy convolution to prevent a
        # CUDNN_STATUS_NOT_INITIALIZED RuntimeError in pytorch
        s = 32
        dev = torch.device(device)
        torch.nn.functional.conv2d(torch.zeros(s, s, s, s, device=dev), torch.zeros(s, s, s, s, device=dev))
    _whisperx_model = whisperx.load_model('large-v2', device, compute_type=compute_type)
    _whisperx_initialized = True
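
if __name__ == '__main__':
    # Minimal usage sketch; the audio file name and environment variable are
    # hypothetical placeholders, not part of the original app.
    import os

    engine = get_engine('AssemblyAI', os.environ.get('ASSEMBLYAI_API_KEY'))
    with open('meeting.wav', 'rb') as audio:  # hypothetical sample file
        print(engine.transcribe('en', BytesIO(audio.read())))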