Explanation of using the ONNX model

#5
by yovelcohen1 - opened

Hi! Thanks for the lovely model :)
I'm trying to run the ONNX version (I tried both the 6-layer and the 24-layer versions), but I get the same error with both. I am able to run the PyTorch-based prediction that you provided.

# Load the model.
# NOTE(review): `ort`, `np` and `Wav2Vec2Processor` are imported outside this
# snippet (presumably onnxruntime, numpy and transformers) — confirm.
six_l_onnx_path = "./ONNX-6 Layers/model.onnx"  # defined but not used below
twenty_four_l_onnx_path = "./ONNX-24 Layers/model.onnx"
ort_session = ort.InferenceSession(twenty_four_l_onnx_path)  # 24-layer export

# Hugging Face processor for the same model; it is only used for input
# preprocessing in process_func_onnx.
model_name = 'audeering/wav2vec2-large-robust-24-ft-age-gender'
processor = Wav2Vec2Processor.from_pretrained(model_name)


def process_func_onnx(x: np.ndarray, sampling_rate: int) -> np.ndarray:
    """Preprocess a waveform with the HF processor and run the ONNX session on it.

    NOTE(review): this version reproduces the reported TypeError.
    ``ort_session.run`` expects a *list* of output names as its first
    argument, but ``output_name`` below is a bare ``str`` ('hidden_states').
    """
    y = processor(x, sampling_rate=sampling_rate)
    y = y['input_values'][0]
    y = y.reshape(1, -1)  # single-row batch of shape (1, N)
    inputs = {ort_session.get_inputs()[0].name: y}
    output_name = ort_session.get_outputs()[0].name  # a str, not List[str]
    ort_outputs = ort_session.run(output_name, inputs)  # raises: arg0 must be List[str]
    y = np.hstack(ort_outputs)
    return y


def test_stuff():
    """Run process_func_onnx on every .wav file in ./data/audio_samples.

    Returns a dict mapping each file path to the model output for that file.
    """
    folder = './data/audio_samples'
    results = {}
    import glob
    for file in glob.glob(f'{folder}/*.wav'):
        from pydub import AudioSegment
        segment = AudioSegment.from_wav(file)
        segment = segment.set_frame_rate(16000)  # resample to 16 kHz
        # NOTE(review): pydub's get_array_of_samples() presumably yields integer
        # PCM samples, interleaved across channels for multi-channel files —
        # confirm; no downmixing or scaling happens here.
        signal = np.array(segment.get_array_of_samples(), dtype=np.float32)
        results[file] = process_func_onnx(signal, segment.frame_rate)
    return results

The error I'm seeing from ONNX Runtime is:

TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]
Invoked with: <onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession object at 0x10a178230>, 'hidden_states', {'signal': array([[ 0.6791527 , -0.40761387,  1.4003403 , ..., -1.1661903 ,
        -0.94601196, -0.813905  ]], dtype=float32)}, None

Any ideas?

audEERING GmbH org

Hi Yovelcohen,

thanks for your interest in the model.
`output_name` must be a list of output names, not a single string. That said, you probably don't want the `hidden_states` output anyway, but rather the logits for age and gender:

import numpy as np
import onnxruntime as ort

# Load the model (the 6-layer export; the path is relative to the working directory).
six_l_onnx_path = "./ONNX-6 Layers/model.onnx"
# Pass `providers` explicitly: since onnxruntime 1.9, builds that ship more than
# one execution provider (e.g. GPU builds) raise ValueError when it is omitted.
# Swap in "CUDAExecutionProvider" (before the CPU one) to run on a GPU build.
ort_session = ort.InferenceSession(
    six_l_onnx_path,
    providers=["CPUExecutionProvider"],
)

def process_func_onnx(x: np.ndarray, sampling_rate: int) -> np.ndarray:
    """Predict age and gender logits for one raw waveform with the ONNX model.

    No ``Wav2Vec2Processor`` step is applied: the samples are fed to the
    session directly, after a dtype cast and a reshape.

    Args:
        x: The audio samples. Any shape (or a plain sequence) is accepted; it
            is flattened into a single-row batch of shape ``(1, N)``.
            Presumably a mono signal already resampled to 16 kHz — confirm.
        sampling_rate: Sampling rate of ``x`` in Hz. Not used here (no
            resampling happens in this function); kept for call-site
            compatibility.

    Returns:
        The ``logits_age`` and ``logits_gender`` outputs joined with
        ``np.hstack`` (side by side along axis 1 when they are 2-D).
    """
    # ONNX Runtime rejects inputs whose dtype differs from the graph input
    # (presumably float32, which is what the working calls pass), so cast
    # defensively. The cast is free when x is already float32.
    y = np.asarray(x, dtype=np.float32).reshape(1, -1)
    inputs = {ort_session.get_inputs()[0].name: y}
    # InferenceSession.run() takes a *list* of output names. Request the
    # logits by name because the first graph output is 'hidden_states'
    # (at least in the 24-layer export).
    output_names = ['logits_age', 'logits_gender']
    ort_outputs = ort_session.run(output_names, inputs)
    return np.hstack(ort_outputs)

def test_stuff():
    """Run the ONNX model on every ``*.wav`` file in ``./data/audio_samples``.

    Each file is resampled to 16 kHz, downmixed to mono, and scaled from
    integer PCM to floats in [-1, 1] before inference. Without the Hugging
    Face processor, nothing else normalizes the raw sample scale.

    Returns:
        dict mapping each file path to the array returned by
        ``process_func_onnx``, in sorted path order.
    """
    import glob

    from pydub import AudioSegment  # imported once, not on every loop iteration

    folder = './data/audio_samples'
    results = {}
    # glob's order is arbitrary; sort for a reproducible processing order.
    for file in sorted(glob.glob(f'{folder}/*.wav')):
        segment = AudioSegment.from_wav(file)
        # get_array_of_samples() interleaves channels, so downmix to mono first.
        segment = segment.set_frame_rate(16000).set_channels(1)
        # pydub yields signed integer PCM; divide by full scale
        # (2 ** (bits - 1)) to get floats in [-1, 1].
        full_scale = float(1 << (8 * segment.sample_width - 1))
        signal = np.array(segment.get_array_of_samples(), dtype=np.float32) / full_scale
        results[file] = process_func_onnx(signal, segment.frame_rate)
    return results

if __name__ == "__main__":
    # Script entry point: run the demo over the sample folder and show the results.
    all_results = test_stuff()
    print(all_results)

More information on how to interpret the outputs can be found here: https://github.com/audeering/w2v2-age-gender-how-to/blob/master/notebook.ipynb

Note: As shown in the code, the Wav2Vec2Processor is also not required when using the ONNX model.

Hope this helps!
Cheers

@audmax Thanks! I'm kind of new to the whole ONNX thing; I didn't realize it was that simple :)

yovelcohen1 changed discussion status to closed

Sign up or log in to comment