Spaces:
Runtime error
Runtime error
File size: 4,384 Bytes
c298c87 8fe64f7 c298c87 8fe64f7 4a59798 8fe64f7 c298c87 e83878b c298c87 8fe64f7 c298c87 8fe64f7 c298c87 fad63af c298c87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import gc
import tempfile
import time
from io import BytesIO
from typing import Protocol

import requests
import torch
import whisperx
class TranscriptEngine(Protocol):
    """Structural interface shared by every transcription engine.

    Any object exposing a compatible ``transcribe`` method satisfies this
    protocol; no inheritance is required.
    """

    def transcribe(self, language, audio_file: BytesIO) -> str:
        """Transcribe an audio file to text.

        Args:
            language: Language code forwarded to the underlying engine.
            audio_file: In-memory binary stream holding the audio. The
                engines in this module read from it as a file-like object,
                so raw ``bytes`` are not accepted.

        Returns:
            The transcript as a single string.
        """
        ...
class AssemblyAI:
    """Transcription engine backed by the AssemblyAI v2 REST API.

    The audio is uploaded, a transcription job with speaker labels is
    submitted, and the job is polled until it finishes.
    """

    transcript = 'https://api.assemblyai.com/v2/transcript'
    upload = 'https://api.assemblyai.com/v2/upload'
    # Seconds to wait between two status polls, so the API is not hammered
    # in a busy loop while the job is still running.
    poll_interval = 3.0

    def __init__(self, api_key: str):
        """Store the AssemblyAI API key used to authorize every request."""
        self.api_key = api_key

    def transcribe(self, language, audio_file: BytesIO) -> str:
        """Upload ``audio_file`` and return its speaker-labelled transcript.

        Args:
            language: Value sent as ``language_code`` in the job request.
            audio_file: Binary stream with the audio to upload.

        Returns:
            One ``"<speaker>: <text>"`` line per utterance, joined by
            newlines.

        Raises:
            requests.HTTPError: If the upload, submit or poll request
                returns an HTTP error status.
            RuntimeError: If the job finishes with status ``'error'``.
        """
        # The upload body is raw audio, so no JSON content-type is sent here;
        # ``json=`` below sets the proper content-type for the job request.
        auth_headers = {'authorization': self.api_key}

        upload_response = requests.post(
            AssemblyAI.upload, headers=auth_headers, data=audio_file
        )
        upload_response.raise_for_status()
        audio_url = upload_response.json()['upload_url']

        payload = {
            'audio_url': audio_url,
            'iab_categories': True,
            'language_code': language,
            'speaker_labels': True,
        }
        response = requests.post(AssemblyAI.transcript, json=payload, headers=auth_headers)
        response.raise_for_status()

        polling_endpoint = f'{AssemblyAI.transcript}/{response.json()["id"]}'
        while True:
            polling_response = requests.get(polling_endpoint, headers=auth_headers)
            polling_response.raise_for_status()
            job = polling_response.json()
            status = job['status']
            if status == 'completed':
                return self._format_utterances(job.get('utterances'))
            if status == 'error':
                # Without this branch a failed job would be polled forever.
                raise RuntimeError(f'AssemblyAI transcription failed: {job.get("error")}')
            time.sleep(self.poll_interval)

    @staticmethod
    def _format_utterances(utterances) -> str:
        """Render utterances as ``"<speaker>: <text>"`` lines.

        A missing or null utterance list yields an empty string instead of
        raising.
        """
        return '\n'.join(
            f'{utterance["speaker"]}: {utterance["text"]}' for utterance in utterances or ()
        )
class WhisperX:
    """Local transcription engine built on the ``whisperx`` package.

    The pipeline transcribes the audio, aligns the segments, then runs
    speaker diarization. The transcription model lives in the module-level
    ``_whisperx_model`` global and is released after each transcription to
    keep GPU memory usage down.
    """

    def __init__(self, api_key: str, device: str = 'cuda', compute_type: str = 'int8', batch_size: int = 8):
        """Store the settings and make sure the transcription model is loaded.

        Args:
            api_key: HuggingFace token passed to the diarization pipeline.
            device: Device string handed to every whisperx call.
            compute_type: Compute type used when loading the model.
            batch_size: Batch size used while transcribing.
        """
        self.api_key = api_key  # HuggingFace API key
        self.device = device
        self.compute_type = compute_type
        self.batch_size = batch_size
        _setup_whisperx(self.device, self.compute_type)

    def transcribe(self, language, audio_file: BytesIO) -> str:
        """Transcribe ``audio_file`` and label each segment with its speaker.

        Args:
            language: Language code used for transcription and alignment.
            audio_file: Binary stream with the audio to transcribe.

        Returns:
            One ``"<speaker>: <text>"`` line per segment, joined by newlines.
            Segments that received no speaker label use ``UNKNOWN``.
        """
        global _whisperx_model

        # The model is released at the end of step 1, so a repeated call on
        # the same instance has to load it again first.
        if _whisperx_model is None:
            _setup_whisperx(self.device, self.compute_type)

        # whisperx loads audio from a path, so spill the stream to a temp file.
        with tempfile.NamedTemporaryFile() as f:
            f.write(audio_file.read())
            f.flush()
            audio = whisperx.load_audio(f.name)

        # 1. Transcribe
        result = _whisperx_model.transcribe(audio, batch_size=self.batch_size, language=language)
        # Rebind to None rather than ``del``: the global name must keep
        # existing, and the reference must be dropped *before* collecting.
        _whisperx_model = None
        self._release_gpu_memory()

        # 2. Align whisper output
        model_a, metadata = whisperx.load_align_model(language_code=language, device=self.device)
        result = whisperx.align(
            result['segments'], model_a, metadata, audio, self.device, return_char_alignments=False
        )
        del model_a
        self._release_gpu_memory()

        # 3. Assign speaker labels
        diarize_model = whisperx.DiarizationPipeline(use_auth_token=self.api_key, device=self.device)
        # add min/max number of speakers if known
        diarize_segments = diarize_model(audio)
        whisperx.assign_word_speakers(diarize_segments, result)

        return self._format_segments(result['segments'])

    @staticmethod
    def _release_gpu_memory() -> None:
        """Collect garbage and hand cached CUDA blocks back to the driver."""
        gc.collect()
        torch.cuda.empty_cache()

    @staticmethod
    def _format_segments(segments) -> str:
        """Render segments as ``"<speaker>: <text>"`` lines.

        A segment without a ``speaker`` key is labelled ``UNKNOWN`` instead
        of raising ``KeyError``.
        """
        return '\n'.join(
            f'{segment.get("speaker", "UNKNOWN")}: {segment["text"]}' for segment in segments
        )
def get_engine(engine_type: str, api_key: str | None) -> TranscriptEngine:
    """Build the transcription engine registered under ``engine_type``.

    Args:
        engine_type: Either ``'AssemblyAI'`` or ``'WhisperX'``.
        api_key: Key passed as the only argument to the engine constructor.

    Returns:
        A freshly constructed engine instance.

    Raises:
        KeyError: If ``engine_type`` is not one of the registered names.
    """
    registry = dict(AssemblyAI=AssemblyAI, WhisperX=WhisperX)
    factory = registry[engine_type]
    return factory(api_key)
# WhisperX auxiliary functions
# Module-level WhisperX state, shared by every engine instance.
_whisperx_initialized = False  # flag consulted by _setup_whisperx before loading
# Transcription model plus alignment model/metadata slots; all start empty.
_whisperx_model = _whisperx_model_a = _whisperx_model_a_metadata = None
# (device, compute_type) the cached model was loaded with; None when no model
# has been loaded by _setup_whisperx yet.
_whisperx_loaded_config = None


def _setup_whisperx(device, compute_type):
    """Load the ``large-v2`` WhisperX model into the module-level cache.

    The load is skipped only when a model is already cached, is still
    present, and was loaded with the same ``device`` and ``compute_type``.

    Args:
        device: Device string passed to ``whisperx.load_model``.
        compute_type: Compute type passed to ``whisperx.load_model``.
    """
    global _whisperx_initialized, _whisperx_model, _whisperx_loaded_config

    requested = (device, compute_type)
    # ``globals().get`` tolerates the name having been deleted elsewhere.
    model_alive = globals().get('_whisperx_model') is not None
    if _whisperx_initialized and model_alive and _whisperx_loaded_config == requested:
        return

    if str(device).startswith('cuda'):
        # Prevent CUDNN_STATUS_NOT_INITIALIZED RuntimeError using pytorch.
        # Only meaningful on CUDA devices; running it unconditionally would
        # crash on hosts without a GPU.
        s = 32
        dev = torch.device(device)
        torch.nn.functional.conv2d(
            torch.zeros(s, s, s, s, device=dev), torch.zeros(s, s, s, s, device=dev)
        )

    _whisperx_model = whisperx.load_model('large-v2', device, compute_type=compute_type)
    _whisperx_loaded_config = requested
    _whisperx_initialized = True
|