Luis commited on
Commit
e31b1cf
1 Parent(s): ce2098c
Files changed (2) hide show
  1. app.py +5 -1
  2. test.py +10 -10
app.py CHANGED
@@ -75,7 +75,11 @@ description = "An audio event classifier trained on the AudioSet dataset to pred
75
 
76
  demo = gr.Interface(
77
  predict_uri,
78
- inputs=[gr.inputs.Audio(type="filepath"), gr.inputs.Audio(source="microphone", type="filepath")],
 
 
 
 
79
  outputs=['image', 'image', 'image', 'text', 'text', 'text', 'text'],
80
  # examples=examples,
81
  title=title,
 
75
 
76
  demo = gr.Interface(
77
  predict_uri,
78
+ inputs=[
79
+ gr.inputs.Audio(type="filepath"),
80
+ gr.inputs.Audio(source="microphone", type="filepath"),
81
+ gr.Slider(minimum=7, maximum=21)
82
+ ],
83
  outputs=['image', 'image', 'image', 'text', 'text', 'text', 'text'],
84
  # examples=examples,
85
  title=title,
test.py CHANGED
@@ -19,7 +19,7 @@ OUT_SAMPLE_RATE = 16000
19
  OUT_PCM = 'PCM_16'
20
  CLASS_MAP_FILE = 'res/yamnet_class_map.csv'
21
  DEBUG = True
22
- SNORING_TOP_N = 7
23
  SNORING_INDEX = 38
24
  IN_MODEL_SAMPLES = 15600
25
 
@@ -68,13 +68,13 @@ def scores_to_index(scores, order):
68
  return np.argsort(means, axis=0)[order]
69
 
70
 
71
- def predict_waveform(idx, waveform):
72
  # Download the YAMNet class map (see main YAMNet model docs) to yamnet_class_map.csv
73
  # See YAMNet TF2 usage sample for class_names_from_csv() definition.
74
  scores = predict(MODEL_PATH, waveform)
75
  class_names = class_names_from_csv(CLASS_MAP_FILE)
76
 
77
- top_n = SNORING_TOP_N
78
  top_n_res = ''
79
  snoring_score = 0.0
80
  for n in range(1, top_n):
@@ -98,15 +98,15 @@ def to_float32(data):
98
  return np.float32(data)
99
 
100
 
101
- def predict_float32(idx, data):
102
- return predict_waveform(idx, to_float32(data))
103
 
104
 
105
  def split_given_size(arr, size):
106
  return np.split(arr, np.arange(size, len(arr), size))
107
 
108
 
109
- def predict_uri(audio_uri1, audio_uri2):
110
  result = ''
111
  if DEBUG: print('audio_uri1:', audio_uri1, 'audio_uri2:', audio_uri2)
112
 
@@ -129,7 +129,7 @@ def predict_uri(audio_uri1, audio_uri2):
129
  second_start = idx * predict_seconds
130
  result += (int_to_min_sec(second_start) + ', ')
131
  if len(split) == predict_samples:
132
- print_result, snoring_score = predict_float32(idx, split)
133
  result += print_result
134
  snoring_scores.append(snoring_score)
135
 
@@ -147,9 +147,9 @@ def predict_uri(audio_uri1, audio_uri2):
147
  apnea_sec = second_total - snoring_sec
148
  apnea_frequency = (apnea_sec / 10) / second_total
149
  ahi_result = str(
150
- 'snoring_sec:' + str(snoring_sec) + ', apnea_sec:' + str(apnea_sec) + ', second_total:' + str(second_total)
151
- + ', snoring_frequency:' + format_float(snoring_frequency)
152
- + ', apnea_frequency:' + format_float(apnea_frequency)
153
  )
154
 
155
  return waveform_line, mfcc_line, mfcc2_line, str(ahi_result), str(snoring_booleans), str(snoring_scores), str(result)
 
19
  OUT_PCM = 'PCM_16'
20
  CLASS_MAP_FILE = 'res/yamnet_class_map.csv'
21
  DEBUG = True
22
+ # SNORING_TOP_N = 21
23
  SNORING_INDEX = 38
24
  IN_MODEL_SAMPLES = 15600
25
 
 
68
  return np.argsort(means, axis=0)[order]
69
 
70
 
71
+ def predict_waveform(idx, waveform, top_n):
72
  # Download the YAMNet class map (see main YAMNet model docs) to yamnet_class_map.csv
73
  # See YAMNet TF2 usage sample for class_names_from_csv() definition.
74
  scores = predict(MODEL_PATH, waveform)
75
  class_names = class_names_from_csv(CLASS_MAP_FILE)
76
 
77
+ # top_n = SNORING_TOP_N
78
  top_n_res = ''
79
  snoring_score = 0.0
80
  for n in range(1, top_n):
 
98
  return np.float32(data)
99
 
100
 
101
+ def predict_float32(idx, data, top_n):
102
+ return predict_waveform(idx, to_float32(data), top_n)
103
 
104
 
105
  def split_given_size(arr, size):
106
  return np.split(arr, np.arange(size, len(arr), size))
107
 
108
 
109
+ def predict_uri(audio_uri1, audio_uri2, top_n):
110
  result = ''
111
  if DEBUG: print('audio_uri1:', audio_uri1, 'audio_uri2:', audio_uri2)
112
 
 
129
  second_start = idx * predict_seconds
130
  result += (int_to_min_sec(second_start) + ', ')
131
  if len(split) == predict_samples:
132
+ print_result, snoring_score = predict_float32(idx, split, top_n)
133
  result += print_result
134
  snoring_scores.append(snoring_score)
135
 
 
147
  apnea_sec = second_total - snoring_sec
148
  apnea_frequency = (apnea_sec / 10) / second_total
149
  ahi_result = str(
150
+ '打鼾秒数snoring_sec=' + str(snoring_sec) + ', 暂停秒数apnea_sec=' + str(apnea_sec) + ', 总秒数second_total=' + str(second_total)
151
+ + ', 打鼾频率snoring_frequency=' + str(snoring_sec) + '/' + str(second_total) + '=' + format_float(snoring_frequency)
152
+ + ', 暂停频率apnea_frequency=(' + str(apnea_sec) + '/' + str(10) + ')/' + str(second_total) + '=' + format_float(apnea_frequency)
153
  )
154
 
155
  return waveform_line, mfcc_line, mfcc2_line, str(ahi_result), str(snoring_booleans), str(snoring_scores), str(result)