yamnet_test / test.py
Luis
add top_n
e31b1cf
# Imports
import csv
import sys
import numpy as np
import soundfile
import tensorflow as tf
from python.util.audio_util import audio_to_wav
from python.util.plt_util import plt_line, plt_mfcc, plt_mfcc2
from python.util.time_util import int_to_min_sec
from python.util.str_util import format_float, truncate_str
from python.util.tensorflow_util import predict
# Constants
# MODEL_PATH = 'res/lite-model_yamnet_tflite_1.tflite'
# Model file handed to predict() by predict_waveform().
MODEL_PATH = 'res/lite-model_yamnet_classification_tflite_1.tflite'
# Sample rate passed to the MFCC plot helpers in predict_uri().
# NOTE(review): audio is read at the file's own sample rate, so inputs are
# presumably expected to already be 16 kHz - confirm.
OUT_SAMPLE_RATE = 16000
# Not referenced by the code reviewed here; presumably an output subtype for
# WAV conversion - TODO confirm.
OUT_PCM = 'PCM_16'
# CSV file read by class_names_from_csv() to map class indices to names.
CLASS_MAP_FILE = 'res/yamnet_class_map.csv'
# When true, predict_waveform() and predict_uri() print progress to stdout.
DEBUG = True
# SNORING_TOP_N = 21
# Class index that predict_waveform() treats as "snoring"; presumably the
# row of the snoring class in CLASS_MAP_FILE - verify against the CSV.
SNORING_INDEX = 38
# Number of samples per window fed to the model by predict_uri()
# (15600 samples is 0.975 s at 16000 Hz).
IN_MODEL_SAMPLES = 15600
# Methods
def to_ndarray(data):
    """Return a new NumPy array built from ``data``."""
    converted = np.array(data)
    return converted
def data_to_single_channel(data):
    """Return the first channel of ``data``.

    Args:
        data: Audio samples, either one-dimensional (already a single
            channel) or two-dimensional and indexed as ``[frame, channel]``.

    Returns:
        Column 0 when ``data`` has a channel axis with at least one channel;
        otherwise ``data`` itself, unchanged.
    """
    samples = np.asarray(data)
    # Mono input is an ordinary case, so test the shape explicitly instead of
    # catching IndexError and printing a misleading error message.
    if samples.ndim < 2 or samples.shape[1] == 0:
        return data
    return samples[:, 0]
def read_single_channel(audio_path):
    """Load an audio file and return its first channel and sample rate.

    Args:
        audio_path: Path handed to ``soundfile.read``.

    Returns:
        A ``(single_channel, sample_rate)`` tuple, where ``single_channel`` is
        the result of ``data_to_single_channel`` on the decoded samples and
        ``sample_rate`` is the rate reported by ``soundfile.read``.
    """
    data, sample_rate = soundfile.read(audio_path)
    print(' sample_rate, audio_path: ', str(sample_rate), str(audio_path))
    # The former duration computation was never used (and divided by
    # OUT_SAMPLE_RATE rather than the file's own rate), so it was removed.
    return data_to_single_channel(data), sample_rate
def class_names_from_csv(class_map_csv):
    """Read the class name definition file and return the display names.

    Args:
        class_map_csv: Path of the class-map CSV, or a tensor holding that
            path (it is unwrapped with ``.numpy()``).

    Returns:
        A NumPy array containing the third column of every non-blank data
        row, in file order. The first row is treated as a header and skipped.

    Raises:
        ValueError: If a non-blank data row does not have exactly three
            columns.
    """
    if tf.is_tensor(class_map_csv):
        class_map_csv = class_map_csv.numpy()
    # The csv module wants newline='' so it can handle line endings (including
    # newlines embedded in quoted fields) itself; the encoding is pinned so
    # the result does not depend on the platform default.
    with open(class_map_csv, newline='', encoding='utf-8') as csv_file:
        reader = csv.reader(csv_file)
        next(reader, None)  # Skip header; tolerate a completely empty file.
        # csv.reader yields [] for blank lines (e.g. a trailing newline);
        # filter(None, ...) drops those instead of failing to unpack them.
        return np.array(
            [display_name for (_, _, display_name) in filter(None, reader)]
        )
def scores_to_index(scores, order):
    """Return the class index at position ``order`` of the ascending ranking.

    The scores are averaged over axis 0 and the averages are arg-sorted in
    ascending order, so ``order=-1`` selects the highest-scoring entry,
    ``order=-2`` the runner-up, and so on.
    """
    ascending_ranking = np.argsort(scores.mean(axis=0), axis=0)
    return ascending_ranking[order]
def predict_waveform(idx, waveform, top_n):
    """Classify one waveform window and format its highest-scoring classes.

    Args:
        idx: Window index supplied by the caller; not used here.
        waveform: Samples passed to ``predict`` together with ``MODEL_PATH``.
        top_n: Exclusive upper bound of the rank loop. Ranks ``1`` to
            ``top_n - 1`` are reported, i.e. ``top_n - 1`` classes.

    Returns:
        A ``(result, snoring_score)`` tuple. ``result`` is one text line
        listing each reported class as ``' <score> [<name>], '``, followed by
        ``'打鼾, <score>'`` ("snoring") when the snoring class scored above
        zero, and a trailing newline. ``snoring_score`` is the mean score of
        class ``SNORING_INDEX`` if it is among the reported classes, else 0.0.
    """
    # Download the YAMNet class map (see main YAMNet model docs) to yamnet_class_map.csv
    # See YAMNet TF2 usage sample for class_names_from_csv() definition.
    scores = predict(MODEL_PATH, waveform)
    class_names = class_names_from_csv(CLASS_MAP_FILE)

    # Both the per-class means and their ranking are loop-invariant, so they
    # are computed once instead of once per reported rank.
    # assumes scores is indexed as (frame, class) - TODO confirm
    means = scores.mean(axis=0)
    ascending_ranking = np.argsort(means, axis=0)

    entries = []
    snoring_score = 0.0
    for n in range(1, top_n):
        index = ascending_ranking[-n]  # -1 is the best class, -2 the next, ...
        score = means[index]
        name = class_names[index]
        if index == SNORING_INDEX:
            snoring_score = score
        entries.append(' ' + format_float(score) + ' [' + truncate_str(name, 4) + '], ')
    top_n_res = ''.join(entries)

    snoring_tail = ('打鼾, ' + format_float(snoring_score)) if snoring_score > 0 else ''
    result = top_n_res + snoring_tail + '\n'
    if DEBUG:
        print(top_n_res)
    return result, snoring_score
def to_float32(data):
    """Convert ``data`` to single-precision floating point via ``np.float32``."""
    as_single_precision = np.float32(data)
    return as_single_precision
def predict_float32(idx, data, top_n):
    """Convert ``data`` with ``to_float32`` and classify it.

    Returns whatever ``predict_waveform`` returns for the converted samples.
    """
    waveform = to_float32(data)
    return predict_waveform(idx, waveform, top_n)
def split_given_size(arr, size):
    """Cut ``arr`` into consecutive pieces of ``size`` elements.

    The final piece holds the remainder and may be shorter than ``size``.
    An input shorter than ``size`` comes back as a single piece.
    """
    cut_points = np.arange(size, len(arr), size)
    return np.split(arr, cut_points)
def predict_uri(audio_uri1, audio_uri2, top_n):
    """Classify an audio file window by window and summarise snoring.

    Args:
        audio_uri1: Audio path used when ``audio_uri2`` is ``None`` or ``''``.
        audio_uri2: Audio path that takes precedence when non-empty.
        top_n: Forwarded to ``predict_float32`` for every window.

    Returns:
        A 7-tuple ``(waveform_line, mfcc_line, mfcc2_line, ahi_result,
        snoring_booleans, snoring_scores, result)``. The first three are the
        return values of ``plt_line``, ``plt_mfcc`` and ``plt_mfcc2``; the
        remaining four are strings: the snoring/apnea summary, the per-window
        0/1 flags, the per-window snoring scores, and the per-window report.

    Notes:
        * ``sys.argv[2]``, when present, sets how many seconds each window is
          counted as (default 1). Windows are always ``IN_MODEL_SAMPLES``
          samples long, so the reported times are nominal.
        * A trailing window shorter than ``IN_MODEL_SAMPLES`` gets a
          timestamp in the report but is not classified.
    """
    if DEBUG:
        print('audio_uri1:', audio_uri1, 'audio_uri2:', audio_uri2)

    mp3_input = audio_uri1 if audio_uri2 in (None, '') else audio_uri2
    # NOTE(review): paths ending in '.mp3' are used as-is and everything else
    # goes through audio_to_wav(). Given the variable names this looks
    # inverted, but the behaviour is preserved until audio_to_wav's contract
    # is confirmed.
    wav_input = mp3_input if mp3_input.endswith('.mp3') else audio_to_wav(mp3_input)

    predict_seconds = int(sys.argv[2]) if len(sys.argv) > 2 else 1
    predict_samples = IN_MODEL_SAMPLES  # OUT_SAMPLE_RATE * predict_seconds

    single_channel, sc_sample_rate = read_single_channel(wav_input)
    splits = split_given_size(single_channel, predict_samples)
    second_total = len(splits) * predict_seconds

    report = [
        ' sc_sample_rate: ' + str(sc_sample_rate) + '\n',
        ' second_total: ' + int_to_min_sec(second_total) + ', \n',
        '\n',
    ]
    snoring_scores = []
    for idx, split in enumerate(splits):
        report.append(int_to_min_sec(idx * predict_seconds) + ', ')
        if len(split) != predict_samples:
            # Incomplete window: keep the timestamp, skip the classification.
            continue
        print_result, snoring_score = predict_float32(idx, split, top_n)
        report.append(print_result)
        snoring_scores.append(snoring_score)
    result = ''.join(report)

    # plt waveform
    waveform_line = plt_line(single_channel)
    # plt mfcc
    mfcc_line = plt_mfcc(single_channel, OUT_SAMPLE_RATE)
    # plt mfcc2
    mfcc2_line = plt_mfcc2(wav_input, OUT_SAMPLE_RATE)

    # One 0/1 flag per classified window: 1 when the snoring score is positive.
    snoring_booleans = [1 if score > 0 else 0 for score in snoring_scores]
    # Number of windows flagged as snoring.
    snoring_sec = sum(snoring_booleans)
    apnea_sec = second_total - snoring_sec
    if second_total:
        snoring_frequency = snoring_sec / second_total
        apnea_frequency = (apnea_sec / 10) / second_total
    else:
        # second_total is 0 when sys.argv[2] is 0; avoid ZeroDivisionError.
        snoring_frequency = 0.0
        apnea_frequency = 0.0

    ahi_result = (
        '打鼾秒数snoring_sec=' + str(snoring_sec) + ', 暂停秒数apnea_sec=' + str(apnea_sec) + ', 总秒数second_total=' + str(second_total)
        + ', 打鼾频率snoring_frequency=' + str(snoring_sec) + '/' + str(second_total) + '=' + format_float(snoring_frequency)
        + ', 暂停频率apnea_frequency=(' + str(apnea_sec) + '/' + str(10) + ')/' + str(second_total) + '=' + format_float(apnea_frequency)
    )
    return waveform_line, mfcc_line, mfcc2_line, ahi_result, str(snoring_booleans), str(snoring_scores), result
# sys.argv
# Command-line entry: classify the file named by the first argument.
if len(sys.argv) > 1 and len(sys.argv[1]) > 0:
    # predict_uri() takes (audio_uri1, audio_uri2, top_n) and returns seven
    # values; the previous one-argument call unpacked into two names always
    # raised TypeError. 21 matches the commented-out SNORING_TOP_N constant.
    (waveform_line, mfcc_line, mfcc2_line, ahi_result,
     snoring_booleans, snoring_scores, res) = predict_uri(sys.argv[1], None, 21)
    # NOTE(review): the old `plt.show()` relied on a second return value that
    # predict_uri() no longer provides; the plot helpers' return types are not
    # visible here, so only the text results are shown.
    print(res)
    print(ahi_result)
else:
    print('usage: python test.py /path/to/audio_file [predict_seconds]')