# File: yamnet_test/python/util/tensorflow_util.py
# Author: Luis
# Commit: init2 (97249f0)
# Web-view residue: raw / history / blame; No virus; 1.15 kB
import tensorflow as tf
import numpy as np
def predict(model_path, waveform):
    """Run a TFLite model on an audio waveform and return its first output.

    The interpreter's first input tensor is resized to the length of
    ``waveform`` so that clips of any length can be scored in one call.

    Args:
        model_path: Path to the ``.tflite`` model file (e.g. a downloaded
            ``yamnet.tflite``), handed to ``tf.lite.Interpreter``.
        waveform: 1-D sequence of audio samples. It is converted to a
            contiguous ``float32`` array before being fed to the model.
            NOTE(review): the original comments suggest mono 16 kHz audio;
            confirm against the model's documentation.

    Returns:
        The array read from the model's first output tensor. Per the
        original inline comment this is the scores tensor, presumably of
        shape ``(N, 521)`` for YAMNet -- TODO confirm.

    Raises:
        ValueError: If ``waveform`` is not one-dimensional.
    """
    # set_tensor() rejects dtype mismatches, so normalise to float32 up
    # front; this also lets callers pass lists or float64 arrays.
    samples = np.ascontiguousarray(waveform, dtype=np.float32)
    if samples.ndim != 1:
        raise ValueError(
            "waveform must be one-dimensional, got shape "
            f"{samples.shape}"
        )

    interpreter = tf.lite.Interpreter(model_path=model_path)

    waveform_input_index = interpreter.get_input_details()[0]['index']
    scores_output_index = interpreter.get_output_details()[0]['index']
    # Further outputs, if needed (per the original comments):
    #   output_details[1]['index'] -> embeddings
    #   output_details[2]['index'] -> spectrogram

    # The input size must be fixed before tensors are allocated.
    interpreter.resize_tensor_input(
        waveform_input_index, [samples.size], strict=True
    )
    interpreter.allocate_tensors()
    interpreter.set_tensor(waveform_input_index, samples)
    interpreter.invoke()

    return interpreter.get_tensor(scores_output_index)