asahi417 committed
Commit bdaf5fc
1 Parent(s): 3d49cbd
pipeline/kotoba_whisper.py CHANGED
@@ -1,5 +1,6 @@
-from typing import Union, Optional, Dict, List, Any
 import requests
+from typing import Union, Optional, Dict, List, Any
+from collections import defaultdict
 
 import torch
 import numpy as np
@@ -38,12 +39,12 @@ class Punctuator:
         return [
             {
                 'timestamp': c['timestamp'],
+                'speaker': c['speaker'],
                 'text': validate_punctuation(c['text'], "".join(e))
             } for c, e in zip(pipeline_chunk, text_edit)
         ]
 
 
-
 class SpeakerDiarization:
 
     def __init__(self,
@@ -58,7 +59,12 @@ class SpeakerDiarization:
             model_id_diarizers
         ).to_pyannote_model().to(self.device)
 
-    def __call__(self, audio: Union[torch.Tensor, np.ndarray], sampling_rate: int) -> Annotation:
+    def __call__(self,
+                 audio: Union[torch.Tensor, np.ndarray],
+                 sampling_rate: int,
+                 num_speakers: Optional[int] = None,
+                 min_speakers: Optional[int] = None,
+                 max_speakers: Optional[int] = None) -> Annotation:
         if sampling_rate is None:
             raise ValueError("sampling_rate must be provided")
         if type(audio) is np.ndarray:
@@ -69,7 +75,7 @@ class SpeakerDiarization:
         elif len(audio.shape) > 3:
             raise ValueError("audio shape must be (channel, time)")
         audio = {"waveform": audio.to(self.device), "sample_rate": sampling_rate}
-        output = self.pipeline(audio)
+        output = self.pipeline(audio, num_speakers=num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)
         return output
 
 
@@ -84,8 +90,6 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
                  device: Union[int, "torch.device"] = None,
                  device_pyannote: Union[int, "torch.device"] = None,
                  torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
-                 return_unique_speaker: bool = True,
-                 punctuator: bool = False,
                  **kwargs):
         self.type = "seq2seq_whisper"
         if device is None:
@@ -99,11 +103,7 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
             model_id=model_pyannote,
             model_id_diarizers=model_diarizers
         )
-        self.return_unique_speaker = return_unique_speaker
-        if punctuator:
-            self.punctuator = Punctuator()
-        else:
-            self.punctuator = None
+        self.punctuator = None
         super().__init__(
             model=model,
             feature_extractor=feature_extractor,
@@ -113,6 +113,71 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
             **kwargs
         )
 
+    def _sanitize_parameters(self,
+                             chunk_length_s=None,
+                             stride_length_s=None,
+                             ignore_warning=None,
+                             decoder_kwargs=None,
+                             return_timestamps=None,
+                             return_language=None,
+                             generate_kwargs=None,
+                             max_new_tokens=None,
+                             add_punctuation: bool = False,
+                             return_unique_speaker: bool = True,
+                             num_speakers: Optional[int] = None,
+                             min_speakers: Optional[int] = None,
+                             max_speakers: Optional[int] = None):
+        # No parameters on this pipeline right now
+        preprocess_params = {}
+        if chunk_length_s is not None:
+            preprocess_params["chunk_length_s"] = chunk_length_s
+        if stride_length_s is not None:
+            preprocess_params["stride_length_s"] = stride_length_s
+
+        forward_params = defaultdict(dict)
+        if max_new_tokens is not None:
+            forward_params["max_new_tokens"] = max_new_tokens
+        if generate_kwargs is not None:
+            if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
+                raise ValueError(
+                    "`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use"
+                    " only 1 version"
+                )
+            forward_params.update(generate_kwargs)
+
+        postprocess_params = {}
+        if decoder_kwargs is not None:
+            postprocess_params["decoder_kwargs"] = decoder_kwargs
+        if return_timestamps is not None:
+            # Check whether we have a valid setting for return_timestamps and throw an error before we perform a forward pass
+            if self.type == "seq2seq" and return_timestamps:
+                raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!")
+            if self.type == "ctc_with_lm" and return_timestamps != "word":
+                raise ValueError("CTC with LM can only predict word level timestamps, set `return_timestamps='word'`")
+            if self.type == "ctc" and return_timestamps not in ["char", "word"]:
+                raise ValueError(
+                    "CTC can either predict character level timestamps, or word level timestamps. "
+                    "Set `return_timestamps='char'` or `return_timestamps='word'` as required."
+                )
+            if self.type == "seq2seq_whisper" and return_timestamps == "char":
+                raise ValueError(
+                    "Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
+                    "Use `return_timestamps='word'` or `return_timestamps=True` respectively."
+                )
+            forward_params["return_timestamps"] = return_timestamps
+            postprocess_params["return_timestamps"] = return_timestamps
+        if return_language is not None:
+            if self.type != "seq2seq_whisper":
+                raise ValueError("Only Whisper can return language for now.")
+            postprocess_params["return_language"] = return_language
+        postprocess_params["return_language"] = return_language
+        postprocess_params["add_punctuation"] = add_punctuation
+        postprocess_params["return_unique_speaker"] = return_unique_speaker
+        postprocess_params["num_speakers"] = num_speakers
+        postprocess_params["min_speakers"] = min_speakers
+        postprocess_params["max_speakers"] = max_speakers
+        return preprocess_params, forward_params, postprocess_params
+
     def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
         if isinstance(inputs, str):
             if inputs.startswith("http://") or inputs.startswith("https://"):
@@ -259,18 +324,31 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
                     model_outputs,
                     decoder_kwargs: Optional[Dict] = None,
                     return_language=None,
+                    add_punctuation: bool = False,
+                    return_unique_speaker: bool = True,
+                    num_speakers: Optional[int] = None,
+                    min_speakers: Optional[int] = None,
+                    max_speakers: Optional[int] = None,
                     *args,
                     **kwargs):
         assert len(model_outputs) > 0
-        audio_array = list(model_outputs)[0]["audio_array"]
-        sd = self.model_speaker_diarization(audio_array, sampling_rate=self.feature_extractor.sampling_rate)
-        timelines = sd.get_timeline()
         outputs = super().postprocess(
             model_outputs=model_outputs,
             decoder_kwargs=decoder_kwargs,
             return_timestamps=True,
             return_language=return_language
         )
+        audio_array = outputs.pop("audio_array")[0]
+        sd = self.model_speaker_diarization(
+            audio_array,
+            num_speakers=num_speakers,
+            min_speakers=min_speakers,
+            max_speakers=max_speakers,
+            sampling_rate=self.feature_extractor.sampling_rate
+        )
+        diarization_result = {s: [[i.start, i.end] for i in sd.label_timeline(s)] for s in sd.labels()}
+        timelines = sd.get_timeline()
+
         pointer_ts = 0
         pointer_chunk = 0
         new_chunks = []
@@ -306,18 +384,19 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
             pointer_ts += 1
         for i in new_chunks:
             if "speaker" in i:
-                if self.return_unique_speaker:
+                if return_unique_speaker:
                     i["speaker"] = [i["speaker"][0]]
                 else:
                     i["speaker"] = list(set(i["speaker"]))
             else:
                 i["speaker"] = []
         outputs["chunks"] = new_chunks
-        if self.punctuator:
+        if add_punctuation:
+            if self.punctuator is None:
+                self.punctuator = Punctuator()
             outputs["chunks"] = self.punctuator.punctuate(outputs["chunks"])
         outputs["text"] = "".join([c["text"] for c in outputs["chunks"]])
         outputs["speakers"] = sd.labels()
-        outputs.pop("audio_array")
         speakers = []
         for s in outputs["speakers"]:
             chunk_s = [c for c in outputs["chunks"] if s in c["speaker"]]
@@ -326,5 +405,5 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
             outputs[f"text/{s}"] = "".join([c["text"] for c in outputs["chunks"] if s in c["speaker"]])
             speakers.append(s)
         outputs["speakers"] = speakers
+        outputs["diarization_result"] = diarization_result
         return outputs
-
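Taken together, the kotoba_whisper.py changes move the speaker and punctuation switches from constructor arguments to call-time arguments, routed by the new _sanitize_parameters into postprocess, and expose pyannote's speaker-count controls plus a new diarization_result key in the output. A minimal usage sketch, assuming the "kotoba-whisper" task registration from pipeline/push_pipeline.py below has already run in the same session and the sample audio file is available locally; speaker labels shown in comments follow pyannote's SPEAKER_XX convention and are illustrative only:

from pprint import pprint
from transformers import pipeline

# Construction mirrors push_pipeline.py; the new options are per-call keyword arguments.
pipe = pipeline(task="kotoba-whisper", model="kotoba-tech/kotoba-whisper-v2.0", chunk_length_s=15, batch_size=16)
result = pipe(
    "sample_diarization_japanese.mp3",
    add_punctuation=True,        # Punctuator is now instantiated lazily on first use
    return_unique_speaker=True,  # keep only the first speaker attached to each chunk
    min_speakers=1,              # speaker-count hints forwarded to the pyannote pipeline
    max_speakers=4,
)
pprint(result["speakers"])            # e.g. ["SPEAKER_00", "SPEAKER_01"]
pprint(result["diarization_result"])  # {speaker: [[start, end], ...]} segments
print(result["text/" + result["speakers"][0]])  # transcript restricted to one speaker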
 
pipeline/push_pipeline.py CHANGED
@@ -14,7 +14,7 @@ PIPELINE_REGISTRY.register_pipeline(
     tf_model=TFWhisperForConditionalGeneration
 )
 pipe = pipeline(task="kotoba-whisper", model="kotoba-tech/kotoba-whisper-v2.0", chunk_length_s=15, batch_size=16)
-output = pipe(test_audio)
+output = pipe(test_audio, add_punctuation=True)
 pprint(output)
 pipe.push_to_hub(model_alias)
 
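For orientation, the dictionary returned by the call above now roughly has the following shape; the keys come from postprocess in kotoba_whisper.py, while every value below is an invented placeholder rather than a real transcription result:

# Illustrative output structure only; timestamps, labels, and text are made up.
example_output = {
    "text": "…full punctuated transcript…",
    "chunks": [
        {"timestamp": (0.0, 2.5), "speaker": ["SPEAKER_00"], "text": "…"},
        {"timestamp": (2.5, 5.1), "speaker": ["SPEAKER_01"], "text": "…"},
    ],
    "speakers": ["SPEAKER_00", "SPEAKER_01"],
    "text/SPEAKER_00": "…",
    "text/SPEAKER_01": "…",
    "diarization_result": {"SPEAKER_00": [[0.0, 2.5]], "SPEAKER_01": [[2.5, 5.1]]},
}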
 
pipeline/test_speaker_diarization.py CHANGED
@@ -1,44 +1,11 @@
-# Setup:
-# pip install pyannote.audio>=3.1
-# Requirement: Sumit access request for the following models.
-# https://huggingface.co/pyannote/speaker-diarization-3.1
-# https://huggingface.co/pyannote/segmentation-3.0
-# wget https://huggingface.co/kotoba-tech/kotoba-whisper-v2.2/resolve/main/sample_audio/sample_diarization_japanese.mp3
 import soundfile as sf
-import numpy as np
-from typing import Union, Dict, List
-
 import torch
-from pyannote.audio import Pipeline
-from diarizers import SegmentationModel
-
-
-class SpeakerDiarization:
-
-    def __init__(self):
-        self.pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
-        self.pipeline._segmentation.model = SegmentationModel().from_pretrained(
-            'diarizers-community/speaker-segmentation-fine-tuned-callhome-jpn'
-        ).to_pyannote_model()
-
-    def __call__(self,
-                 audio: Union[torch.Tensor, np.ndarray],
-                 sampling_rate: int) -> Dict[str, List[List[float]]]:
-        if sampling_rate is None:
-            raise ValueError("sampling_rate must be provided")
-        if type(audio) is np.ndarray:
-            audio = torch.as_tensor(audio)
-        audio = torch.as_tensor(audio, dtype=torch.float32)
-        if len(audio.shape) == 1:
-            audio = audio.unsqueeze(0)
-        elif len(audio.shape) > 3:
-            raise ValueError("audio shape must be (channel, time)")
-        audio = {"waveform": audio, "sample_rate": sampling_rate}
-        output = self.pipeline(audio)
-        return {s: [[i.start, i.end] for i in output.label_timeline(s)] for s in output.labels()}
+from kotoba_whisper import SpeakerDiarization
 
 
-pipeline = SpeakerDiarization()
+pipeline = SpeakerDiarization(device=torch.device("cpu"))
 a, sr = sf.read("sample_diarization_japanese.mp3")
-print(pipeline(a.T, sampling_rate=sr))
+output = pipeline(a.T, sampling_rate=sr)
+output = {s: [[i.start, i.end] for i in output.label_timeline(s)] for s in output.labels()}
+print(output)
 
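Because SpeakerDiarization.__call__ now accepts speaker-count hints, the same test can pin or bound the number of speakers. A small variation of the script above (the count of 2 is arbitrary, purely for illustration):

import soundfile as sf
import torch
from kotoba_whisper import SpeakerDiarization

pipeline = SpeakerDiarization(device=torch.device("cpu"))
a, sr = sf.read("sample_diarization_japanese.mp3")
# Pin the speaker count when it is known, or bound it with min_speakers/max_speakers.
annotation = pipeline(a.T, sampling_rate=sr, num_speakers=2)
print({s: [[i.start, i.end] for i in annotation.label_timeline(s)] for s in annotation.labels()})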