import gc
import tempfile
import time
from typing import Protocol
from io import BytesIO

import requests
import torch
import whisperx


class TranscriptEngine(Protocol):
    """Protocol for a transcription engine"""

    def transcribe(self, language: str, audio_file: BytesIO) -> str:
        """Transcribe an audio file to text."""
        ...
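

# Example of consuming the protocol: any object with a matching `transcribe`
# method satisfies TranscriptEngine via structural typing, so callers can be
# typed against the protocol rather than a concrete engine.
# `transcribe_meeting` below is purely illustrative.
def transcribe_meeting(engine: TranscriptEngine, audio: BytesIO, language: str = 'en') -> str:
    """Run any conforming engine over an in-memory audio file."""
    return engine.transcribe(language, audio)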


class AssemblyAI:
    transcript = 'https://api.assemblyai.com/v2/transcript'
    upload = 'https://api.assemblyai.com/v2/upload'

    def __init__(self, api_key: str):
        self.api_key = api_key

    def transcribe(self, language: str, audio_file: BytesIO) -> str:
        headers = {'authorization': self.api_key, 'content-type': 'application/json'}
        # Upload the raw audio bytes; AssemblyAI returns a URL it can transcribe from
        upload_response = requests.post(
            AssemblyAI.upload, headers=headers, data=audio_file
        )

        audio_url = upload_response.json()['upload_url']

        json = {
            'audio_url': audio_url,
            'iab_categories': True,
            'language_code': language,
            'speaker_labels': True,
        }

        response = requests.post(AssemblyAI.transcript, json=json, headers=headers)

        if not response.ok:
            # TODO: Handle errors more gracefully
            response.raise_for_status()

        polling_endpoint = f'{AssemblyAI.transcript}/{response.json()["id"]}'

        # Poll until AssemblyAI reports the transcript as completed (or errored)
        status = 'submitted'
        while status not in ('completed', 'error'):
            time.sleep(1)
            polling_response = requests.get(polling_endpoint, headers=headers)
            status = polling_response.json()['status']

        if status == 'error':
            raise RuntimeError(f'Transcription failed: {polling_response.json().get("error")}')

        return '\n'.join(
            f'{utterance["speaker"]}: {utterance["text"]}'
            for utterance in polling_response.json()['utterances']
        )


class WhisperX:
    def __init__(self, api_key: str, device: str = 'cuda', compute_type: str = 'int8', batch_size: int = 8):
        self.api_key = api_key  # HuggingFace API key
        self.device = device
        self.compute_type = compute_type
        self.batch_size = batch_size
        _setup_whisperx(self.device, self.compute_type)

    def transcribe(self, language: str, audio_file: BytesIO) -> str:
        global _whisperx_model, _whisperx_initialized
        # 1. Write the bytes to a temporary file, load it and transcribe it with original whisper
        with tempfile.NamedTemporaryFile() as f:
            f.write(audio_file.read())
            f.seek(0)
            audio = whisperx.load_audio(f.name)
        result = _whisperx_model.transcribe(audio, batch_size=self.batch_size, language=language)

        # Free the ASR model before loading the alignment model, to avoid
        # out-of-memory errors when GPU resources are low
        del _whisperx_model
        _whisperx_initialized = False
        gc.collect()
        torch.cuda.empty_cache()

        # 2. Align whisper output
        model_a, metadata = whisperx.load_align_model(language_code=language, device=self.device)
        result = whisperx.align(result['segments'], model_a, metadata, audio, self.device, return_char_alignments=False)

        # Free the alignment model as well before running diarization
        del model_a
        gc.collect()
        torch.cuda.empty_cache()

        # 3. Assign speaker labels
        diarize_model = whisperx.DiarizationPipeline(use_auth_token=self.api_key, device=self.device)

        # add min/max number of speakers if known
        diarize_segments = diarize_model(audio)

        result = whisperx.assign_word_speakers(diarize_segments, result)

        # segments are now assigned speaker IDs
        for segment in result['segments']:
            print(f'{segment["speaker"]}: {segment["text"]}')
        return '\n'.join(
            f'{segment["speaker"]}: {segment["text"]}' for segment in result['segments']
        )


def get_engine(engine_type: str, api_key: str | None) -> TranscriptEngine:
    engine_cls = {
        'AssemblyAI': AssemblyAI,
        'WhisperX': WhisperX,
    }[engine_type]

    return engine_cls(api_key)
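
# Example usage (illustrative sketch; the API key and file name below are
# placeholders, not values provided by this module):
#
#     engine = get_engine('AssemblyAI', 'YOUR_ASSEMBLYAI_KEY')
#     with open('meeting.wav', 'rb') as f:
#         print(engine.transcribe('en', BytesIO(f.read())))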


# WhisperX auxiliary functions

_whisperx_initialized = False
_whisperx_model = None
_whisperx_model_a = None
_whisperx_model_a_metadata = None

def _setup_whisperx(device, compute_type):
    global _whisperx_initialized, _whisperx_model, _whisperx_model_a, _whisperx_model_a_metadata
    if _whisperx_initialized:
        return
    if device == 'cuda':
        # Run a dummy convolution first to prevent a CUDNN_STATUS_NOT_INITIALIZED
        # RuntimeError when pytorch first touches cuDNN
        s = 32
        dev = torch.device(device)
        torch.nn.functional.conv2d(torch.zeros(s, s, s, s, device=dev), torch.zeros(s, s, s, s, device=dev))

    _whisperx_model = whisperx.load_model('large-v2', device, compute_type=compute_type)