#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
"""
Please refer to
https://github.com/k2-fsa/sherpa-onnx/blob/master/.github/workflows/speaker-diarization.yaml
for usages.
"""
"""
1. Go to https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/tree/main
wget https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/resolve/main/speaker-embedding.onnx
2. Change line 166 of pyannote/audio/pipelines/speaker_diarization.py
```
# self._embedding = PretrainedSpeakerEmbedding(
# self.embedding, use_auth_token=use_auth_token
# )
self._embedding = embedding
```
"""
import argparse
from pathlib import Path

import torch
from pyannote.audio import Model
from pyannote.audio.pipelines import SpeakerDiarization as SpeakerDiarizationPipeline
from pyannote.audio.pipelines.speaker_verification import (
    ONNXWeSpeakerPretrainedSpeakerEmbedding,
)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")

    return parser.parse_args()


def build_pipeline():
    embedding_filename = "./speaker-embedding.onnx"
    if Path(embedding_filename).is_file():
        # You need to modify line 166
        # of pyannote/audio/pipelines/speaker_diarization.py.
        # Please see the comments at the start of this script for details.
        embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding(embedding_filename)
    else:
        embedding = "hbredin/wespeaker-voxceleb-resnet34-LM"

    # "./pytorch_model.bin" is a pyannote segmentation checkpoint that has to
    # be downloaded beforehand.
    pt_filename = "./pytorch_model.bin"
    segmentation = Model.from_pretrained(pt_filename)
    segmentation.eval()

    pipeline = SpeakerDiarizationPipeline(
        segmentation=segmentation,
        embedding=embedding,
        embedding_exclude_overlap=True,
    )

    # Hyper-parameters for the clustering and segmentation stages.
    params = {
        "clustering": {
            "method": "centroid",
            "min_cluster_size": 12,
            "threshold": 0.7045654963945799,
        },
        "segmentation": {"min_duration_off": 0.5},
    }
    pipeline.instantiate(params)

    return pipeline


@torch.no_grad()
def main():
    args = get_args()
    assert Path(args.wav).is_file(), args.wav

    pipeline = build_pipeline()
    print(pipeline)

    t = pipeline(args.wav)
    print(type(t))
    print(t)
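
    # A minimal sketch of consuming the result (assuming the standard
    # pyannote.core.Annotation API returned by the pipeline): iterate over
    # the speaker turns and print start/end times with the speaker label.
    for turn, _, speaker in t.itertracks(yield_label=True):
        print(f"{turn.start:.3f} -- {turn.end:.3f} {speaker}")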


if __name__ == "__main__":
    main()