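"""Snoring detection demo.

Reads an audio file, splits it into YAMNet-sized frames, classifies each frame with the
YAMNet TFLite model, and summarises snoring activity over the recording.

Usage: python test.py /path/to/audio_file [predict_seconds]
"""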
# Imports
import csv
import sys

import numpy as np
import soundfile
import tensorflow as tf

from python.util.audio_util import audio_to_wav
from python.util.plt_util import plt_line, plt_mfcc, plt_mfcc2
from python.util.time_util import int_to_min_sec
from python.util.str_util import format_float, truncate_str
from python.util.tensorflow_util import predict

# Constants
# MODEL_PATH = 'res/lite-model_yamnet_tflite_1.tflite'
MODEL_PATH = 'res/lite-model_yamnet_classification_tflite_1.tflite'
OUT_SAMPLE_RATE = 16000
OUT_PCM = 'PCM_16'
CLASS_MAP_FILE = 'res/yamnet_class_map.csv'
DEBUG = True
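# Number of top-scoring classes reported per frame when scanning for 'Snoring'.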
SNORING_TOP_N = 7


# Methods
def to_ndarray(data):
    return np.array(data)


def data_to_single_channel(data):
    """Return the first channel of multi-channel audio; mono data is returned unchanged."""
    if data.ndim > 1:
        return data[:, 0]
    return data


def read_single_channel(audio_path):
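    """Read an audio file with soundfile and return (first-channel samples, sample rate)."""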
    data, sample_rate = soundfile.read(audio_path)
    print(' sample_rate, audio_path: ', str(sample_rate), str(audio_path))
    # print(' sample_rate, len, type, shape, shape[1]: ', str(sample_rate), len(data), str(type(data)), str(data.shape), str(data.shape[1]))

    single_channel = data_to_single_channel(data)
    single_channel_seconds = len(single_channel) / OUT_SAMPLE_RATE
    # print(' single_channel, shape: ', str(single_channel), str(single_channel.shape))
    # print(' len, seconds: ', str(len(single_channel)), str(single_channel_seconds))

    return single_channel, sample_rate


def class_names_from_csv(class_map_csv):
    """Read the class name definition file and return a list of strings."""
    if tf.is_tensor(class_map_csv):
        class_map_csv = class_map_csv.numpy()
    with open(class_map_csv) as csv_file:
        reader = csv.reader(csv_file)
        next(reader)  # Skip header
        return np.array([display_name for (_, _, display_name) in reader])


def scores_to_index(scores, order):
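    """Return the class index ranked `order` by mean score across frames (order=-1 is the top class)."""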
    means = scores.mean(axis=0)
    return np.argsort(means, axis=0)[order]


def predict_waveform(idx, waveform):
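    """Classify one waveform frame with the YAMNet TFLite model.

    Returns a formatted summary of the top-scoring classes and the frame's 'Snoring' score.
    """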
    # Download the YAMNet class map (see main YAMNet model docs) to yamnet_class_map.csv
    # See YAMNet TF2 usage sample for class_names_from_csv() definition.
    scores = predict(MODEL_PATH, waveform)
    class_names = class_names_from_csv(CLASS_MAP_FILE)

    top_n = SNORING_TOP_N
    top_n_res = ''
    snoring_score = 0.0
    means = scores.mean(axis=0)
    for n in range(1, top_n + 1):  # n = 1..top_n, so scores_to_index walks the top_n classes by mean score
        index = scores_to_index(scores, -n)
        score = means[index]
        name = class_names[index]

        if name == 'Snoring':
            snoring_score = score
        top_n_res += ' ' + format_float(score) + ' [' + truncate_str(name, 4) + '], '

    snoring_tail = ('Snoring, ' + format_float(snoring_score)) if snoring_score > 0 else ''
    result = top_n_res + snoring_tail + '\n'
    if DEBUG:
        print(top_n_res)

    return result, snoring_score


def to_float32(data):
    return np.float32(data)


def predict_float32(idx, data):
    return predict_waveform(idx, to_float32(data))


def split_given_size(arr, size):
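    """Split arr into consecutive chunks of `size` elements; the final chunk may be shorter."""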
    return np.split(arr, np.arange(size, len(arr), size))


def predict_uri(mp3_uri):
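    """Convert the input to WAV if needed, split it into model-sized frames, classify each frame,
    and return the waveform/MFCC plots plus text summaries (stats, per-frame snoring flags/scores, log).
    """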
    result = ''
    # result = ' mp3_uri: '
    # result += mp3_uri + '\n'

    mp3_input = mp3_uri
    # Convert MP3 input to WAV; files that are already WAV are passed through unchanged.
    wav_input = audio_to_wav(mp3_input) if mp3_input.endswith('.mp3') else mp3_input
    predict_seconds = int(sys.argv[2]) if len(sys.argv) > 2 else 1

    # The YAMNet TFLite classification model expects fixed-length frames of
    # 15600 samples (0.975 s at 16 kHz), hence the hard-coded frame size.
    predict_samples = 15600  # OUT_SAMPLE_RATE * predict_seconds
    single_channel, sc_sample_rate = read_single_channel(wav_input)
    splits = split_given_size(single_channel, predict_samples)
    result += ' sc_sample_rate: ' + str(sc_sample_rate) + '\n'

    second_total = len(splits) * predict_seconds
    result += (' second_total: ' + int_to_min_sec(second_total) + ', \n')
    result += '\n'
    snoring_scores = []

    for idx in range(len(splits)):
        split = splits[idx]
        second_start = idx * predict_seconds
        result += (int_to_min_sec(second_start) + ', ')
        if len(split) == predict_samples:
            print_result, snoring_score = predict_float32(idx, split)
            result += print_result
            snoring_scores.append(snoring_score)

    # plt waveform
    waveform_line = plt_line(single_channel)
    # plt mfcc
    mfcc_line = plt_mfcc(single_channel, OUT_SAMPLE_RATE)
    # plt mfcc2
    mfcc2_line = plt_mfcc2(wav_input, OUT_SAMPLE_RATE)
    # plt snoring_booleans
    snoring_booleans = [1 if score > 0 else 0 for score in snoring_scores]
    # calc snoring frequency
    snoring_sec = sum(snoring_booleans)
    snoring_frequency = snoring_sec / second_total
    apnea_sec = second_total - snoring_sec
    apnea_frequency = (apnea_sec / 10) / second_total
    ahi_result = (
        'snoring_sec:' + str(snoring_sec) + ', apnea_sec:' + str(apnea_sec) + ', second_total:' + str(second_total)
        + ', snoring_frequency:' + format_float(snoring_frequency)
        + ', apnea_frequency:' + format_float(apnea_frequency)
    )

    return waveform_line, mfcc_line, mfcc2_line, str(ahi_result), str(snoring_booleans), str(snoring_scores), str(result)


# sys.argv
if len(sys.argv) > 1 and len(sys.argv[1]) > 0:
    waveform_line, mfcc_line, mfcc2_line, ahi_result, snoring_booleans, snoring_scores, result = predict_uri(sys.argv[1])
    print(result)
    print(ahi_result)
else:
    print('usage: python test.py /path/to/audio_file [predict_seconds]')