# External modules
import os
from typing import List
from urllib.parse import urlparse
import torch
from src.hooks.progressListener import ProgressListener

import whisper
from whisper import Whisper

from src.config import ModelConfig, VadInitialPromptMode
from src.hooks.whisperProgressHook import create_progress_listener_handle

from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
from src.utils import download_file
from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer

class WhisperContainer(AbstractWhisperContainer):
    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
                 download_root: str = None,
                 cache: ModelCache = None, models: List[ModelConfig] = []):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        super().__init__(model_name, device, compute_type, download_root, cache, models)
    
    def ensure_downloaded(self):
        """
        Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
        passing the container to a subprocess.
        """
        # Warning: Using private API here
        try:
            root_dir = self.download_root
            model_config = self._get_model_config()

            if root_dir is None:
                root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")

            if self.model_name in whisper._MODELS:
                whisper._download(whisper._MODELS[self.model_name], root_dir, False)
            else:
                # If the model is not in the official list, see if it needs to be downloaded
                model_config.download_url(root_dir)
            return True
        
        except Exception as e:
            # Given that the API is private, it could change at any time. We don't want to crash the program
            print("Error pre-downloading model: " + str(e))
            return False
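    # Note (illustrative): ensure_downloaded() is typically called in the parent
    # process, so worker subprocesses that receive this container do not each
    # trigger their own download. See the __main__ sketch at the end of this file.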

    def _get_model_config(self) -> ModelConfig:
        """
        Get the model configuration for the model.
        """
        for model in self.models:
            if model.name == self.model_name:
                return model
        return None

    def _create_model(self):
        print("Loading whisper model " + self.model_name)
        model_config = self._get_model_config()

        # Note that the model will not be downloaded in the case of an official Whisper model
        model_path = self._get_model_path(model_config, self.download_root)

        return whisper.load_model(model_path, device=self.device, download_root=self.download_root)

    def create_callback(self, language: str = None, task: str = None, 
                        prompt_strategy: AbstractPromptStrategy = None,
                        **decodeOptions: dict) -> AbstractWhisperCallback:
        """
        Create a WhisperCallback object that can be used to transcribe audio files.

        Parameters
        ----------
        language: str
            The target language of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
        prompt_strategy: AbstractPromptStrategy
            The prompt strategy to use. If not specified, the prompt from Whisper will be used.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.

        Returns
        -------
        A WhisperCallback object.
        """
        return WhisperCallback(self, language=language, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
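    # decodeOptions examples (illustrative; they are forwarded verbatim to
    # whisper's transcribe step, so any of its keyword arguments apply):
    #   container.create_callback(task="transcribe", beam_size=5, temperature=0)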

    def _get_model_path(self, model_config: ModelConfig, root_dir: str = None):
        """
        Resolve the local path to the model, downloading or converting it if necessary.

        Parameters
        ----------
        model_config: ModelConfig
            The model configuration.
        root_dir: str
            The directory to store downloaded models in. Defaults to ~/.cache/whisper.

        Returns
        -------
        The path to the model, which is also stored in model_config.path.
        """
        # Imported lazily - the converter is only needed for HuggingFace models
        from src.conversion.hf_converter import convert_hf_whisper
        # See if path is already set
        if model_config.path is not None:
            return model_config.path
        
        if root_dir is None:
            root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")

        model_type = model_config.type.lower() if model_config.type is not None else "whisper"

        if model_type in ["huggingface", "hf"]:
            destination_target = os.path.join(root_dir, model_config.name + ".pt")

            # Convert from HuggingFace format to Whisper format
            if os.path.exists(destination_target):
                print(f"File {destination_target} already exists, skipping conversion")
            else:
                print("Saving HuggingFace model in Whisper format to " + destination_target)
                convert_hf_whisper(model_config.url, destination_target)

            model_config.path = destination_target

        elif model_type in ["whisper", "w"]:

            # See if the URL is the name of an official model
            if model_config.url in whisper._MODELS:
                # No need to download anything - Whisper will handle it
                model_config.path = model_config.url
            elif model_config.url.startswith("file://"):
                # Get file path
                model_config.path = urlparse(model_config.url).path
            # See if it is a URL
            elif model_config.url.startswith("http://") or model_config.url.startswith("https://"):
                # Extension (or file name)
                extension = os.path.splitext(model_config.url)[-1]
                download_target = os.path.join(root_dir, model_config.name + extension)

                if os.path.exists(download_target) and not os.path.isfile(download_target):
                    raise RuntimeError(f"{download_target} exists and is not a regular file")

                if not os.path.isfile(download_target):
                    download_file(model_config.url, download_target)
                else:
                    print(f"File {download_target} already exists, skipping download")

                model_config.path = download_target
            # Must be a local file
            else:
                model_config.path = model_config.url

        else:
            raise ValueError(f"Unknown model type {model_type}")

        return model_config.path

class WhisperCallback(AbstractWhisperCallback):
    def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, 
                 prompt_strategy: AbstractPromptStrategy = None, 
                 **decodeOptions: dict):
        self.model_container = model_container
        self.language = language
        self.task = task
        self.prompt_strategy = prompt_strategy

        self.decodeOptions = decodeOptions
        
    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
        """
        Peform the transcription of the given audio file or data.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            The audio file to transcribe, or the audio data as a numpy array or torch tensor.
        segment_index: int
            The target language of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
        progress_listener: ProgressListener
            A callback to receive progress updates.
        """
        model = self.model_container.get_model()

        if progress_listener is not None:
            with create_progress_listener_handle(progress_listener):
                return self._transcribe(model, audio, segment_index, prompt, detected_language)
        else:
            return self._transcribe(model, audio, segment_index, prompt, detected_language)
    
    def _transcribe(self, model: Whisper, audio, segment_index: int, prompt: str, detected_language: str):
        decodeOptions = self.decodeOptions.copy()

        # Add fp16
        if self.model_container.compute_type in ["fp16", "float16"]:
            decodeOptions["fp16"] = True

        initial_prompt = self.prompt_strategy.get_segment_prompt(segment_index, prompt, detected_language) \
                           if self.prompt_strategy else prompt

        result = model.transcribe(audio,
            language=self.language if self.language else detected_language,
            task=self.task,
            initial_prompt=initial_prompt,
            **decodeOptions
        )

        # If we have a prompt strategy, we need to increment the current prompt
        if self.prompt_strategy:
            self.prompt_strategy.on_segment_finished(segment_index, prompt, detected_language, result)

        return result
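
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module API). The model
# name "tiny" and the ModelConfig fields used below (name/url/type) are
# assumptions - check src.config for the exact ModelConfig signature - and
# "audio.wav" is a placeholder path.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    models = [ModelConfig(name="tiny", url="tiny", type="whisper")]
    container = WhisperContainer(model_name="tiny", models=models)

    # Download up front (useful before handing the container to a subprocess)
    container.ensure_downloaded()

    # Transcribe a single segment; language and prompt are left for Whisper to infer
    callback = container.create_callback(task="transcribe")
    result = callback.invoke("audio.wav", segment_index=0, prompt=None, detected_language=None)
    print(result["text"])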