justus-tobias committed
Commit 453eb24
1 Parent(s): 941bb7a

new beat detection and more features

Files changed (3)
  1. README.md +1 -1
  2. app.py +70 -41
  3. utils.py +540 -47
README.md CHANGED
@@ -10,4 +10,4 @@ pinned: false
 license: mit
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,15 +1,17 @@
 from plotly.subplots import make_subplots
 from scipy.signal import find_peaks, butter, filtfilt
 import plotly.graph_objects as go
+from io import StringIO
 import pandas as pd
 import gradio as gr
 import numpy as np
 import itertools
+import tempfile
 import librosa
 import random
 import os
 
-from utils import getaudiodata, getBeats, plotBeattimes
+from utils import getaudiodata, getBeats, plotBeattimes, find_s1s2
 
 example_dir = "Examples"
 example_files = [os.path.join(example_dir, f) for f in os.listdir(example_dir) if f.endswith(('.wav', '.mp3', '.ogg'))]
@@ -185,34 +187,67 @@ def analyze_single(audio:gr.Audio):
 def getBeatsv2(audio:gr.Audio):
 
     sr, audiodata = getaudiodata(audio)
-    _, beattimes = getBeats(audiodata, sr)
+    _, beattimes, audiodata = getBeats(audiodata, sr)
 
-    fig = plotBeattimes(beattimes, audiodata, sr)
     beattimes_table = pd.DataFrame(data={"Beattimes": beattimes})
 
-    return fig, beattimes_table
+    feature_array = find_s1s2(beattimes_table)
+
+    featuredf = pd.DataFrame(
+        data=feature_array,
+        columns=[
+            "Beattimes",
+            "S1 to S2",
+            "S2 to S1",
+            "Label (S1=0/S2=1)"]
+    )
+
+    # Create boolean masks for each label
+    mask_ones = feature_array[:, 3] == 1
+    mask_zeros = feature_array[:, 3] == 0
+
+    # Extract times/positions using the masks
+    times_label_one = feature_array[mask_ones, 0]
+    times_label_zero = feature_array[mask_zeros, 0]
+
+    # Label 0 is S1 (primary markers), label 1 is S2 (secondary markers)
+    fig = plotBeattimes(times_label_zero, audiodata, sr, times_label_one)
+
+    return fig, featuredf, (sr, audiodata)
 
-def updateBeatsv2(beattimes_table:gr.Dataframe, audio:gr.Audio)-> go.Figure:
+def updateBeatsv2(beattimes_table:gr.Dataframe, audio:gr.Audio, uploadeddf:gr.File=None) -> tuple[go.Figure, pd.DataFrame]:
+
     sr, audiodata = getaudiodata(audio)
-    return plotBeattimes(beattimes_table["Beattimes"], audiodata, sr)
+
+    if uploadeddf is not None:
+        beattimes_table = pd.read_csv(uploadeddf)
+
+    s1_times = beattimes_table[beattimes_table["Label (S1=0/S2=1)"] == 0]["Beattimes"].to_numpy()
+    s2_times = beattimes_table[beattimes_table["Label (S1=0/S2=1)"] == 1]["Beattimes"].to_numpy()
+
+    fig = plotBeattimes(s1_times, audiodata, sr, s2_times)
+
+    return fig, beattimes_table
+
+def download_df(df: pd.DataFrame):
+
+    temp_dir = tempfile.gettempdir()
+    temp_path = os.path.join(temp_dir, "feature_data.csv")
+
+    df.to_csv(temp_path, index=False)
+
+    return temp_path
 
 with gr.Blocks() as app:
 
     gr.Markdown("# Heartbeat")
     gr.Markdown("This App helps to analyze and extract Information from Heartbeat Audios")
-    gr.Markdown("""
-    - Beat (mean) (average heartbeat duration)
-    - S1, S2 (mean) (average S1, S2 duration)
-    - mean heartbeat (synthesised) image (wave & spectrogram)
-    - FFT & mel spectrogram
-    - Plot of wave & spectrogram (beats annotated)
-    """)
 
     audiofile = gr.Audio(
         type="filepath",
@@ -220,50 +255,44 @@ with gr.Blocks() as app:
         sources="upload")
 
-    with gr.Tab("Single Audio V2"):
+    with gr.Tab("Preprocessing"):
 
         getBeatsbtn = gr.Button("get Beats")
+        cleanedaudio = gr.Audio(label="Cleaned Audio", show_download_button=True)
 
         beats_wave_plot = gr.Plot()
-        beattimes_table = gr.Dataframe(
-            col_count=1,
-            type='pandas',
-            interactive=True)
-
-        updateBeatsbtn = gr.Button("update Beats")
-
-        getBeatsbtn.click(getBeatsv2, inputs=audiofile, outputs=[beats_wave_plot, beattimes_table])
-        updateBeatsbtn.click(updateBeatsv2, inputs=[beattimes_table, audiofile], outputs=[beats_wave_plot])
-
-        gr.Examples(
-            examples=example_files,
-            inputs=audiofile,
-            fn=getBeatsv2,
-            cache_examples=False
-        )
-
-    with gr.Tab("Single Audio V1"):
-
-        analyzebtn = gr.Button("analyze")
-
-        results = gr.Markdown()
-        spectogram_wave = gr.Plot()
-        avg_beat_plot = gr.Plot()
-
-        analyzebtn.click(analyze_single, audiofile, [results, spectogram_wave, avg_beat_plot])
+        beattimes_df = gr.Dataframe(
+            value=pd.DataFrame({"Beattimes": [], "S1 to S2": [], "S2 to S1": [], "Label (S1=0/S2=1)": []}),
+            label="Beattimes")
+
+        with gr.Row():
+            csv_download = gr.DownloadButton()
+            updateBeatsbtn = gr.Button("update Beats")
+            uploadDF = gr.File(
+                file_count="single",
+                file_types=[".csv"],
+                label="upload a csv",
+                height=25
+            )
+
+        csv_download.click(download_df, inputs=[beattimes_df], outputs=[csv_download])
+        getBeatsbtn.click(getBeatsv2, inputs=audiofile, outputs=[beats_wave_plot, beattimes_df, cleanedaudio])
+        updateBeatsbtn.click(updateBeatsv2, inputs=[beattimes_df, audiofile, uploadDF], outputs=[beats_wave_plot, beattimes_df])
 
         gr.Examples(
             examples=example_files,
             inputs=audiofile,
-            outputs=[results, spectogram_wave],
-            fn=analyze_single,
+            fn=getBeatsv2,
             cache_examples=False
         )
 
+    with gr.Tab("Analysis"):
+
+        gr.Markdown("🚨 Please make sure to first run the 'Preprocessing'")
+
 app.launch()
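The Preprocessing tab round-trips the feature table through a CSV file: download_df writes it to a temp path for the DownloadButton, and updateBeatsv2 prefers an uploaded CSV over the on-screen table, splitting rows into S1 and S2 by the label column. A minimal pandas sketch of that contract (all values invented for illustration):

import pandas as pd

# Stand-in for the table getBeatsv2 builds from find_s1s2's output
featuredf = pd.DataFrame({
    "Beattimes":         [0.30, 0.60, 1.15, 1.45],
    "S1 to S2":          [0.30, 0.30, 0.55, 0.30],
    "S2 to S1":          [0.30, 0.55, 0.30, 0.55],
    "Label (S1=0/S2=1)": [0, 1, 0, 1],
})
featuredf.to_csv("feature_data.csv", index=False)    # what download_df produces

table = pd.read_csv("feature_data.csv")              # what updateBeatsv2 re-reads
s1_times = table[table["Label (S1=0/S2=1)"] == 0]["Beattimes"].to_numpy()
s2_times = table[table["Label (S1=0/S2=1)"] == 1]["Beattimes"].to_numpy()
print(s1_times, s2_times)                            # the arrays fed to plotBeattimes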
 
utils.py CHANGED
@@ -1,76 +1,534 @@
 import librosa
 import numpy as np
 import plotly.graph_objects as go
-from scipy.signal import find_peaks
+from scipy.signal import savgol_filter, find_peaks
+from scipy.signal import butter, filtfilt, medfilt, find_peaks, hilbert
+from scipy.ndimage import gaussian_filter1d
+from sklearn.cluster import KMeans
+from sklearn.preprocessing import StandardScaler
+from io import StringIO
+import soundfile as sf
+import pywt
+import pandas as pd
 
 
 # GENERAL HELPER FUNCTIONS
-def getaudiodata(filepath)->tuple[int,np.ndarray]:
-
-    audiodata, sr = librosa.load(filepath, sr=None)
-
-    # Ensure audiodata is a numpy array
-    if not isinstance(audiodata, np.ndarray):
-        audiodata = np.array(audiodata)
+def denoise_audio(audiodata: np.ndarray, sr: int) -> tuple[np.ndarray, int]:
+    """
+    Enhanced denoising of audio signals optimized for heart sounds.
+    Uses a combination of bandpass filtering, adaptive wavelet denoising,
+    and improved spectral subtraction.
+
+    Parameters:
+    -----------
+    audiodata : np.ndarray
+        Input audio signal (1D numpy array)
+    sr : int
+        Sampling rate in Hz
+
+    Returns:
+    --------
+    tuple[np.ndarray, int]
+        Tuple containing (denoised_signal, sampling_rate)
+    """
+    # Input validation and conversion
+    if not isinstance(audiodata, np.ndarray) or audiodata.ndim != 1:
+        raise ValueError("audiodata must be a 1D numpy array")
+    if not isinstance(sr, int) or sr <= 0:
+        raise ValueError("sr must be a positive integer")
+
+    # Convert to float32 and normalize
+    audio = audiodata.astype(np.float32)
+    audio = audio / np.max(np.abs(audio))
+
+    # 1. Enhanced Bandpass Filter
+    # Optimize frequency range for heart sounds (20-200 Hz)
+    nyquist = sr / 2
+    low, high = 20 / nyquist, 200 / nyquist
+    order = 4  # Filter order
+    b, a = butter(order, [low, high], btype='band')
+    filtered = filtfilt(b, a, audio)
+
+    # 2. Adaptive Wavelet Denoising
+    def apply_wavelet_denoising(sig):
+        # Use sym4 wavelet (good for biomedical signals)
+        wavelet = 'sym4'
+        level = min(6, pywt.dwt_max_level(len(sig), pywt.Wavelet(wavelet).dec_len))
+
+        # Decompose signal
+        coeffs = pywt.wavedec(sig, wavelet, level=level)
+
+        # Adaptive thresholding based on level
+        for i in range(1, len(coeffs)):
+            # Calculate level-dependent threshold
+            sigma = np.median(np.abs(coeffs[i])) / 0.6745
+            threshold = sigma * np.sqrt(2 * np.log(len(coeffs[i])))
+            # Adjust threshold based on decomposition level
+            level_factor = 1 - (i / len(coeffs))  # Higher levels get lower thresholds
+            coeffs[i] = pywt.threshold(coeffs[i], threshold * level_factor, mode='soft')
+
+        return pywt.waverec(coeffs, wavelet)
+
+    # Apply wavelet denoising
+    denoised = apply_wavelet_denoising(filtered)
+
+    # Ensure consistent length (size= keyword required by recent librosa)
+    if len(denoised) != len(audio):
+        denoised = librosa.util.fix_length(denoised, size=len(audio))
+
+    # 3. Improved Spectral Subtraction
+    def spectral_subtract(sig):
+        # Parameters
+        frame_length = int(sr * 0.04)  # 40ms frames
+        hop_length = frame_length // 2
+
+        # Compute STFT
+        D = librosa.stft(sig, n_fft=frame_length, hop_length=hop_length)
+        mag, phase = np.abs(D), np.angle(D)
+
+        # Estimate noise spectrum from low-energy frames
+        frame_energy = np.sum(mag**2, axis=0)
+        noise_threshold = np.percentile(frame_energy, 15)
+        noise_frames = mag[:, frame_energy < noise_threshold]
+
+        if noise_frames.size > 0:
+            noise_spectrum = np.median(noise_frames, axis=1)
+
+            # Oversubtraction factor (frequency-dependent)
+            freq_bins = np.fft.rfftfreq(frame_length, 1/sr)
+            alpha = 1.0 + 0.01 * (freq_bins / nyquist)
+            alpha = alpha[:len(noise_spectrum)].reshape(-1, 1)
+
+            # Spectral subtraction with flooring
+            mag_clean = np.maximum(mag - alpha * noise_spectrum.reshape(-1, 1), 0.01 * mag)
+
+            # Reconstruct signal
+            D_clean = mag_clean * np.exp(1j * phase)
+            return librosa.istft(D_clean, hop_length=hop_length)
+
+        return sig
+
+    # Apply spectral subtraction
+    final = spectral_subtract(denoised)
+
+    # Final normalization
+    final = final / np.max(np.abs(final))
+
+    return final, sr
+
+def getaudiodata(filepath: str, target_sr: int = 16000) -> tuple[int, np.ndarray]:
+    """
+    Load and process audio data with consistent output properties.
+
+    Parameters:
+    -----------
+    filepath : str
+        Path to the audio file
+    target_sr : int
+        Target sampling rate (default: 16000 Hz)
+
+    Returns:
+    --------
+    tuple[int, np.ndarray]
+        Sampling rate and processed audio data with consistent properties:
+        - dtype: float32
+        - shape: (N,) mono audio
+        - amplitude range: [-0.95, 0.95]
+        - no NaN or Inf values
+        - C-contiguous memory layout
+    """
+    # Load audio with specified sampling rate
+    audiodata, sr = librosa.load(filepath, sr=target_sr)
+
+    # Ensure numpy array
+    audiodata = np.asarray(audiodata)
 
-    # Check if audio is mono or stereo
+    # Convert to mono if stereo
     if len(audiodata.shape) > 1:
-        # If stereo, convert to mono by averaging channels
         audiodata = np.mean(audiodata, axis=1)
-
-    audiodata = np.astype(audiodata, np.float16)
-
+
+    # Handle any NaN or Inf values
+    audiodata = np.nan_to_num(audiodata, nan=0.0, posinf=0.0, neginf=0.0)
+
+    # Normalize to prevent clipping while maintaining relative amplitudes
+    max_abs = np.max(np.abs(audiodata))
+    if max_abs > 0:  # Avoid division by zero
+        audiodata = audiodata * (0.95 / max_abs)
+
+    # Ensure float32 dtype and memory contiguous
+    audiodata = np.ascontiguousarray(audiodata, dtype=np.float32)
+
     return sr, audiodata
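The level-dependent soft threshold inside apply_wavelet_denoising above is the universal threshold sigma * sqrt(2 ln N), with sigma estimated from the median absolute deviation of each detail band and then scaled by decomposition level. A standalone sketch of just that rule, run on arbitrary random data rather than heart audio:

import numpy as np
import pywt

coeffs = pywt.wavedec(np.random.randn(1024), 'sym4', level=4)
for i in range(1, len(coeffs)):
    sigma = np.median(np.abs(coeffs[i])) / 0.6745            # MAD noise estimate
    threshold = sigma * np.sqrt(2 * np.log(len(coeffs[i])))  # universal threshold
    level_factor = 1 - (i / len(coeffs))                     # level-dependent scaling
    coeffs[i] = pywt.threshold(coeffs[i], threshold * level_factor, mode='soft')
reconstructed = pywt.waverec(coeffs, 'sym4')
print(reconstructed.shape)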
 
-def getBeats(audiodata:np.ndarray, sr:int):
-    # Convert audio data to float32
-    audiodata = audiodata.astype(np.float32)
-
-    # Normalize the audio data
-    audiodata = audiodata / np.max(np.abs(audiodata))
-
-    # Set the threshold for peak detection (adjust this value as needed)
-    threshold = 0.5  # 50% of the maximum amplitude
-
-    # Find peaks above the threshold
-    peaks, _ = find_peaks(np.abs(audiodata), height=threshold, distance=int(sr * 0.3))
-
-    # Convert peak indices to times
-    peak_times = (peaks / sr)*2
-
-    # Calculate tempo (beats per minute)
-    if len(peak_times) > 1:
-        avg_interval = np.mean(np.diff(peak_times))
-        tempo = 60 / avg_interval
-    else:
-        tempo = 0
-
-    return [tempo], peak_times
+def getBeats(audiodata: np.ndarray, sr: int, method='envelope') -> tuple[float, np.ndarray, np.ndarray]:
+    """
+    Advanced heartbeat detection optimized for peak detection with improved sensitivity.
+
+    Parameters:
+    -----------
+    audiodata : np.ndarray
+        Audio time series
+    sr : int
+        Sampling rate
+    method : str
+        Detection method: 'envelope' (default), 'onset', or 'fusion'
+
+    Returns:
+    --------
+    tempo : float
+        Estimated heart rate in BPM
+    peak_times : np.ndarray
+        Times of detected heartbeat peaks
+    cleaned_audio : np.ndarray
+        Cleaned audio signal
+    """
+    # Denoise and normalize
+    audiodata, sr = denoise_audio(audiodata, sr)
+    cleaned_audio = audiodata / np.max(np.abs(audiodata))
+
+    def get_envelope_peaks():
+        """Detect peaks using enhanced envelope method with better sensitivity"""
+        # Calculate envelope using appropriate frame sizes
+        hop_length = int(sr * 0.01)    # 10ms hop
+        frame_length = int(sr * 0.04)  # 40ms window
+
+        # Calculate RMS energy
+        rms = librosa.feature.rms(
+            y=cleaned_audio,
+            frame_length=frame_length,
+            hop_length=hop_length
+        )[0]
+
+        # Smooth the envelope (less aggressive smoothing)
+        rms_smooth = savgol_filter(rms, 7, 3)
+
+        # Find peaks with more lenient thresholds
+        peaks, properties = find_peaks(
+            rms_smooth,
+            distance=int(0.2 * (sr / hop_length)),  # Minimum 0.2s between peaks (300 BPM max)
+            height=np.mean(rms_smooth) + 0.1 * np.std(rms_smooth),  # Lower height threshold
+            prominence=np.mean(rms_smooth) * 0.1,  # Lower prominence threshold
+            width=(int(0.01 * (sr / hop_length)), int(0.2 * (sr / hop_length)))  # 10-200ms width
+        )
+
+        # Refine peak locations using original signal
+        refined_peaks = []
+        window_size = int(0.05 * sr)  # 50ms window for refinement
+
+        for peak in peaks:
+            # Convert envelope peak to sample domain
+            sample_idx = peak * hop_length
+
+            # Define window boundaries
+            start = max(0, sample_idx - window_size//2)
+            end = min(len(cleaned_audio), sample_idx + window_size//2)
+
+            # Find the maximum amplitude within the window
+            window = np.abs(cleaned_audio[int(start):int(end)])
+            max_idx = np.argmax(window)
+            refined_peaks.append(start + max_idx)
+
+        return np.array(refined_peaks), rms_smooth
+
+    def get_onset_peaks():
+        """Enhanced onset detection with better sensitivity"""
+        # Multi-band onset detection with adjusted parameters
+        onset_env = librosa.onset.onset_strength(
+            y=cleaned_audio,
+            sr=sr,
+            hop_length=256,  # Smaller hop length for better temporal resolution
+            aggregate=np.median,
+            n_mels=128
+        )
+
+        # More lenient thresholding
+        threshold = np.mean(onset_env) + 0.3 * np.std(onset_env)
+
+        # Get onset positions
+        onset_frames = librosa.onset.onset_detect(
+            onset_envelope=onset_env,
+            sr=sr,
+            hop_length=256,
+            backtrack=True,
+            threshold=threshold,
+            pre_max=20,   # 20 frames before peak
+            post_max=20,  # 20 frames after peak
+            pre_avg=25,   # 25 frames before for mean
+            post_avg=25,  # 25 frames after for mean
+            wait=10       # Wait 10 frames before detecting next onset
+        )
+
+        # Refine onset positions to peaks
+        refined_peaks = []
+        window_size = int(0.05 * sr)  # 50ms window
+
+        for frame in onset_frames:
+            # Convert frame to sample index
+            sample_idx = frame * 256  # Using hop_length=256
+
+            # Define window boundaries
+            start = max(0, sample_idx - window_size//2)
+            end = min(len(cleaned_audio), sample_idx + window_size//2)
+
+            # Find the maximum amplitude within the window
+            window = np.abs(cleaned_audio[int(start):int(end)])
+            max_idx = np.argmax(window)
+            refined_peaks.append(start + max_idx)
+
+        return np.array(refined_peaks), onset_env
+
+    # Apply selected method
+    if method == 'envelope':
+        peaks, _ = get_envelope_peaks()
+
+    elif method == 'onset':
+        peaks, _ = get_onset_peaks()
+
+    else:  # fusion method
+        # Get peaks from both methods
+        env_peaks, _ = get_envelope_peaks()
+        onset_peaks, _ = get_onset_peaks()
+
+        # Merge nearby peaks (within 50ms)
+        all_peaks = np.sort(np.concatenate([env_peaks, onset_peaks]))
+        merged_peaks = []
+        last_peak = -np.inf
+
+        for peak in all_peaks:
+            if (peak - last_peak) / sr > 0.05:  # 50ms minimum separation
+                merged_peaks.append(peak)
+                last_peak = peak
+
+        peaks = np.array(merged_peaks)
+
+    # Convert peaks to times
+    peak_times = peaks / sr
+
+    # Calculate tempo using peak times
+    if len(peak_times) > 1:
+        # Use intervals between successive peaks
+        intervals = np.diff(peak_times)
+        tempos = 60 / intervals  # Convert intervals to BPM
+
+        # Remove physiologically impossible tempos
+        valid_tempos = tempos[(tempos >= 30) & (tempos <= 300)]
+
+        if len(valid_tempos) > 0:
+            tempo = np.median(valid_tempos)  # Use median for robustness
+        else:
+            tempo = 0
+    else:
+        tempo = 0
+
+    return tempo, peak_times, cleaned_audio
 
-def plotBeattimes(beattimes:np.ndarray, audiodata:np.ndarray, sr:int)->go.Figure:
-
-    # Time array for the full audio
-    time = (np.arange(0, len(audiodata)) / sr) * 2
-
-    # CREATE BEATTIMES PLOT
-    # Waveform plot
-    fig = go.Figure(
-        go.Scatter(x=time, y=audiodata, mode='lines', name='Waveform', line=dict(color='blue', width=1))
-    )
-    # Add beat markers
-    beat_amplitudes = np.interp(beattimes, time, audiodata)
-    fig.add_trace(
-        go.Scatter(x=beattimes, y=beat_amplitudes, mode='markers', name='Beats',
-                   marker=dict(color='red', size=8, symbol='circle'))
-    )
-
-    fig.update_layout(
-        showlegend=False
-    )
-
-    return fig
+# def plotBeattimes(beattimes: np.ndarray, audiodata: np.ndarray, sr: int) -> go.Figure:
+#     """
+#     Plot audio waveform with beat markers.
+#
+#     Parameters:
+#     -----------
+#     beattimes : np.ndarray
+#         Array of beat times in seconds
+#     audiodata : np.ndarray
+#         Audio time series data
+#     sr : int
+#         Sampling rate
+#
+#     Returns:
+#     --------
+#     go.Figure
+#         Plotly figure with waveform and beat markers
+#     """
+#     # Calculate correct time array for the full audio
+#     time = np.arange(len(audiodata)) / sr
+#
+#     # Convert beat times to sample indices
+#     beat_indices = np.round(beattimes * sr).astype(int)
+#
+#     # Ensure indices are within bounds
+#     beat_indices = beat_indices[beat_indices < len(audiodata)]
+#
+#     # Get actual amplitudes at beat positions
+#     beat_amplitudes = audiodata[beat_indices]
+#
+#     # Create the figure
+#     fig = go.Figure()
+#
+#     # Add waveform
+#     fig.add_trace(
+#         go.Scatter(
+#             x=time,
+#             y=audiodata,
+#             mode='lines',
+#             name='Waveform',
+#             line=dict(color='blue', width=1)
+#         )
+#     )
+#
+#     # Add beat markers
+#     fig.add_trace(
+#         go.Scatter(
+#             x=beattimes[beat_indices < len(audiodata)],  # Use filtered beat times
+#             y=beat_amplitudes,
+#             mode='markers',
+#             name='Beats',
+#             marker=dict(
+#                 color='red',
+#                 size=8,
+#                 symbol='circle',
+#                 line=dict(color='darkred', width=1)
+#             )
+#         )
+#     )
+#
+#     # Update layout
+#     fig.update_layout(
+#         title="Audio Waveform with Beat Detection",
+#         xaxis_title="Time (seconds)",
+#         yaxis_title="Amplitude",
+#         showlegend=False,
+#         hovermode='closest',
+#         plot_bgcolor='white',
+#         legend=dict(
+#             yanchor="top",
+#             y=0.99,
+#             xanchor="left",
+#             x=0.01
+#         )
+#     )
+#
+#     # Add vertical lines at beat positions (optional)
+#     for beat_time in beattimes[beat_indices < len(audiodata)]:
+#         fig.add_vline(
+#             x=beat_time,
+#             line=dict(color="rgba(255, 0, 0, 0.2)", width=1),
+#             layer="below"
+#         )
+#
+#     return fig
+
+def plotBeattimes(beattimes: np.ndarray,
+                  audiodata: np.ndarray,
+                  sr: int,
+                  beattimes2: np.ndarray = None) -> go.Figure:
+    """
+    Plot audio waveform with beat markers for one or two sets of beat times.
+
+    Parameters:
+    -----------
+    beattimes : np.ndarray
+        Primary array of beat times in seconds (S1 beats if beattimes2 is provided)
+    audiodata : np.ndarray
+        Audio time series data
+    sr : int
+        Sampling rate
+    beattimes2 : np.ndarray, optional
+        Secondary array of beat times in seconds (S2 beats)
+
+    Returns:
+    --------
+    go.Figure
+        Plotly figure with waveform and beat markers
+    """
+    # Calculate time array for the full audio
+    time = np.arange(len(audiodata)) / sr
+
+    # Create the figure
+    fig = go.Figure()
+
+    # Add waveform
+    fig.add_trace(
+        go.Scatter(
+            x=time,
+            y=audiodata,
+            mode='lines',
+            name='Waveform',
+            line=dict(color='blue', width=1)
+        )
+    )
+
+    # Process and plot primary beat times; the validity mask is computed before
+    # filtering so it can index beattimes as well
+    beat_indices = np.round(beattimes * sr).astype(int)
+    valid = beat_indices < len(audiodata)
+    beat_indices = beat_indices[valid]
+    beat_amplitudes = audiodata[beat_indices]
+
+    # Define beat name based on whether secondary beats are provided
+    beat_name = "Beats S1" if beattimes2 is not None else "Beats"
+
+    # Add primary beat markers
+    fig.add_trace(
+        go.Scatter(
+            x=beattimes[valid],
+            y=beat_amplitudes,
+            mode='markers',
+            name=beat_name,
+            marker=dict(
+                color='red',
+                size=8,
+                symbol='circle',
+                line=dict(color='darkred', width=1)
+            )
+        )
+    )
+
+    # Add primary beat vertical lines
+    for beat_time in beattimes[valid]:
+        fig.add_vline(
+            x=beat_time,
+            line=dict(color="rgba(255, 0, 0, 0.2)", width=1),
+            layer="below"
+        )
+
+    # Process and plot secondary beat times if provided
+    if beattimes2 is not None:
+        beat_indices2 = np.round(beattimes2 * sr).astype(int)
+        valid2 = beat_indices2 < len(audiodata)
+        beat_indices2 = beat_indices2[valid2]
+        beat_amplitudes2 = audiodata[beat_indices2]
+
+        # Add secondary beat markers
+        fig.add_trace(
+            go.Scatter(
+                x=beattimes2[valid2],
+                y=beat_amplitudes2,
+                mode='markers',
+                name="Beats S2",
+                marker=dict(
+                    color='green',
+                    size=8,
+                    symbol='circle',
+                    line=dict(color='darkgreen', width=1)
+                )
+            )
+        )
+
+        # Add secondary beat vertical lines
+        for beat_time in beattimes2[valid2]:
+            fig.add_vline(
+                x=beat_time,
+                line=dict(color="rgba(0, 255, 0, 0.2)", width=1),
+                layer="below"
+            )
+
+    # Update layout
+    fig.update_layout(
+        title="Audio Waveform with Beat Detection",
+        xaxis_title="Time (seconds)",
+        yaxis_title="Amplitude",
+        showlegend=True,  # show which markers are S1 vs S2
+        hovermode='closest',
+        plot_bgcolor='white',
+        legend=dict(
+            yanchor="top",
+            y=0.99,
+            xanchor="left",
+            x=0.01
+        )
+    )
+
+    return fig
 
73
-
74
  def iterate_beat_segments(beat_times, sr, audio):
75
  """
76
  Iterate over audio segments between beats.
@@ -96,8 +554,6 @@ def iterate_beat_segments(beat_times, sr, audio):
 
         segment_metrics = segment_analysis(segment, sr)
 
-
-
 def segment_analysis(segment, sr):
     """
     Analyze an audio segment and compute various metrics.
@@ -138,4 +594,41 @@ def segment_analysis(segment, sr):
         duration,
         s1_to_s2_duration,
         s2_to_s1_duration
-        ]
+    ]
+
+def find_s1s2(df: pd.DataFrame):
+
+    times = df['Beattimes'].to_numpy()
+    n_peaks = len(times)
+
+    # Initialize the feature array
+    feature_array = np.zeros((n_peaks, 4))
+
+    # Fill in the peak times (first column)
+    feature_array[:, 0] = times
+
+    # Calculate and fill distances to previous peaks (second column)
+    feature_array[1:, 1] = np.diff(times)       # For all except first peak
+    feature_array[0, 1] = feature_array[1, 1]   # First peak uses same as second
+
+    # Calculate and fill distances to next peaks (third column)
+    feature_array[:-1, 2] = np.diff(times)        # For all except last peak
+    feature_array[-1, 2] = feature_array[-2, 2]   # Last peak uses same as second-to-last
+
+    # Extract features (distances to prev and next peaks)
+    X = feature_array[:, 1:3]
+
+    # Scale features
+    scaler = StandardScaler()
+    X_scaled = scaler.fit_transform(X)
+
+    # Apply K-means clustering
+    kmeans = KMeans(n_clusters=2, random_state=42)
+    labels = kmeans.fit_predict(X_scaled)
+
+    # Update the labels in the feature array
+    feature_array[:, 3] = labels
+
+    return feature_array
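find_s1s2 labels each peak by clustering its (distance to previous, distance to next) pair, so an alternating short/long interval pattern should split cleanly into two clusters. A small demo with invented timings (real systolic S1-to-S2 gaps are shorter than diastolic S2-to-S1 gaps); note that K-means cluster ids are arbitrary, so label 0 is not guaranteed to mean S1:

import pandas as pd
from utils import find_s1s2

beats, t = [], 0.0
for _ in range(10):
    beats.append(round(t, 2)); t += 0.30   # short gap (S1 -> S2)
    beats.append(round(t, 2)); t += 0.55   # long gap (S2 -> next S1)

features = find_s1s2(pd.DataFrame({"Beattimes": beats}))
print(features[:4])   # rows: [time, dist to prev, dist to next, cluster label]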
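Putting the commit together, a plausible end-to-end smoke test of the new pipeline; the toy "lub-dub" signal, file name, and durations are all illustrative, and librosa, soundfile, pywt, scikit-learn and plotly need to be installed:

import os, tempfile
import numpy as np
import pandas as pd
import soundfile as sf

from utils import getaudiodata, getBeats, plotBeattimes, find_s1s2

# Synthesize a toy recording: S1 followed by S2 300 ms later, one cycle per 0.85 s
sr = 16000
t = np.arange(0, 6.0, 1 / sr)
audio = 0.005 * np.random.randn(t.size)
for start in np.arange(0.5, 5.5, 0.85):
    for offset in (0.0, 0.3):
        idx = int((start + offset) * sr)
        audio[idx:idx + 320] += np.hanning(320) * np.sin(2 * np.pi * 50 * t[:320])
path = os.path.join(tempfile.gettempdir(), "toy_heart.wav")
sf.write(path, audio, sr)

# Same chain the app's getBeatsv2 runs
sr, audiodata = getaudiodata(path)
_, beattimes, cleaned = getBeats(audiodata, sr)
features = find_s1s2(pd.DataFrame({"Beattimes": beattimes}))
s1 = features[features[:, 3] == 0, 0]
s2 = features[features[:, 3] == 1, 0]
plotBeattimes(s1, cleaned, sr, s2).show()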