sherpa-onnx-reverb-diarization-v1 / speaker-diarization-torch.py
csukuangfj's picture
add models
a275607
raw
history blame
2.37 kB
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
"""
Please refer to
https://github.com/k2-fsa/sherpa-onnx/blob/master/.github/workflows/speaker-diarization.yaml
for usages.
"""
"""
1. Go to https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/tree/main
wget https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/resolve/main/speaker-embedding.onnx
2. Change line 166 of pyannote/audio/pipelines/speaker_diarization.py
```
# self._embedding = PretrainedSpeakerEmbedding(
# self.embedding, use_auth_token=use_auth_token
# )
self._embedding = embedding
```
"""
import argparse
from pathlib import Path
import torch
from pyannote.audio import Model
from pyannote.audio.pipelines import SpeakerDiarization as SpeakerDiarizationPipeline
from pyannote.audio.pipelines.speaker_verification import (
ONNXWeSpeakerPretrainedSpeakerEmbedding,
)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")
return parser.parse_args()
def build_pipeline():
embedding_filename = "./speaker-embedding.onnx"
if Path(embedding_filename).is_file():
# You need to modify line 166
# of pyannote/audio/pipelines/speaker_diarization.py
# Please see the comments at the start of this script for details
embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding(embedding_filename)
else:
embedding = "hbredin/wespeaker-voxceleb-resnet34-LM"
pt_filename = "./pytorch_model.bin"
segmentation = Model.from_pretrained(pt_filename)
segmentation.eval()
pipeline = SpeakerDiarizationPipeline(
segmentation=segmentation,
embedding=embedding,
embedding_exclude_overlap=True,
)
params = {
"clustering": {
"method": "centroid",
"min_cluster_size": 12,
"threshold": 0.7045654963945799,
},
"segmentation": {"min_duration_off": 0.5},
}
pipeline.instantiate(params)
return pipeline
@torch.no_grad()
def main():
args = get_args()
assert Path(args.wav).is_file(), args.wav
pipeline = build_pipeline()
print(pipeline)
t = pipeline(args.wav)
print(type(t))
print(t)
if __name__ == "__main__":
main()