"""Speech-to-text engines that return speaker-labelled transcripts."""

import gc
import tempfile
import time
from io import BytesIO
from typing import Protocol

import requests
import torch
import whisperx


class TranscriptionError(RuntimeError):
    """Raised when a transcription backend reports a failure."""


class TranscriptEngine(Protocol):
    """Protocol for a transcription engine"""

    def transcribe(self, language, audio_file: BytesIO) -> str:
        """Transcribe an audio file to text.

        Args:
            language: Language code forwarded to the backend.
            audio_file: Binary stream holding the audio to transcribe.

        Returns:
            The transcript as ``"<speaker>: <text>"`` lines joined by newlines.
        """
        ...


class AssemblyAI:
    """Transcription engine backed by the AssemblyAI v2 REST API."""

    transcript = 'https://api.assemblyai.com/v2/transcript'
    upload = 'https://api.assemblyai.com/v2/upload'

    # Seconds to wait between two status polls, so the loop does not hammer
    # the API while the job is still being processed.
    poll_interval = 3.0
    # Per-request timeout in seconds; without one a stalled connection would
    # block the caller forever.
    request_timeout = 60.0

    def __init__(self, api_key: str):
        """Store the AssemblyAI API key used to authorize every request."""
        self.api_key = api_key

    @staticmethod
    def _checked_json(response, action: str) -> dict:
        """Return the JSON body of *response*, raising if the request failed."""
        if not response.ok:
            raise TranscriptionError(
                f'AssemblyAI {action} failed ({response.status_code}): {response.text}'
            )
        return response.json()

    def transcribe(self, language, audio_file: BytesIO) -> str:
        """Upload the audio, request a transcript and wait for its completion.

        Args:
            language: Sent as ``language_code`` in the transcript request.
            audio_file: Binary stream uploaded as the request body.

        Returns:
            One ``"<speaker>: <text>"`` line per utterance, joined by newlines.

        Raises:
            TranscriptionError: If any request fails or the job ends with
                status ``'error'``.
        """
        auth = {'authorization': self.api_key}

        # The upload body is raw audio, so it must not be labelled as JSON.
        upload_response = requests.post(
            AssemblyAI.upload,
            headers={**auth, 'content-type': 'application/octet-stream'},
            data=audio_file,
            timeout=self.request_timeout,
        )
        audio_url = self._checked_json(upload_response, 'upload')['upload_url']

        payload = {
            'audio_url': audio_url,
            'iab_categories': True,
            'language_code': language,
            'speaker_labels': True,
        }
        # ``json=`` makes requests set the JSON content type itself.
        response = requests.post(
            AssemblyAI.transcript,
            json=payload,
            headers=auth,
            timeout=self.request_timeout,
        )
        job_id = self._checked_json(response, 'transcript request')['id']
        polling_endpoint = f'{AssemblyAI.transcript}/{job_id}'

        while True:
            polling_response = requests.get(
                polling_endpoint, headers=auth, timeout=self.request_timeout
            )
            job = self._checked_json(polling_response, 'status poll')
            status = job['status']
            if status == 'completed':
                break
            if status == 'error':
                # A failed job never becomes 'completed'; stop instead of
                # polling forever.
                raise TranscriptionError(
                    f'AssemblyAI transcription failed: {job.get("error")}'
                )
            time.sleep(self.poll_interval)

        # 'utterances' may be missing or null; treat that as "no utterances".
        utterances = job.get('utterances') or []
        return '\n'.join(
            f'{utterance["speaker"]}: {utterance["text"]}'
            for utterance in utterances
        )


class WhisperX:
    """Transcription engine running WhisperX locally with speaker diarization."""

    def __init__(self, api_key: str, device: str = 'cuda', compute_type: str = 'int8', batch_size: int = 8):
        """Store the settings and make sure the shared Whisper model is loaded.

        Args:
            api_key: HuggingFace API key, passed to the diarization pipeline.
            device: Device the models are loaded on.
            compute_type: Compute type passed to ``whisperx.load_model``.
            batch_size: Batch size used for transcription.
        """
        self.api_key = api_key  # HuggingFace API key
        self.device = device
        self.compute_type = compute_type
        self.batch_size = batch_size

        _setup_whisperx(self.device, self.compute_type)

    def transcribe(self, language, audio_file: BytesIO) -> str:
        """Transcribe, align and diarize the audio.

        Args:
            language: Language passed to transcription and to the alignment
                model loader.
            audio_file: Binary stream holding the audio to transcribe.

        Returns:
            One ``"<speaker>: <text>"`` line per segment, joined by newlines.
            Segments without a speaker label are labelled ``UNKNOWN``.
        """
        global _whisperx_initialized, _whisperx_model

        # The model is released after every transcription (see below), so
        # reload it here if a previous call dropped it.
        _setup_whisperx(self.device, self.compute_type)

        # 1. Write the bytes to a temporary file, load it and transcribe it
        with tempfile.NamedTemporaryFile() as f:
            f.write(audio_file.read())
            f.flush()  # make the data visible to the reader opening f.name
            audio = whisperx.load_audio(f.name)

        result = _whisperx_model.transcribe(
            audio, batch_size=self.batch_size, language=language
        )

        # Release the model to prevent memory errors when low on GPU
        # resources. The global is reset (not deleted) so that later calls can
        # reload it instead of failing with NameError.
        _whisperx_model = None
        _whisperx_initialized = False
        _free_gpu_memory()

        # 2. Align whisper output
        model_a, metadata = whisperx.load_align_model(language_code=language, device=self.device)
        result = whisperx.align(
            result['segments'], model_a, metadata, audio, self.device,
            return_char_alignments=False,
        )

        # The reference must be dropped before collecting, otherwise nothing
        # is freed.
        del model_a
        _free_gpu_memory()

        # 3. Assign speaker labels
        diarize_model = whisperx.DiarizationPipeline(use_auth_token=self.api_key, device=self.device)

        # add min/max number of speakers if known
        diarize_segments = diarize_model(audio)

        result = whisperx.assign_word_speakers(diarize_segments, result)

        # Segments are now assigned speaker IDs; a segment that got none falls
        # back to a placeholder instead of raising KeyError.
        lines = [
            f'{segment.get("speaker", "UNKNOWN")}: {segment["text"]}'
            for segment in result['segments']
        ]
        for line in lines:
            print(line)

        return '\n'.join(lines)


def get_engine(engine_type: str, api_key: str | None) -> TranscriptEngine:
    """Instantiate the transcription engine registered under *engine_type*.

    Args:
        engine_type: Either ``'AssemblyAI'`` or ``'WhisperX'``.
        api_key: Key passed to the engine's constructor.

    Raises:
        KeyError: If *engine_type* is not a known engine name.
    """
    engines = {
        'AssemblyAI': AssemblyAI,
        'WhisperX': WhisperX,
    }
    try:
        engine_cls = engines[engine_type]
    except KeyError:
        raise KeyError(
            f'Unknown engine type {engine_type!r}; expected one of {sorted(engines)}'
        ) from None

    return engine_cls(api_key)


# WhisperX auxiliary functions
_whisperx_initialized = False
_whisperx_model = None
_whisperx_model_a = None
_whisperx_model_a_metadata = None


def _free_gpu_memory():
    """Collect garbage and release cached CUDA memory held by PyTorch."""
    gc.collect()
    torch.cuda.empty_cache()


def _setup_whisperx(device, compute_type):
    """Load the shared Whisper model once; later calls are no-ops.

    Args:
        device: Device passed to ``whisperx.load_model``.
        compute_type: Compute type passed to ``whisperx.load_model``.
    """
    global _whisperx_initialized, _whisperx_model
    if _whisperx_initialized:
        return

    # The warm-up only makes sense on CUDA; skipping it elsewhere keeps
    # non-CUDA devices from crashing on a hard-coded 'cuda' device.
    if str(device).startswith('cuda'):
        s = 32
        dev = torch.device(device)

        # Prevent CUDNN_STATUS_NOT_INITIALIZED RuntimeError using pytorch
        torch.nn.functional.conv2d(
            torch.zeros(s, s, s, s, device=dev),
            torch.zeros(s, s, s, s, device=dev),
        )

    _whisperx_model = whisperx.load_model('large-v2', device, compute_type=compute_type)
    # Mark as initialized so the guard above actually short-circuits.
    _whisperx_initialized = True