Spaces:
Runtime error
Runtime error
import gc | |
import tempfile | |
from typing import Protocol | |
from io import BytesIO | |
import requests | |
import torch | |
import whisperx | |
from google.cloud import speech_v2 as speech | |
class TranscriptEngine(Protocol):
    """Structural interface shared by every transcription engine.

    Any object exposing a compatible ``transcribe`` method satisfies this
    protocol; inheriting from it is not required.
    """

    def transcribe(self, language: str, audio_file: BytesIO) -> str:
        """Transcribe an audio stream to text.

        Args:
            language: Language code of the audio, forwarded to the backend.
            audio_file: In-memory binary stream holding the audio. The
                engines in this module call ``.read()`` on it or stream it
                to an upload endpoint, so a plain ``bytes`` object is not
                accepted.

        Returns:
            The transcript as a single string.
        """
        ...
class AssemblyAI:
    """Transcription engine backed by the AssemblyAI REST API.

    The audio is uploaded, a transcription job with speaker labels is
    created, and the job is polled until it reaches a final state.
    """

    # REST endpoints (public class attributes, names kept for compatibility).
    transcript = 'https://api.assemblyai.com/v2/transcript'
    upload = 'https://api.assemblyai.com/v2/upload'

    # Seconds to wait between two status polls of a running job.
    poll_interval = 3.0

    def __init__(self, api_key: str):
        """Store the API key that is sent in the ``authorization`` header."""
        self.api_key = api_key

    def transcribe(self, language, audio_file: BytesIO) -> str:
        """Upload ``audio_file`` and return its speaker-labelled transcript.

        Args:
            language: Value sent as ``language_code`` of the job.
            audio_file: Binary stream with the audio; streamed as the body
                of the upload request.

        Returns:
            One ``"<speaker>: <text>"`` line per utterance, joined with
            newlines (an empty string when there are no utterances).

        Raises:
            requests.HTTPError: If the upload, the job submission or a
                status poll is answered with an HTTP error status.
            RuntimeError: If the job finishes with status ``error``.
        """
        import time  # only needed for the polling delay below

        # The upload body is raw audio, so it must not be labelled as JSON;
        # ``json=`` below sets the JSON content type for the job request.
        headers = {'authorization': self.api_key}

        upload_response = requests.post(
            AssemblyAI.upload, headers=headers, data=audio_file
        )
        upload_response.raise_for_status()
        audio_url = upload_response.json()['upload_url']

        payload = {
            'audio_url': audio_url,
            'iab_categories': True,
            'language_code': language,
            'speaker_labels': True,
        }
        response = requests.post(AssemblyAI.transcript, json=payload, headers=headers)
        # Fail loudly instead of returning the error payload (a dict) from a
        # method that is declared to return ``str``.
        response.raise_for_status()

        polling_endpoint = f'{AssemblyAI.transcript}/{response.json()["id"]}'
        while True:
            polling_response = requests.get(polling_endpoint, headers=headers)
            polling_response.raise_for_status()
            job = polling_response.json()
            status = job['status']
            if status == 'completed':
                break
            if status == 'error':
                # Without this check a failed job would be polled forever.
                raise RuntimeError(
                    f'AssemblyAI transcription failed: {job.get("error")}'
                )
            time.sleep(self.poll_interval)

        return self._format_utterances(job.get('utterances'))

    @staticmethod
    def _format_utterances(utterances) -> str:
        """Render utterance dicts as newline-separated ``speaker: text`` lines."""
        return '\n'.join(
            f'{utterance["speaker"]}: {utterance["text"]}'
            for utterance in utterances or ()
        )
class GoogleCloud:
    """Transcription engine backed by Google Cloud Speech-to-Text (v1 API)."""

    def __init__(self, api_key: str):
        """Accept ``api_key`` for parity with the other engines; it is ignored.

        No API key is needed for Google Cloud here: the client is created
        without explicit credentials.
        """

    def transcribe(self, language, audio_file: BytesIO) -> str:
        """Run long-running recognition on ``audio_file`` and return the text.

        Args:
            language: Value passed as ``language_code`` of the request.
            audio_file: Binary stream with the audio; read fully and sent
                inline with the request.

        Returns:
            The top transcript of every result, joined with single spaces.
        """
        # RecognitionAudio, long_running_recognize and
        # SpeakerDiarizationConfig(enable_speaker_diarization=...) belong to
        # the v1 API; the module-level ``speech`` alias points at speech_v2,
        # which does not provide them, so v1 is imported explicitly here.
        from google.cloud import speech_v1

        client = speech_v1.SpeechClient()
        audio = speech_v1.RecognitionAudio(content=audio_file.read())
        config = speech_v1.RecognitionConfig(
            encoding=speech_v1.RecognitionConfig.AudioEncoding.ENCODING_UNSPECIFIED,
            language_code=language,
            diarization_config=speech_v1.SpeakerDiarizationConfig(
                enable_speaker_diarization=True,
            ),
        )
        operation = client.long_running_recognize(config=config, audio=audio)
        response = operation.result()
        return self._join_transcripts(response.results)

    @staticmethod
    def _join_transcripts(results) -> str:
        """Join the first alternative of each result, skipping empty results."""
        return ' '.join(
            result.alternatives[0].transcript
            for result in results
            if result.alternatives
        )
class WhisperX:
    """Local transcription engine built on WhisperX.

    Runs three stages: transcription with the shared module-level model,
    alignment of the output, and speaker diarization.
    """

    def __init__(self, api_key: str, device: str = 'cuda', compute_type: str = 'int8', batch_size: int = 8):
        """Store the settings and make sure the shared model is loaded.

        Args:
            api_key: HuggingFace API key, used as auth token for diarization.
            device: Device the models run on.
            compute_type: Compute type passed on when loading the model.
            batch_size: Batch size used for transcription.
        """
        self.api_key = api_key  # HuggingFace API key
        self.device = device
        self.compute_type = compute_type
        self.batch_size = batch_size
        _setup_whisperx(self.device, self.compute_type)

    def transcribe(self, language, audio_file: BytesIO) -> str:
        """Transcribe, align and diarize ``audio_file``.

        Args:
            language: Language code of the audio. When falsy, the language
                reported by the transcription result is used for alignment.
            audio_file: Binary stream with the audio.

        Returns:
            One ``"<speaker>: <text>"`` line per segment, joined with
            newlines. Segments without an assigned speaker are labelled
            ``UNKNOWN``.
        """
        global _whisperx_initialized, _whisperx_model

        # A previous call released the shared model to free memory; load it
        # again instead of failing on the second transcription.
        if _whisperx_model is None:
            _whisperx_initialized = False
            _setup_whisperx(self.device, self.compute_type)

        # whisperx.load_audio takes a path, so the bytes go through a
        # temporary file; flush so the data is on disk before it is read.
        with tempfile.NamedTemporaryFile() as f:
            f.write(audio_file.read())
            f.flush()
            audio = whisperx.load_audio(f.name)

        # 1. Transcribe with the shared model
        result = _whisperx_model.transcribe(audio, batch_size=self.batch_size, language=language)

        # Release the model to prevent memory errors when low on GPU
        # resources. The reference is dropped BEFORE collecting, otherwise
        # nothing can be freed; the global name itself stays defined.
        _whisperx_model = None
        _whisperx_initialized = False
        self._release_memory()

        # 2. Align whisper output
        align_language = language or result['language']
        model_a, metadata = whisperx.load_align_model(language_code=align_language, device=self.device)
        result = whisperx.align(result['segments'], model_a, metadata, audio, self.device, return_char_alignments=False)
        del model_a
        self._release_memory()

        # 3. Assign speaker labels (add min/max number of speakers if known)
        diarize_model = whisperx.DiarizationPipeline(use_auth_token=self.api_key, device=self.device)
        diarize_segments = diarize_model(audio)
        whisperx.assign_word_speakers(diarize_segments, result)

        # segments are now assigned speaker IDs
        return self._format_segments(result['segments'])

    @staticmethod
    def _release_memory() -> None:
        """Collect garbage and return cached CUDA memory to the driver."""
        gc.collect()
        torch.cuda.empty_cache()

    @staticmethod
    def _format_segments(segments) -> str:
        """Render segments as newline-separated ``speaker: text`` lines."""
        return '\n'.join(
            f'{segment.get("speaker", "UNKNOWN")}: {segment["text"]}'
            for segment in segments
        )
def get_engine(engine_type: str, api_key: str | None) -> TranscriptEngine:
    """Instantiate the transcription engine registered under ``engine_type``.

    Args:
        engine_type: One of ``'AssemblyAI'``, ``'Google'`` or ``'WhisperX'``.
        api_key: Passed as the only argument to the engine's constructor.

    Returns:
        A new instance of the selected engine class.

    Raises:
        KeyError: If ``engine_type`` is not a registered engine name.
    """
    registry = dict(
        AssemblyAI=AssemblyAI,
        Google=GoogleCloud,
        WhisperX=WhisperX,
    )
    factory = registry[engine_type]
    return factory(api_key)
# WhisperX auxiliary functions and the module-level state they share.
# `_whisperx_model` is assigned by `_setup_whisperx`; the two `_a` names are
# only declared global there and are not assigned anywhere in this module.
_whisperx_model = _whisperx_model_a = _whisperx_model_a_metadata = None
# Guard flag read by `_setup_whisperx` before it loads the model.
_whisperx_initialized = False
def _setup_whisperx(device, compute_type):
    """Load the shared WhisperX ``large-v2`` model unless it is already loaded.

    Args:
        device: Device the model is loaded on. The cuDNN warm-up below only
            runs when this names a CUDA device.
        compute_type: Compute type passed to ``whisperx.load_model``.
    """
    global _whisperx_initialized, _whisperx_model
    # The model counts as loaded only if the flag is set AND the global still
    # holds it: transcription may release (or delete) the model to free GPU
    # memory, in which case it has to be loaded again.
    if _whisperx_initialized and globals().get('_whisperx_model') is not None:
        return
    if str(device).startswith('cuda'):
        s = 32
        dev = torch.device(device)
        # Prevent CUDNN_STATUS_NOT_INITIALIZED RuntimeError using pytorch
        torch.nn.functional.conv2d(torch.zeros(s, s, s, s, device=dev), torch.zeros(s, s, s, s, device=dev))
    _whisperx_model = whisperx.load_model('large-v2', device, compute_type=compute_type)
    # Record the successful load so later calls can return early.
    _whisperx_initialized = True