|
|
|
|
|
|
|
""" |
|
Please refer to |
|
https://github.com/k2-fsa/sherpa-onnx/blob/master/.github/workflows/speaker-diarization.yaml |
|
for usage examples.
|
""" |
|
|
|
""" |
|
1. Go to https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/tree/main and download the speaker embedding model:
|
wget https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/resolve/main/speaker-embedding.onnx |
|
|
|
2. Change line 166 of pyannote/audio/pipelines/speaker_diarization.py as follows:
|
|
|
``` |
|
# self._embedding = PretrainedSpeakerEmbedding( |
|
# self.embedding, use_auth_token=use_auth_token |
|
# ) |
|
self._embedding = embedding |
|
``` |
|
""" |
|
|
|
import argparse |
|
from pathlib import Path |
|
|
|
import torch |
|
from pyannote.audio import Model |
|
from pyannote.audio.pipelines import SpeakerDiarization as SpeakerDiarizationPipeline |
|
from pyannote.audio.pipelines.speaker_verification import ( |
|
ONNXWeSpeakerPretrainedSpeakerEmbedding, |
|
) |
|
|
|
|
|
def get_args():
    """Parse the command-line arguments of this script.

    Returns:
      The ``argparse.Namespace`` produced by the parser. Its ``wav``
      attribute holds the string passed to the required ``--wav`` option.
    """
    arg_parser = argparse.ArgumentParser()

    # The only option: the wave file to run diarization on.
    wav_option = dict(type=str, required=True, help="Path to test.wav")
    arg_parser.add_argument("--wav", **wav_option)

    parsed = arg_parser.parse_args()
    return parsed
|
|
|
|
|
def build_pipeline():
    """Create and configure the speaker diarization pipeline.

    The segmentation model is loaded from ``./pytorch_model.bin`` and put
    into eval mode. For the speaker embedding, ``./speaker-embedding.onnx``
    is wrapped in ``ONNXWeSpeakerPretrainedSpeakerEmbedding`` when that file
    exists in the current directory; otherwise the model name
    ``hbredin/wespeaker-voxceleb-resnet34-LM`` is handed to the pipeline.

    Returns:
      The pipeline after ``instantiate()`` has been called with the
      clustering and segmentation parameters below.
    """
    local_embedding = "./speaker-embedding.onnx"
    has_local_embedding = Path(local_embedding).is_file()

    # NOTE: passing an already-constructed embedding object presumably relies
    # on the pyannote change described in the docstring at the top of this
    # file -- confirm against the installed pyannote version.
    speaker_embedding = (
        ONNXWeSpeakerPretrainedSpeakerEmbedding(local_embedding)
        if has_local_embedding
        else "hbredin/wespeaker-voxceleb-resnet34-LM"
    )

    seg_model = Model.from_pretrained("./pytorch_model.bin")
    seg_model.eval()

    diarizer = SpeakerDiarizationPipeline(
        segmentation=seg_model,
        embedding=speaker_embedding,
        embedding_exclude_overlap=True,
    )

    clustering_params = {
        "method": "centroid",
        "min_cluster_size": 12,
        "threshold": 0.7045654963945799,
    }
    segmentation_params = {"min_duration_off": 0.5}

    diarizer.instantiate(
        {
            "clustering": clustering_params,
            "segmentation": segmentation_params,
        }
    )
    return diarizer
|
|
|
|
|
@torch.no_grad()
def main():
    """Run speaker diarization on the wave file given via ``--wav``.

    Prints the constructed pipeline, then the type of the diarization
    result, then the result itself.

    Raises:
      FileNotFoundError: If the path passed to ``--wav`` is not an existing
        file.
    """
    args = get_args()

    # Validate explicitly instead of with ``assert``: assertions are removed
    # when Python runs with -O, which would let a bad path reach the
    # pipeline and fail there with a much less helpful error.
    if not Path(args.wav).is_file():
        raise FileNotFoundError(f"{args.wav} does not exist or is not a file")

    pipeline = build_pipeline()
    print(pipeline)

    t = pipeline(args.wav)
    print(type(t))
    print(t)
|
|
|
|
|
if __name__ == "__main__":
    # Run the demo only when this file is executed as a script,
    # not when it is imported as a module.
    main()
|
|