File size: 13,672 Bytes
8069744 f2eb7f2 8069744 aaccb5f 6252154 aaccb5f f2eb7f2 aaccb5f 8069744 aaccb5f 8069744 aaccb5f 8069744 3d49cbd aaccb5f 3d49cbd 8069744 aaccb5f 8069744 aaccb5f 8069744 aaccb5f 3d49cbd 8069744 3d49cbd f2eb7f2 3d49cbd f2eb7f2 9814043 3d49cbd f2eb7f2 8069744 f2eb7f2 8069744 f2eb7f2 a7e3c07 8069744 f2eb7f2 8069744 a7e3c07 8069744 f2eb7f2 8069744 f2eb7f2 3d49cbd f2eb7f2 3d49cbd 986454a f2eb7f2 986454a 74437d7 f2eb7f2 aa56a4f f2eb7f2 8069744 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 |
import requests
from typing import Union, Optional, Dict
import torch
import numpy as np
from transformers.pipelines.audio_utils import ffmpeg_read
from transformers.pipelines.automatic_speech_recognition import AutomaticSpeechRecognitionPipeline, chunk_iter
from transformers.utils import is_torchaudio_available
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from pyannote.audio import Pipeline
from pyannote.core.annotation import Annotation
from punctuators.models import PunctCapSegModelONNX
from diarizers import SegmentationModel
class Punctuator:
    """Add punctuation and true-casing to text with a ``PunctCapSegModelONNX`` model.

    Text that already contains one of ``ja_punctuations`` is treated as punctuated and
    is returned untouched.
    """

    # Presence of any of these marks makes punctuate() skip the model entirely.
    ja_punctuations = ["!", "?", "、", "。"]

    def __init__(self, model: str = "1-800-BAD-CODE/xlm-roberta_punctuation_fullstop_truecase"):
        self.punctuation_model = PunctCapSegModelONNX.from_pretrained(model)

    def punctuate(self, text: str) -> str:
        """Return ``text`` with punctuation restored, or ``text`` unchanged.

        The original text is returned when it already contains a mark from
        ``ja_punctuations``, or when the model output contains "unk" (case-insensitive),
        which is treated as a failed restoration.
        """
        already_punctuated = any(mark in text for mark in self.ja_punctuations)
        if already_punctuated:
            return text
        # infer() takes a batch; take the first (only) result and glue its pieces back together.
        pieces = self.punctuation_model.infer([text])[0]
        candidate = "".join(pieces)
        return text if "unk" in candidate.lower() else candidate
class SpeakerDiarization:
    """Run a pyannote speaker-diarization pipeline on an in-memory waveform.

    If ``model_id_diarizers`` is given, the pipeline's segmentation model is replaced by
    that ``diarizers`` checkpoint, converted with ``to_pyannote_model()``.
    """

    def __init__(self,
                 device: torch.device,
                 model_id: str = "pyannote/speaker-diarization-3.1",
                 model_id_diarizers: Optional[str] = None):
        self.device = device
        self.pipeline = Pipeline.from_pretrained(model_id)
        self.pipeline = self.pipeline.to(self.device)
        if model_id_diarizers:
            # Swap in the fine-tuned segmentation model, placed on the same device as the pipeline.
            self.pipeline._segmentation.model = SegmentationModel().from_pretrained(
                model_id_diarizers
            ).to_pyannote_model().to(self.device)

    def __call__(self,
                 audio: Union[torch.Tensor, np.ndarray],
                 sampling_rate: int,
                 num_speakers: Optional[int] = None,
                 min_speakers: Optional[int] = None,
                 max_speakers: Optional[int] = None) -> "Annotation":
        """Diarize ``audio`` and return whatever the pyannote pipeline returns.

        Args:
            audio: ``(time,)`` or ``(channel, time)`` waveform as a tensor or numpy array;
                it is converted to float32 and moved to ``self.device``.
            sampling_rate: sample rate of ``audio`` in Hz (required).
            num_speakers, min_speakers, max_speakers: forwarded unchanged to the pipeline.

        Raises:
            ValueError: if ``sampling_rate`` is None, or ``audio`` is not 1-D or 2-D.
        """
        if sampling_rate is None:
            raise ValueError("sampling_rate must be provided")
        # as_tensor handles numpy arrays and tensors alike and casts to float32 in one step.
        audio = torch.as_tensor(audio, dtype=torch.float32)
        if audio.dim() == 1:
            # Add the channel axis so the waveform is (channel, time).
            audio = audio.unsqueeze(0)
        if audio.dim() != 2:
            raise ValueError("audio shape must be (channel, time)")
        audio = {"waveform": audio.to(self.device), "sample_rate": sampling_rate}
        return self.pipeline(
            audio,
            num_speakers=num_speakers,
            min_speakers=min_speakers,
            max_speakers=max_speakers,
        )
class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
    """Whisper ASR pipeline that transcribes each diarized speaker segment separately.

    Speaker turns are found with ``SpeakerDiarization``. Each turn is transcribed with
    Whisper (``language="ja"``, ``task="transcribe"``, timestamps on). The decoded chunks
    are tagged with their speaker and shifted by the segment start time.
    """

    def __init__(self,
                 model: "PreTrainedModel",
                 model_pyannote: str = "pyannote/speaker-diarization-3.1",
                 model_diarizers: Optional[str] = "diarizers-community/speaker-segmentation-fine-tuned-callhome-jpn",
                 feature_extractor: Union["SequenceFeatureExtractor", str] = None,
                 tokenizer: Optional[PreTrainedTokenizer] = None,
                 device: Union[int, "torch.device"] = None,
                 device_pyannote: Union[int, "torch.device"] = None,
                 torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
                 **kwargs):
        """
        Args:
            model: Whisper model handed to the base ASR pipeline.
            model_pyannote: pyannote diarization pipeline id.
            model_diarizers: optional ``diarizers`` segmentation model id; a falsy value
                keeps pyannote's own segmentation model.
            feature_extractor, tokenizer, torch_dtype, kwargs: forwarded to the base pipeline.
            device: device for the Whisper model; defaults to "cpu".
            device_pyannote: device for diarization; defaults to ``device``. Strings and
                ints are converted to ``torch.device`` (negative int -> CPU, n -> cuda:n).
        """
        self.type = "seq2seq_whisper"
        if device is None:
            device = "cpu"
        if device_pyannote is None:
            device_pyannote = device
        if isinstance(device_pyannote, str):
            device_pyannote = torch.device(device_pyannote)
        elif isinstance(device_pyannote, int):
            # transformers pipeline convention: -1 means CPU, a non-negative int is a CUDA ordinal.
            device_pyannote = torch.device("cpu") if device_pyannote < 0 else torch.device(f"cuda:{device_pyannote}")
        self.model_speaker_diarization = SpeakerDiarization(
            device=device_pyannote,
            model_id=model_pyannote,
            model_id_diarizers=model_diarizers,
        )
        # Loaded lazily by postprocess() the first time punctuation is requested.
        self.punctuator = None
        super().__init__(
            model=model,
            feature_extractor=feature_extractor,
            tokenizer=tokenizer,
            device=device,
            torch_dtype=torch_dtype,
            **kwargs,
        )

    def _sanitize_parameters(self,
                             chunk_length_s: Optional[int] = None,
                             stride_length_s: Optional[int] = None,
                             generate_kwargs: Optional[Dict] = None,
                             max_new_tokens: Optional[int] = None,
                             add_punctuation: bool = False,
                             return_unique_speaker: bool = True,
                             add_silence_end: Optional[float] = None,
                             add_silence_start: Optional[float] = None,
                             num_speakers: Optional[int] = None,
                             min_speakers: Optional[int] = None,
                             max_speakers: Optional[int] = None):
        """Split call-time kwargs into (preprocess, forward, postprocess) parameter dicts.

        ``return_unique_speaker`` is accepted but not used by this implementation.
        Generation is always forced to Japanese transcription with timestamps.
        """
        preprocess_params = {
            "chunk_length_s": chunk_length_s,
            "stride_length_s": stride_length_s,
            "add_silence_end": add_silence_end,
            "add_silence_start": add_silence_start,
            "num_speakers": num_speakers,
            "min_speakers": min_speakers,
            "max_speakers": max_speakers,
        }
        postprocess_params = {"add_punctuation": add_punctuation, "return_timestamps": True, "return_language": False}
        # Copy so the update below never mutates the caller's generate_kwargs dict.
        forward_params = {} if generate_kwargs is None else dict(generate_kwargs)
        forward_params.update({"max_new_tokens": max_new_tokens, "return_timestamps": True, "language": "ja", "task": "transcribe"})
        return preprocess_params, forward_params, postprocess_params

    def preprocess(self,
                   inputs,
                   chunk_length_s: Optional[int] = None,
                   stride_length_s: Optional[int] = None,
                   add_silence_end: Optional[float] = None,
                   add_silence_start: Optional[float] = None,
                   num_speakers: Optional[int] = None,
                   min_speakers: Optional[int] = None,
                   max_speakers: Optional[int] = None):
        """Load audio, diarize it, and yield Whisper model inputs for every speaker segment.

        Args:
            inputs: http(s) URL, local file path, encoded audio bytes, a dict with
                "sampling_rate" and "array"/"raw", or a 1-D numpy array already at the
                feature extractor's sampling rate.
            chunk_length_s, stride_length_s: when ``chunk_length_s`` is set, each segment is
                split with ``chunk_iter``; the stride defaults to ``chunk_length_s / 6`` per side.
            add_silence_start, add_silence_end: seconds of zeros padded around each segment.
            num_speakers, min_speakers, max_speakers: forwarded to the diarization pipeline.

        Yields:
            Feature dicts extended with "speaker_id", "speaker_span" ([start, end] in
            seconds, from diarization) and "is_last" (True only for the input's final item).

        Raises:
            ValueError: malformed dict input, non-1-D audio, or strides longer than the chunk.
            ImportError: resampling is required but torchaudio is not installed.
        """
        sampling_rate = self.feature_extractor.sampling_rate

        def _pad_audio_array(_audio):
            # Zeros share the audio's dtype so padding does not upcast it (np.zeros defaults to float64).
            if add_silence_start:
                _audio = np.concatenate([np.zeros(int(sampling_rate * add_silence_start), dtype=_audio.dtype), _audio])
            if add_silence_end:
                _audio = np.concatenate([_audio, np.zeros(int(sampling_rate * add_silence_end), dtype=_audio.dtype)])
            return _audio

        # load file
        if isinstance(inputs, str):
            if inputs.startswith(("http://", "https://")):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file like http_huggingface_co.png
                inputs = requests.get(inputs).content
            else:
                with open(inputs, "rb") as f:
                    inputs = f.read()
        if isinstance(inputs, bytes):
            inputs = ffmpeg_read(inputs, sampling_rate)
        if isinstance(inputs, dict):
            # Shallow copy so the pops below do not mutate the caller's dict (e.g. a `datasets` row).
            inputs = dict(inputs)
            # Accepting `"array"` which is the key defined in `datasets` for better integration
            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                raise ValueError(
                    "When passing a dictionary to AutomaticSpeechRecognitionPipeline, the dict needs to contain a "
                    '"array" key containing the numpy array representing the audio and a "sampling_rate" key, '
                    "containing the sampling_rate associated with that array"
                )
            in_sampling_rate = inputs.pop("sampling_rate")
            inputs = inputs.pop("array", inputs.pop("raw", None))
            if in_sampling_rate != sampling_rate:
                if is_torchaudio_available():
                    from torchaudio import functional as F
                else:
                    raise ImportError(
                        "torchaudio is required to resample audio samples in AutomaticSpeechRecognitionPipeline. "
                        "The torchaudio package can be installed through: `pip install torchaudio`."
                    )
                inputs = F.resample(
                    torch.from_numpy(inputs), in_sampling_rate, sampling_rate
                ).numpy()

        # validate audio array
        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
        if len(inputs.shape) != 1:
            raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")

        # Chunking parameters do not depend on the segment, so compute and validate them once,
        # before the comparatively expensive diarization step.
        if chunk_length_s is not None:
            if stride_length_s is None:
                stride_length_s = chunk_length_s / 6
            if isinstance(stride_length_s, (int, float)):
                stride_length_s = [stride_length_s, stride_length_s]
            align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
            chunk_len = int(round(chunk_length_s * sampling_rate / align_to) * align_to)
            stride_left = int(round(stride_length_s[0] * sampling_rate / align_to) * align_to)
            stride_right = int(round(stride_length_s[1] * sampling_rate / align_to) * align_to)
            if chunk_len < stride_left + stride_right:
                raise ValueError("Chunk length must be superior to stride length")

        # diarization
        sd = self.model_speaker_diarization(
            inputs,
            num_speakers=num_speakers,
            min_speakers=min_speakers,
            max_speakers=max_speakers,
            sampling_rate=sampling_rate,
        )

        # loop over speakers and their segments; all items of one segment are yielded consecutively
        labels = list(sd.labels())
        for n, speaker in enumerate(labels):
            timelines = list(sd.label_timeline(speaker))
            for m, segment in enumerate(timelines):
                # Final segment of the final speaker: its last item closes the whole input.
                is_last_segment = m == len(timelines) - 1 and n == len(labels) - 1
                start = int(segment.start * sampling_rate)
                end = int(segment.end * sampling_rate)
                audio_array = _pad_audio_array(inputs[start:end])
                if chunk_length_s is not None:
                    for item in chunk_iter(
                        audio_array, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
                    ):
                        item["speaker_id"] = speaker
                        item["speaker_span"] = [segment.start, segment.end]
                        item["is_last"] = is_last_segment and item["is_last"]
                        yield item
                else:
                    if audio_array.shape[0] > self.feature_extractor.n_samples:
                        # Longer than one Whisper window: keep the full length for long-form generation.
                        processed = self.feature_extractor(
                            audio_array,
                            sampling_rate=sampling_rate,
                            truncation=False,
                            padding="longest",
                            return_tensors="pt",
                        )
                    else:
                        processed = self.feature_extractor(
                            audio_array,
                            sampling_rate=sampling_rate,
                            return_tensors="pt",
                        )
                    if self.torch_dtype is not None:
                        processed = processed.to(dtype=self.torch_dtype)
                    processed["speaker_id"] = speaker
                    processed["speaker_span"] = [segment.start, segment.end]
                    processed["is_last"] = is_last_segment
                    yield processed

    def _forward(self, model_inputs, **generate_kwargs):
        """Run ``model.generate`` on one item; non-feature keys are passed through to the output."""
        generate_kwargs["attention_mask"] = model_inputs.pop("attention_mask", None)
        generate_kwargs["input_features"] = model_inputs.pop("input_features")
        tokens = self.model.generate(**generate_kwargs)
        return {"tokens": tokens, **model_inputs}

    def postprocess(self, model_outputs, **postprocess_parameters):
        """Decode each diarized segment, shift its timestamps, and group the results by speaker.

        Returns:
            dict with "chunks" (every decoded chunk, each carrying "text", "timestamp" and
            "speaker_id"), "speaker_ids" (sorted unique speakers), and, per speaker ``s``,
            "chunks/{s}" (chunks sorted by start time) and "text/{s}" (their concatenated
            text, punctuated when ``add_punctuation`` is set).
        """
        from itertools import groupby

        add_punctuation = postprocess_parameters["add_punctuation"]
        if add_punctuation and self.punctuator is None:
            self.punctuator = Punctuator()
        sampling_rate = self.feature_extractor.sampling_rate
        time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions

        outputs = {"chunks": []}
        # preprocess() yields all items of one segment back to back with the same
        # (speaker_id, speaker_span). Decoding them together lets the tokenizer merge
        # overlapping chunks through their strides.
        for _, group in groupby(model_outputs, key=lambda o: (o["speaker_id"], tuple(o["speaker_span"]))):
            segment_outputs = []
            for o in group:
                if "stride" in o:
                    chunk_len, stride_left, stride_right = o["stride"]
                    # chunk_iter reports strides in samples; Whisper's _decode_asr works in seconds.
                    o = {**o, "stride": (chunk_len / sampling_rate, stride_left / sampling_rate, stride_right / sampling_rate)}
                segment_outputs.append(o)
            _, decoded = self.tokenizer._decode_asr(
                segment_outputs,
                return_language=postprocess_parameters["return_language"],
                return_timestamps=postprocess_parameters["return_timestamps"],
                time_precision=time_precision,
            )
            speaker_id = segment_outputs[0]["speaker_id"]
            span_start, span_end = segment_outputs[0]["speaker_span"]
            for c in decoded["chunks"]:
                ts_start, ts_end = c["timestamp"]
                # Decoded timestamps are relative to the segment's audio, so both ends are
                # offset by the segment start. A missing end (None) falls back to the segment end.
                # NOTE(review): add_silence_start padding is not subtracted here — confirm intent.
                c["timestamp"] = [
                    round(ts_start + span_start, 2) if ts_start is not None else round(span_start, 2),
                    round(ts_end + span_start, 2) if ts_end is not None else round(span_end, 2),
                ]
                c["speaker_id"] = speaker_id
                outputs["chunks"].append(c)

        outputs["speaker_ids"] = sorted({c["speaker_id"] for c in outputs["chunks"]})
        for s in outputs["speaker_ids"]:
            outputs[f"chunks/{s}"] = sorted(
                (c for c in outputs["chunks"] if c["speaker_id"] == s), key=lambda x: x["timestamp"][0]
            )
            outputs[f"text/{s}"] = "".join(c["text"] for c in outputs[f"chunks/{s}"])
            if add_punctuation:
                outputs[f"text/{s}"] = self.punctuator.punctuate(outputs[f"text/{s}"])
        return outputs
|