jhj0517 committed
Commit 595b5f3 · 1 Parent(s): 6148cfe

add diarization

modules/diarize_pipeline.py ADDED
@@ -0,0 +1,91 @@
+ import numpy as np
+ import pandas as pd
+ from pyannote.audio import Pipeline
+ from typing import Optional, Union
+ import torch
+ import whisperx
+ import os
+
+
+ class DiarizationPipeline:
+     def __init__(
+             self,
+             model_name="pyannote/speaker-diarization-3.1",
+             cache_dir: str = os.path.join("models", "Whisper", "whisperx"),
+             use_auth_token=None,
+             device: Optional[Union[str, torch.device]] = "cpu",
+     ):
+         if isinstance(device, str):
+             device = torch.device(device)
+         self.model = Pipeline.from_pretrained(
+             model_name,
+             use_auth_token=use_auth_token,
+             cache_dir=cache_dir
+         ).to(device)
+
+     def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speakers=None):
+         if isinstance(audio, str):
+             audio = whisperx.load_audio(audio)
+         audio_data = {
+             'waveform': torch.from_numpy(audio[None, :]),
+             'sample_rate': whisperx.audio.SAMPLE_RATE
+         }
+         segments = self.model(audio_data, min_speakers=min_speakers, max_speakers=max_speakers)
+         diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
+         diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
+         diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
+         return diarize_df
+
+
+ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
+     transcript_segments = transcript_result["segments"]
+     for seg in transcript_segments:
+         # assign a speaker to the segment (if any)
+         diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
+                                                                                             seg['start'])
+         diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])
+
+         intersected = diarize_df[diarize_df["intersection"] > 0]
+
+         speaker = None
+         if len(intersected) > 0:
+             # choose the speaker with the largest total overlap
+             speaker = intersected.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
+         elif fill_nearest:
+             # otherwise fall back to the closest diarization segment
+             speaker = diarize_df.sort_values(by=["intersection"], ascending=False)["speaker"].values[0]
+
+         if speaker is not None:
+             seg["speaker"] = speaker
+
+         # assign speakers to words
+         if 'words' in seg:
+             for word in seg['words']:
+                 if 'start' in word:
+                     diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(
+                         diarize_df['start'], word['start'])
+                     diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'],
+                                                                                                   word['start'])
+
+                     intersected = diarize_df[diarize_df["intersection"] > 0]
+
+                     word_speaker = None
+                     if len(intersected) > 0:
+                         # choose the speaker with the largest total overlap
+                         word_speaker = \
+                             intersected.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
+                     elif fill_nearest:
+                         # otherwise fall back to the closest diarization segment
+                         word_speaker = diarize_df.sort_values(by=["intersection"], ascending=False)["speaker"].values[0]
+
+                     if word_speaker is not None:
+                         word["speaker"] = word_speaker
+
+     return transcript_result
+
+
+ class Segment:
+     def __init__(self, start, end, speaker=None):
+         self.start = start
+         self.end = end
+         self.speaker = speaker
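
For orientation, a minimal usage sketch of the new module follows (not part of the commit). The audio path and token handling are placeholders; the pyannote model's terms must be accepted on Hugging Face before the first download.

```python
import os

from modules.diarize_pipeline import DiarizationPipeline, assign_word_speakers

# Hypothetical READ-scoped token; only needed for the first model download.
pipeline = DiarizationPipeline(
    use_auth_token=os.environ.get("HF_TOKEN"),
    device="cpu",
)

# Returns a DataFrame with one row per speaker turn:
# columns segment, label, speaker, start, end.
diarize_df = pipeline("sample.wav", min_speakers=1, max_speakers=4)

# Attach the speaker with the largest overlap to each transcribed segment
# (and to each word, when word timestamps are present).
transcript = {"segments": [{"start": 0.0, "end": 2.5, "text": " Hello there."}]}
result = assign_word_speakers(diarize_df, transcript)
print(result["segments"][0].get("speaker"))
```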
modules/diarizer.py ADDED
@@ -0,0 +1,122 @@
+ import os
+ import whisperx
+ import torch
+ from typing import List
+ import time
+
+ from modules.diarize_pipeline import DiarizationPipeline
+
+
+ class Diarizer:
+     def __init__(self,
+                  model_dir: str = os.path.join("models", "Whisper", "whisperx")
+                  ):
+         self.device = self.get_device()
+         self.available_device = self.get_available_device()
+         self.compute_type = "float16"
+         self.model_dir = model_dir
+         os.makedirs(self.model_dir, exist_ok=True)
+         self.pipe = None
+
+     def run(self,
+             audio: str,
+             transcribed_result: List[dict],
+             use_auth_token: str,
+             device: str
+             ):
+         """
+         Diarize the transcribed result as a post-processing step.
+
+         Parameters
+         ----------
+         audio: Union[str, BinaryIO, np.ndarray]
+             Audio input. This can be a file path or binary data.
+         transcribed_result: List[dict]
+             Transcription result produced by Whisper.
+         use_auth_token: str
+             Hugging Face token with READ permission. Only needed the first time the model is downloaded.
+             You must visit https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to the terms of use to download the model.
+         device: str
+             Device to run diarization on.
+
+         Returns
+         ----------
+         segments_result: List[dict]
+             List of dicts that includes start and end timestamps and the transcribed text.
+         elapsed_time: float
+             Elapsed time for running.
+         """
+         start_time = time.time()
+
+         if (device != self.device
+                 or self.pipe is None):
+             self.update_pipe(
+                 device=device,
+                 use_auth_token=use_auth_token
+             )
+
+         audio = whisperx.load_audio(audio)
+         diarization_segments = self.pipe(audio)
+         diarized_result = whisperx.assign_word_speakers(
+             diarization_segments,
+             {"segments": transcribed_result}
+         )
+
+         for segment in diarized_result["segments"]:
+             speaker = "None"
+             if "speaker" in segment:
+                 speaker = segment["speaker"]
+             segment["text"] = speaker + "|" + segment["text"][1:]
+
+         elapsed_time = time.time() - start_time
+         return diarized_result["segments"], elapsed_time
+
+     def update_pipe(self,
+                     use_auth_token: str,
+                     device: str
+                     ):
+         """
+         Set up the pipeline for diarization.
+
+         Parameters
+         ----------
+         use_auth_token: str
+             Hugging Face token with READ permission. Only needed the first time the model is downloaded.
+             You must visit https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to the terms of use to download the model.
+         device: str
+             Device to run diarization on.
+         """
+
+         os.makedirs(self.model_dir, exist_ok=True)
+
+         if (not os.listdir(self.model_dir) and
+                 not use_auth_token):
+             print(
+                 "\nFailed to diarize. You need a Hugging Face token and must agree to the model's terms to download the diarization model.\n"
+                 "Go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and follow the instructions there to download the model.\n"
+             )
+             return
+
+         self.pipe = DiarizationPipeline(
+             use_auth_token=use_auth_token,
+             device=device,
+             cache_dir=self.model_dir
+         )
+
+     @staticmethod
+     def get_device():
+         if torch.cuda.is_available():
+             return "cuda"
+         elif torch.backends.mps.is_available():
+             return "mps"
+         else:
+             return "cpu"
+
+     @staticmethod
+     def get_available_device():
+         devices = ["cpu"]
+         if torch.cuda.is_available():
+             devices.append("cuda")
+         elif torch.backends.mps.is_available():
+             devices.append("mps")
+         return devices
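
A minimal sketch of how `Diarizer` is driven (not part of the commit; the audio path and token are placeholders):

```python
import os

from modules.diarizer import Diarizer

diarizer = Diarizer()
segments = [{"start": 0.0, "end": 2.5, "text": " Hello there."}]

# run() rewrites each segment's text as "<speaker>|<text>",
# e.g. "SPEAKER_00|Hello there.", and returns the elapsed time.
diarized_segments, elapsed = diarizer.run(
    audio="audio.wav",  # hypothetical input file
    transcribed_result=segments,
    use_auth_token=os.environ.get("HF_TOKEN"),  # needed for the first download
    device=diarizer.device,
)
```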
modules/whisper_base.py CHANGED
@@ -1,19 +1,18 @@
  import os
  import torch
  from typing import List
- import whisperx
  import whisper
  import gradio as gr
  from abc import ABC, abstractmethod
  from typing import BinaryIO, Union, Tuple, List
  import numpy as np
  from datetime import datetime
- from dataclasses import astuple
  import time

  from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
  from modules.youtube_manager import get_ytdata, get_ytaudio
  from modules.whisper_parameter import *
+ from modules.diarizer import Diarizer


  class WhisperBase(ABC):
@@ -24,20 +23,16 @@ class WhisperBase(ABC):
          self.model = None
          self.current_model_size = None
          self.model_dir = model_dir
-         self.diarization_model_dir = os.path.join(self.model_dir, "..", "whisperx")
          self.output_dir = output_dir
          os.makedirs(self.output_dir, exist_ok=True)
          os.makedirs(self.model_dir, exist_ok=True)
-         os.makedirs(self.diarization_model_dir, exist_ok=True)
          self.available_models = whisper.available_models()
          self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
          self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
          self.device = self.get_device()
          self.available_compute_types = ["float16", "float32"]
          self.current_compute_type = "float16" if self.device == "cuda" else "float32"
-         self.diarization_model = None
-         self.diarization_model_metadata = None
-         self.diarization_pipe = None
+         self.diarizer = Diarizer()

      @abstractmethod
      def transcribe(self,
@@ -59,8 +54,28 @@
              audio: Union[str, BinaryIO, np.ndarray],
              progress: gr.Progress,
              *whisper_params,
-             ):
-         params = WhisperParameters.post_process(*whisper_params)
+             ) -> Tuple[List[dict], float]:
+         """
+         Run transcription with conditional post-processing.
+         Diarization is performed as a post-processing step if enabled.
+
+         Parameters
+         ----------
+         audio: Union[str, BinaryIO, np.ndarray]
+             Audio input. This can be a file path or binary data.
+         progress: gr.Progress
+             Indicator to show progress directly in gradio.
+         *whisper_params: tuple
+             Parameters related to Whisper. These are handled by the WhisperParameters data class.
+
+         Returns
+         ----------
+         segments_result: List[dict]
+             List of dicts that includes start and end timestamps and the transcribed text.
+         elapsed_time: float
+             Elapsed time for running.
+         """
+         params = WhisperParameters.as_value(*whisper_params)

          if params.lang == "Automatic Detection":
              params.lang = None
@@ -75,65 +90,14 @@
              )

          if params.is_diarize:
-             if params.lang is None:
-                 print("Diarization Failed!! You have to specify the language explicitly to use diarization")
-             else:
-                 result, elapsed_time_diarization = self.diarize(
-                     audio=audio,
-                     language_code=params.lang,
-                     use_auth_token=params.hf_token,
-                     transcribed_result=result
-                 )
-                 elapsed_time += elapsed_time_diarization
-         return result, elapsed_time
-
-     def diarize(self,
-                 audio: str,
-                 language_code: str,
-                 use_auth_token: str,
-                 transcribed_result: List[dict]
-                 ):
-         start_time = time.time()
-
-         if (self.diarization_model is None or
-                 self.diarization_model_metadata is None or
-                 self.diarization_pipe is None):
-             self._update_diarization_model(
-                 language_code=language_code,
-                 use_auth_token=use_auth_token
+             result, elapsed_time_diarization = self.diarizer.run(
+                 audio=audio,
+                 use_auth_token=params.hf_token,
+                 transcribed_result=result,
+                 device=self.device
              )
-
-         audio = whisperx.load_audio(audio)
-         diarization_segments = self.diarization_pipe(audio)
-         diarized_result = whisperx.assign_word_speakers(
-             diarization_segments,
-             {"segments": transcribed_result}
-         )
-
-         for segment in diarized_result["segments"]:
-             speaker = "None"
-             if "speaker" in segment:
-                 speaker = segment["speaker"]
-
-             segment["text"] = speaker + "|" + segment["text"][1:]
-
-         elapsed_time = time.time() - start_time
-         return diarized_result["segments"], elapsed_time
-
-     def _update_diarization_model(self,
-                                   use_auth_token: str,
-                                   language_code: str
-                                   ):
-         print("loading diarization model...")
-         self.diarization_model, self.diarization_model_metadata = whisperx.load_align_model(
-             language_code=language_code,
-             device=self.device,
-             model_dir=self.diarization_model_dir,
-         )
-         self.diarization_pipe = whisperx.DiarizationPipeline(
-             use_auth_token=use_auth_token,
-             device=self.device
-         )
+             elapsed_time += elapsed_time_diarization
+         return result, elapsed_time

      def transcribe_file(self,
                          files: list,
@@ -156,7 +120,7 @@
          progress: gr.Progress
              Indicator to show progress directly in gradio.
          *whisper_params: tuple
-             Gradio components related to Whisper. see whisper_data_class.py for details.
+             Parameters related to Whisper. These are handled by the WhisperParameters data class.

          Returns
          ----------
@@ -223,7 +187,7 @@
          progress: gr.Progress
              Indicator to show progress directly in gradio.
          *whisper_params: tuple
-             Gradio components related to Whisper. see whisper_data_class.py for details.
+             Parameters related to Whisper. These are handled by the WhisperParameters data class.

          Returns
          ----------
@@ -278,7 +242,7 @@
          progress: gr.Progress
              Indicator to show progress directly in gradio.
          *whisper_params: tuple
-             Gradio components related to Whisper. see whisper_data_class.py for details.
+             Parameters related to Whisper. These are handled by the WhisperParameters data class.

          Returns
          ----------
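
The speaker-label convention the diarization step applies is easy to miss in the diff; here is a self-contained sketch of the transformation (not code from the commit):

```python
# Each diarized segment's text becomes "<speaker>|<text>"; the leading
# space that Whisper emits is dropped via segment["text"][1:], and
# segments with no resolved speaker fall back to the literal "None".
segments = [
    {"text": " Hello there.", "speaker": "SPEAKER_00"},
    {"text": " Hi!"},  # no speaker assigned
]
for segment in segments:
    speaker = segment.get("speaker", "None")
    segment["text"] = speaker + "|" + segment["text"][1:]

print([s["text"] for s in segments])
# ['SPEAKER_00|Hello there.', 'None|Hi!']
```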