# note-taker/engines.py
import gc
import tempfile
import time
from io import BytesIO
from typing import Any, Protocol

import requests
import torch
import whisperx


class TranscriptEngine(Protocol):
    """Protocol for a transcription engine"""

    def transcribe(self, language: str, audio_file: BytesIO) -> str:
        """Transcribe an audio file to text"""
        ...


class AssemblyAI:
    transcript = 'https://api.assemblyai.com/v2/transcript'
    upload = 'https://api.assemblyai.com/v2/upload'

    def __init__(self, api_key: str, **kwargs: Any):
        self.api_key = api_key

    def transcribe(self, language: str, audio_file: BytesIO) -> str:
        headers = {'authorization': self.api_key, 'content-type': 'application/json'}
        upload_response = requests.post(
            AssemblyAI.upload, headers=headers, data=audio_file
        )
        audio_url = upload_response.json()['upload_url']
        json = {
            'audio_url': audio_url,
            'iab_categories': True,
            'language_code': language,
            'speaker_labels': True,
        }
        response = requests.post(AssemblyAI.transcript, json=json, headers=headers)
        if not response.ok:
            # TODO: handle errors properly instead of returning the raw JSON payload
            return response.json()
        polling_endpoint = f'{AssemblyAI.transcript}/{response.json()["id"]}'
        status = 'submitted'
        while status not in ('completed', 'error'):  # avoid polling forever on a failed job
            time.sleep(3)  # poll at a modest interval instead of hammering the endpoint
            polling_response = requests.get(polling_endpoint, headers=headers)
            status = polling_response.json()['status']
        if status == 'error':
            raise RuntimeError(f'Transcription failed: {polling_response.json().get("error")}')
        return '\n'.join(
            f'{utterance["speaker"]}: {utterance["text"]}'
            for utterance in polling_response.json()['utterances']
        )
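

# For reference, a completed AssemblyAI polling response includes (among other
# fields) the 'status' and the diarized 'utterances' consumed above, roughly:
#
#     {
#         'status': 'completed',
#         'utterances': [
#             {'speaker': 'A', 'text': 'Hello, everyone.', ...},
#             {'speaker': 'B', 'text': 'Hi!', ...},
#         ],
#     }
#
# which transcribe() flattens into 'A: Hello, everyone.' / 'B: Hi!' lines.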


class WhisperX:
    def __init__(
        self,
        api_key: str,
        device: str = 'cuda',
        compute_type: str = 'int8',
        batch_size: int = 8,
        whisper_model: str = 'large-v2',
        **kwargs: Any,
    ):
        self.api_key = api_key  # HuggingFace API key (used for the diarization models)
        self.device = device
        self.compute_type = compute_type
        self.batch_size = batch_size
        _setup_whisperx(self.device, self.compute_type, whisper_model=whisper_model)

    def transcribe(self, language: str, audio_file: BytesIO) -> str:
        global _whisperx_model, _whisperx_initialized
        # 1. Write the bytes to a temporary file, load it and transcribe it with whisper
        with tempfile.NamedTemporaryFile() as f:
            f.write(audio_file.read())
            f.seek(0)
            audio = whisperx.load_audio(f.name)
        result = _whisperx_model.transcribe(audio, batch_size=self.batch_size, language=language)
        # Free the model to prevent memory errors when low on GPU resources; the
        # reference must be dropped *before* collecting for the memory to be reclaimed
        del _whisperx_model
        _whisperx_initialized = False  # so that a new WhisperX instance reloads the model
        gc.collect()
        torch.cuda.empty_cache()
        # 2. Align whisper output
        model_a, metadata = whisperx.load_align_model(language_code=language, device=self.device)
        result = whisperx.align(
            result['segments'], model_a, metadata, audio, self.device, return_char_alignments=False
        )
        del model_a
        gc.collect()
        torch.cuda.empty_cache()
        # 3. Assign speaker labels
        diarize_model = whisperx.DiarizationPipeline(use_auth_token=self.api_key, device=self.device)
        # TODO: pass min/max number of speakers to the pipeline if known
        diarize_segments = diarize_model(audio)
        result = whisperx.assign_word_speakers(diarize_segments, result)
        # Segments now carry speaker IDs
        return '\n'.join(
            f'{segment.get("speaker", "UNKNOWN")}: {segment["text"]}'
            for segment in result['segments']
        )


def get_engine(engine_type: str, **kwargs: Any) -> TranscriptEngine:
    engine_cls = {
        'AssemblyAI': AssemblyAI,
        'WhisperX': WhisperX,
    }[engine_type]
    return engine_cls(**kwargs)
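

# Example usage (a minimal sketch; the API key is a placeholder and
# 'meeting.wav' a hypothetical file):
#
#     engine = get_engine('AssemblyAI', api_key='YOUR_ASSEMBLYAI_KEY')
#     with open('meeting.wav', 'rb') as f:
#         print(engine.transcribe('en', BytesIO(f.read())))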


# WhisperX auxiliary functions

_whisperx_initialized = False
_whisperx_model = None
_whisperx_model_a = None
_whisperx_model_a_metadata = None


def _setup_whisperx(device: str, compute_type: str, whisper_model: str = 'large-v2') -> None:
    global _whisperx_initialized, _whisperx_model, _whisperx_model_a, _whisperx_model_a_metadata
    if _whisperx_initialized:
        return
    s = 32
    if device == 'cuda':
        # Warm up cuDNN with a dummy convolution to prevent a
        # CUDNN_STATUS_NOT_INITIALIZED RuntimeError from pytorch
        dev = torch.device(device)
        torch.nn.functional.conv2d(
            torch.zeros(s, s, s, s, device=dev), torch.zeros(s, s, s, s, device=dev)
        )
    _whisperx_model = whisperx.load_model(whisper_model, device, compute_type=compute_type)
    _whisperx_initialized = True  # without this, the early-return guard above never fires
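

# Minimal command-line smoke test (a sketch, not part of the original module;
# HF_TOKEN is an assumed environment variable holding a HuggingFace token with
# access to the pyannote models whisperx uses for diarization):
if __name__ == '__main__':
    import os
    import sys

    engine = get_engine(
        'WhisperX',
        api_key=os.environ['HF_TOKEN'],
        device='cuda' if torch.cuda.is_available() else 'cpu',
    )
    with open(sys.argv[1], 'rb') as f:
        print(engine.transcribe('en', BytesIO(f.read())))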