ecg-classification / baseline_wander_removal.py
nabeelraza's picture
Add: included butter filter
e9321a8
import numpy as np
from scipy.signal import medfilt
import matplotlib.pyplot as plt
# import seaborn as sns
# sns.set_theme()
# Sampling rate handed to bw_remover() by the __main__ demo at the bottom of
# this file (presumably in Hz, i.e. samples per second -- confirm for your data).
BASIC_SRATE = 160
def get_median_filter_width(sampling_rate, duration):
    """Return an odd, positive median-filter kernel length for a time window.

    The window of ``duration`` seconds is converted to a whole number of
    samples (truncating), then rounded *down* to the nearest odd number
    because median filters need an odd kernel size.

    Args:
        sampling_rate: Number of samples per second.
        duration: Length of the filter window, in seconds.

    Returns:
        int: Odd kernel length, never smaller than 1.
    """
    width = int(sampling_rate * duration)
    if width % 2 == 0:
        width -= 1  # even -> previous odd number
    # A window shorter than one sample used to yield -1 (0 rounded "down" to
    # odd), which is not a usable kernel size; fall back to the smallest
    # valid kernel instead.
    return max(width, 1)
# baseline fitting by filtering
# === Define Filtering Params for Baseline fitting Leads======================
# ms_flt_array = [0.2, 0.6] #<-- length of baseline fitting filters (in seconds)
# mfa = np.zeros(len(ms_flt_array), dtype='int')
# for i in range(0, len(ms_flt_array)):
# mfa[i] = get_median_filter_width(BASIC_SRATE,ms_flt_array[i])
# def filter_signal(X):
# global mfa
# X0 = X # read original signal
# for mi in range(0,len(mfa)):
# X0 = medfilt(X0,mfa[mi]) # apply median filter one by one on top of each other
# print(X0 - X)
# X0 = np.subtract(X,X0) # finally subtract from original signal
# return X0
def bw_remover(fs, X, window_sec=0.4):
    """Remove baseline wander by subtracting a median-filtered copy of ``X``.

    A median filter spanning roughly ``window_sec`` seconds is run over the
    signal to estimate its slowly varying baseline; that estimate is then
    subtracted from the original signal.

    Args:
        fs: Sampling rate of ``X`` (samples per second).
        X: Signal to correct. The demo in this file passes a 1-D NumPy
            array; other shapes are handed to ``medfilt`` unchanged.
        window_sec: Length of the median-filter window in seconds. Defaults
            to 0.4, the value that was previously hard-coded.

    Returns:
        ``X`` minus its median-filtered baseline estimate.
    """
    kernel = get_median_filter_width(fs, window_sec)
    # Median-filter the original signal to obtain the baseline estimate ...
    baseline = medfilt(X, kernel)
    # ... and subtract it from the original signal.
    return np.subtract(X, baseline)
def band_pass_filter(signal):
    """Band-pass a sampled signal with two cascaded recursive filters.

    Stage 1 computes, sample by sample::

        y[n] = x[n] + 2*y[n-1] - y[n-2] - 2*x[n-6] + x[n-12]

    Stage 2 (the high-pass stage) is then run on the stage-1 output ``y``::

        z[n] = -y[n] - z[n-1] + 32*y[n-16] + y[n-32]

    Terms whose index would fall before the start of the signal are skipped,
    i.e. the history before sample 0 is treated as zero.

    Args:
        signal: The samples to filter. May be a NumPy array, any object
            exposing ``to_numpy()`` (e.g. a pandas Series), or a plain
            sequence such as a list or tuple. The input is never modified.

    Returns:
        numpy.ndarray: Filtered samples, same length and dtype as the input
        array. The output is not amplitude-normalized. Integer input is not
        promoted to float, so the arithmetic stays in that integer dtype.
    """
    # Use isinstance (not an exact type comparison) so ndarray subclasses are
    # accepted, and fall back to np.asarray so plain sequences work instead
    # of failing on a missing ``to_numpy`` attribute.
    if isinstance(signal, np.ndarray):
        samples = signal
    elif hasattr(signal, "to_numpy"):
        samples = signal.to_numpy()
    else:
        samples = np.asarray(signal)

    n_samples = len(samples)

    # Stage 1: feedback taps read the stage's own output (low_passed),
    # feed-forward taps read the untouched input (samples).
    low_passed = samples.copy()
    for n in range(n_samples):
        low_passed[n] = samples[n]
        if n >= 1:
            low_passed[n] += 2 * low_passed[n - 1]
        if n >= 2:
            low_passed[n] -= low_passed[n - 2]
        if n >= 6:
            low_passed[n] -= 2 * samples[n - 6]
        if n >= 12:
            low_passed[n] += samples[n - 12]

    # Stage 2 (high-pass): input is the stage-1 output.
    high_passed = low_passed.copy()
    for n in range(n_samples):
        high_passed[n] = -1 * low_passed[n]
        if n >= 1:
            high_passed[n] -= high_passed[n - 1]
        if n >= 16:
            high_passed[n] += 32 * low_passed[n - 16]
        if n >= 32:
            high_passed[n] += low_passed[n - 32]

    return high_passed
if __name__ == '__main__':
    # Demo: load the first 1000 samples of one recording, remove the baseline
    # wander, band-pass the result, and plot raw vs. processed traces.
    raw_trace = np.loadtxt("./Lead 2/n1_lead2.txt")[:1000]
    processed_trace = band_pass_filter(bw_remover(BASIC_SRATE, raw_trace))

    plt.figure(figsize=(16, 9))
    # Top panel: raw input. Bottom panel: output after baseline removal
    # followed by the band-pass filter.
    panels = (
        (raw_trace, "RAW signal"),
        (processed_trace, "baseline removed signal"),
    )
    for row, (trace, title) in enumerate(panels, start=1):
        plt.subplot(2, 1, row)
        plt.plot(trace)
        plt.title(title)
        plt.grid(True)
    plt.show()