File size: 12,595 Bytes
e11256b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291fcec
 
 
 
e11256b
291fcec
 
e11256b
291fcec
 
 
 
e11256b
291fcec
 
 
 
 
 
 
 
 
 
 
 
e11256b
291fcec
 
 
e11256b
291fcec
e11256b
291fcec
 
e11256b
291fcec
e11256b
291fcec
e11256b
291fcec
 
 
e11256b
291fcec
 
e11256b
291fcec
e11256b
291fcec
e11256b
291fcec
 
 
e11256b
291fcec
 
 
e11256b
291fcec
 
 
e11256b
291fcec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e11256b
 
291fcec
 
e11256b
 
291fcec
e11256b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
"""
Diarisation Class
------------------

This class serves as the heart of the speaker diarization system, responsible for identifying
and segmenting individual speakers from a given audio file. It leverages a pretrained model
from pyannote.audio, providing an accessible interface for audio processing tasks such as
speaker separation, and timestamping.

By encapsulating the complexities of the underlying model, it allows for straightforward
integration into various applications, ranging from transcription services to voice assistants.

Available Classes:
- Diariser: Main class for performing speaker diarization. 
            Includes methods for loading models, processing audio files,
            and formatting the diarization output.

Constants:
- TOKEN_PATH (str): Path to the local file in which the Huggingface authentication token is cached.
- PYANNOTE_DEFAULT_PATH (str): Default path to Pyannote models.
- PYANNOTE_DEFAULT_CONFIG (str): Default configuration for Pyannote models.

Usage:
    from .diarisation import Diariser

    model = Diariser.load_model(model="path/to/model/config.yaml")
    diarisation_output = model.diarization("path/to/audiofile.wav")
"""

import warnings
import os
import yaml
from pathlib import Path
from typing import TypeVar, Union

from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor
from torch import device as torch_device

from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError

from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, SCRAIBE_TORCH_DEVICE
Annotation = TypeVar('Annotation')

TOKEN_PATH = os.path.join(os.path.dirname(
    os.path.realpath(__file__)), '.pyannotetoken')


class Diariser:
    """
    Handles the diarization process of an audio file using a pretrained model
    from pyannote.audio. Diarization is the task of determining "who spoke when."

    Args:
        model: The pretrained pyannote pipeline to use for diarization.
    """

    def __init__(self, model) -> None:
        # The loaded pyannote Pipeline that performs the actual diarization.
        self.model = model

    def diarization(self, audiofile: Union[str, Tensor, dict],
                    *args, **kwargs) -> Annotation:
        """
        Perform speaker diarization on the provided audio file,
        effectively separating different speakers
        and providing a timestamp for each segment.

        Args:
            audiofile: The path to the audio file or a torch.Tensor
                        containing the audio data.
            args: Additional arguments for the diarization model.
            kwargs: Additional keyword arguments for the diarization model.

        Returns:
            dict: A dictionary containing speaker names,
                    segments, and other information related
                    to the diarization process.
        """
        # Drop any keyword arguments the pyannote pipeline does not accept.
        kwargs = self._get_diarisation_kwargs(**kwargs)

        diarization = self.model(audiofile, *args, **kwargs)

        return self.format_diarization_output(diarization)

    @staticmethod
    def format_diarization_output(dia: Annotation) -> dict:
        """
        Formats the raw diarization output into a more usable structure for this
        project, without combining consecutive segments of the same speaker.

        Args:
            dia: Raw diarization output (a pyannote Annotation).

        Returns:
            dict: {"speakers": [...], "segments": [...]} where entry i of
                  "speakers" labels the [start, end] pair at entry i of
                  "segments".
        """
        diarization_output = {"speakers": [], "segments": []}

        # Each track is a (segment, track_id, speaker_label) triple.
        for segment, _, speaker in dia.itertracks(yield_label=True):
            diarization_output["segments"].append([segment.start, segment.end])
            diarization_output["speakers"].append(speaker)

        return diarization_output

    @staticmethod
    def _get_token():
        """
        Retrieves the Huggingface token from a local file. This token is required
        for accessing certain online resources.

        Raises:
            ValueError: If the token file is not found.

        Returns:
            str: The Huggingface token.
        """
        if not os.path.exists(TOKEN_PATH):
            # NOTE: the original message concatenated its fragments without
            # separating spaces ("found.Please", "tokenand"); fixed here.
            raise ValueError('No token found. '
                             'Please create a token at https://huggingface.co/settings/token '
                             f'and save it in a file called {TOKEN_PATH}')

        with open(TOKEN_PATH, 'r', encoding="utf-8") as file:
            return file.read()

    @staticmethod
    def _save_token(token):
        """
        Saves the provided Huggingface token to a local file. This facilitates future
        access to online resources without needing to repeatedly authenticate.

        Args:
            token: The Huggingface token to save.
        """
        with open(TOKEN_PATH, 'w', encoding="utf-8") as file:
            file.write(token)

    @classmethod
    def load_model(cls,
                   model: str = PYANNOTE_DEFAULT_CONFIG,
                   use_auth_token: Union[str, None] = None,
                   cache_token: bool = False,
                   cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
                   hparams_file: Union[str, Path] = None,
                   device: str = SCRAIBE_TORCH_DEVICE,
                   ) -> 'Diariser':
        """
        Loads a pretrained model from pyannote.audio,
        either from a local cache or some online repository.

        Args:
            model: Path to a local pyannote config.yaml, or a tuple of two
                Huggingface repo ids tried in order.
                default: '/home/[user]/.cache/torch/models/pyannote/config.yaml'
                or one of 'jaikinator/scraibe', 'pyannote/speaker-diarization-3.1'
            use_auth_token: Optional HUGGINGFACE_TOKEN for authenticated access.
            cache_token: Whether to cache the token locally for future use.
            cache_dir: Directory for caching models.
            hparams_file: Path to a YAML file containing hyperparameters.
            device: Device to load the model on.

        Raises:
            FileNotFoundError: If no usable local model or repo id is found.
            ValueError: If the pipeline could not be instantiated.

        Returns:
            Diariser: An instance wrapping the loaded pyannote Pipeline.
        """
        if isinstance(model, str) and os.path.exists(model):
            # Local config file: verify that the segmentation checkpoint it
            # references exists, repairing the config in place if possible.
            with open(model, 'r') as file:
                config = yaml.safe_load(file)

            path_to_model = config['pipeline']['params']['segmentation']

            if not os.path.exists(path_to_model):
                warnings.warn(f"Model not found at {path_to_model}. "
                              "Trying to find it nearby the config file.")

                pwd = os.path.dirname(model)
                path_to_model = os.path.join(pwd, "pytorch_model.bin")

                if not os.path.exists(path_to_model):
                    warnings.warn(f"Model not found at {path_to_model}. "
                                  "Trying to find nearby .bin files instead.")
                    warnings.warn(
                        'Searching for nearby files in a folder path is '
                        'deprecated and will be removed in future versions.',
                        category=DeprecationWarning)
                    # Fall back to a unique .bin file next to the config.
                    bin_files = [f for f in os.listdir(pwd)
                                 if f.endswith(".bin")]
                    if len(bin_files) == 1:
                        path_to_model = os.path.join(pwd, bin_files[0])
                    else:
                        warnings.warn("Found more than one .bin file or none. "
                                      "Please specify the path to the model "
                                      "or setup a huggingface token.")
                        raise FileNotFoundError(
                            f'Could not resolve a model checkpoint near {model}.')

                warnings.warn(
                    f"Found model at {path_to_model} overwriting config file.")

                config['pipeline']['params']['segmentation'] = path_to_model

                with open(model, 'w') as file:
                    yaml.dump(config, file)

        elif isinstance(model, tuple):
            # Two candidate repo ids: try the first (expected to be public);
            # on failure fall back to the second, which may require a token.
            try:
                HfApi().model_info(model[0])
                model = model[0]
                use_auth_token = None
            except RepositoryNotFoundError:
                print(f'{model[0]} not found on Huggingface, '
                      f'trying {model[1]}')
                HfApi().model_info(model[1])
                model = model[1]

                if cache_token and use_auth_token is not None:
                    cls._save_token(use_auth_token)

                if use_auth_token is None:
                    use_auth_token = cls._get_token()
        else:
            raise FileNotFoundError(
                f'No local model or directory found at {model}.')

        _model = Pipeline.from_pretrained(model,
                                          use_auth_token=use_auth_token,
                                          cache_dir=cache_dir,
                                          hparams_file=hparams_file)
        if _model is None:
            # NOTE: the original message concatenated its fragments without
            # separating spaces ("cacheor", "tokenor"); fixed here.
            raise ValueError('Unable to load model either from local cache '
                             'or from huggingface.co models. Please check your token '
                             'or your local model path')

        # torch_device is torch.device, renamed at import to avoid shadowing
        # the `device` argument.
        _model = _model.to(torch_device(device))

        return cls(_model)

    @staticmethod
    def _get_diarisation_kwargs(**kwargs) -> dict:
        """
        Validates and extracts the keyword arguments for the pyannote diarization model.

        Ensures that the provided keyword arguments match the expected parameters,
        filtering out any invalid or unnecessary arguments.

        Returns:
            dict: A dictionary containing the validated keyword arguments.
        """
        # NOTE(review): co_varnames includes apply()'s local variable names as
        # well as its parameters, so this filter is permissive — kept as-is to
        # preserve behavior; confirm before tightening to co_argcount.
        _possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames

        return {k: v for k, v in kwargs.items() if k in _possible_kwargs}

    def __repr__(self):
        return f"Diarisation(model={self.model})"