File size: 4,309 Bytes
295de00
052fe7e
74b7d77
8031785
295de00
 
 
74b7d77
295de00
 
74b7d77
052fe7e
74b7d77
295de00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
052fe7e
 
 
 
 
 
 
 
295de00
33ee1bb
 
 
295de00
 
33ee1bb
295de00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e744c4
74b7d77
8031785
295de00
 
 
 
 
1e744c4
 
295de00
 
74b7d77
 
295de00
 
 
 
 
 
 
 
 
 
 
33ee1bb
 
 
 
 
 
 
295de00
 
 
 
 
 
33ee1bb
295de00
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import abc
from typing import Any, Callable, List

from src.config import ModelConfig, VadInitialPromptMode

from src.hooks.progressListener import ProgressListener
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy

class AbstractWhisperCallback:
    """
    Interface for callbacks that run a transcription on a piece of audio.

    The class does not use abc.ABCMeta, so the abstractmethod marker on invoke is
    not enforced at instantiation time; calling invoke on this base class raises
    NotImplementedError.
    """
    def __init__(self):
        """The base class keeps no state."""

    @abc.abstractmethod
    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
        """
        Perform the transcription of the given audio file or data.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            The audio file to transcribe, or the audio data as a numpy array or torch tensor.
        segment_index: int
            Index of the segment being processed (presumably its position within
            the full audio - confirm against the callers).
        prompt: str
            Prompt text supplied for this segment (presumably forwarded to the
            model as the initial prompt - confirm against the implementations).
        detected_language: str
            Language associated with the audio (presumably the language detected
            for earlier segments - confirm against the callers).
        progress_listener: ProgressListener
            A callback to receive progress updates.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses provide the implementation.
        """
        raise NotImplementedError

class LambdaWhisperCallback(AbstractWhisperCallback):
    """
    Adapter that exposes a plain callable as an AbstractWhisperCallback.
    """
    def __init__(self, callback_lambda: Callable[[Any, int, str, str, ProgressListener], None]):
        """
        Parameters
        ----------
        callback_lambda: Callable
            Callable invoked positionally with (audio, segment_index, prompt,
            detected_language, progress_listener) each time invoke is called.
        """
        super().__init__()
        self.callback_lambda = callback_lambda

    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
        """
        Forward all arguments, in order, to the wrapped callable and return its result.
        """
        delegate = self.callback_lambda
        forwarded = (audio, segment_index, prompt, detected_language, progress_listener)
        return delegate(*forwarded)

class AbstractWhisperContainer:
    """
    Base class for objects that create a Whisper model on demand and optionally
    share it through a model cache.

    Subclasses are expected to override _create_model and create_callback; both
    raise NotImplementedError here. The class does not use abc.ABCMeta, so the
    abstractmethod markers are not enforced at instantiation time.
    """
    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
                 download_root: str = None,
                 cache: ModelCache = None, models: List[ModelConfig] = None):
        """
        Parameters
        ----------
        model_name: str
            Name of the model; also used as part of the cache key.
        device: str
            Device identifier; also used as part of the cache key (empty when not set).
        compute_type: str
            Stored as-is for use by subclasses.
        download_root: str
            Stored as-is for use by subclasses.
        cache: ModelCache
            Optional cache used by get_model. When None, the model is created directly.
        models: List[ModelConfig]
            List of known models. When omitted, each instance gets its own new empty list.
        """
        self.model_name = model_name
        self.device = device
        self.compute_type = compute_type
        self.download_root = download_root
        self.cache = cache

        # Will be created on demand
        self.model = None

        # List of known models. A fresh list is created per instance when none is
        # given, so instances never share (and accidentally mutate) a default list.
        # A list passed by the caller is still stored by reference.
        self.models = models if models is not None else []

    def get_model(self):
        """
        Return the model, creating it on first use.

        Without a cache the model comes straight from _create_model. With a cache,
        the cache is asked for the model under a key built from the model name and
        the device, with _create_model passed as the factory (presumably only
        called by the cache on a miss - confirm against ModelCache.get).
        """
        if self.model is None:

            if (self.cache is None):
                self.model = self._create_model()
            else:
                # NOTE(review): the key contains model name and device only;
                # compute_type is not part of it.
                model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
                self.model = self.cache.get(model_key, self._create_model)
        return self.model

    @abc.abstractmethod
    def _create_model(self):
        """
        Create and return the underlying model. Must be provided by subclasses.
        """
        raise NotImplementedError()

    def ensure_downloaded(self):
        """
        Hook for subclasses; does nothing in this base class.
        """
        pass

    @abc.abstractmethod
    def create_callback(self, languageCode: str = None, task: str = None,
                        prompt_strategy: AbstractPromptStrategy = None,
                        **decodeOptions: dict) -> AbstractWhisperCallback:
        """
        Create a WhisperCallback object that can be used to transcribe audio files.

        Parameters
        ----------
        languageCode: str
            The target language code of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
        prompt_strategy: AbstractPromptStrategy
            The prompt strategy to use for the transcription.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.

        Returns
        -------
        A WhisperCallback object.
        """
        raise NotImplementedError()

    # This is required for multiprocessing
    def __getstate__(self):
        """
        Return the picklable configuration of the container.

        The created model and the cache reference are intentionally left out.
        """
        return {
            "model_name": self.model_name,
            "device": self.device,
            "download_root": self.download_root,
            "models": self.models,
            "compute_type": self.compute_type
        }

    def __setstate__(self, state):
        """
        Restore the configuration from a state dictionary produced by __getstate__.

        The model is reset so that it is created again on demand.
        """
        self.model_name = state["model_name"]
        self.device = state["device"]
        self.download_root = state["download_root"]
        self.models = state["models"]
        self.compute_type = state["compute_type"]
        self.model = None
        # Depickled objects must use the global cache
        self.cache = GLOBAL_MODEL_CACHE