Spaces:
Sleeping
Sleeping
MiloSobral
commited on
Commit
•
321613a
1
Parent(s):
c2adb11
Added ability to read from file and fixed delayer bugs
Browse files- portiloop/src/capture.py +20 -8
- portiloop/src/stimulation.py +66 -14
- portiloop/src/utils.py +21 -3
portiloop/src/capture.py
CHANGED
@@ -16,7 +16,7 @@ from portiloop.src.hardware.frontend import Frontend
|
|
16 |
from portiloop.src.hardware.leds import LEDs, Color
|
17 |
from portiloop.src.processing import FilterPipeline, int_to_float
|
18 |
from portiloop.src.config import mod_config, LEADOFF_CONFIG, FRONTEND_CONFIG, to_ads_frequency
|
19 |
-
from portiloop.src.utils import FileReader, LiveDisplay, DummyAlsaMixer, EDFRecorder, EDF_PATH
|
20 |
from IPython.display import clear_output, display
|
21 |
import ipywidgets as widgets
|
22 |
|
@@ -493,6 +493,7 @@ class Capture:
|
|
493 |
|
494 |
self.b_capture.observe(self.on_b_capture, 'value')
|
495 |
self.b_clock.observe(self.on_b_clock, 'value')
|
|
|
496 |
self.b_frequency.observe(self.on_b_frequency, 'value')
|
497 |
self.b_threshold.observe(self.on_b_threshold, 'value')
|
498 |
self.b_duration.observe(self.on_b_duration, 'value')
|
@@ -901,6 +902,8 @@ class Capture:
|
|
901 |
self.__capture_on = True
|
902 |
p_msg_io, p_msg_io_2 = mp.Pipe()
|
903 |
p_data_i, p_data_o = mp.Pipe(duplex=False)
|
|
|
|
|
904 |
|
905 |
# Initialize filtering pipeline
|
906 |
if filter:
|
@@ -933,7 +936,7 @@ class Capture:
|
|
933 |
self._p_capture.start()
|
934 |
print(f"PID capture: {self._p_capture.pid}")
|
935 |
else:
|
936 |
-
filename =
|
937 |
file_reader = FileReader(filename)
|
938 |
|
939 |
# Initialize display if requested
|
@@ -966,7 +969,7 @@ class Capture:
|
|
966 |
|
967 |
# Initialize stimulation delayer if requested
|
968 |
if not self.spindle_detection_mode == 'Fast' and stimulator is not None:
|
969 |
-
stimulation_delayer = UpStateDelayer(self.frequency, self.
|
970 |
stimulator.add_delayer(stimulation_delayer)
|
971 |
else:
|
972 |
stimulation_delayer = None
|
@@ -998,7 +1001,14 @@ class Capture:
|
|
998 |
# Convert point from int to corresponding value in microvolts
|
999 |
n_array_raw = int_to_float(np.array([point]))
|
1000 |
elif self.signal_input == "File":
|
1001 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1002 |
|
1003 |
# Go through filtering pipeline
|
1004 |
if filter:
|
@@ -1017,7 +1027,7 @@ class Capture:
|
|
1017 |
|
1018 |
# Adds point to buffer for delayed stimulation
|
1019 |
if stimulation_delayer is not None:
|
1020 |
-
stimulation_delayer.
|
1021 |
|
1022 |
# Check if detection is on or off
|
1023 |
with self._pause_detect_lock:
|
@@ -1037,8 +1047,10 @@ class Capture:
|
|
1037 |
if test_stimulus:
|
1038 |
stimulator.test_stimulus()
|
1039 |
|
1040 |
-
|
1041 |
-
|
|
|
|
|
1042 |
|
1043 |
# Add point to the buffer to send to viz and recorder
|
1044 |
buffer += filtered_point
|
@@ -1062,7 +1074,7 @@ class Capture:
|
|
1062 |
p_data_i.close()
|
1063 |
p_msg_io.close()
|
1064 |
self._p_capture.join()
|
1065 |
-
|
1066 |
|
1067 |
if record:
|
1068 |
recorder.close_recording_file()
|
|
|
16 |
from portiloop.src.hardware.leds import LEDs, Color
|
17 |
from portiloop.src.processing import FilterPipeline, int_to_float
|
18 |
from portiloop.src.config import mod_config, LEADOFF_CONFIG, FRONTEND_CONFIG, to_ads_frequency
|
19 |
+
from portiloop.src.utils import FileReader, LiveDisplay, DummyAlsaMixer, EDFRecorder, EDF_PATH, RECORDING_PATH
|
20 |
from IPython.display import clear_output, display
|
21 |
import ipywidgets as widgets
|
22 |
|
|
|
493 |
|
494 |
self.b_capture.observe(self.on_b_capture, 'value')
|
495 |
self.b_clock.observe(self.on_b_clock, 'value')
|
496 |
+
self.b_signal_input.observe(self.on_b_signal_input, 'value')
|
497 |
self.b_frequency.observe(self.on_b_frequency, 'value')
|
498 |
self.b_threshold.observe(self.on_b_threshold, 'value')
|
499 |
self.b_duration.observe(self.on_b_duration, 'value')
|
|
|
902 |
self.__capture_on = True
|
903 |
p_msg_io, p_msg_io_2 = mp.Pipe()
|
904 |
p_data_i, p_data_o = mp.Pipe(duplex=False)
|
905 |
+
else:
|
906 |
+
p_msg_io, _ = mp.Pipe()
|
907 |
|
908 |
# Initialize filtering pipeline
|
909 |
if filter:
|
|
|
936 |
self._p_capture.start()
|
937 |
print(f"PID capture: {self._p_capture.pid}")
|
938 |
else:
|
939 |
+
filename = RECORDING_PATH / 'test_recording.csv'
|
940 |
file_reader = FileReader(filename)
|
941 |
|
942 |
# Initialize display if requested
|
|
|
969 |
|
970 |
# Initialize stimulation delayer if requested
|
971 |
if not self.spindle_detection_mode == 'Fast' and stimulator is not None:
|
972 |
+
stimulation_delayer = UpStateDelayer(self.frequency, self.spindle_detection_mode == 'Peak', 0.3)
|
973 |
stimulator.add_delayer(stimulation_delayer)
|
974 |
else:
|
975 |
stimulation_delayer = None
|
|
|
1001 |
# Convert point from int to corresponding value in microvolts
|
1002 |
n_array_raw = int_to_float(np.array([point]))
|
1003 |
elif self.signal_input == "File":
|
1004 |
+
# Check if the message to stop has been sent
|
1005 |
+
with self._lock_msg_out:
|
1006 |
+
if self._msg_out == "STOP":
|
1007 |
+
break
|
1008 |
+
|
1009 |
+
index, raw_point, off_filtered_point, past_stimulation, lacourse_stimulation = file_reader.get_point()
|
1010 |
+
n_array_raw = np.array([0, raw_point, 0, 0, 0, 0, 0, 0])
|
1011 |
+
n_array_raw = np.reshape(n_array_raw, (1, 8))
|
1012 |
|
1013 |
# Go through filtering pipeline
|
1014 |
if filter:
|
|
|
1027 |
|
1028 |
# Adds point to buffer for delayed stimulation
|
1029 |
if stimulation_delayer is not None:
|
1030 |
+
stimulation_delayer.step_timesteps(filtered_point[0][channel-1])
|
1031 |
|
1032 |
# Check if detection is on or off
|
1033 |
with self._pause_detect_lock:
|
|
|
1047 |
if test_stimulus:
|
1048 |
stimulator.test_stimulus()
|
1049 |
|
1050 |
+
# Send the stimulation from the file reader
|
1051 |
+
if stimulator is not None:
|
1052 |
+
if self.signal_input == "File" and lacourse_stimulation:
|
1053 |
+
stimulator.send_stimulation("GROUND_TRUTH_STIM", False)
|
1054 |
|
1055 |
# Add point to the buffer to send to viz and recorder
|
1056 |
buffer += filtered_point
|
|
|
1074 |
p_data_i.close()
|
1075 |
p_msg_io.close()
|
1076 |
self._p_capture.join()
|
1077 |
+
self.__capture_on = False
|
1078 |
|
1079 |
if record:
|
1080 |
recorder.close_recording_file()
|
portiloop/src/stimulation.py
CHANGED
@@ -7,6 +7,11 @@ import alsaaudio
|
|
7 |
import wave
|
8 |
import pylsl
|
9 |
from scipy.signal import find_peaks
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
|
12 |
# Abstract interface for developers:
|
@@ -47,14 +52,14 @@ class SleepSpindleRealTimeStimulator(Stimulator):
|
|
47 |
channel_format='string',
|
48 |
source_id='portiloop1') # TODO: replace this by unique device identifier
|
49 |
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
|
56 |
self.lsl_outlet_markers = pylsl.StreamOutlet(lsl_markers_info)
|
57 |
-
|
58 |
|
59 |
# Initialize Alsa stuff
|
60 |
# Open WAV file and set PCM device
|
@@ -114,6 +119,7 @@ class SleepSpindleRealTimeStimulator(Stimulator):
|
|
114 |
self.last_detected_ts = ts
|
115 |
|
116 |
def send_stimulation(self, lsl_text, sound):
|
|
|
117 |
# Send lsl stimulation
|
118 |
self.lsl_outlet_markers.push_sample([lsl_text])
|
119 |
# Send sound to patient
|
@@ -137,24 +143,22 @@ class SleepSpindleRealTimeStimulator(Stimulator):
|
|
137 |
|
138 |
def add_delayer(self, delayer):
|
139 |
self.delayer = delayer
|
140 |
-
self.delayer.stimulate = lambda
|
141 |
|
142 |
# Class that delays stimulation to always stimulate peak or through
|
143 |
class UpStateDelayer:
|
144 |
-
def __init__(self, sample_freq,
|
145 |
'''
|
146 |
args:
|
147 |
sample_freq: int -> Sampling frequency of signal in Hz
|
148 |
time_to_wait: float -> Time to wait to build buffer in seconds
|
149 |
'''
|
150 |
# Get number of timesteps for a whole spindle
|
151 |
-
self.spindle_timesteps = (1/spindle_freq) * sample_freq # s *
|
152 |
self.sample_freq = sample_freq
|
153 |
-
self.buffer_size = 1.5 * self.spindle_timesteps
|
154 |
self.peak = peak
|
155 |
self.buffer = []
|
156 |
self.time_to_buffer = time_to_buffer
|
157 |
-
self.stimulate =
|
158 |
|
159 |
self.state = States.NO_SPINDLE
|
160 |
|
@@ -177,7 +181,37 @@ class UpStateDelayer:
|
|
177 |
return False
|
178 |
elif self.state == States.DELAYING:
|
179 |
# Check if we are done delaying
|
180 |
-
if time.time() - self.time_started >= self.time_to_wait
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
# Actually stimulate the patient after the delay
|
182 |
if self.stimulate is not None:
|
183 |
self.stimulate()
|
@@ -190,7 +224,6 @@ class UpStateDelayer:
|
|
190 |
def detected(self):
|
191 |
if self.state == States.NO_SPINDLE:
|
192 |
self.state = States.BUFFERING
|
193 |
-
self.time_started = time.time()
|
194 |
|
195 |
def compute_time_to_wait(self):
|
196 |
"""
|
@@ -203,8 +236,27 @@ class UpStateDelayer:
|
|
203 |
# Returns the index of the last peak in the buffer
|
204 |
peaks, _ = find_peaks(self.buffer, prominence=1)
|
205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
# Compute the time until next peak and return it
|
207 |
-
|
|
|
|
|
|
|
208 |
|
209 |
class States(Enum):
|
210 |
NO_SPINDLE = 0
|
|
|
7 |
import wave
|
8 |
import pylsl
|
9 |
from scipy.signal import find_peaks
|
10 |
+
import numpy as np
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
|
13 |
+
import alsaaudio
|
14 |
+
import pylsl
|
15 |
|
16 |
|
17 |
# Abstract interface for developers:
|
|
|
52 |
channel_format='string',
|
53 |
source_id='portiloop1') # TODO: replace this by unique device identifier
|
54 |
|
55 |
+
# lsl_markers_info_fast = pylsl.StreamInfo(name='Portiloop_stimuli_fast',
|
56 |
+
# type='Markers',
|
57 |
+
# channel_count=1,
|
58 |
+
# channel_format='string',
|
59 |
+
# source_id='portiloop1') # TODO: replace this by unique device identifier
|
60 |
|
61 |
self.lsl_outlet_markers = pylsl.StreamOutlet(lsl_markers_info)
|
62 |
+
# self.lsl_outlet_markers_fast = pylsl.StreamOutlet(lsl_markers_info_fast)
|
63 |
|
64 |
# Initialize Alsa stuff
|
65 |
# Open WAV file and set PCM device
|
|
|
119 |
self.last_detected_ts = ts
|
120 |
|
121 |
def send_stimulation(self, lsl_text, sound):
|
122 |
+
print(f"Stimulating with text: {lsl_text}")
|
123 |
# Send lsl stimulation
|
124 |
self.lsl_outlet_markers.push_sample([lsl_text])
|
125 |
# Send sound to patient
|
|
|
143 |
|
144 |
def add_delayer(self, delayer):
|
145 |
self.delayer = delayer
|
146 |
+
self.delayer.stimulate = lambda: self.send_stimulation("DELAY_STIM", True)
|
147 |
|
148 |
# Class that delays stimulation to always stimulate peak or through
|
149 |
class UpStateDelayer:
|
150 |
+
def __init__(self, sample_freq, peak, time_to_buffer, stimulate=None):
|
151 |
'''
|
152 |
args:
|
153 |
sample_freq: int -> Sampling frequency of signal in Hz
|
154 |
time_to_wait: float -> Time to wait to build buffer in seconds
|
155 |
'''
|
156 |
# Get number of timesteps for a whole spindle
|
|
|
157 |
self.sample_freq = sample_freq
|
|
|
158 |
self.peak = peak
|
159 |
self.buffer = []
|
160 |
self.time_to_buffer = time_to_buffer
|
161 |
+
self.stimulate = stimulate
|
162 |
|
163 |
self.state = States.NO_SPINDLE
|
164 |
|
|
|
181 |
return False
|
182 |
elif self.state == States.DELAYING:
|
183 |
# Check if we are done delaying
|
184 |
+
if time.time() - self.time_started >= self.time_to_wait:
|
185 |
+
# Actually stimulate the patient after the delay
|
186 |
+
if self.stimulate is not None:
|
187 |
+
self.stimulate()
|
188 |
+
# Reset state
|
189 |
+
self.time_to_wait = -1
|
190 |
+
self.state = States.NO_SPINDLE
|
191 |
+
return True
|
192 |
+
return False
|
193 |
+
|
194 |
+
def step_timesteps(self, point):
|
195 |
+
'''
|
196 |
+
Step the delayer, ads a point to buffer if necessary.
|
197 |
+
Returns True if stimulation is actually done
|
198 |
+
'''
|
199 |
+
if self.state == States.NO_SPINDLE:
|
200 |
+
return False
|
201 |
+
elif self.state == States.BUFFERING:
|
202 |
+
self.buffer.append(point)
|
203 |
+
# If we are done buffering, move on to the waiting stage
|
204 |
+
if len(self.buffer) >= self.time_to_buffer * self.sample_freq:
|
205 |
+
# Compute the necessary time to wait
|
206 |
+
self.time_to_wait = self.compute_time_to_wait()
|
207 |
+
self.state = States.DELAYING
|
208 |
+
self.buffer = []
|
209 |
+
self.delaying_counter = 0
|
210 |
+
return False
|
211 |
+
elif self.state == States.DELAYING:
|
212 |
+
# Check if we are done delaying
|
213 |
+
self.delaying_counter += 1
|
214 |
+
if self.delaying_counter >= self.time_to_wait * self.sample_freq:
|
215 |
# Actually stimulate the patient after the delay
|
216 |
if self.stimulate is not None:
|
217 |
self.stimulate()
|
|
|
224 |
def detected(self):
|
225 |
if self.state == States.NO_SPINDLE:
|
226 |
self.state = States.BUFFERING
|
|
|
227 |
|
228 |
def compute_time_to_wait(self):
|
229 |
"""
|
|
|
236 |
# Returns the index of the last peak in the buffer
|
237 |
peaks, _ = find_peaks(self.buffer, prominence=1)
|
238 |
|
239 |
+
# Make a figure to show the peaks
|
240 |
+
if False:
|
241 |
+
plt.figure()
|
242 |
+
plt.plot(self.buffer)
|
243 |
+
for peak in peaks:
|
244 |
+
plt.axvline(x=peak)
|
245 |
+
plt.plot(np.zeros_like(self.buffer), "--", color="gray")
|
246 |
+
plt.show()
|
247 |
+
|
248 |
+
if len(peaks) == 0:
|
249 |
+
print("No peaks found, increase buffer size")
|
250 |
+
return (self.sample_freq / 10) * (1.0 / self.sample_freq)
|
251 |
+
|
252 |
+
# Compute average distance between each peak
|
253 |
+
avg_dist = np.mean(np.diff(peaks))
|
254 |
+
|
255 |
# Compute the time until next peak and return it
|
256 |
+
if (avg_dist < len(self.buffer) - peaks[-1]):
|
257 |
+
print("Average distance between peaks is smaller than the time to last peak, decrease buffer size")
|
258 |
+
return (len(self.buffer) - peaks[-1]) * (1.0 / self.sample_freq)
|
259 |
+
return (avg_dist - (len(self.buffer) - peaks[-1])) * (1.0 / self.sample_freq)
|
260 |
|
261 |
class States(Enum):
|
262 |
NO_SPINDLE = 0
|
portiloop/src/utils.py
CHANGED
@@ -2,9 +2,13 @@ from EDFlib.edfwriter import EDFwriter
|
|
2 |
from portilooplot.jupyter_plot import ProgressPlot
|
3 |
from pathlib import Path
|
4 |
import numpy as np
|
|
|
|
|
5 |
|
6 |
-
EDF_PATH = Path.home() / 'workspace' / 'edf_recording'
|
7 |
|
|
|
|
|
|
|
8 |
|
9 |
|
10 |
class DummyAlsaMixer:
|
@@ -102,7 +106,21 @@ class LiveDisplay():
|
|
102 |
|
103 |
class FileReader:
|
104 |
def __init__(self, filename):
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
def get_point(self):
|
108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from portilooplot.jupyter_plot import ProgressPlot
|
3 |
from pathlib import Path
|
4 |
import numpy as np
|
5 |
+
import csv
|
6 |
+
import time
|
7 |
|
|
|
8 |
|
9 |
+
EDF_PATH = Path.home() / 'workspace' / 'edf_recording'
|
10 |
+
# Path to the recordings
|
11 |
+
RECORDING_PATH = Path.home() / 'portiloop-software' / 'portiloop' / 'recordings'
|
12 |
|
13 |
|
14 |
class DummyAlsaMixer:
|
|
|
106 |
|
107 |
class FileReader:
|
108 |
def __init__(self, filename):
|
109 |
+
file = open(filename, 'r')
|
110 |
+
# Open a csv file
|
111 |
+
print(f"Reading from file {filename}")
|
112 |
+
self.csv_reader = csv.reader(file, delimiter=',')
|
113 |
+
self.wait_time = 1/250.0
|
114 |
+
self.index = -1
|
115 |
+
self.last_time = time.time()
|
116 |
|
117 |
def get_point(self):
|
118 |
+
"""
|
119 |
+
Returns the next point in the file
|
120 |
+
"""
|
121 |
+
point = next(self.csv_reader)
|
122 |
+
self.index += 1
|
123 |
+
while time.time() - self.last_time < self.wait_time:
|
124 |
+
continue
|
125 |
+
self.last_time = time.time()
|
126 |
+
return self.index, float(point[0]), float(point[1]), point[2] == '1', point[3] == '1'
|