import abc
from typing import List
from src.config import ModelConfig, VadInitialPromptMode

from src.hooks.progressListener import ProgressListener
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache

class AbstractWhisperCallback:
    @abc.abstractmethod
    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
        """
        Peform the transcription of the given audio file or data.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            The audio file to transcribe, or the audio data as a numpy array or torch tensor.
        segment_index: int
            The target language of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
        progress_listener: ProgressListener
            A callback to receive progress updates.
        """
        raise NotImplementedError()

    def _get_initial_prompt(self, initial_prompt: str, initial_prompt_mode: VadInitialPromptMode, 
                               prompt: str, segment_index: int):
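        """
        Compute the prompt to use for a segment, based on the initial prompt mode:
        PREPEND_ALL_SEGMENTS prepends the initial prompt to every segment's prompt,
        while PREPREND_FIRST_SEGMENT prepends it only to the first segment (index 0).
        """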
        if (initial_prompt_mode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS):
            return self._concat_prompt(initial_prompt, prompt)
        elif (initial_prompt_mode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
            return self._concat_prompt(initial_prompt, prompt) if segment_index == 0 else prompt
        else:
            raise ValueError(f"Unknown initial prompt mode {initial_prompt_mode}")

    def _concat_prompt(self, prompt1, prompt2):
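        """Join two prompts with a space, returning the other one when either is None."""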
        if (prompt1 is None):
            return prompt2
        elif (prompt2 is None):
            return prompt1
        else:
            return prompt1 + " " + prompt2
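
# Illustrative only: a minimal sketch of a concrete callback. The subclass name
# and its trivial invoke() body are this example's own invention; the real
# implementations in this project would run a Whisper model here instead of
# returning the resolved prompt.
class _ExampleWhisperCallback(AbstractWhisperCallback):
    def __init__(self, initial_prompt: str = None,
                 initial_prompt_mode: VadInitialPromptMode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
        self.initial_prompt = initial_prompt
        self.initial_prompt_mode = initial_prompt_mode

    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str,
               progress_listener: ProgressListener = None):
        # Resolve the effective prompt for this segment; a real subclass would
        # pass it to the model's decoder along with the audio.
        effective_prompt = self._get_initial_prompt(self.initial_prompt, self.initial_prompt_mode,
                                                    prompt, segment_index)
        return effective_prompt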

class AbstractWhisperContainer:
    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
                 download_root: str = None,
                 cache: ModelCache = None, models: List[ModelConfig] = None):
        self.model_name = model_name
        self.device = device
        self.compute_type = compute_type
        self.download_root = download_root
        self.cache = cache

        # Will be created on demand
        self.model = None

        # List of known models
        self.models = models if models is not None else []
    
    def get_model(self):
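        """Lazily create the model on first use, or fetch a shared instance from the model cache."""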
        if self.model is None:

            if (self.cache is None):
                self.model = self._create_model()
            else:
                model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
                self.model = self.cache.get(model_key, self._create_model)
        return self.model
    
    @abc.abstractmethod
    def _create_model(self):
        raise NotImplementedError()

    def ensure_downloaded(self):
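        """Pre-download the model if needed. The default implementation is a no-op; subclasses may override."""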
        pass

    @abc.abstractmethod
    def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, 
                        initial_prompt_mode: VadInitialPromptMode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT, 
                        **decodeOptions: dict) -> AbstractWhisperCallback:
        """
        Create a WhisperCallback object that can be used to transcribe audio files.

        Parameters
        ----------
        language: str
            The target language of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
        initial_prompt: str
            The initial prompt to use for the transcription.
        initial_prompt_mode: VadInitialPromptMode
            The mode to use for the initial prompt. If set to PREPREND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
            If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.

        Returns
        -------
        A WhisperCallback object.
        """
        raise NotImplementedError()

    # This is required for multiprocessing
    def __getstate__(self):
        return { 
            "model_name": self.model_name, 
            "device": self.device, 
            "download_root": self.download_root, 
            "models": self.models, 
            "compute_type": self.compute_type 
        }

    def __setstate__(self, state):
        self.model_name = state["model_name"]
        self.device = state["device"]
        self.download_root = state["download_root"]
        self.models = state["models"]
        self.compute_type = state["compute_type"]
        self.model = None
        # Depickled objects must use the global cache
        self.cache = GLOBAL_MODEL_CACHE
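
# Usage sketch (illustrative; assumes a concrete subclass of
# AbstractWhisperContainer, here named FasterWhisperContainer as an example):
#
#   container = FasterWhisperContainer("medium", device="cuda", cache=GLOBAL_MODEL_CACHE)
#   callback = container.create_callback(language="en", task="transcribe")
#   model = container.get_model()  # created lazily, or reused via the cache key
#                                  # "WhisperContainer.<model_name>:<device>"
#
# Because __getstate__ omits both the model and the cache, a container pickles
# cheaply for multiprocessing; the worker process re-creates the model on demand
# and reattaches to GLOBAL_MODEL_CACHE when unpickled.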