|
|
|
|
|
|
|
import os |
|
from typing import Any, Dict |
|
|
|
import onnx |
|
import torch |
|
from onnxruntime.quantization import QuantType, quantize_dynamic |
|
from pyannote.audio import Model |
|
from pyannote.audio.core.task import Problem, Resolution |
|
|
|
|
|
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Replace the metadata of an ONNX model file, rewriting it in-place.

    Any metadata entries already stored in the model are discarded before
    the new ones are written.

    Args:
      filename:
        Path of the ONNX model to load, modify, and save back.
      meta_data:
        Key-value pairs to store. Each value is converted with ``str()``.
    """
    onnx_model = onnx.load(filename)
    props = onnx_model.metadata_props

    # Start from an empty metadata list so stale entries do not survive.
    del props[:]

    for name, raw_value in meta_data.items():
        entry = props.add()
        entry.key = name
        entry.value = str(raw_value)

    onnx.save(onnx_model, filename)
|
|
|
|
|
@torch.no_grad()
def main():
    """Export ``./pytorch_model.bin`` to ONNX and write a quantized copy.

    Steps, in order:
      1. Load the checkpoint and assert that it has the expected properties
         (7 output classes, 10-second window, 16 kHz audio, 589 output
         frames for the example input, fixed receptive field).
      2. Export it to ``model.onnx`` with dynamic batch and time axes.
      3. Attach metadata to ``model.onnx`` via ``add_meta_data``.
      4. Run dynamic quantization and save it as ``model.int8.onnx``.

    If the environment variable ``SHERPA_ONNX_IS_REVAI`` is set to a
    non-empty value, the URLs stored in the metadata point to the Revai
    reverb-diarization model instead of pyannote segmentation-3.0.
    """
    # Literal rate used for every seconds -> samples conversion below.
    samples_per_second = 16000

    checkpoint = "./pytorch_model.bin"
    net = Model.from_pretrained(checkpoint)
    net.eval()
    assert net.dimension == 7, net.dimension

    specs = net.specifications
    print(specs)

    assert specs.problem == Problem.MONO_LABEL_CLASSIFICATION, specs.problem
    assert specs.resolution == Resolution.FRAME, specs.resolution
    assert specs.duration == 10.0, specs.duration
    assert net.audio.sample_rate == 16000, net.audio.sample_rate

    expected_input_shape = [1, 1, samples_per_second * 10]
    assert (
        list(net.example_input_array.shape) == expected_input_shape
    ), net.example_input_array.shape

    # Run one forward pass to confirm the output layout before exporting.
    example_output = net(net.example_input_array)
    assert list(example_output.shape) == [1, 589, 7], example_output.shape

    rf = net.receptive_field
    assert rf.step == 0.016875, rf.step
    assert rf.duration == 0.0619375, rf.duration
    assert rf.step * samples_per_second == 270, rf.step * samples_per_second
    assert rf.duration * samples_per_second == 991, (
        rf.duration * samples_per_second
    )

    onnx_path = "model.onnx"
    # Batch size and the time dimension are exported as dynamic axes.
    axes = {
        "x": {0: "N", 2: "T"},
        "y": {0: "N", 1: "T"},
    }
    torch.onnx.export(
        net,
        net.example_input_array,
        onnx_path,
        opset_version=13,
        input_names=["x"],
        output_names=["y"],
        dynamic_axes=axes,
    )

    sample_rate = net.audio.sample_rate

    # All three sizes are expressed in samples.
    window_size = int(specs.duration) * samples_per_second
    receptive_field_size = int(rf.duration * samples_per_second)
    receptive_field_shift = int(rf.step * samples_per_second)

    # Any non-empty value of the variable selects the Revai URLs.
    if os.getenv("SHERPA_ONNX_IS_REVAI", "") != "":
        url_1 = "https://huggingface.co/Revai/reverb-diarization-v1"
        url_2 = "https://huggingface.co/csukuangfj/sherpa-onnx-reverb-diarization-v1"
        license_url = (
            "https://huggingface.co/Revai/reverb-diarization-v1/blob/main/LICENSE"
        )
    else:
        url_1 = "https://huggingface.co/pyannote/segmentation-3.0"
        url_2 = "https://huggingface.co/csukuangfj/pyannote-models/tree/main/segmentation-3.0"
        license_url = (
            "https://huggingface.co/pyannote/segmentation-3.0/blob/main/LICENSE"
        )

    add_meta_data(
        filename=onnx_path,
        meta_data={
            "num_speakers": len(specs.classes),
            "powerset_max_classes": specs.powerset_max_classes,
            "num_classes": net.dimension,
            "sample_rate": sample_rate,
            "window_size": window_size,
            "receptive_field_size": receptive_field_size,
            "receptive_field_shift": receptive_field_shift,
            "model_type": "pyannote-segmentation-3.0",
            "version": "1",
            "model_author": "pyannote",
            "maintainer": "k2-fsa",
            "url_1": url_1,
            "url_2": url_2,
            "license": license_url,
        },
    )

    print("Generate int8 quantization models")

    int8_path = "model.int8.onnx"
    quantize_dynamic(
        model_input=onnx_path,
        model_output=int8_path,
        weight_type=QuantType.QUInt8,
    )

    print(f"Saved to {onnx_path} and {int8_path}")
|
|
|
|
|
# Run the export only when this file is executed as a script, so importing
# the module does not trigger it.
if __name__ == "__main__":
    main()
|
|