# Imports
import csv
import sys

import numpy as np
import soundfile
import tensorflow as tf

from python.util.audio_util import audio_to_wav
from python.util.plt_util import plt_line, plt_mfcc, plt_mfcc2
from python.util.time_util import int_to_min_sec
from python.util.str_util import format_float, truncate_str
from python.util.tensorflow_util import predict

# Constants
# MODEL_PATH = 'res/lite-model_yamnet_tflite_1.tflite'
MODEL_PATH = 'res/lite-model_yamnet_classification_tflite_1.tflite'
OUT_SAMPLE_RATE = 16000
OUT_PCM = 'PCM_16'
CLASS_MAP_FILE = 'res/yamnet_class_map.csv'
DEBUG = True
SNORING_TOP_N = 7


# Methods
def to_ndarray(data):
    return np.array(data)


def data_to_single_channel(data):
    """Return the first channel of a multi-channel signal; mono input is returned unchanged."""
    result = data
    try:
        result = data[:, 0]
    except IndexError:
        print('data_to_single_channel: input is already single-channel')
    return result


def read_single_channel(audio_path):
    data, sample_rate = soundfile.read(audio_path)
    print(' sample_rate, audio_path: ', str(sample_rate), str(audio_path))
    # print(' sample_rate, len, type, shape, shape[1]: ', str(sample_rate), len(data), str(type(data)), str(data.shape), str(data.shape[1]))
    single_channel = data_to_single_channel(data)
    single_channel_seconds = len(single_channel) / OUT_SAMPLE_RATE
    # print(' single_channel, shape: ', str(single_channel), str(single_channel.shape))
    # print(' len, seconds: ', str(len(single_channel)), str(single_channel_seconds))
    return single_channel, sample_rate


def class_names_from_csv(class_map_csv):
    """Read the class map CSV and return an array of display-name strings."""
    if tf.is_tensor(class_map_csv):
        class_map_csv = class_map_csv.numpy()
    with open(class_map_csv) as csv_file:
        reader = csv.reader(csv_file)
        next(reader)  # Skip header
        return np.array([display_name for (_, _, display_name) in reader])


def scores_to_index(scores, order):
    """Return the class index at the given position in the mean-score ranking."""
    means = scores.mean(axis=0)
    return np.argsort(means, axis=0)[order]


def predict_waveform(idx, waveform):
    # Download the YAMNet class map (see the main YAMNet model docs) to yamnet_class_map.csv.
    # See the YAMNet TF2 usage sample for the class_names_from_csv() definition.
    scores = predict(MODEL_PATH, waveform)
    class_names = class_names_from_csv(CLASS_MAP_FILE)
    top_n = SNORING_TOP_N
    top_n_res = ''
    snoring_score = 0.0
    means = scores.mean(axis=0)
    for n in range(1, top_n + 1):  # n-th highest mean score (1 = highest)
        index = scores_to_index(scores, -n)
        score = means[index]
        name = class_names[index]
        if name == 'Snoring':
            snoring_score = score
        top_n_res += ' ' + format_float(score) + ' [' + truncate_str(name, 4) + '], '
    snoring_tail = ('Snoring, ' + format_float(snoring_score)) if snoring_score > 0 else ''
    result = top_n_res + snoring_tail + '\n'
    if DEBUG:
        print(top_n_res)
    return result, snoring_score


def to_float32(data):
    return np.float32(data)


def predict_float32(idx, data):
    return predict_waveform(idx, to_float32(data))


def split_given_size(arr, size):
    """Split arr into consecutive chunks of the given size; the last chunk may be shorter."""
    return np.split(arr, np.arange(size, len(arr), size))


def predict_uri(mp3_uri):
    result = ''
    # result = ' mp3_uri: '
    # result += mp3_uri + '\n'
    mp3_input = mp3_uri
    wav_input = mp3_input if mp3_input.endswith('.mp3') else audio_to_wav(mp3_input)
    predict_seconds = int(sys.argv[2]) if len(sys.argv) > 2 else 1
    predict_samples = 15600  # YAMNet classification input length; OUT_SAMPLE_RATE * predict_seconds
    single_channel, sc_sample_rate = read_single_channel(wav_input)
    splits = split_given_size(single_channel, predict_samples)
    result += ' sc_sample_rate: ' + str(sc_sample_rate) + '\n'
    second_total = len(splits) * predict_seconds
    result += ' second_total: ' + int_to_min_sec(second_total) + ', \n'
    result += '\n'
    snoring_scores = []
    for idx in range(len(splits)):
        split = splits[idx]
        second_start = idx * predict_seconds
        result += int_to_min_sec(second_start) + ', '
        if len(split) == predict_samples:
            print_result, snoring_score = predict_float32(idx, split)
            result += print_result
            snoring_scores.append(snoring_score)

    # plt waveform
    waveform_line = plt_line(single_channel)
    # plt mfcc
    mfcc_line = plt_mfcc(single_channel, OUT_SAMPLE_RATE)
    # plt mfcc2
    mfcc2_line = plt_mfcc2(wav_input, OUT_SAMPLE_RATE)
    # plt snoring_booleans
    snoring_booleans = [1 if x > 0 else 0 for x in snoring_scores]
    # calc snoring frequency
    snoring_sec = sum(snoring_booleans)
    snoring_frequency = snoring_sec / second_total
    apnea_sec = second_total - snoring_sec
    apnea_frequency = (apnea_sec / 10) / second_total
    ahi_result = (
        'snoring_sec:' + str(snoring_sec) +
        ', apnea_sec:' + str(apnea_sec) +
        ', second_total:' + str(second_total) +
        ', snoring_frequency:' + format_float(snoring_frequency) +
        ', apnea_frequency:' + format_float(apnea_frequency)
    )
    return waveform_line, mfcc_line, mfcc2_line, ahi_result, str(snoring_booleans), str(snoring_scores), result


# sys.argv
if len(sys.argv) > 1 and len(sys.argv[1]) > 0:
    # predict_uri() returns seven values (the plt_* handles plus the textual summaries);
    # print the textual summaries here.
    (waveform_line, mfcc_line, mfcc2_line, ahi_result,
     snoring_booleans, snoring_scores, result) = predict_uri(sys.argv[1])
    print(result)
    print(ahi_result)
else:
    print('usage: python test.py /path/to/audio_file [predict_seconds]')
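
# ---------------------------------------------------------------------------
# Illustrative sketch only: predict() is imported from python.util.tensorflow_util
# and its implementation is not part of this file. A minimal TFLite version for
# the YAMNet classification model could look like the function below. The name
# and the exact input/output shapes are assumptions, not the project's actual
# helper; predict_waveform() only requires a (frames, classes) score array.
def predict_sketch(model_path, waveform):
    """Run one waveform frame through a YAMNet TFLite model and return class scores."""
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_index = interpreter.get_input_details()[0]['index']
    output_index = interpreter.get_output_details()[0]['index']
    interpreter.set_tensor(input_index, np.float32(waveform))  # expects 15600 samples at 16 kHz
    interpreter.invoke()
    scores = interpreter.get_tensor(output_index)
    return np.atleast_2d(scores)  # shape (frames, num_classes)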