# note-taker/engines.py

import gc
import tempfile
import time
from io import BytesIO
from typing import Protocol

import requests
import torch
import whisperx
# The calls below (RecognitionAudio, RecognitionConfig, long_running_recognize)
# belong to the v1 Speech API, so import v1 rather than v2
from google.cloud import speech_v1 as speech

class TranscriptEngine(Protocol):
"""Protocol for a transcription engine"""
    def transcribe(self, language: str, audio_file: BytesIO) -> str:
        """Transcribe an audio file to text"""
...
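
# Illustrative only (not part of the app): a minimal stub that satisfies the
# TranscriptEngine protocol, e.g. for exercising callers in tests. The class
# name and behaviour are assumptions, not an engine used elsewhere.
#
#     class EchoEngine:
#         def transcribe(self, language: str, audio_file: BytesIO) -> str:
#             return f'({language}) received {len(audio_file.read())} bytes'
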
class AssemblyAI:
transcript = 'https://api.assemblyai.com/v2/transcript'
upload = 'https://api.assemblyai.com/v2/upload'
def __init__(self, api_key: str):
self.api_key = api_key
    def transcribe(self, language: str, audio_file: BytesIO) -> str:
        headers = {'authorization': self.api_key}
        # Upload the raw audio bytes; the transcript request then references the returned URL
        upload_response = requests.post(
            AssemblyAI.upload, headers=headers, data=audio_file
        )
        audio_url = upload_response.json()['upload_url']
        payload = {
            'audio_url': audio_url,
            'iab_categories': True,
            'language_code': language,
            'speaker_labels': True,
        }
        response = requests.post(AssemblyAI.transcript, json=payload, headers=headers)
        if not response.ok:
            raise RuntimeError(f'AssemblyAI transcript request failed: {response.text}')
polling_endpoint = f'{AssemblyAI.transcript}/{response.json()["id"]}'
        status = 'submitted'
        while status not in ('completed', 'error'):
            time.sleep(1)  # avoid hammering the polling endpoint
            polling_response = requests.get(polling_endpoint, headers=headers)
            status = polling_response.json()['status']
        if status == 'error':
            raise RuntimeError(f'AssemblyAI transcription failed: {polling_response.json().get("error")}')
        # One line per utterance, prefixed with its speaker label
        return '\n'.join(
            f'{utterance["speaker"]}: {utterance["text"]}'
            for utterance in polling_response.json()['utterances']
        )
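
# Usage sketch (assumes a valid AssemblyAI API key and a local audio file;
# the file name is illustrative):
#
#     engine = AssemblyAI(api_key='...')
#     with open('meeting.wav', 'rb') as f:
#         text = engine.transcribe('en', BytesIO(f.read()))
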
class GoogleCloud:
    def __init__(self, api_key: str):
        pass  # Google Cloud authenticates via Application Default Credentials, not an API key
    def transcribe(self, language: str, audio_file: BytesIO) -> str:
client = speech.SpeechClient()
audio = speech.RecognitionAudio(content=audio_file.read())
config = speech.RecognitionConfig(
encoding=speech.RecognitionConfig.AudioEncoding.ENCODING_UNSPECIFIED,
language_code=language,
diarization_config=speech.SpeakerDiarizationConfig(
enable_speaker_diarization=True,
),
)
        # long_running_recognize is required for audio longer than one minute
        operation = client.long_running_recognize(config=config, audio=audio)
        response = operation.result()
        return ' '.join(
            result.alternatives[0].transcript for result in response.results
        )
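
# Note: with diarization enabled, the v1 Speech API attaches speaker tags to
# individual words of the last result rather than to whole segments, so the
# joined transcript above drops the speaker information. A sketch of
# recovering it (variable names are illustrative):
#
#     words = response.results[-1].alternatives[0].words
#     labelled = ' '.join(f'{w.speaker_tag}:{w.word}' for w in words)
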
class WhisperX:
def __init__(self, api_key: str, device: str = 'cuda', compute_type: str = 'int8', batch_size: int = 8):
self.api_key = api_key # HuggingFace API key
self.device = device
self.compute_type = compute_type
self.batch_size = batch_size
_setup_whisperx(self.device, self.compute_type)
    def transcribe(self, language: str, audio_file: BytesIO) -> str:
        global _whisperx_model
        # Reload the ASR model if a previous call released it
        _setup_whisperx(self.device, self.compute_type)
        # 1. Write the bytes to a temporary file, then load and transcribe it with whisper
        with tempfile.NamedTemporaryFile() as f:
            f.write(audio_file.read())
            f.flush()
            audio = whisperx.load_audio(f.name)
        result = _whisperx_model.transcribe(audio, batch_size=self.batch_size, language=language)
        # Release the ASR model to avoid running out of GPU memory
        _whisperx_model = None
        gc.collect(); torch.cuda.empty_cache()
        # 2. Align whisper output
        model_a, metadata = whisperx.load_align_model(language_code=language, device=self.device)
        result = whisperx.align(result['segments'], model_a, metadata, audio, self.device, return_char_alignments=False)
        # Release the alignment model to avoid running out of GPU memory
        del model_a
        gc.collect(); torch.cuda.empty_cache()
        # 3. Assign speaker labels
        diarize_model = whisperx.DiarizationPipeline(use_auth_token=self.api_key, device=self.device)
        # Pass min_speakers/max_speakers here if the number of speakers is known
        diarize_segments = diarize_model(audio)
        result = whisperx.assign_word_speakers(diarize_segments, result)
        # One line per segment, prefixed with its speaker ID; segments the
        # diarizer could not attribute fall back to 'UNKNOWN'
        return '\n'.join(
            f'{segment.get("speaker", "UNKNOWN")}: {segment["text"]}' for segment in result['segments']
        )
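
# Usage sketch (assumes a HuggingFace token that has accepted the pyannote
# diarization model terms, and a CUDA-capable GPU; names are illustrative):
#
#     engine = WhisperX(api_key='hf_...', device='cuda')
#     with open('meeting.wav', 'rb') as f:
#         text = engine.transcribe('en', BytesIO(f.read()))
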
def get_engine(engine_type: str, api_key: str | None) -> TranscriptEngine:
engine_cls = {
'AssemblyAI': AssemblyAI,
'Google': GoogleCloud,
'WhisperX': WhisperX,
}[engine_type]
return engine_cls(api_key)
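
# Example (illustrative): selecting an engine by name, as a caller would.
#
#     engine = get_engine('WhisperX', api_key='hf_...')
#     text = engine.transcribe('en', audio_buffer)  # audio_buffer: BytesIO
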
# WhisperX auxiliary functions

_whisperx_initialized = False
_whisperx_model = None


def _setup_whisperx(device, compute_type):
    global _whisperx_initialized, _whisperx_model
    if not _whisperx_initialized:
        s = 32
        dev = torch.device(device)
        # Warm up cuDNN to prevent a CUDNN_STATUS_NOT_INITIALIZED RuntimeError in pytorch
        torch.nn.functional.conv2d(torch.zeros(s, s, s, s, device=dev), torch.zeros(s, s, s, s, device=dev))
        _whisperx_initialized = True
    if _whisperx_model is None:
        # (Re)load the ASR model; transcribe() sets it back to None after each run
        _whisperx_model = whisperx.load_model('large-v2', device, compute_type=compute_type)