Spaces:
Build error
Build error
| import numpy as np | |
| import pyxdf | |
| from wonambi.detect.spindle import DetectSpindle, detect_Lacourse2018, detect_Wamsley2012 | |
| from scipy.signal import butter, filtfilt, iirnotch, detrend | |
| import time | |
| from portiloop.src.stimulation import Stimulator | |
| STREAM_NAMES = { | |
| 'filtered_data': 'Portiloop Filtered', | |
| 'raw_data': 'Portiloop Raw Data', | |
| 'stimuli': 'Portiloop_stimuli' | |
| } | |
| class OfflineSleepSpindleRealTimeStimulator(Stimulator): | |
| def __init__(self): | |
| self.last_detected_ts = time.time() | |
| self.wait_t = 0.4 # 400 ms | |
| self.delayer = None | |
| def stimulate(self, detection_signal): | |
| stim = False | |
| for sig in detection_signal: | |
| # We detect a stimulation | |
| if sig: | |
| # Record time of stimulation | |
| ts = time.time() | |
| # Check if time since last stimulation is long enough | |
| if ts - self.last_detected_ts > self.wait_t: | |
| if self.delayer is not None: | |
| # If we have a delayer, notify it | |
| self.delayer.detected() | |
| stim = True | |
| self.last_detected_ts = ts | |
| return stim | |
| def add_delayer(self, delayer): | |
| self.delayer = delayer | |
| self.delayer.stimulate = lambda: True | |
| def xdf2array(xdf_path, channel): | |
| xdf_data, _ = pyxdf.load_xdf(xdf_path) | |
| # Load all streams given their names | |
| filtered_stream, raw_stream, markers = None, None, None | |
| for stream in xdf_data: | |
| # print(stream['info']['name']) | |
| if stream['info']['name'][0] == STREAM_NAMES['filtered_data']: | |
| filtered_stream = stream | |
| elif stream['info']['name'][0] == STREAM_NAMES['raw_data']: | |
| raw_stream = stream | |
| elif stream['info']['name'][0] == STREAM_NAMES['stimuli']: | |
| markers = stream | |
| if filtered_stream is None or raw_stream is None: | |
| raise ValueError("One of the necessary streams could not be found. Make sure that at least one signal stream is present in XDF recording") | |
| # Add all samples from raw and filtered signals | |
| csv_list = [] | |
| diffs = [] | |
| shortest_stream = min(int(filtered_stream['footer']['info']['sample_count'][0]), | |
| int(raw_stream['footer']['info']['sample_count'][0])) | |
| for i in range(shortest_stream): | |
| if markers is not None: | |
| datapoint = [filtered_stream['time_stamps'][i], | |
| float(filtered_stream['time_series'][i, channel-1]), | |
| raw_stream['time_series'][i, channel-1], | |
| 0] | |
| else: | |
| datapoint = [filtered_stream['time_stamps'][i], | |
| float(filtered_stream['time_series'][i, channel-1]), | |
| raw_stream['time_series'][i, channel-1]] | |
| diffs.append(abs(filtered_stream['time_stamps'][i] - raw_stream['time_stamps'][i])) | |
| csv_list.append(datapoint) | |
| # Add markers | |
| columns = ["time_stamps", "online_filtered_signal_portiloop", "raw_signal"] | |
| if markers is not None: | |
| columns.append("online_stimulations_portiloop") | |
| for time_stamp in markers['time_stamps']: | |
| new_index = np.abs(filtered_stream['time_stamps'] - time_stamp).argmin() | |
| csv_list[new_index][3] = 1 | |
| return np.array(csv_list), columns | |
| def offline_detect(method, data, timesteps, freq): | |
| # Get the spindle data from the offline methods | |
| time = np.arange(0, len(data)) / freq | |
| if method == "Lacourse": | |
| detector = DetectSpindle(method='Lacourse2018') | |
| spindles, _, _ = detect_Lacourse2018(data, freq, time, detector) | |
| elif method == "Wamsley": | |
| detector = DetectSpindle(method='Wamsley2012') | |
| spindles, _, _ = detect_Wamsley2012(data, freq, time, detector) | |
| else: | |
| raise ValueError("Invalid method") | |
| # Convert the spindle data to a numpy array | |
| spindle_result = np.zeros(data.shape) | |
| for spindle in spindles: | |
| start = spindle["start"] | |
| end = spindle["end"] | |
| # Find index of timestep closest to start and end | |
| start_index = np.argmin(np.abs(timesteps - start)) | |
| end_index = np.argmin(np.abs(timesteps - end)) | |
| spindle_result[start_index:end_index] = 1 | |
| return spindle_result | |
| def offline_filter(signal, freq): | |
| # Notch filter | |
| f0 = 60.0 # Frequency to be removed from signal (Hz) | |
| Q = 100.0 # Quality factor | |
| b, a = iirnotch(f0, Q, freq) | |
| signal = filtfilt(b, a, signal) | |
| # Bandpass filter | |
| lowcut = 0.5 | |
| highcut = 40.0 | |
| order = 4 | |
| b, a = butter(order, [lowcut / (freq / 2.0), highcut / (freq / 2.0)], btype='bandpass') | |
| signal = filtfilt(b, a, signal) | |
| # Detrend the signal | |
| signal = detrend(signal) | |
| return signal | |