asahi417 committed
Commit f2eb7f2 · 1 Parent(s): 64f7d68

Upload KotobaWhisperPipeline

Files changed (1):
  1. kotoba_whisper.py +130 -269
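
The commit uploads a custom `KotobaWhisperPipeline` (speech recognition with speaker diarization and optional punctuation restoration). A minimal, hedged usage sketch follows; the repo id, audio path, and parameter values are illustrative assumptions, and the call-time keyword arguments mirror the `_sanitize_parameters` signature added in the diff below.

```python
import torch
from transformers import pipeline

# Repo id and audio path are assumptions; torch, torchaudio, pyannote.audio and punctuators must be installed.
pipe = pipeline(
    model="kotoba-tech/kotoba-whisper-v2.2",       # assumed repo id hosting kotoba_whisper.py
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device="cuda:0" if torch.cuda.is_available() else "cpu",
    trust_remote_code=True,                        # loads KotobaWhisperPipeline from kotoba_whisper.py
)
result = pipe(
    "sample_audio.wav",        # hypothetical local file
    chunk_length_s=15,         # routed to preprocess by the new _sanitize_parameters
    add_punctuation=True,      # routed to postprocess
    add_silence_start=0.5,
    add_silence_end=0.5,
)
print(result["speaker_ids"])
print(result["chunks"][:3])
```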
kotoba_whisper.py CHANGED
@@ -1,6 +1,5 @@
 import requests
-from typing import Union, Optional, Dict, List, Any
-from collections import defaultdict
+from typing import Union, Optional, Dict

 import torch
 import numpy as np
@@ -24,25 +23,13 @@ class Punctuator:
     def __init__(self, model: str = "pcs_47lang"):
         self.punctuation_model = PunctCapSegModelONNX.from_pretrained(model)

-    def punctuate(self, pipeline_chunk: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-
-        def validate_punctuation(raw: str, punctuated: str):
-            if 'unk' in punctuated.lower() or any(p in raw for p in self.ja_punctuations):
-                return raw
-            if punctuated.count("。") > 1:
-                ind = punctuated.rfind("。")
-                punctuated = punctuated.replace("。", "")
-                punctuated = punctuated[:ind] + "。" + punctuated[ind:]
-            return punctuated
-
-        text_edit = self.punctuation_model.infer([c['text'] for c in pipeline_chunk])
-        return [
-            {
-                'timestamp': c['timestamp'],
-                'speaker': c['speaker'],
-                'text': validate_punctuation(c['text'], "".join(e))
-            } for c, e in zip(pipeline_chunk, text_edit)
-        ]
+    def punctuate(self, text: str) -> str:
+        if any(p in text for p in self.ja_punctuations):
+            return text
+        punctuated = "".join(self.punctuation_model.infer([text])[0])
+        if 'unk' in punctuated.lower():
+            return text
+        return punctuated


 class SpeakerDiarization:
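
The reworked `Punctuator.punctuate` above now takes a single string per speaker instead of a list of chunk dicts. A small sketch of calling it directly, assuming `kotoba_whisper.py` has been downloaded and is importable locally, with a made-up Japanese sentence:

```python
from kotoba_whisper import Punctuator  # assumes the uploaded file is on the local path

punctuator = Punctuator()              # loads the default "pcs_47lang" punctuation model
raw = "こんにちは今日はいい天気ですね"       # hypothetical unpunctuated ASR output
# Returns `raw` unchanged if it already contains Japanese punctuation or if the model output contains "unk".
print(punctuator.punctuate(raw))
```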
@@ -114,104 +101,68 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
         )

     def _sanitize_parameters(self,
-                             chunk_length_s=None,
-                             stride_length_s=None,
-                             ignore_warning=None,
-                             decoder_kwargs=None,
-                             return_timestamps=None,
-                             return_language=None,
-                             generate_kwargs=None,
-                             max_new_tokens=None,
-                             add_punctuation: bool =False,
-                             return_unique_speaker: bool =True,
+                             chunk_length_s: Optional[int] = None,
+                             stride_length_s: Optional[int] = None,
+                             generate_kwargs: Optional[Dict] = None,
+                             max_new_tokens: Optional[int] = None,
+                             add_punctuation: bool = False,
+                             return_unique_speaker: bool = True,
+                             add_silence_end: Optional[float] = None,
+                             add_silence_start: Optional[float] = None,
                              num_speakers: Optional[int] = None,
                              min_speakers: Optional[int] = None,
                              max_speakers: Optional[int] = None):
-        # No parameters on this pipeline right now
-        preprocess_params = {}
-        if chunk_length_s is not None:
-            preprocess_params["chunk_length_s"] = chunk_length_s
-        if stride_length_s is not None:
-            preprocess_params["stride_length_s"] = stride_length_s
-
-        forward_params = defaultdict(dict)
-        if max_new_tokens is not None:
-            forward_params["max_new_tokens"] = max_new_tokens
-        if generate_kwargs is not None:
-            if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
-                raise ValueError(
-                    "`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use"
-                    " only 1 version"
-                )
-            forward_params.update(generate_kwargs)
-
-        postprocess_params = {}
-        if decoder_kwargs is not None:
-            postprocess_params["decoder_kwargs"] = decoder_kwargs
-        if return_timestamps is not None:
-            # Check whether we have a valid setting for return_timestamps and throw an error before we perform a forward pass
-            if self.type == "seq2seq" and return_timestamps:
-                raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!")
-            if self.type == "ctc_with_lm" and return_timestamps != "word":
-                raise ValueError("CTC with LM can only predict word level timestamps, set `return_timestamps='word'`")
-            if self.type == "ctc" and return_timestamps not in ["char", "word"]:
-                raise ValueError(
-                    "CTC can either predict character level timestamps, or word level timestamps. "
-                    "Set `return_timestamps='char'` or `return_timestamps='word'` as required."
-                )
-            if self.type == "seq2seq_whisper" and return_timestamps == "char":
-                raise ValueError(
-                    "Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
-                    "Use `return_timestamps='word'` or `return_timestamps=True` respectively."
-                )
-            forward_params["return_timestamps"] = return_timestamps
-            postprocess_params["return_timestamps"] = return_timestamps
-        if return_language is not None:
-            if self.type != "seq2seq_whisper":
-                raise ValueError("Only Whisper can return language for now.")
-            postprocess_params["return_language"] = return_language
-        postprocess_params["return_language"] = return_language
-        postprocess_params["add_punctuation"] = add_punctuation
-        postprocess_params["return_unique_speaker"] = return_unique_speaker
-        postprocess_params["num_speakers"] = num_speakers
-        postprocess_params["min_speakers"] = min_speakers
-        postprocess_params["max_speakers"] = max_speakers
+        preprocess_params = {
+            "chunk_length_s": chunk_length_s,
+            "stride_length_s": stride_length_s,
+            "add_silence_end": add_silence_end,
+            "add_silence_start": add_silence_start,
+            "num_speakers": num_speakers,
+            "min_speakers": min_speakers,
+            "max_speakers": max_speakers,
+        }
+        postprocess_params = {"add_punctuation": add_punctuation, "return_timestamps": True, "return_language": False}
+        forward_params = {} if generate_kwargs is None else generate_kwargs
+        forward_params.update({"max_new_tokens": max_new_tokens, "return_timestamps": True})
         return preprocess_params, forward_params, postprocess_params

-    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
+    def preprocess(self,
+                   inputs,
+                   chunk_length_s: Optional[int] = None,
+                   stride_length_s: Optional[int] = None,
+                   add_silence_end: Optional[float] = None,
+                   add_silence_start: Optional[float] = None,
+                   num_speakers: Optional[int] = None,
+                   min_speakers: Optional[int] = None,
+                   max_speakers: Optional[int] = None):
+
+        def _pad_audio_array(_audio):
+            if add_silence_start:
+                _audio = np.concatenate([np.zeros(int(self.feature_extractor.sampling_rate * add_silence_start)), _audio])
+            if add_silence_end:
+                _audio = np.concatenate([_audio, np.zeros(int(self.feature_extractor.sampling_rate * add_silence_end))])
+            return _audio
+
+        # load file
         if isinstance(inputs, str):
             if inputs.startswith("http://") or inputs.startswith("https://"):
-                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
-                # like http_huggingface_co.png
+                # We need to actually check for a real protocol, otherwise it's impossible to use a local file like http_huggingface_co.png
                 inputs = requests.get(inputs).content
             else:
                 with open(inputs, "rb") as f:
                     inputs = f.read()
-
         if isinstance(inputs, bytes):
             inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
-
-        stride = None
-        extra = {}
         if isinstance(inputs, dict):
-            stride = inputs.pop("stride", None)
-            # Accepting `"array"` which is the key defined in `datasets` for
-            # better integration
-            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
+            # Accepting `"array"` which is the key defined in `datasets` for better integration
+            if not ("sampling_rate" in inputs and "array" in inputs):
                 raise ValueError(
                     "When passing a dictionary to AutomaticSpeechRecognitionPipeline, the dict needs to contain a "
-                    '"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
+                    '"array" key containing the numpy array representing the audio and a "sampling_rate" key, '
                     "containing the sampling_rate associated with that array"
                 )
-
-            _inputs = inputs.pop("raw", None)
-            if _inputs is None:
-                # Remove path which will not be used from `datasets`.
-                inputs.pop("path", None)
-                _inputs = inputs.pop("array", None)
             in_sampling_rate = inputs.pop("sampling_rate")
-            extra = inputs
-            inputs = _inputs
+            inputs = inputs.pop("array", None)
             if in_sampling_rate != self.feature_extractor.sampling_rate:
                 if is_torchaudio_available():
                     from torchaudio import functional as F
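
The `add_silence_start` / `add_silence_end` options introduced above zero-pad each diarized segment before feature extraction. A standalone sketch of the same padding, with an arbitrary sampling rate and pad durations:

```python
import numpy as np

def pad_with_silence(audio: np.ndarray, sampling_rate: int, start_s: float, end_s: float) -> np.ndarray:
    # Prepend start_s seconds and append end_s seconds of zeros, as _pad_audio_array does above.
    if start_s:
        audio = np.concatenate([np.zeros(int(sampling_rate * start_s)), audio])
    if end_s:
        audio = np.concatenate([audio, np.zeros(int(sampling_rate * end_s))])
    return audio

segment = np.zeros(16000, dtype=np.float32)          # stand-in for 1 s of speech at 16 kHz
padded = pad_with_silence(segment, 16000, 0.5, 0.5)  # -> 2 s total: 0.5 s of silence on each side
```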
@@ -220,190 +171,100 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
                         "torchaudio is required to resample audio samples in AutomaticSpeechRecognitionPipeline. "
                         "The torchaudio package can be installed through: `pip install torchaudio`."
                     )
-
                 inputs = F.resample(
                     torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
                 ).numpy()
-                ratio = self.feature_extractor.sampling_rate / in_sampling_rate
-            else:
-                ratio = 1
-            if stride is not None:
-                if stride[0] + stride[1] > inputs.shape[0]:
-                    raise ValueError("Stride is too large for input")

-                # Stride needs to get the chunk length here, it's going to get
-                # swallowed by the `feature_extractor` later, and then batching
-                # can add extra data in the inputs, so we need to keep track
-                # of the original length in the stride so we can cut properly.
-                stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
+        # validate audio array
         if not isinstance(inputs, np.ndarray):
             raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
         if len(inputs.shape) != 1:
             raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")

-        if chunk_length_s:
-            if stride_length_s is None:
-                stride_length_s = chunk_length_s / 6
-
-            if isinstance(stride_length_s, (int, float)):
-                stride_length_s = [stride_length_s, stride_length_s]
-
-            # XXX: Carefuly, this variable will not exist in `seq2seq` setting.
-            # Currently chunking is not possible at this level for `seq2seq` so
-            # it's ok.
-            align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
-            chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
-            stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
-            stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
-
-            if chunk_len < stride_left + stride_right:
-                raise ValueError("Chunk length must be superior to stride length")
-
-            for item in chunk_iter(
-                inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
-            ):
-                item["audio_array"] = inputs
-                yield item
-        else:
-            if inputs.shape[0] > self.feature_extractor.n_samples:
-                processed = self.feature_extractor(
-                    inputs,
-                    sampling_rate=self.feature_extractor.sampling_rate,
-                    truncation=False,
-                    padding="longest",
-                    return_tensors="pt",
-                )
-            else:
-                processed = self.feature_extractor(
-                    inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
-                )
-
-            if self.torch_dtype is not None:
-                processed = processed.to(dtype=self.torch_dtype)
-            if stride is not None:
-                processed["stride"] = stride
-            yield {"is_last": True, "audio_array": inputs, **processed, **extra}
-
-    def _forward(self, model_inputs, **generate_kwargs):
-        attention_mask = model_inputs.pop("attention_mask", None)
-        stride = model_inputs.pop("stride", None)
-        is_last = model_inputs.pop("is_last")
-        audio_array = model_inputs.pop("audio_array")
-        encoder = self.model.get_encoder()
-        # Consume values so we can let extra information flow freely through
-        # the pipeline (important for `partial` in microphone)
-        if "input_features" in model_inputs:
-            inputs = model_inputs.pop("input_features")
-        elif "input_values" in model_inputs:
-            inputs = model_inputs.pop("input_values")
-        else:
-            raise ValueError(
-                "Seq2Seq speech recognition model requires either a "
-                f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
-            )
-
-        # custom processing for Whisper timestamps and word-level timestamps
-        generate_kwargs["return_timestamps"] = True
-        if inputs.shape[-1] > self.feature_extractor.nb_max_frames:
-            generate_kwargs["input_features"] = inputs
-        else:
-            generate_kwargs["encoder_outputs"] = encoder(inputs, attention_mask=attention_mask)
-
-        tokens = self.model.generate(attention_mask=attention_mask, **generate_kwargs)
-        # whisper longform generation stores timestamps in "segments"
-        out = {"tokens": tokens}
-        if self.type == "seq2seq_whisper":
-            if stride is not None:
-                out["stride"] = stride
-
-        # Leftover
-        extra = model_inputs
-        return {"is_last": is_last, "audio_array": audio_array, **out, **extra}
-
-    def postprocess(self,
-                    model_outputs,
-                    decoder_kwargs: Optional[Dict] = None,
-                    return_language=None,
-                    add_punctuation: bool = False,
-                    return_unique_speaker: bool = True,
-                    num_speakers: Optional[int] = None,
-                    min_speakers: Optional[int] = None,
-                    max_speakers: Optional[int] = None,
-                    *args,
-                    **kwargs):
-        assert len(model_outputs) > 0
-        outputs = super().postprocess(
-            model_outputs=model_outputs,
-            decoder_kwargs=decoder_kwargs,
-            return_timestamps=True,
-            return_language=return_language
-        )
-        audio_array = outputs.pop("audio_array")[0]
+        # diarization
         sd = self.model_speaker_diarization(
-            audio_array,
+            inputs,
             num_speakers=num_speakers,
             min_speakers=min_speakers,
             max_speakers=max_speakers,
             sampling_rate=self.feature_extractor.sampling_rate
         )
-        diarization_result = {s: [[i.start, i.end] for i in sd.label_timeline(s)] for s in sd.labels()}
-        timelines = sd.get_timeline()
-
-        pointer_ts = 0
-        pointer_chunk = 0
-        new_chunks = []
-        while True:
-            if pointer_ts == len(timelines):
-                ts = timelines[-1]
-                for chunk in outputs["chunks"][pointer_chunk:]:
-                    chunk["speaker"] = sd.get_labels(ts)
-                    new_chunks.append(chunk)
-                break
-            if pointer_chunk == len(outputs["chunks"]):
-                break
-            ts = timelines[pointer_ts]
-
-            chunk = outputs["chunks"][pointer_chunk]
-            if "speaker" not in chunk:
-                chunk["speaker"] = []

-            start, end = chunk["timestamp"]
-            if ts.end <= start:
-                pointer_ts += 1
-            elif end <= ts.start:
-                if len(chunk["speaker"]) == 0:
-                    chunk["speaker"] += list(sd.get_labels(ts))
-                new_chunks.append(chunk)
-                pointer_chunk += 1
-            else:
-                chunk["speaker"] += list(sd.get_labels(ts))
-                if ts.end >= end:
-                    new_chunks.append(chunk)
-                    pointer_chunk += 1
-                else:
-                    pointer_ts += 1
-        for i in new_chunks:
-            if "speaker" in i:
-                if return_unique_speaker:
-                    i["speaker"] = [i["speaker"][0]]
+        # loop over audio chunks and speakers
+        labels = list(sd.labels())
+        for n, s in enumerate(labels):
+            timelines = list(sd.label_timeline(s))
+            for m, i in enumerate(timelines):
+                start = int(i.start * self.feature_extractor.sampling_rate)
+                end = int(i.end * self.feature_extractor.sampling_rate)
+                audio_array = _pad_audio_array(inputs[start: end])
+
+                if chunk_length_s is not None:
+                    stride_length_s = chunk_length_s / 6 if stride_length_s is None else stride_length_s
+                    stride_length_s = [stride_length_s, stride_length_s] if isinstance(stride_length_s, (int, float)) else stride_length_s
+                    align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
+                    chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
+                    stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
+                    stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
+                    if chunk_len < stride_left + stride_right:
+                        raise ValueError("Chunk length must be superior to stride length")
+                    for item in chunk_iter(
+                            audio_array, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
+                    ):
+                        item["speaker_id"] = s
+                        item["speaker_span"] = [i.start, i.end]
+                        item["is_last"] = m == len(timelines) - 1 and n == len(labels) - 1 and item["is_last"]
+                        yield item
                 else:
-                    i["speaker"] = list(set(i["speaker"]))
-            else:
-                i["speaker"] = []
-        outputs["chunks"] = new_chunks
-        if add_punctuation:
-            if self.punctuator is None:
-                self.punctuator = Punctuator()
-            outputs["chunks"] = self.punctuator.punctuate(outputs["chunks"])
-        outputs["text"] = "".join([c["text"] for c in outputs["chunks"]])
-        outputs["speakers"] = sd.labels()
-        speakers = []
-        for s in outputs["speakers"]:
-            chunk_s = [c for c in outputs["chunks"] if s in c["speaker"]]
-            if len(chunk_s) != 0:
-                outputs[f"chunks/{s}"] = chunk_s
-                outputs[f"text/{s}"] = "".join([c["text"] for c in outputs["chunks"] if s in c["speaker"]])
-                speakers.append(s)
-        outputs["speakers"] = speakers
-        outputs["diarization_result"] = diarization_result
+                    if inputs.shape[0] > self.feature_extractor.n_samples:
+                        processed = self.feature_extractor(
+                            audio_array,
+                            sampling_rate=self.feature_extractor.sampling_rate,
+                            truncation=False,
+                            padding="longest",
+                            return_tensors="pt",
+                        )
+                    else:
+                        processed = self.feature_extractor(
+                            audio_array,
+                            sampling_rate=self.feature_extractor.sampling_rate,
+                            return_tensors="pt"
+                        )
+                    if self.torch_dtype is not None:
+                        processed = processed.to(dtype=self.torch_dtype)
+                    processed["speaker_id"] = s
+                    processed["speaker_span"] = [i.start, i.end]
+                    processed["is_last"] = m == len(timelines) - 1 and n == len(labels) - 1
+                    yield processed
+
+    def _forward(self, model_inputs, **generate_kwargs):
+        generate_kwargs["attention_mask"] = model_inputs.pop("attention_mask", None)
+        generate_kwargs["input_features"] = model_inputs.pop("input_features")
+        tokens = self.model.generate(**generate_kwargs)
+        return {"tokens": tokens, **model_inputs}
+
+    def postprocess(self, model_outputs, **postprocess_parameters):
+        if postprocess_parameters["add_punctuation"] and self.punctuator is None:
+            self.punctuator = Punctuator()
+        outputs = {"chunks": []}
+        for o in model_outputs:
+            text, chunks = self.tokenizer._decode_asr(
+                [o],
+                return_language=postprocess_parameters["return_language"],
+                return_timestamps=postprocess_parameters["return_timestamps"],
+                time_precision=self.feature_extractor.chunk_length / self.model.config.max_source_positions,
+            )
+            start, end = o["speaker_span"]
+            new_chunk = []
+            for c in chunks["chunks"]:
+                c["timestamp"] = [round(c["timestamp"][0] + start, 2), round(c["timestamp"][0] + end, 2)]
+                c["speaker_id"] = o["speaker_id"]
+                new_chunk.append(c)
+            outputs["chunks"] += new_chunk
+        outputs["speaker_ids"] = sorted(set([o["speaker_id"] for o in outputs["chunks"]]))
+        for s in outputs["speaker_ids"]:
+            outputs[f"chunk/{s}"] = sorted([o for o in outputs["chunks"] if o["speaker_id"] == s], key=lambda x: x["timestamp"][0])
+            outputs[f"text/{s}"] = "".join([i["text"] for i in outputs[f"chunk/{s}"]])
+            if postprocess_parameters["add_punctuation"]:
+                outputs[f"text/{s}"] = self.punctuator.punctuate(outputs[f"text/{s}"])
         return outputs
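
For reference, a hedged sketch of consuming the per-speaker output assembled by the new `postprocess`; the key names come from the diff above, and `result` is assumed to be the dict returned by a pipeline call such as the one sketched near the top of this page:

```python
# result["chunks"]      : all transcribed chunks with "text", "timestamp", "speaker_id"
# result["speaker_ids"] : sorted speaker labels found by the diarization model
for speaker in result["speaker_ids"]:
    print(f"--- {speaker} ---")
    print(result[f"text/{speaker}"])          # speaker-level transcript (punctuated if add_punctuation=True)
    for chunk in result[f"chunk/{speaker}"]:
        start, end = chunk["timestamp"]
        print(f"[{start:.2f}s - {end:.2f}s] {chunk['text']}")
```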