Upload 5 files
Browse files- cyton.py +381 -0
- decode.py +612 -0
- eegembed.py +543 -0
- embed.py +424 -0
- morphism.py +436 -0
cyton.py
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
import serial
|
| 5 |
+
import time
|
| 6 |
+
import paramiko
|
| 7 |
+
import io
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
|
| 11 |
+
def create_ssh_connection():
    """Open an SSH session to the remote data server.

    Returns:
        A connected paramiko.SSHClient, or None if the connection failed.
    """
    client = paramiko.SSHClient()
    # Accept unknown host keys automatically (server is not in known_hosts)
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        client.connect('topos.exypno.tech', port=420, username='albert')
    except Exception as e:
        print(f"Failed to connect to remote server: {e}")
        return None
    return client
|
| 21 |
+
|
| 22 |
+
def set_gain(ser, gain=8):
    """Set the requested amplifier gain on all 16 channels (Cyton + Daisy).

    Note: the original docstring claimed "2x gain" while the default is 8;
    the gain actually applied is whatever `gain` is passed.

    Args:
        ser: Open serial connection to the board.
        gain: Amplifier gain; one of 1, 2, 4, 6, 8, 12, 24 (default 8).

    Raises:
        ValueError: If `gain` is not one of the supported values.
    """
    gain_mapping = [1, 2, 4, 6, 8, 12, 24]
    # Validate up front with a clear message (previously an unsupported gain
    # surfaced as a bare ValueError from list.index).
    if gain not in gain_mapping:
        raise ValueError(f"Unsupported gain {gain}. Supported gains: {gain_mapping}")

    print(f"Setting {gain}x gain on all channels...")

    # Stop any streaming first
    ser.write(b's')
    time.sleep(0.5)

    # The command encodes the gain as its index in the supported list
    gain_val = gain_mapping.index(gain)

    # Main board channels (1-8)
    main_channels = ['1', '2', '3', '4', '5', '6', '7', '8']
    # Daisy board channels (9-16 represented as Q-I)
    daisy_channels = ['Q', 'W', 'E', 'R', 'T', 'Y', 'U', 'I']

    # Combine all per-channel settings commands into one string
    commands = ''.join(f'x{ch}0{gain_val}0000X' for ch in main_channels + daisy_channels)

    # Send all commands at once
    ser.write(commands.encode())
    time.sleep(0.5)

    # Clear any response from the serial buffer
    ser.reset_input_buffer()

    print("Gain settings updated")
|
| 51 |
+
|
| 52 |
+
def set_sample_rate(ser, freq):
    """Set sample rate using the '~' command"""
    # Sample rate mapping according to documentation
    freq_mapping = {
        16000: '0',
        8000: '1',
        4000: '2',
        2000: '3',
        1000: '4',
        500: '5',
        250: '6',
    }

    if freq not in freq_mapping:
        raise ValueError(f"Unsupported frequency {freq}Hz. Supported rates: {list(freq_mapping.keys())}")

    # Halt streaming before reconfiguring
    ser.write(b's')
    time.sleep(0.5)

    # Apply the new rate via the '~<code>' command
    ser.write(f"~{freq_mapping[freq]}".encode())
    time.sleep(0.5)

    # Discard the board's textual reply
    ser.reset_input_buffer()
    print(f"Sample rate set to {freq}Hz")
|
| 80 |
+
|
| 81 |
+
def init_board(ser):
    """Initialize the OpenBCI board for 16 channels.

    Sends the full startup command sequence: stop streaming, soft reset,
    16-channel (Daisy) mode, enable all channels, switch to 230400 baud,
    and set gain 2 on every channel. Order and delays matter: each write
    gives the firmware time to process before the next command.
    """
    print("Initializing board...")

    # Stop any previous streaming
    ser.write(b's')
    time.sleep(1)

    # Soft reset ('v'); the board prints a banner, hence the longer wait
    ser.write(b'v')
    time.sleep(2)

    # Clear buffers (drop the reset banner and anything stale)
    ser.reset_input_buffer()
    ser.reset_output_buffer()

    # Enable 16 channel mode
    ser.write(b'C')
    time.sleep(1)

    # Enable all channels (1-16)
    # First 8 channels
    commands = [b'!', b'@', b'#', b'$', b'%', b'^', b'&', b'*']
    # Next 8 channels (Daisy module)
    commands.extend([b'Q', b'W', b'E', b'R', b'T', b'Y', b'U', b'I'])

    for cmd in commands:
        ser.write(cmd)
        time.sleep(0.1)

    # Set high-speed mode
    ser.write(b'\xF0\x06')  # Set baud rate to 230400
    time.sleep(1)
    # Re-open our side of the link at the new speed
    ser.baudrate = 230400

    # NOTE(review): gain is set to 2 here, but process_packet() scales
    # samples assuming gain 24 — confirm which gain is actually intended.
    set_gain(ser, gain=2)

    print("Board initialized")
|
| 119 |
+
|
| 120 |
+
def find_packet_start(ser):
    """Scan the serial stream for the 0xA0 packet-header byte.

    Returns:
        The single header byte (b'\\xa0'), or None when the port returns
        no data (read timeout / end of stream).
    """
    while True:
        byte = ser.read()
        if not byte:
            # BUG FIX: an empty read previously raised IndexError on
            # byte[0], and the trailing `return None` was unreachable.
            return None
        if byte[0] == 0xA0:  # Header byte
            return byte
|
| 127 |
+
|
| 128 |
+
def read_complete_packet(ser):
    """Read a complete packet ensuring proper alignment"""
    # Locate the 0xA0 header byte first
    header = find_packet_start(ser)
    if not header:
        return None

    # Pull in the rest of the 33-byte frame
    body = ser.read(32)
    if len(body) != 32:
        return None

    # The final byte must be a footer of the form 0xCx
    if (body[31] & 0xF0) != 0xC0:
        return None

    return header + body
|
| 145 |
+
|
| 146 |
+
def process_packet(packet):
    """Process a 33-byte packet and extract channel data according to documentation.

    Args:
        packet: Raw 33-byte frame: header (0xA0), sample number, then
            eight 24-bit big-endian two's-complement channel samples.

    Returns:
        List of 8 channel values in microvolts, or None if `packet` is
        not exactly 33 bytes.
    """
    if len(packet) != 33:
        return None

    # Convert to microvolts: 4.5V / gain / (2^23 - 1), scaled to uV.
    # Hoisted out of the loop — it is identical for every channel.
    # NOTE(review): this assumes gain 24, but init_board() calls
    # set_gain(ser, gain=2) — confirm which gain is actually in effect.
    scale_factor = 4.5 / (24.0 * 8388607.0) * 1000000  # Using gain of 24

    channels = []
    for i in range(8):
        start_idx = 2 + (i * 3)  # Start at byte 2 (after header and sample number)
        channel_data = packet[start_idx:start_idx + 3]

        # Convert 24-bit to 32-bit signed int according to documentation
        if channel_data[0] & 0x80:  # If negative number
            value = -1 * ((((~channel_data[0] & 0xFF) << 16) |
                           ((~channel_data[1] & 0xFF) << 8) |
                           (~channel_data[2] & 0xFF)) + 1)
        else:  # If positive number
            value = (channel_data[0] << 16) | (channel_data[1] << 8) | channel_data[2]

        channels.append(value * scale_factor)

    return channels
|
| 169 |
+
|
| 170 |
+
def start_sd_recording(ser, duration='G'):
    """Begin an SD-card recording session on the board.

    Duration codes:
        A = 5MIN, S = 15MIN, F = 30MIN, G = 1HR (default),
        H = 2HR, J = 4HR, K = 12HR, L = 24HR, a = ~14sec (test)

    Raises:
        ValueError: If `duration` is not a recognized code.
    """
    valid_durations = {'A', 'S', 'F', 'G', 'H', 'J', 'K', 'L', 'a'}
    if duration not in valid_durations:
        raise ValueError(f"Invalid duration code. Valid codes: {valid_durations}")

    print(f"Starting SD card recording with duration code {duration}")
    # Arm the SD card with the duration code, then start streaming ('b')
    for cmd in (duration.encode(), b'b'):
        ser.write(cmd)
        time.sleep(0.5)
|
| 192 |
+
|
| 193 |
+
def stop_sd_recording(ser):
    """Stop SD card recording"""
    print("Stopping SD card recording")
    # 's' halts streaming, then 'j' closes the SD-card file
    for cmd in (b's', b'j'):
        ser.write(cmd)
        time.sleep(0.5)
|
| 200 |
+
|
| 201 |
+
def sd_record(port, duration='G', sample_rate=1000):
    """Record a session onto the board's SD card.

    Opens the serial port, initializes the board, starts an SD-card
    recording, then busy-waits (with a countdown display) for the wall-clock
    time implied by the duration code. The recording is always stopped and
    the port always closed, even on Ctrl-C.

    Args:
        port: Serial device path (e.g. '/dev/ttyUSB0').
        duration: Board duration code; see start_sd_recording().
        sample_rate: Sample rate in Hz; see set_sample_rate() for valid values.
    """
    # Wall-clock seconds corresponding to each board duration code
    duration_map = {
        'A': 5*60,  # 5 minutes
        'S': 15*60,  # 15 minutes
        'F': 30*60,  # 30 minutes
        'G': 60*60,  # 1 hour
        'H': 2*60*60,  # 2 hours
        'J': 4*60*60,  # 4 hours
        'K': 12*60*60,  # 12 hours
        'L': 24*60*60,  # 24 hours
        'a': 14  # ~14 seconds (test)
    }

    # Open serial port
    ser = serial.Serial(port, 115200)
    time.sleep(2)

    try:
        # Initialize board
        init_board(ser)

        # Set sample rate
        set_sample_rate(ser, sample_rate)

        # Start recording
        start_sd_recording(ser, duration)

        # Calculate wait time
        wait_time = duration_map[duration]
        start_time = time.time()

        print(f"Recording to SD card for {wait_time} seconds...")

        try:
            while (time.time() - start_time) < wait_time:
                remaining = wait_time - (time.time() - start_time)
                print(f"\rRecording... {remaining:.1f} seconds remaining ", end='')
                time.sleep(0.1)

        except KeyboardInterrupt:
            print("\nRecording interrupted by user")

        finally:
            # Always stop recording
            stop_sd_recording(ser)
            print("\nRecording complete")

    finally:
        ser.close()
|
| 251 |
+
|
| 252 |
+
def main():
    """Parse CLI arguments and run the requested recording mode.

    Modes:
      --sd        record onto the board's SD card (no PC streaming)
      --remote    stream samples to a remote server over SFTP
      (default)   stream samples into a local CSV-style text file
    """
    parser = argparse.ArgumentParser(description='OpenBCI EEG Recording Tool')
    parser.add_argument('--port', '-p', type=str, default='/dev/ttyUSB0',
                        help='Serial port to use (default: /dev/ttyUSB0)')
    parser.add_argument('--filename', '-o', type=str,
                        help='Output filename (default: openbci_<timestamp>.txt)')
    parser.add_argument('--sd', action='store_true',
                        help='Record to SD card instead of streaming to PC')
    parser.add_argument('--duration', type=str, default='G',
                        help='SD card recording duration code (default: G = 1 hour)')
    parser.add_argument('--sample-rate', type=int, default=1000,
                        help='Sample rate in Hz (default: 1000)')
    parser.add_argument('--remote', action='store_true',
                        help='Write to remote server instead of local file')
    args = parser.parse_args()

    if args.sd:
        sd_record(args.port, args.duration, args.sample_rate)
        return

    if args.filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        args.filename = f"openbci_{timestamp}.txt"

    # Open serial port
    ser = serial.Serial(args.port, 115200)
    time.sleep(2)
    init_board(ser)

    filename = args.filename

    # CSV header: one timestamp column plus 16 channel columns.
    # (Built once; previously duplicated in the remote and local branches.)
    header = "Timestamp," + ",".join([f"Channel{i+1}" for i in range(16)]) + "\n"

    if args.remote:
        ssh = create_ssh_connection()
        if not ssh:
            print("Failed to establish SSH connection. Exiting.")
            return

        sftp = ssh.open_sftp()
        remote_file = sftp.open(filename, 'w')
        remote_file.write(header)
    else:
        # Local file: write the header now; data lines are appended later
        with open(filename, 'w') as f:
            f.write(header)

    # Start streaming
    ser.write(b'b')
    time.sleep(0.5)
    ser.reset_input_buffer()

    # BUG FIX: the status messages were f-strings with no placeholder and
    # printed a literal instead of the output filename.
    print(f"Started recording to {filename}")

    packet_count = 0
    start_time = time.time()
    buffer = io.StringIO()
    last_write = time.time()

    try:
        while True:
            # Cyton+Daisy interleaves two 33-byte frames per sample:
            # read a properly aligned pair (8 channels each).
            packet1 = read_complete_packet(ser)
            if packet1:
                packet2 = read_complete_packet(ser)
                if packet2:
                    # Process both packets
                    data1 = process_packet(packet1)
                    data2 = process_packet(packet2)

                    if data1 and data2:
                        packet_count += 1
                        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
                        all_channels = data1 + data2  # Combine all 16 channels
                        data_str = [f"{x:.6f}" for x in all_channels]
                        line = timestamp + "," + ",".join(data_str) + "\n"

                        if args.remote:
                            buffer.write(line)

                            # Flush buffered lines to the remote file every 100ms
                            if time.time() - last_write >= 0.1:
                                remote_file.write(buffer.getvalue())
                                buffer = io.StringIO()
                                last_write = time.time()
                        else:
                            with open(filename, 'a') as f:
                                f.write(line)

                        # Print status roughly once per second (125 packets)
                        if packet_count % 125 == 0:
                            elapsed_time = time.time() - start_time
                            rate = packet_count / elapsed_time
                            print(f"\rRecording... {rate:.1f} Hz, {packet_count} packets", end='')

            # Check for buffer overflow on the serial input side
            if ser.in_waiting > 1000:
                print(f"\nWarning: Buffer building up ({ser.in_waiting} bytes)")
                ser.reset_input_buffer()

    except KeyboardInterrupt:
        # Stop streaming
        ser.write(b's')
        ser.close()

        if args.remote:
            # Write any remaining data in buffer
            if buffer.getvalue():
                remote_file.write(buffer.getvalue())
            remote_file.close()
            sftp.close()
            ssh.close()

        # Print final statistics
        elapsed_time = time.time() - start_time
        rate = packet_count / elapsed_time
        print(f"\n\nRecording stopped")
        print(f"Duration: {elapsed_time:.1f} seconds")
        print(f"Packets recorded: {packet_count}")
        print(f"Average sample rate: {rate:.1f} Hz")
        print(f"Data saved to: {filename}")
|
| 379 |
+
|
| 380 |
+
# Script entry point
if __name__ == "__main__":
    main()
|
decode.py
ADDED
|
@@ -0,0 +1,612 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import time
|
| 5 |
+
import hashlib
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import sqlite3
|
| 9 |
+
import logging
|
| 10 |
+
import argparse
|
| 11 |
+
import random
|
| 12 |
+
import traceback
|
| 13 |
+
import faiss
|
| 14 |
+
import pickle
|
| 15 |
+
from datetime import datetime
|
| 16 |
+
from collections import deque, defaultdict
|
| 17 |
+
from typing import List, Dict, Tuple, Optional, Union, Any
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
# Import from our streaming module
|
| 21 |
+
from eegembed import EEGEmbeddingStream
|
| 22 |
+
|
| 23 |
+
# Toggle for verbose hash debugging output (not referenced in this chunk —
# confirm usage elsewhere before removing)
PRINT_DEBUG_HASH = False
|
| 24 |
+
|
| 25 |
+
def fix_encoding(s):
    """Best-effort repair of mis-encoded text.

    Round-trips the input through UTF-8 and returns the repaired string.
    Strings containing common mojibake marker characters ('ì', 'í', 'ï')
    are dropped entirely (returned as "").

    Args:
        s: A str, bytes, or falsy value (None / "").

    Returns:
        The repaired string; "" when mojibake markers are detected;
        `s` unchanged when it is falsy.
    """
    if not s:
        return s

    if isinstance(s, str):
        b = s.encode('utf-8', 'surrogateescape')
    else:
        b = s

    fixed = b.decode('utf-8', 'replace')
    # BUG FIX: the markers were previously tested against the raw input,
    # which raises TypeError for bytes input ('ì' in b'...'). Test the
    # decoded text instead.
    if 'ì' in fixed or 'í' in fixed or 'ï' in fixed:
        return ""

    return fixed
|
| 39 |
+
|
| 40 |
+
# Set up logging
# Module-wide logger used by EmbeddingIndex / EEGSemanticProcessor below
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("EEGSemanticStream")
|
| 46 |
+
|
| 47 |
+
def setup_eeg_logger(eeg_file_path):
    """Set up a file logger based on the EEG filename.

    Creates a `session_logs/` directory next to this script (if missing)
    and opens a timestamped log file for the session.

    Args:
        eeg_file_path: Path of the EEG recording being processed; its
            basename is used in the log filename.

    Returns:
        An open text file handle with the session header already written;
        the caller is responsible for closing it.
    """
    base_name = os.path.basename(eeg_file_path)
    file_name = os.path.splitext(base_name)[0]

    logs_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "session_logs")
    # exist_ok avoids the check-then-create race of the original
    # os.path.exists()/os.makedirs() pair.
    os.makedirs(logs_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file_path = os.path.join(logs_dir, f"{file_name}_{timestamp}.log")

    log_file = open(log_file_path, "w", encoding="utf-8")
    log_file.write(f"--- Session started at {timestamp} for EEG file: {base_name} ---\n")
    log_file.flush()

    return log_file
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class EmbeddingIndex:
    """FAISS-backed nearest-neighbour index over message embeddings.

    Embeddings are L2-normalized and stored in a flat inner-product index,
    so the inner product equals cosine similarity; `search` converts the
    similarities back to distances as (1 - cosine).
    """
    def __init__(self, dim=1536, use_gpu=True):
        # dim: embedding dimensionality of the flat index
        self.dim = dim
        # Only use the GPU when faiss actually reports one
        self.use_gpu = use_gpu and faiss.get_num_gpus() > 0
        self.index = None
        self.gpu_resources = None
        # Maps FAISS row positions back to database message ids
        self.message_ids = []

        if self.use_gpu:
            self.gpu_resources = faiss.StandardGpuResources()

    def add_embeddings(self, embeddings: np.ndarray, message_ids: List[int]):
        """Build a fresh index from `embeddings`, replacing any existing one."""
        logger.info(f"Building FAISS index with {len(embeddings)} embeddings")

        # Normalize rows so inner product == cosine similarity
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        embeddings = embeddings / (norms + 1e-8)

        cpu_index = faiss.IndexFlatIP(self.dim)
        cpu_index.add(embeddings.astype(np.float32))

        # Move to GPU when requested, falling back to CPU on failure
        if self.use_gpu:
            try:
                self.index = faiss.index_cpu_to_gpu(self.gpu_resources, 0, cpu_index)
                logger.info("Using GPU FAISS index")
            except Exception as e:
                logger.warning(f"GPU failed: {e}. Using CPU.")
                self.index = cpu_index
                self.use_gpu = False
        else:
            self.index = cpu_index
            logger.info("Using CPU FAISS index")

        self.message_ids = message_ids

    def get_current_count(self):
        """Number of vectors currently in the index (0 if not built)."""
        if self.index is None:
            return 0
        return self.index.ntotal

    def search(self, query: np.ndarray, k: int = 10) -> Tuple[np.ndarray, np.ndarray]:
        """Return (distances, message-id labels) for the k nearest embeddings.

        Raises:
            RuntimeError: If the index has not been built or loaded yet.
        """
        if self.index is None:
            raise RuntimeError("Index not initialized")

        # Normalize the query to match the normalized stored vectors
        norm = np.linalg.norm(query)
        if norm > 0:
            query = query / norm

        # Never ask FAISS for more neighbours than the index holds
        actual_k = min(k, self.get_current_count())
        if actual_k == 0:
            return np.array([]), np.array([])

        # NOTE(review): faiss expects a 2-D (n_queries, dim) query; the
        # known caller reshapes to (1, -1) first — confirm for new call sites.
        similarities, indices = self.index.search(query.astype(np.float32), actual_k)
        # Cosine similarity -> cosine distance
        distances = 1.0 - similarities

        # Translate FAISS row positions into database message ids
        labels = np.array([[self.message_ids[idx] for idx in row] for row in indices])

        return distances, labels

    def save(self, path: str):
        """Persist the index to `{path}.index` and ids to `{path}_message_ids.pkl`.

        Raises:
            RuntimeError: If the index has not been built yet.
        """
        if self.index is None:
            raise RuntimeError("Cannot save uninitialized index")

        # GPU indexes must be copied back to the CPU before serialization
        if self.use_gpu:
            cpu_index = faiss.index_gpu_to_cpu(self.index)
            faiss.write_index(cpu_index, f"{path}.index")
        else:
            faiss.write_index(self.index, f"{path}.index")

        with open(f"{path}_message_ids.pkl", 'wb') as f:
            pickle.dump(self.message_ids, f)

    @classmethod
    def load(cls, path: str, use_gpu: bool = True) -> 'EmbeddingIndex':
        """Load a previously saved index, moving it onto the GPU when possible.

        NOTE(review): `dim` keeps its default (1536) here rather than being
        restored from the file — confirm this never matters for loaded indexes.
        """
        # NOTE: pickle.load is only safe because these files are produced
        # locally by save(); never point this at untrusted paths.
        with open(f"{path}_message_ids.pkl", 'rb') as f:
            message_ids = pickle.load(f)

        index = cls(use_gpu=use_gpu)
        index.message_ids = message_ids

        cpu_index = faiss.read_index(f"{path}.index")

        if index.use_gpu:
            try:
                index.index = faiss.index_cpu_to_gpu(index.gpu_resources, 0, cpu_index)
                logger.info("Loaded existing index and moved to GPU")
            except Exception as e:
                logger.warning(f"Failed to move loaded index to GPU: {e}. Using CPU.")
                index.index = cpu_index
                index.use_gpu = False
        else:
            index.index = cpu_index
            logger.info("Loaded existing index on CPU")

        return index
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class EEGSemanticProcessor:
|
| 163 |
+
"""
|
| 164 |
+
Process EEG data through autoencoder and semantic model pipeline,
|
| 165 |
+
then lookup similar messages.
|
| 166 |
+
"""
|
| 167 |
+
def __init__(
    self,
    autoencoder_model_path: str,
    semantic_model_path: str,
    nexus_db_path: str,
    embeddings_db_path: str,
    index_path: str = None,
    eeg_file_path: str = None,
    window_size: int = 624,
    stride: int = 64,
    batch_size: int = 32,
    normalize: bool = True,
    device: str = None,
    search_k: int = 180,
    final_k: int = 90,
    use_raw_eeg: bool = False,
    last_n_messages: int = 3,
    input_dim_override: int = None,
    save_vectors: bool = False,
    vector_output_path: str = None
):
    """Wire up the EEG -> semantic-embedding -> message-lookup pipeline.

    Args:
        autoencoder_model_path: Autoencoder weights for the EEG stream.
        semantic_model_path: TorchScript (torch.jit) model mapping EEG
            embeddings into text-embedding space.
        nexus_db_path: SQLite database holding message content.
        embeddings_db_path: SQLite database holding precomputed embeddings.
        index_path: Optional path prefix for a cached FAISS index.
        eeg_file_path: EEG recording to stream from; also names the session log.
        window_size: Samples per EEG window (forwarded to EEGEmbeddingStream).
        stride: Step between consecutive windows.
        batch_size: Windows per embedding batch.
        normalize: Whether the stream normalizes windows.
        device: Torch device string; auto-selects CUDA when available.
        search_k: Neighbours fetched per similarity lookup.
        final_k: Neighbours kept after filtering.
        use_raw_eeg: Presumably bypasses the autoencoder embedding path —
            stored only; confirm semantics in callers.
        last_n_messages: How many previous result sets to remember.
        input_dim_override: Forces the semantic model input width.
        save_vectors: If True, accumulate produced vectors for later dump.
        vector_output_path: Where accumulated vectors will be written.
    """
    if device is None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        self.device = torch.device(device)

    logger.info(f"Using device: {self.device}")

    self.last_n_messages = last_n_messages
    self.use_raw_eeg = use_raw_eeg
    self.input_dim_override = input_dim_override

    # Initialize EEG stream
    self.eeg_stream = EEGEmbeddingStream(
        file_path=eeg_file_path if eeg_file_path else "",
        model_path=autoencoder_model_path,
        window_size=window_size,
        stride=stride,
        normalize=normalize,
        batch_size=batch_size,
        device=self.device
    )

    # Load traced semantic model
    logger.info(f"Loading traced semantic model from {semantic_model_path}")
    self.semantic_model = torch.jit.load(semantic_model_path, map_location=self.device)
    self.semantic_model.eval()

    # Probe to get input/output dims
    # Try a few common input sizes to find the right one
    self._semantic_input_dim = None
    self._semantic_output_dim = None
    for test_dim in [64, 10112]:
        try:
            dummy = torch.randn(1, test_dim, device=self.device)
            with torch.no_grad():
                out = self.semantic_model(dummy)
            self._semantic_input_dim = test_dim
            self._semantic_output_dim = out.shape[1]
            logger.info(f"Semantic model: input_dim={self._semantic_input_dim}, output_dim={self._semantic_output_dim}")
            break
        except Exception:
            # Wrong size: the traced model rejects the dummy; try the next one
            continue

    if self._semantic_input_dim is None:
        logger.warning("Could not auto-detect semantic model input dim. Will adapt at runtime.")

    # Per-session log file named after the EEG recording (None when no file)
    self.log_file = setup_eeg_logger(eeg_file_path) if eeg_file_path else None

    # Initialize database connections
    self.nexus_conn = sqlite3.connect(nexus_db_path)
    self.embeddings_conn = sqlite3.connect(embeddings_db_path)

    # Message tracking system
    self.search_k = search_k
    self.final_k = final_k
    self.message_counts = defaultdict(int)   # per-message hit counts
    self.recent_messages = deque(maxlen=10)  # sliding window of recent hits
    # NOTE(review): presumably used to damp repeated results in ranking
    # code outside this chunk — confirm.
    self.repetition_penalty = 1.5

    logger.info("Creating embedding index")
    self.embedding_index = self._create_index(index_path)

    # Error budget: give up after this many consecutive failures
    self.error_count = 0
    self.max_consecutive_errors = 5

    self.save_vectors = save_vectors
    self.vector_output_path = vector_output_path

    if self.save_vectors:
        self.vectors_list = []
        self.timestamps = []
        logger.info(f"Vector saving enabled. Output will be saved to {self.vector_output_path}")

    # Result sets from the last N lookups
    self.previous_message_sets = deque(maxlen=self.last_n_messages)
|
| 262 |
+
|
| 263 |
+
def _create_index(self, index_path: str = None) -> EmbeddingIndex:
    """Create or load the embedding index for similarity search.

    A cached index at `index_path` is reused only when its saved metadata
    (row count and max message id) still matches the embeddings database;
    otherwise a fresh index is built from all rows and, when `index_path`
    is given, saved together with updated metadata.

    Raises:
        ValueError: If the embeddings table is empty.
    """

    cursor = self.embeddings_conn.cursor()

    # Snapshot DB stats; used to detect whether a cached index is stale
    cursor.execute("SELECT COUNT(*) FROM embeddings")
    db_count = cursor.fetchone()[0]

    cursor.execute("SELECT MAX(message_id) FROM embeddings")
    db_max_id = cursor.fetchone()[0]

    if index_path and os.path.exists(f"{index_path}.index"):
        try:
            logger.info(f"Checking existing index at {index_path}")

            index = EmbeddingIndex.load(index_path)

            metadata_path = f"{index_path}_metadata.npz"
            if os.path.exists(metadata_path):
                metadata = np.load(metadata_path, allow_pickle=True)
                saved_count = int(metadata.get('count', 0))
                saved_max_id = int(metadata.get('max_message_id', 0))

                logger.info(f"Saved index: {saved_count} items, max_id={saved_max_id}")
                logger.info(f"Database: {db_count} items, max_id={db_max_id}")

                # Reuse the cache only when both markers still match
                if db_count != saved_count or db_max_id != saved_max_id:
                    logger.info("Database has changed. Recreating index...")
                else:
                    logger.info("Database unchanged. Using existing index...")
                    return index

        except Exception as e:
            # Any load/metadata failure falls through to a full rebuild
            logger.warning(f"Error checking existing index: {str(e)}")
            logger.info("Will create new index")

    logger.info("Creating new index from database...")

    cursor.execute("SELECT message_id, embedding FROM embeddings ORDER BY message_id")

    embeddings = []
    message_ids = []

    # Each stored embedding is a raw float32 blob
    for message_id, emb in cursor.fetchall():
        embedding = np.frombuffer(emb, dtype=np.float32)
        embeddings.append(embedding)
        message_ids.append(message_id)

    if not embeddings:
        raise ValueError("No embeddings found in database")

    embeddings = np.vstack(embeddings)
    logger.info(f"Loaded {len(embeddings)} embeddings with shape: {embeddings.shape}")

    index = EmbeddingIndex(dim=embeddings.shape[1])
    index.add_embeddings(embeddings, message_ids)

    if index_path:
        logger.info(f"Saving index to {index_path}")
        index.save(index_path)

        # Persist the freshness markers next to the index files
        metadata = {
            'count': db_count,
            'max_message_id': db_max_id
        }
        np.savez(f"{index_path}_metadata.npz", **metadata)

    return index
|
| 331 |
+
|
| 332 |
+
def process_eeg_embedding(self, eeg_embedding: np.ndarray) -> torch.Tensor:
    """Convert an EEG autoencoder embedding to a text-space embedding.

    The input is moved to the processor's device, flattened to
    (batch, features), and zero-padded or truncated on the feature axis to
    match the traced semantic model's expected input width before the
    forward pass.
    """
    with torch.no_grad():
        x = torch.tensor(eeg_embedding, dtype=torch.float32).to(self.device)

        # Promote a single vector to a one-row batch.
        if x.dim() < 2:
            x = x.unsqueeze(0)

        n_rows = x.shape[0]
        x = x.reshape(n_rows, -1)

        # Reconcile the feature count with what the semantic model expects.
        expected = self._semantic_input_dim
        if expected is not None:
            n_feats = x.shape[1]
            if n_feats < expected:
                # Too few features: right-pad with zeros.
                padded = torch.zeros(n_rows, expected, device=self.device)
                padded[:, :n_feats] = x
                x = padded
            elif n_feats > expected:
                # Too many features: drop the trailing ones.
                x = x[:, :expected]

        return self.semantic_model(x)
|
| 355 |
+
|
| 356 |
+
def find_similar_messages(self, embedding: torch.Tensor, assistant_only=False) -> List[str]:
    """Return the contents of messages nearest to `embedding`.

    Queries the FAISS-style index for `self.search_k` candidates, resolves
    each candidate id against the messages table (optionally restricted to
    assistant messages), and returns at most `self.final_k` contents.
    Returns an empty list if the search fails.
    """
    query = embedding.detach().cpu().numpy()
    if len(query.shape) > 1:
        query = query.reshape(1, -1)

    try:
        distances, indices = self.embedding_index.search(query, self.search_k)
        distances = distances.flatten()
        indices = indices.flatten()

        cursor = self.nexus_conn.cursor()

        # Role filtering happens in SQL, not after fetching.
        if assistant_only:
            sql = "SELECT content FROM messages WHERE id = ? AND role = 'assistant'"
        else:
            sql = "SELECT content FROM messages WHERE id = ?"

        matches = []
        for msg_id, _dist in zip(indices, distances):
            cursor.execute(sql, (int(msg_id),))
            row = cursor.fetchone()
            if row is not None:
                matches.append(row[0])

        return matches[:self.final_k]

    except Exception as e:
        logger.error(f"Error during similarity search: {str(e)}")
        traceback.print_exc()
        return []
|
| 393 |
+
|
| 394 |
+
def save_vectors_to_disk(self):
    """Persist collected semantic vectors and window timestamps as an .npz file.

    Writes `vectors` (stacked 2-D float array) and `timestamps` (object array
    of {'start', 'end'} dicts; loading it requires allow_pickle=True) to
    `self.vector_output_path`. No-op with a warning when nothing was collected.
    """
    if not self.vectors_list:
        logger.warning("No vectors to save")
        return

    output_dir = os.path.dirname(self.vector_output_path)
    if output_dir:
        # exist_ok=True removes the race between an exists() check and
        # makedirs() (another process could create the directory in between).
        os.makedirs(output_dir, exist_ok=True)

    vectors_array = np.vstack(self.vectors_list)
    timestamps_array = np.array(self.timestamps)

    logger.info(f"Saving {len(self.vectors_list)} vectors to {self.vector_output_path}")
    np.savez(
        self.vector_output_path,
        vectors=vectors_array,
        timestamps=timestamps_array
    )
    logger.info(f"Vectors saved successfully to {self.vector_output_path}")
|
| 414 |
+
|
| 415 |
+
def process_streaming_embeddings(self, callback=None):
    """
    Process streaming EEG embeddings through the semantic model
    and find similar messages.

    Runs until interrupted. Each embedding window is pushed through the
    semantic model; in vector-collection mode (self.save_vectors) results
    are accumulated instead of searched. Transient errors are counted and
    the loop aborts after self.max_consecutive_errors failures in a row.
    """
    self.eeg_stream.start()

    try:
        consecutive_errors = 0
        while True:
            try:
                for embedding_data in self.eeg_stream.get_embeddings(timeout=0.5):
                    try:
                        autoencoder_embedding = embedding_data['embedding']
                        semantic_embedding = self.process_eeg_embedding(autoencoder_embedding)

                        # Vector-collection mode: store the vector and its
                        # window timestamps, then skip similarity retrieval.
                        if self.save_vectors:
                            embedding_np = semantic_embedding.detach().cpu().numpy()
                            self.vectors_list.append(embedding_np)
                            self.timestamps.append({
                                'start': embedding_data['start_timestamp'],
                                'end': embedding_data['end_timestamp']
                            })

                            # Periodic progress report.
                            if len(self.vectors_list) % 100 == 0:
                                logger.info(f"Collected {len(self.vectors_list)} vectors so far")

                            continue

                        similar_messages = self.find_similar_messages(semantic_embedding)

                        result = {
                            'start_timestamp': embedding_data['start_timestamp'],
                            'end_timestamp': embedding_data['end_timestamp'],
                            # NOTE(review): placeholder — processing time is not
                            # currently measured.
                            'processing_time': 0,
                            'similar_messages': similar_messages
                        }

                        if callback:
                            callback(result)
                        else:
                            self._print_unique_lines(result)

                        # A successful window resets the error streak.
                        consecutive_errors = 0

                    except Exception as e:
                        print(f"Error: {str(e)}", file=sys.stderr)
                        consecutive_errors += 1

                        if consecutive_errors >= self.max_consecutive_errors:
                            raise RuntimeError(f"Too many consecutive errors ({consecutive_errors})")

                time.sleep(0.01)

            except Exception as e:
                # Re-raise the bail-out sentinel raised above; anything else
                # is treated as transient and counted toward the same limit.
                if "Too many consecutive errors" in str(e):
                    raise
                print(f"Error: {str(e)}", file=sys.stderr)
                consecutive_errors += 1
                if consecutive_errors >= self.max_consecutive_errors:
                    raise RuntimeError(f"Too many consecutive errors ({consecutive_errors})")
                time.sleep(1)

    except KeyboardInterrupt:
        pass
    except Exception as e:
        print(f"Fatal error: {str(e)}", file=sys.stderr)
    finally:
        # Flush collected vectors (if any) and always release the stream.
        if self.save_vectors and self.vectors_list:
            self.save_vectors_to_disk()

        self.eeg_stream.stop()
|
| 487 |
+
|
| 488 |
+
def _print_unique_lines(self, result):
|
| 489 |
+
"""Print only lines that aren't in common with the last n batches of messages"""
|
| 490 |
+
if not result['similar_messages']:
|
| 491 |
+
return
|
| 492 |
+
|
| 493 |
+
sample_size = min(42, len(result['similar_messages']))
|
| 494 |
+
current_messages = random.sample(result['similar_messages'], sample_size)
|
| 495 |
+
|
| 496 |
+
current_lines = set()
|
| 497 |
+
for message in current_messages:
|
| 498 |
+
for line in message.splitlines():
|
| 499 |
+
line = line.strip()
|
| 500 |
+
if line:
|
| 501 |
+
current_lines.add(line)
|
| 502 |
+
|
| 503 |
+
unique_lines = current_lines.copy()
|
| 504 |
+
for previous_lines in self.previous_message_sets:
|
| 505 |
+
unique_lines -= previous_lines
|
| 506 |
+
|
| 507 |
+
self.previous_message_sets.append(current_lines)
|
| 508 |
+
|
| 509 |
+
__uniq_log_empty = False
|
| 510 |
+
if unique_lines:
|
| 511 |
+
if PRINT_DEBUG_HASH:
|
| 512 |
+
unique_lines = [f"{hash} | {line}" for (hash, line) in zip(
|
| 513 |
+
map(lambda s: hashlib.md5(s.encode()).hexdigest()[:8], unique_lines),
|
| 514 |
+
unique_lines)]
|
| 515 |
+
|
| 516 |
+
unique_lines = filter(lambda s: bool(s), map(fix_encoding, unique_lines))
|
| 517 |
+
|
| 518 |
+
output_text = "\n".join(sorted(unique_lines))
|
| 519 |
+
print(output_text)
|
| 520 |
+
|
| 521 |
+
if hasattr(self, 'log_file') and self.log_file:
|
| 522 |
+
try:
|
| 523 |
+
self.log_file.write(output_text + "\n")
|
| 524 |
+
self.log_file.flush()
|
| 525 |
+
except Exception as e:
|
| 526 |
+
print(f"Error writing to log file: {str(e)}", file=sys.stderr)
|
| 527 |
+
|
| 528 |
+
elif __uniq_log_empty:
|
| 529 |
+
logger.info(f"No unique lines")
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
def main():
    """CLI entry point: build an EEGSemanticProcessor from arguments and run it."""
    parser = argparse.ArgumentParser(description='Process EEG data through semantic model and lookup similar messages')

    # Traced (TorchScript) model paths.
    parser.add_argument('--autoencoder', '-a', type=str, required=True,
                        help='Path to the traced autoencoder encoder model')
    parser.add_argument('--semantic-model', '-s', type=str, required=True,
                        help='Path to the traced semantic model')

    # Database and index locations.
    parser.add_argument('--nexus-db', '-n', type=str,
                        default=os.path.expanduser('~/.nexus/data/nexus-new.db'),
                        help='Path to the nexus database')
    parser.add_argument('--embeddings-db', '-e', type=str, default='emb_full.db',
                        help='Path to the embeddings database')
    parser.add_argument('--index', '-i', type=str, default='embedding_index',
                        help='Path to save/load the FAISS index')

    # EEG input source and windowing parameters.
    parser.add_argument('--eeg-file', '-f', type=str, required=True,
                        help='Path to the EEG data file to monitor')
    parser.add_argument('--window-size', type=int, default=624,
                        help='Window size in samples')
    parser.add_argument('--stride', type=int, default=32,
                        help='Stride between windows')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size for processing')
    parser.add_argument('--no-normalize', dest='normalize', action='store_false',
                        help='Disable normalization of EEG data')

    # Retrieval parameters: search_k candidates are fetched, final_k shown.
    parser.add_argument('--search-k', type=int, default=180,
                        help='Number of candidates to retrieve for selection')
    parser.add_argument('--final-k', type=int, default=90,
                        help='Number of results to show')

    parser.add_argument('--device', type=str, default=None,
                        help='Device to use (cuda or cpu)')

    parser.add_argument('--last_n', type=int, default=None,
                        help='Window queue size for repetition filter')

    # Optional bypass of the autoencoder stage.
    parser.add_argument('--use-raw-eeg', action='store_true',
                        help='Use raw EEG data with semantic model (skip autoencoder)')
    parser.add_argument('--input-dim', type=int,
                        help='Override the input dimension for the semantic model')

    # Vector-collection mode (accumulate vectors instead of printing matches).
    parser.add_argument('--save-vectors', action='store_true',
                        help='Save semantic vectors to disk without generating output')
    parser.add_argument('--vector-output', type=str, default='semantic_vectors.npz',
                        help='Path to save the semantic vectors')

    args = parser.parse_args()

    processor = EEGSemanticProcessor(
        autoencoder_model_path=args.autoencoder,
        semantic_model_path=args.semantic_model,
        nexus_db_path=args.nexus_db,
        embeddings_db_path=args.embeddings_db,
        index_path=args.index,
        last_n_messages=args.last_n,
        eeg_file_path=args.eeg_file,
        window_size=args.window_size,
        stride=args.stride,
        batch_size=args.batch_size,
        normalize=args.normalize,
        device=args.device,
        search_k=args.search_k,
        final_k=args.final_k,
        use_raw_eeg=args.use_raw_eeg,
        input_dim_override=args.input_dim,
        save_vectors=args.save_vectors,
        vector_output_path=args.vector_output
    )

    try:
        processor.process_streaming_embeddings()
    except KeyboardInterrupt:
        pass
    except Exception as e:
        print(f"Error: {str(e)}", file=sys.stderr)
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
eegembed.py
ADDED
|
@@ -0,0 +1,543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import csv
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import threading
|
| 7 |
+
import queue
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import List, Dict, Tuple, Optional, Callable, Generator, Union, Any
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
# Module-wide logging configuration for the streaming components.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("EEGStream")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class EncoderExtractor:
    """Wrapper around a TorchScript-traced EEG encoder.

    Loads the traced model onto the chosen device, probes it once to learn
    the embedding width, and exposes a batched `embed` method that can
    resample inputs to a fixed sequence length.
    """

    def __init__(self, model_path, device=None, force_sequence_length=None):
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        self.force_sequence_length = force_sequence_length
        logger.info(f"Loading traced encoder from {model_path} to {self.device}")
        self.model = torch.jit.load(model_path, map_location=self.device)
        self.model.eval()

        # Run one dummy batch (16 channels) through the model to discover the
        # embedding dimensionality. NOTE(review): channel count 16 is assumed
        # to match the traced model — confirm against training config.
        probe = torch.randn(1, 16, force_sequence_length or 624, device=self.device)
        with torch.no_grad():
            self._embedding_size = self.model(probe).shape[1]
        logger.info(f"Embedding size: {self._embedding_size}")

    def get_embedding_size(self):
        """Return the encoder's output embedding width."""
        return self._embedding_size

    def embed(self, data):
        """Encode a [batch, channels, time] tensor; resamples the time axis
        when a fixed sequence length is enforced."""
        wrong_length = (
            self.force_sequence_length
            and data.shape[2] != self.force_sequence_length
        )
        if wrong_length:
            data = torch.nn.functional.interpolate(
                data, size=self.force_sequence_length, mode='linear', align_corners=False
            )
        with torch.no_grad():
            return self.model(data.to(self.device))
|
| 47 |
+
|
| 48 |
+
class EEGFileWatcher:
    """
    Watches a CSV file for new data and yields new lines as they appear.

    A background thread tails the file and pushes complete lines (header
    first, trailing newline stripped) onto an internal queue; consumers
    drain the queue with get_new_lines().
    """
    def __init__(self, file_path: str, poll_interval: float = 0.1):
        """
        Initialize the file watcher.

        Args:
            file_path: Path to the CSV file to watch
            poll_interval: How often to check for new data (in seconds)
        """
        self.file_path = Path(file_path)
        self.poll_interval = poll_interval
        self.last_position = 0  # text-mode tell() cookie of consumed data
        self.running = False
        self.thread = None
        self.queue = queue.Queue()
        self.header = None

    def start(self):
        """Start watching the file in a background thread."""
        if self.running:
            return

        self.running = True
        self.thread = threading.Thread(target=self._watch_file, daemon=True)
        self.thread.start()

    def stop(self):
        """Stop watching the file."""
        self.running = False
        if self.thread:
            self.thread.join(timeout=1.0)

    def _watch_file(self):
        """Background thread that tails the file, enqueuing complete lines.

        File positions are taken exclusively from text-mode tell(). The
        previous implementation subtracted a character count from a tell()
        cookie to rewind over a partial line; text-mode tell() values are
        opaque cookies, so that arithmetic is undefined and wrong for
        multi-byte content. Here each line's start position is captured with
        tell() before readline(), and a partial trailing line is simply not
        consumed (last_position stays at the end of the last complete line).
        """
        # Wait for the file to exist.
        while self.running and not self.file_path.exists():
            logger.info(f"Waiting for file {self.file_path} to exist...")
            time.sleep(self.poll_interval)

        if not self.running:
            # Stopped before the file ever appeared.
            return

        logger.info(f"File {self.file_path} found, starting to watch")

        self.last_position = 0

        # Read the header (first line) immediately.
        try:
            with open(self.file_path, 'r') as f:
                self.header = f.readline().strip()
                self.last_position = f.tell()
                self.queue.put(self.header)
        except Exception as e:
            logger.error(f"Error reading header: {e}")

        while self.running:
            try:
                # Only reopen and read when the file has grown.
                current_size = self.file_path.stat().st_size
                if current_size > self.last_position:
                    with open(self.file_path, 'r') as f:
                        f.seek(self.last_position)
                        while True:
                            line_start = f.tell()
                            line = f.readline()
                            if not line:
                                break
                            if not line.endswith('\n'):
                                # Incomplete trailing line: leave it for the
                                # next poll, once the writer finishes it.
                                f.seek(line_start)
                                break
                            self.last_position = f.tell()
                            if line.strip():  # skip empty lines
                                self.queue.put(line.rstrip('\n'))
            except Exception as e:
                logger.error(f"Error watching file: {e}")

            time.sleep(self.poll_interval)

    def get_new_lines(self, timeout: Optional[float] = None) -> List[str]:
        """
        Get any new lines that have been read since the last call.

        Args:
            timeout: How long to wait for new data (in seconds). None means don't wait.

        Returns:
            List of new lines (might be empty if no new data)
        """
        lines = []
        try:
            # Get the first line (with timeout).
            lines.append(self.queue.get(timeout=timeout))

            # Drain any remaining lines without waiting.
            while True:
                try:
                    lines.append(self.queue.get_nowait())
                except queue.Empty:
                    break
        except queue.Empty:
            pass

        return lines
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class SlidingWindowProcessor:
    """
    Turns a growing buffer of per-sample data points into overlapping,
    optionally z-scored windows of shape [num_channels, window_size].
    """
    def __init__(
        self,
        window_size: int,
        stride: int,
        num_channels: int,
        channel_means: List[float],
        channel_stds: List[float],
        normalize: bool = True
    ):
        """
        Set up the sliding-window state.

        Args:
            window_size: Number of data points in each window
            stride: Number of data points to advance between windows
            num_channels: Number of data channels
            channel_means: Mean value for each channel (for normalization)
            channel_stds: Standard deviation for each channel (for normalization)
            normalize: Whether to normalize the data
        """
        self.window_size = window_size
        self.stride = stride
        self.num_channels = num_channels
        self.channel_means = np.array(channel_means)
        self.channel_stds = np.array(channel_stds)
        self.normalize = normalize

        self.buffer = []       # accumulated data-point dicts
        self.current_pos = 0   # start index of the next window in buffer

    def add_data(self, data_points: List[Dict[str, Union[str, float]]]):
        """
        Append new data points to the buffer.

        Args:
            data_points: List of dicts, each with 'timestamp' and ChannelN keys.
        """
        self.buffer.extend(data_points)

    def get_windows(self) -> Generator[Tuple[List[str], np.ndarray], None, None]:
        """
        Yield (timestamps, data) pairs for every full window available.

        The data array has shape [num_channels, window_size]. After yielding,
        consumed points are dropped from the buffer, keeping only the overlap
        needed by the next window.
        """
        while self.current_pos + self.window_size <= len(self.buffer):
            start = self.current_pos
            chunk = self.buffer[start:start + self.window_size]

            stamps = [sample['timestamp'] for sample in chunk]

            arr = np.zeros((self.num_channels, self.window_size), dtype=np.float32)
            for col, sample in enumerate(chunk):
                for ch in range(self.num_channels):
                    key = f'Channel{ch+1}'
                    if key in sample:
                        arr[ch, col] = sample[key]

            if self.normalize:
                # Z-score channel-wise; channels with non-positive std are
                # left untouched to avoid division by zero.
                for ch in range(self.num_channels):
                    if self.channel_stds[ch] > 0:
                        arr[ch] = (arr[ch] - self.channel_means[ch]) / self.channel_stds[ch]

            yield stamps, arr

            self.current_pos += self.stride

        # Trim the buffer, retaining the (window_size - stride) overlap.
        if self.current_pos > 0:
            keep_from = max(0, self.current_pos - (self.window_size - self.stride))
            self.buffer = self.buffer[keep_from:]
            self.current_pos = max(0, self.current_pos - keep_from)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class EEGEmbeddingStream:
|
| 250 |
+
"""
|
| 251 |
+
Stream of EEG embeddings from a live CSV file.
|
| 252 |
+
"""
|
| 253 |
+
def __init__(
|
| 254 |
+
self,
|
| 255 |
+
file_path: str,
|
| 256 |
+
model_path: str,
|
| 257 |
+
window_size: int = 256,
|
| 258 |
+
stride: int = 64,
|
| 259 |
+
normalizer_params: Dict[str, List[float]] = None,
|
| 260 |
+
poll_interval: float = 0.1,
|
| 261 |
+
batch_size: int = 32,
|
| 262 |
+
normalize: bool = True,
|
| 263 |
+
device: str = None,
|
| 264 |
+
start_from_timestamp: str = None,
|
| 265 |
+
force_sequence_length: int = None # New parameter
|
| 266 |
+
):
|
| 267 |
+
"""
|
| 268 |
+
Initialize the EEG embedding stream.
|
| 269 |
+
|
| 270 |
+
Args:
|
| 271 |
+
file_path: Path to the CSV file to watch
|
| 272 |
+
model_path: Path to the trained model checkpoint
|
| 273 |
+
window_size: Number of data points in each window
|
| 274 |
+
stride: Number of data points to advance between windows
|
| 275 |
+
normalizer_params: Dictionary with 'means' and 'stds' for each channel
|
| 276 |
+
If None, default values will be used
|
| 277 |
+
poll_interval: How often to check for new data (in seconds)
|
| 278 |
+
batch_size: How many windows to encode at once
|
| 279 |
+
normalize: Whether to normalize the data
|
| 280 |
+
device: Device to use for encoding ('cuda' or 'cpu')
|
| 281 |
+
start_from_timestamp: Only process data from this timestamp onwards
|
| 282 |
+
force_sequence_length: Force the model to use this sequence length (to match training)
|
| 283 |
+
"""
|
| 284 |
+
self.file_path = file_path
|
| 285 |
+
self.poll_interval = poll_interval
|
| 286 |
+
self.window_size = window_size
|
| 287 |
+
self.stride = stride
|
| 288 |
+
self.normalize = normalize
|
| 289 |
+
self.batch_size = batch_size
|
| 290 |
+
self.start_from_timestamp = start_from_timestamp
|
| 291 |
+
|
| 292 |
+
# Set default normalizer parameters if not provided
|
| 293 |
+
if normalizer_params is None:
|
| 294 |
+
self.channel_means = [-70446.6562, -51197.2070, -42351.2812, -32628.9004, -58139.0547,
|
| 295 |
+
-56271.2852, -48508.2305, -57654.8711, -69949.6484, -49663.8398,
|
| 296 |
+
-43010.7070, -30252.7207, -56295.6250, -56075.9375, -48470.3086,
|
| 297 |
+
-56338.5820]
|
| 298 |
+
self.channel_stds = [76037.4453, 56048.1445, 71950.6328, 60051.6523, 64877.7422,
|
| 299 |
+
59371.3203, 56742.6055, 62344.4805, 75861.9141, 55614.6055,
|
| 300 |
+
70795.6719, 59312.4180, 64780.2109, 60292.6992, 56598.4609,
|
| 301 |
+
61472.3633]
|
| 302 |
+
else:
|
| 303 |
+
self.channel_means = normalizer_params['means']
|
| 304 |
+
self.channel_stds = normalizer_params['stds']
|
| 305 |
+
|
| 306 |
+
# Determine the number of channels
|
| 307 |
+
self.num_channels = len(self.channel_means)
|
| 308 |
+
|
| 309 |
+
# Initialize components
|
| 310 |
+
self.file_watcher = EEGFileWatcher(file_path, poll_interval)
|
| 311 |
+
self.window_processor = SlidingWindowProcessor(
|
| 312 |
+
window_size, stride, self.num_channels,
|
| 313 |
+
self.channel_means, self.channel_stds, normalize
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
# Set the device
|
| 317 |
+
if device is None:
|
| 318 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 319 |
+
else:
|
| 320 |
+
self.device = torch.device(device)
|
| 321 |
+
|
| 322 |
+
# Load the encoder with the forced sequence length
|
| 323 |
+
self.encoder = EncoderExtractor(model_path, self.device, force_sequence_length)
|
| 324 |
+
|
| 325 |
+
# CSV header
|
| 326 |
+
self.header = None
|
| 327 |
+
|
| 328 |
+
# Running flag
|
| 329 |
+
self.running = False
|
| 330 |
+
|
| 331 |
+
# Statistics
|
| 332 |
+
self.stats = {
|
| 333 |
+
'windows_processed': 0,
|
| 334 |
+
'start_time': None,
|
| 335 |
+
'last_timestamp': None
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
def start(self):
|
| 339 |
+
"""Start the embedding stream."""
|
| 340 |
+
if self.running:
|
| 341 |
+
return
|
| 342 |
+
|
| 343 |
+
self.running = True
|
| 344 |
+
self.stats['start_time'] = time.time()
|
| 345 |
+
self.file_watcher.start()
|
| 346 |
+
|
| 347 |
+
def stop(self):
|
| 348 |
+
"""Stop the embedding stream."""
|
| 349 |
+
self.running = False
|
| 350 |
+
self.file_watcher.stop()
|
| 351 |
+
|
| 352 |
+
if self.stats['start_time'] is not None:
|
| 353 |
+
elapsed = time.time() - self.stats['start_time']
|
| 354 |
+
windows_processed = self.stats['windows_processed']
|
| 355 |
+
if windows_processed > 0 and elapsed > 0:
|
| 356 |
+
rate = windows_processed / elapsed
|
| 357 |
+
logger.info(f"Processed {windows_processed} windows in {elapsed:.2f}s ({rate:.2f} windows/s)")
|
| 358 |
+
|
| 359 |
+
def _parse_csv_line(self, line: str) -> Dict[str, Union[str, float]]:
|
| 360 |
+
"""
|
| 361 |
+
Parse a CSV line into a data point.
|
| 362 |
+
|
| 363 |
+
Args:
|
| 364 |
+
line: CSV line
|
| 365 |
+
|
| 366 |
+
Returns:
|
| 367 |
+
Dictionary with timestamp and channel values, or None if header
|
| 368 |
+
"""
|
| 369 |
+
if not self.header:
|
| 370 |
+
# First line is the header
|
| 371 |
+
self.header = line.split(',')
|
| 372 |
+
logger.info(f"CSV header: {self.header}")
|
| 373 |
+
return None
|
| 374 |
+
|
| 375 |
+
values = line.split(',')
|
| 376 |
+
if len(values) != len(self.header):
|
| 377 |
+
logger.warning(f"Line has wrong number of values: {line}")
|
| 378 |
+
return None
|
| 379 |
+
|
| 380 |
+
data_point = {}
|
| 381 |
+
for i, column in enumerate(self.header):
|
| 382 |
+
if i == 0:
|
| 383 |
+
# Timestamp column
|
| 384 |
+
data_point['timestamp'] = values[i]
|
| 385 |
+
|
| 386 |
+
# Skip if before start_from_timestamp
|
| 387 |
+
if self.start_from_timestamp and values[i] < self.start_from_timestamp:
|
| 388 |
+
return None
|
| 389 |
+
else:
|
| 390 |
+
# Channel column
|
| 391 |
+
try:
|
| 392 |
+
data_point[column] = float(values[i])
|
| 393 |
+
except ValueError:
|
| 394 |
+
logger.warning(f"Could not parse value {values[i]} as float for column {column}")
|
| 395 |
+
data_point[column] = 0.0
|
| 396 |
+
|
| 397 |
+
return data_point
|
| 398 |
+
|
| 399 |
+
def get_embeddings(self, timeout: Optional[float] = None) -> Generator[Dict[str, Any], None, None]:
|
| 400 |
+
"""
|
| 401 |
+
Get embeddings for new data.
|
| 402 |
+
|
| 403 |
+
Args:
|
| 404 |
+
timeout: How long to wait for new data (in seconds). None means don't wait.
|
| 405 |
+
|
| 406 |
+
Yields:
|
| 407 |
+
Dictionary with window information and embedding
|
| 408 |
+
"""
|
| 409 |
+
if not self.running:
|
| 410 |
+
self.start()
|
| 411 |
+
|
| 412 |
+
# Get new lines from the file
|
| 413 |
+
new_lines = self.file_watcher.get_new_lines(timeout)
|
| 414 |
+
if not new_lines:
|
| 415 |
+
return
|
| 416 |
+
|
| 417 |
+
# Parse CSV lines
|
| 418 |
+
data_points = []
|
| 419 |
+
for line in new_lines:
|
| 420 |
+
data_point = self._parse_csv_line(line)
|
| 421 |
+
if data_point:
|
| 422 |
+
data_points.append(data_point)
|
| 423 |
+
self.stats['last_timestamp'] = data_point['timestamp']
|
| 424 |
+
|
| 425 |
+
if not data_points:
|
| 426 |
+
return
|
| 427 |
+
|
| 428 |
+
# Add to the window processor
|
| 429 |
+
self.window_processor.add_data(data_points)
|
| 430 |
+
|
| 431 |
+
# Get windows and batch them for embedding
|
| 432 |
+
windows = list(self.window_processor.get_windows())
|
| 433 |
+
if not windows:
|
| 434 |
+
return
|
| 435 |
+
|
| 436 |
+
for batch_start in range(0, len(windows), self.batch_size):
|
| 437 |
+
batch_end = min(batch_start + self.batch_size, len(windows))
|
| 438 |
+
batch = windows[batch_start:batch_end]
|
| 439 |
+
|
| 440 |
+
# Extract timestamps and data
|
| 441 |
+
batch_timestamps = [window[0] for window in batch]
|
| 442 |
+
batch_data = [window[1] for window in batch]
|
| 443 |
+
|
| 444 |
+
# Convert to tensors
|
| 445 |
+
batch_tensor = torch.tensor(np.array(batch_data), dtype=torch.float32)
|
| 446 |
+
|
| 447 |
+
# Generate embeddings
|
| 448 |
+
embeddings = self.encoder.embed(batch_tensor)
|
| 449 |
+
|
| 450 |
+
# Convert to numpy and yield
|
| 451 |
+
embeddings_np = embeddings.cpu().numpy()
|
| 452 |
+
|
| 453 |
+
for i in range(len(batch)):
|
| 454 |
+
self.stats['windows_processed'] += 1
|
| 455 |
+
yield {
|
| 456 |
+
'start_timestamp': batch_timestamps[i][0],
|
| 457 |
+
'end_timestamp': batch_timestamps[i][-1],
|
| 458 |
+
'embedding': embeddings_np[i],
|
| 459 |
+
'window_index': self.stats['windows_processed'] - 1
|
| 460 |
+
}
|
| 461 |
+
|
| 462 |
+
    def get_streaming_embeddings(self, callback: Optional[Callable[[Dict[str, Any]], None]] = None) -> Generator[Dict[str, Any], None, None]:
        """
        Continuously generate embeddings and call the callback function with each one.

        WARNING: because this function contains ``yield`` it is a generator
        function in BOTH modes — even when ``callback`` is given, calling it
        only creates a generator object; the caller must iterate (exhaust) the
        returned generator for the callback to actually fire.

        Args:
            callback: Function to call with each embedding. If None, embeddings are yielded.

        Yields:
            If no callback is provided, yields dictionaries with window information and embedding
        """
        # Ensure the file watcher / processing pipeline is running before polling.
        self.start()

        try:
            # Loop until stop() flips self.running (or the caller abandons the generator).
            while self.running:
                any_embeddings = False
                # Drain whatever embeddings are currently available.
                for embedding in self.get_embeddings(timeout=self.poll_interval):
                    any_embeddings = True
                    if callback:
                        callback(embedding)
                    else:
                        yield embedding

                if not any_embeddings:
                    # No new embeddings, just wait a bit
                    time.sleep(self.poll_interval)
        finally:
            # Runs on normal exit, exception, or generator close (GeneratorExit),
            # so the stream is always shut down cleanly.
            self.stop()
|
| 489 |
+
|
| 490 |
+
def get_stats(self) -> Dict[str, Any]:
|
| 491 |
+
"""
|
| 492 |
+
Get statistics about the streaming process.
|
| 493 |
+
|
| 494 |
+
Returns:
|
| 495 |
+
Dictionary with statistics
|
| 496 |
+
"""
|
| 497 |
+
stats = dict(self.stats)
|
| 498 |
+
if stats['start_time'] is not None:
|
| 499 |
+
stats['elapsed'] = time.time() - stats['start_time']
|
| 500 |
+
if stats['windows_processed'] > 0 and stats['elapsed'] > 0:
|
| 501 |
+
stats['windows_per_second'] = stats['windows_processed'] / stats['elapsed']
|
| 502 |
+
return stats
|
| 503 |
+
|
| 504 |
+
# Example usage
|
| 505 |
+
def example():
    """Demonstrate streaming EEG embeddings from a CSV file.

    Creates an EEGEmbeddingStream and consumes it until interrupted with
    Ctrl+C, printing a short summary of each embedding window.
    """
    def handle_embedding(embedding):
        """Callback function to handle new embeddings."""
        start_time = embedding['start_timestamp']
        end_time = embedding['end_timestamp']
        embedding_data = embedding['embedding']

        print(f"Got embedding for window from {start_time} to {end_time}")
        print(f"Embedding shape: {embedding_data.shape}")
        print(f"First few values: {embedding_data.flatten()[:5]}")

    # Create the embedding stream
    stream = EEGEmbeddingStream(
        file_path="eeg_data.csv",
        model_path="models/eeg_autoencoder.pth",
        window_size=256,    # Number of data points in each window
        stride=128,         # How much to advance between windows
        poll_interval=0.5   # Check for new data every 0.5 seconds
    )

    print("Starting embedding stream...")
    print("Press Ctrl+C to stop")

    try:
        # Method 1: Using callback.
        # BUG FIX: get_streaming_embeddings contains `yield`, so it is a
        # generator function even in callback mode — merely calling it did
        # nothing. The generator must be exhausted for callbacks to fire.
        for _ in stream.get_streaming_embeddings(callback=handle_embedding):
            pass

        # Method 2: Using generator
        # for embedding in stream.get_streaming_embeddings():
        #     handle_embedding(embedding)
    except KeyboardInterrupt:
        print("\nStopping...")
    finally:
        stream.stop()
        print("Stopped")
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
if __name__ == "__main__":
|
| 543 |
+
example()
|
embed.py
ADDED
|
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Text embedding script with SQLite storage (using numpy buffers)
|
| 4 |
+
Now with flexible text splitting modes!
|
| 5 |
+
|
| 6 |
+
Usage: python embed_flex.py <directory_path> <db_path> [--split-mode MODE]
|
| 7 |
+
|
| 8 |
+
Split modes:
|
| 9 |
+
- line (default): Each non-empty line becomes one embedding
|
| 10 |
+
- block: Double-newline separated blocks (paragraphs)
|
| 11 |
+
- sentence: Split on sentence boundaries (., !, ?)
|
| 12 |
+
- chunk: Fixed token-ish chunks with overlap (for long docs)
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
import argparse
|
| 18 |
+
import sqlite3
|
| 19 |
+
import numpy as np
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
from transformers import AutoModel, AutoTokenizer
|
| 22 |
+
import torch
|
| 23 |
+
import gc
|
| 24 |
+
import random
|
| 25 |
+
import re
|
| 26 |
+
|
| 27 |
+
INITIAL_BATCH_SIZE = 128
|
| 28 |
+
MIN_BATCH_SIZE = 1
|
| 29 |
+
SHUFFLE_SEED = 42
|
| 30 |
+
|
| 31 |
+
# Chunk mode settings
|
| 32 |
+
DEFAULT_CHUNK_SIZE = 512 # characters
|
| 33 |
+
DEFAULT_CHUNK_OVERLAP = 64
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def create_index_if_possible(cursor):
    """Create an index on messages.content; silently skip if the table is absent.

    The index is purely a lookup optimization, so a missing `messages` table
    (fresh database) is not an error.
    """
    try:
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_content ON messages(content)
        """)
    except sqlite3.OperationalError:
        pass
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_existing_content(cursor):
    """Return the set of message contents already stored.

    Returns an empty set when the `messages` table does not exist yet.
    """
    try:
        rows = cursor.execute("SELECT content FROM messages").fetchall()
    except sqlite3.OperationalError:
        return set()
    return {content for (content,) in rows}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def clear_gpu_memory():
    """Release cached CUDA memory and run garbage collection.

    A no-op on machines without CUDA (matching the original: gc.collect()
    is only invoked when CUDA is available).
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    gc.collect()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# =============================================================================
|
| 60 |
+
# SPLITTING STRATEGIES
|
| 61 |
+
# =============================================================================
|
| 62 |
+
|
| 63 |
+
def split_by_lines(text):
    """Original behavior: each non-empty line is one unit."""
    stripped = (raw.strip() for raw in text.split('\n'))
    return [line for line in stripped if line]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def split_by_blocks(text):
    """Split text into paragraph blocks (blank-line separated), whitespace-normalized."""
    normalized = (' '.join(raw.split()) for raw in re.split(r'\n\s*\n+', text))
    return [block for block in normalized if block]
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def split_by_sentences(text):
    """
    Split on sentence boundaries.
    Handles common abbreviations somewhat gracefully.
    """
    # Collapse all runs of whitespace first so the boundary regex sees
    # single spaces between sentences.
    collapsed = ' '.join(text.split())

    # Break after ., ! or ? when followed by whitespace and a capital letter.
    # Imperfect, but reasonable for most prose.
    pieces = re.split(r'(?<=[.!?])\s+(?=[A-Z])', collapsed)

    return [piece.strip() for piece in pieces if piece.strip()]
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def split_by_chunks(text, chunk_size=DEFAULT_CHUNK_SIZE, overlap=DEFAULT_CHUNK_OVERLAP):
    """
    Fixed-size character chunks with overlap.

    Good for long documents where you want sliding window coverage.

    Args:
        text: Input text; whitespace is normalized before chunking.
        chunk_size: Maximum characters per chunk.
        overlap: Characters of overlap between consecutive chunks.

    Returns:
        List of non-empty chunk strings (empty list for empty input).
    """
    # Normalize whitespace
    text = ' '.join(text.split())

    if len(text) <= chunk_size:
        return [text] if text else []

    chunks = []
    start = 0
    while start < len(text):
        end = start + chunk_size
        chunk = text[start:end]

        # Try to break at word boundary if not at end
        if end < len(text):
            last_space = chunk.rfind(' ')
            if last_space > chunk_size // 2:  # Only if we're not losing too much
                chunk = chunk[:last_space]
                end = start + last_space

        chunk = chunk.strip()
        if chunk:
            chunks.append(chunk)

        # Move forward with overlap.
        # BUG FIX: the original guard was `if start <= chunks[-1] if chunks
        # else 0:` which compares the int `start` against the last chunk
        # *string*, raising TypeError on the first long input. The intent
        # was a forward-progress guarantee: if stepping back by `overlap`
        # would not advance past the previous start, skip the overlap.
        next_start = end - overlap
        if next_start <= start:
            next_start = end
        start = next_start

    return chunks
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def get_splitter(mode, chunk_size=DEFAULT_CHUNK_SIZE, chunk_overlap=DEFAULT_CHUNK_OVERLAP):
    """Return the appropriate splitting function."""
    dispatch = {
        'line': split_by_lines,
        'block': split_by_blocks,
        'sentence': split_by_sentences,
        'chunk': lambda text: split_by_chunks(text, chunk_size, chunk_overlap),
    }
    if mode not in dispatch:
        raise ValueError(f"Unknown split mode: {mode}")
    return dispatch[mode]
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# =============================================================================
|
| 155 |
+
# PROCESSING
|
| 156 |
+
# =============================================================================
|
| 157 |
+
|
| 158 |
+
def process_batch(model, batch_lines, cursor, task="text-matching"):
    """
    Encode a batch of text units and store each (message, embedding) pair.

    Each text goes into `messages` (role 'system') and its embedding into
    `embeddings` as a raw float32 buffer keyed by the new message id.
    Per-row sqlite errors are reported and skipped.

    Returns:
        True on success; False after a CUDA out-of-memory error (cache is
        cleared first) so the caller can retry with a smaller batch. Any
        other RuntimeError is re-raised.
    """
    try:
        with torch.no_grad():
            vectors = model.encode(batch_lines, task=task, device="cuda")

        for text_unit, vector in zip(batch_lines, vectors):
            try:
                cursor.execute(
                    "INSERT INTO messages (content, role) VALUES (?, ?)",
                    (text_unit, "system")
                )
                row_id = cursor.lastrowid

                # Normalize whatever the model returned to a numpy array.
                if torch.is_tensor(vector):
                    as_np = vector.cpu().numpy()
                elif isinstance(vector, np.ndarray):
                    as_np = vector
                else:
                    as_np = np.array(vector)

                cursor.execute(
                    "INSERT INTO embeddings (message_id, embedding) VALUES (?, ?)",
                    (row_id, as_np.astype(np.float32).tobytes())
                )
            except sqlite3.Error as e:
                print(f"Error processing entry: {e}")
                continue

        return True

    except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
        if "out of memory" not in str(e).lower():
            raise
        clear_gpu_memory()
        return False
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def convert_existing_pickles(cursor, conn):
    """Convert any existing pickle embeddings to numpy buffers.

    Scans every row of `embeddings`; blobs that do not look like raw float32
    buffers are unpickled and rewritten in place as float32 byte buffers.
    Commits once at the end if anything was converted.
    """
    import pickle

    def is_numpy_buffer(blob):
        # Heuristic only: np.frombuffer succeeds for ANY blob whose byte
        # length is a multiple of 4, so a pickled blob of such a size will be
        # misclassified as already-converted and silently left as pickle.
        # NOTE(review): confirm against real data whether this false-positive
        # case can occur; `np_array.ndim >= 1` is always true for a
        # frombuffer result, so only the length check filters anything.
        try:
            np_array = np.frombuffer(blob, dtype=np.float32)
            if np_array.ndim >= 1 and len(np_array) > 0:
                return True
        except Exception:
            pass
        return False

    def unpickle_to_numpy(blob):
        # Best-effort decode of a legacy blob: accept ndarray, torch tensor,
        # or anything np.array() can coerce; None signals an undecodable blob.
        try:
            pickled_obj = pickle.loads(blob)
            if isinstance(pickled_obj, np.ndarray):
                return pickled_obj
            elif torch.is_tensor(pickled_obj):
                return pickled_obj.cpu().numpy()
            else:
                return np.array(pickled_obj)
        except Exception:
            return None

    cursor.execute("SELECT COUNT(*) FROM embeddings")
    total_embeddings = cursor.fetchone()[0]

    if total_embeddings == 0:
        return

    print(f"Checking {total_embeddings} existing embeddings for pickle->numpy conversion...")

    # Loads all blobs into memory at once; fine for modest tables.
    cursor.execute("SELECT message_id, embedding FROM embeddings")
    embeddings = cursor.fetchall()

    converted_count = 0
    for message_id, embedding_blob in embeddings:
        if not is_numpy_buffer(embedding_blob):
            numpy_array = unpickle_to_numpy(embedding_blob)

            # Undecodable blobs (numpy_array is None) are left untouched.
            if numpy_array is not None:
                np_buffer = numpy_array.astype(np.float32).tobytes()
                cursor.execute(
                    "UPDATE embeddings SET embedding = ? WHERE message_id = ?",
                    (np_buffer, message_id)
                )
                converted_count += 1

    if converted_count > 0:
        conn.commit()
        print(f"Converted {converted_count} pickle embeddings to numpy buffers")
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def main():
    """CLI entry point: embed every .txt file in a directory into a SQLite DB.

    Pipeline: parse args -> load encoder model -> ensure schema -> optionally
    migrate legacy pickle blobs -> split all files with the chosen splitter ->
    shuffle deterministically -> skip already-embedded units -> embed in
    batches with adaptive batch sizing (halve on OOM, grow back on success).
    """
    parser = argparse.ArgumentParser(
        description='Generate embeddings for text files with flexible splitting modes',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Split Modes:
  line       Each non-empty line = one embedding (default, original behavior)
  block      Double-newline separated paragraphs = one embedding each
  sentence   Split on sentence boundaries (., !, ?)
  chunk      Fixed-size character chunks with overlap (good for long docs)

Examples:
  %(prog)s ~/docs embeddings.db                          # line mode (default)
  %(prog)s ~/docs embeddings.db --split-mode block       # paragraph mode
  %(prog)s ~/docs embeddings.db --split-mode sentence    # sentence mode
  %(prog)s ~/docs embeddings.db --split-mode chunk --chunk-size 1024 --chunk-overlap 128
"""
    )

    parser.add_argument('directory',
                        help='Directory containing .txt files to process')
    parser.add_argument('database',
                        help='SQLite database path (will be created if not exists)')
    parser.add_argument('--split-mode', '-s',
                        choices=['line', 'block', 'sentence', 'chunk'],
                        default='line',
                        help='Text splitting strategy (default: line)')
    parser.add_argument('--chunk-size', type=int, default=DEFAULT_CHUNK_SIZE,
                        help=f'Character chunk size for chunk mode (default: {DEFAULT_CHUNK_SIZE})')
    parser.add_argument('--chunk-overlap', type=int, default=DEFAULT_CHUNK_OVERLAP,
                        help=f'Overlap between chunks (default: {DEFAULT_CHUNK_OVERLAP})')
    parser.add_argument('--batch-size', type=int, default=INITIAL_BATCH_SIZE,
                        help=f'Initial batch size (default: {INITIAL_BATCH_SIZE})')
    parser.add_argument('--task', default='text-matching',
                        help='Encoding task (default: text-matching)')
    parser.add_argument('--model', default='jinaai/jina-embeddings-v3',
                        help='Model name (default: jinaai/jina-embeddings-v3)')
    parser.add_argument('--skip-conversion', action='store_true',
                        help='Skip checking/converting existing pickle embeddings')

    args = parser.parse_args()

    directory_path = os.path.expanduser(args.directory)
    db_path = os.path.expanduser(args.database)

    if not os.path.isdir(directory_path):
        print(f"Error: Directory '{directory_path}' does not exist")
        sys.exit(1)

    print(f"Processing directory: {directory_path}")
    print(f"Database: {db_path}")
    print(f"Split mode: {args.split_mode}")
    if args.split_mode == 'chunk':
        print(f"Chunk size: {args.chunk_size}, overlap: {args.chunk_overlap}")
    print(f"Initial batch size: {args.batch_size}")

    # Get splitter function
    splitter = get_splitter(args.split_mode, args.chunk_size, args.chunk_overlap)

    # Initialize model (requires CUDA; .cuda() will fail on CPU-only hosts)
    print(f"Loading model: {args.model}")
    model = AutoModel.from_pretrained(args.model, trust_remote_code=True).cuda()
    model.eval()

    # Set up SQLite
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    cursor.execute("""
        CREATE TABLE IF NOT EXISTS messages (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            content TEXT,
            role TEXT
        )
    """)

    # NOTE(review): the FK references messages(message_id), but messages'
    # primary key column is named `id` — SQLite only enforces this if
    # foreign_keys is ON, which is off by default, so it goes unnoticed.
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS embeddings (
            message_id INTEGER PRIMARY KEY,
            embedding BLOB,
            FOREIGN KEY (message_id) REFERENCES messages(message_id) ON DELETE CASCADE
        )
    """)
    conn.commit()

    create_index_if_possible(cursor)
    conn.commit()

    # One-time migration of legacy pickle blobs to raw float32 buffers.
    if not args.skip_conversion:
        convert_existing_pickles(cursor, conn)

    existing_content = get_existing_content(cursor)
    print(f"Already processed: {len(existing_content)} entries")

    # Collect all text units using the selected splitter
    all_units = []
    txt_files = [f for f in os.listdir(directory_path) if f.lower().endswith(".txt")]

    if not txt_files:
        print(f"Warning: No .txt files found in {directory_path}")
        conn.close()
        return

    print(f"Found {len(txt_files)} .txt files")

    for filename in txt_files:
        filepath = os.path.join(directory_path, filename)
        with open(filepath, "r", encoding="utf-8", errors="ignore") as f:
            content = f.read()
        units = splitter(content)
        all_units.extend(units)

    print(f"Total units from source ({args.split_mode} mode): {len(all_units)}")

    # Deterministic shuffle (fixed seed) so reruns process in the same order.
    random.seed(SHUFFLE_SEED)
    random.shuffle(all_units)

    # Filter out already processed (exact string match against DB contents).
    new_units = [u for u in all_units if u not in existing_content]

    print(f"Remaining to process: {len(new_units)}")

    if not new_units:
        print("Nothing new to process.")
        conn.close()
        return

    # Process with dynamic batch sizing
    batch_size = args.batch_size
    total = len(new_units)
    task = args.task

    idx = 0
    processed_count = 0

    with tqdm(total=total, desc="Processing") as pbar:
        while idx < total:
            end_idx = min(idx + batch_size, total)
            batch = new_units[idx:end_idx]

            # False means CUDA OOM; the batch is retried at a smaller size.
            success = process_batch(model, batch, cursor, task)

            if success:
                try:
                    conn.commit()
                except sqlite3.Error as e:
                    print(f"Error committing batch: {e}")

                batch_processed = len(batch)
                pbar.update(batch_processed)
                processed_count += batch_processed
                idx = end_idx

                # Gradually grow the batch back toward the configured size
                # after sustained success.
                if batch_size < args.batch_size and processed_count % (batch_size * 10) == 0:
                    batch_size = min(batch_size * 2, args.batch_size)
            else:
                if batch_size > MIN_BATCH_SIZE:
                    batch_size = max(batch_size // 2, MIN_BATCH_SIZE)
                    print(f"\nOOM - batch size -> {batch_size}")
                else:
                    # Even a single unit OOMs: skip it but still count it as
                    # "processed" for progress purposes.
                    print(f"\nSkipping: {batch[0][:100]}...")
                    idx += 1
                    pbar.update(1)
                    processed_count += 1

    conn.close()
    print(f"\nProcessed {processed_count:,} entries total.")
    print("All embeddings stored as numpy buffers (float32).")
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
if __name__ == "__main__":
|
| 424 |
+
main()
|
morphism.py
ADDED
|
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
morphism — EEG-to-text semantic search
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
morphism record [options]
|
| 7 |
+
morphism index create|info|rebuild [options]
|
| 8 |
+
morphism decode [options]
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import sys
|
| 12 |
+
import os
|
| 13 |
+
import argparse
|
| 14 |
+
|
| 15 |
+
from retrieval import FloodMode, DriftMode, FocusMode, LayeredMode
|
| 16 |
+
|
| 17 |
+
def cmd_record(args):
    """Record EEG data from OpenBCI Cyton+Daisy.

    Streams serial packets (two per sample: Cyton + Daisy halves), formats
    them as timestamped CSV rows, and writes either to a local file or —
    with --remote — to a file on a remote host over SFTP, flushed roughly
    every 100 ms. Stops on Ctrl+C. With --sd, delegates to SD-card
    recording and returns immediately.
    """
    from cyton import (
        init_board, set_sample_rate, read_complete_packet, process_packet,
        start_sd_recording, stop_sd_recording, create_ssh_connection, sd_record
    )
    import serial, time, io
    from datetime import datetime

    if args.sd:
        sd_record(args.port, args.duration, args.sample_rate)
        return

    filename = args.output
    if filename is None:
        filename = f"openbci_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"

    # 115200 baud is the Cyton dongle's fixed rate; 2 s settle time after open.
    ser = serial.Serial(args.port, 115200)
    time.sleep(2)
    init_board(ser)

    if args.sample_rate != 1000:
        set_sample_rate(ser, args.sample_rate)

    ssh, sftp, remote_file = None, None, None
    if args.remote:
        ssh = create_ssh_connection()
        if not ssh:
            print("SSH connection failed.")
            return
        sftp = ssh.open_sftp()
        remote_file = sftp.open(filename, 'w')

    # 16 channels total: 8 from the Cyton packet + 8 from the Daisy packet.
    header = "Timestamp," + ",".join(f"Channel{i+1}" for i in range(16)) + "\n"
    if args.remote:
        remote_file.write(header)
    else:
        with open(filename, 'w') as f:
            f.write(header)

    # 'b' starts the board's data stream; drop whatever arrived beforehand.
    ser.write(b'b')
    time.sleep(0.5)
    ser.reset_input_buffer()

    # NOTE(review): "(unknown)" looks like a redacted/lost placeholder — the
    # message presumably should interpolate {filename}; confirm and restore.
    print(f"Recording to (unknown) — Ctrl+C to stop")

    pkt_count = 0
    t0 = time.time()
    buf = io.StringIO()
    last_flush = time.time()

    try:
        while True:
            # A full sample is a pair of packets (Cyton then Daisy); if
            # either half is missing/corrupt the sample is dropped.
            p1 = read_complete_packet(ser)
            if not p1:
                continue
            p2 = read_complete_packet(ser)
            if not p2:
                continue

            d1, d2 = process_packet(p1), process_packet(p2)
            if not (d1 and d2):
                continue

            pkt_count += 1
            ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
            line = ts + "," + ",".join(f"{x:.6f}" for x in d1 + d2) + "\n"

            if args.remote:
                # Buffer locally and push over SFTP at most every 100 ms to
                # avoid a round trip per sample.
                buf.write(line)
                if time.time() - last_flush >= 0.1:
                    remote_file.write(buf.getvalue())
                    buf = io.StringIO()
                    last_flush = time.time()
            else:
                # Local path reopens the file per line (append mode).
                with open(filename, 'a') as f:
                    f.write(line)

            if pkt_count % 125 == 0:
                rate = pkt_count / (time.time() - t0)
                print(f"\r  {rate:.1f} Hz, {pkt_count} packets", end='')

            # Drop backlog if we fall behind the serial stream.
            if ser.in_waiting > 1000:
                ser.reset_input_buffer()

    except KeyboardInterrupt:
        # 's' stops the board's stream before closing the port.
        ser.write(b's')
        ser.close()
        if args.remote:
            if buf.getvalue():
                remote_file.write(buf.getvalue())
            remote_file.close()
            sftp.close()
            ssh.close()

        elapsed = time.time() - t0
        print(f"\n\nDone — {pkt_count} packets in {elapsed:.1f}s ({pkt_count/elapsed:.1f} Hz)")
        # NOTE(review): same "(unknown)" placeholder — presumably {filename}.
        print(f"Saved to (unknown)")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def cmd_index(args):
    """Manage the text embedding index.

    Actions:
      info    — print message/embedding counts and whether the FAISS index
                file exists.
      create  — embed all .txt files in the corpus directory that are not
                already in the DB, then (re)build the FAISS index.
      rebuild — like create, but wipes the messages/embeddings tables first.
    """
    from embed import (
        get_splitter, process_batch, create_index_if_possible,
        get_existing_content, INITIAL_BATCH_SIZE, MIN_BATCH_SIZE, SHUFFLE_SEED
    )
    import sqlite3, numpy as np, random
    from tqdm import tqdm

    db_path = os.path.expanduser(args.db)
    index_prefix = args.index

    if args.action == 'info':
        if not os.path.exists(db_path):
            print(f"No database at {db_path}")
            return

        conn = sqlite3.connect(db_path)
        c = conn.cursor()
        c.execute("SELECT COUNT(*) FROM messages")
        msg_count = c.fetchone()[0]
        c.execute("SELECT COUNT(*) FROM embeddings")
        emb_count = c.fetchone()[0]
        conn.close()

        index_exists = os.path.exists(f"{index_prefix}.index")

        print(f"Database: {db_path}")
        print(f"Messages: {msg_count:,}")
        print(f"Embeddings: {emb_count:,}")
        print(f"FAISS index: {'exists' if index_exists else 'not built'} ({index_prefix}.index)")
        return

    if args.action in ('create', 'rebuild'):
        corpus = os.path.expanduser(args.corpus)
        if not os.path.isdir(corpus):
            print(f"Not a directory: {corpus}")
            sys.exit(1)

        splitter = get_splitter(args.split_mode, args.chunk_size, args.chunk_overlap)

        # Model load requires CUDA (.cuda()); deferred import keeps
        # transformers out of the 'info' path.
        print(f"Loading model: {args.model}")
        from transformers import AutoModel
        model = AutoModel.from_pretrained(args.model, trust_remote_code=True).cuda()
        model.eval()

        conn = sqlite3.connect(db_path)
        c = conn.cursor()

        if args.action == 'rebuild':
            print("Dropping existing data...")
            c.execute("DELETE FROM embeddings")
            c.execute("DELETE FROM messages")
            conn.commit()

        # Schema mirrors embed.py's main(); kept in sync manually.
        c.execute("""CREATE TABLE IF NOT EXISTS messages (
            id INTEGER PRIMARY KEY AUTOINCREMENT, content TEXT, role TEXT)""")
        c.execute("""CREATE TABLE IF NOT EXISTS embeddings (
            message_id INTEGER PRIMARY KEY, embedding BLOB,
            FOREIGN KEY (message_id) REFERENCES messages(message_id) ON DELETE CASCADE)""")
        conn.commit()
        create_index_if_possible(c)
        conn.commit()

        existing = get_existing_content(c)
        print(f"Already indexed: {len(existing):,}")

        txt_files = [f for f in os.listdir(corpus) if f.lower().endswith('.txt')]
        if not txt_files:
            print(f"No .txt files in {corpus}")
            conn.close()
            return

        units = []
        for fn in txt_files:
            with open(os.path.join(corpus, fn), 'r', encoding='utf-8', errors='ignore') as f:
                units.extend(splitter(f.read()))

        # Deterministic shuffle so reruns process in the same order.
        random.seed(SHUFFLE_SEED)
        random.shuffle(units)
        new_units = [u for u in units if u not in existing]
        print(f"New units to embed: {len(new_units):,}")

        if not new_units:
            print("Nothing new.")
            conn.close()
            return

        batch_size = args.batch_size
        idx = 0
        processed = 0

        with tqdm(total=len(new_units), desc="Embedding") as pbar:
            while idx < len(new_units):
                batch = new_units[idx:idx + batch_size]
                # False means CUDA OOM; halve the batch and retry.
                ok = process_batch(model, batch, c, args.task)
                if ok:
                    conn.commit()
                    pbar.update(len(batch))
                    processed += len(batch)
                    idx += len(batch)
                else:
                    if batch_size > MIN_BATCH_SIZE:
                        batch_size = max(batch_size // 2, MIN_BATCH_SIZE)
                        print(f"\nOOM — batch size → {batch_size}")
                    else:
                        # NOTE(review): at MIN_BATCH_SIZE a failing unit is
                        # skipped silently (unlike embed.py's main, which
                        # prints a "Skipping" message) yet still counted.
                        idx += 1
                        pbar.update(1)
                        processed += 1

        conn.close()
        print(f"Embedded {processed:,} units.")

        print("Building FAISS index...")
        _build_faiss_index(db_path, index_prefix)
        print("Done.")
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def _build_faiss_index(db_path, index_prefix):
    """Build a FAISS index from the embeddings stored in the SQLite database.

    Reads every (message_id, embedding) row from the ``embeddings`` table,
    stacks the float32 vectors, loads them into a ``decode.EmbeddingIndex``,
    and saves the index files under ``index_prefix``.  A small
    ``{index_prefix}_metadata.npz`` sidecar records the vector count and the
    highest message_id seen.

    Args:
        db_path: Path to the SQLite database holding the ``embeddings`` table.
        index_prefix: Filename prefix for the saved FAISS index and metadata.
    """
    import sqlite3

    import numpy as np

    from decode import EmbeddingIndex

    conn = sqlite3.connect(db_path)
    c = conn.cursor()
    c.execute("SELECT message_id, embedding FROM embeddings ORDER BY message_id")

    embeddings, ids = [], []
    for mid, blob in c.fetchall():
        # Embeddings are stored as raw float32 bytes in the BLOB column.
        embeddings.append(np.frombuffer(blob, dtype=np.float32))
        ids.append(mid)
    conn.close()

    if not embeddings:
        print(" No embeddings found.")
        return

    embeddings = np.vstack(embeddings)
    print(f" {len(embeddings):,} vectors, dim={embeddings.shape[1]}")

    idx = EmbeddingIndex(dim=embeddings.shape[1])
    idx.add_embeddings(embeddings, ids)
    idx.save(index_prefix)

    # We already fetched every row above, so derive the metadata locally
    # instead of opening a second connection and re-scanning the table
    # (the original re-ran SELECT COUNT(*) / MAX(message_id)).  The rows
    # were ordered by message_id, so max(ids) equals the SQL MAX.
    np.savez(f"{index_prefix}_metadata.npz",
             count=len(ids), max_message_id=max(ids))
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def cmd_decode(args):
    """Run EEG → text decoding.

    Builds an ``EEGSemanticProcessor`` from the CLI arguments, instantiates
    the requested output mode, then consumes the EEG stream: each window's
    embedding is projected into semantic space and either collected for
    later saving (``--save-vectors``) or handed to the mode, whose text
    output is printed and optionally appended to the processor's log file.

    Stops cleanly on Ctrl-C.  Aborts after 5 consecutive processing errors.
    """
    # Hoisted once here instead of re-importing inside the streaming loop
    # and every exception handler, as the original did.
    import logging
    import sys
    import time

    from decode import EEGSemanticProcessor

    processor = EEGSemanticProcessor(
        autoencoder_model_path=args.autoencoder,
        semantic_model_path=args.semantic,
        nexus_db_path=args.db,
        embeddings_db_path=args.db,
        index_path=args.index,
        eeg_file_path=args.eeg,
        window_size=args.window_size,
        stride=args.stride,
        batch_size=args.batch_size,
        device=args.device,
        search_k=args.search_k,
        final_k=args.final_k,
        use_raw_eeg=args.raw_eeg,
        input_dim_override=args.input_dim,
        save_vectors=args.save_vectors,
        vector_output_path=args.vector_output,
        last_n_messages=args.last_n,
    )

    # NOTE(review): FloodMode/DriftMode/FocusMode/LayeredMode are not
    # imported in this function and no import of them is visible in this
    # chunk — confirm they are in scope at module level, or add them to
    # the `from decode import ...` line above.
    modes = {
        'flood': lambda: FloodMode(processor.embedding_index, processor.nexus_conn,
                                   search_k=args.search_k, final_k=args.final_k,
                                   last_n=args.last_n),
        'drift': lambda: DriftMode(processor.embedding_index, processor.nexus_conn,
                                   search_k=64),
        'focus': lambda: FocusMode(processor.embedding_index, processor.nexus_conn,
                                   search_k=48),
        'layered': lambda: LayeredMode(processor.embedding_index, processor.nexus_conn),
    }

    mode = modes[args.mode]()
    log = logging.getLogger("EEGSemanticStream")

    processor.eeg_stream.start()
    try:
        consecutive_errors = 0
        while True:
            try:
                for embedding_data in processor.eeg_stream.get_embeddings(timeout=0.5):
                    try:
                        semantic_embedding = processor.process_eeg_embedding(
                            embedding_data['embedding'])

                        if processor.save_vectors:
                            # Vector-capture mode: accumulate embeddings and
                            # timestamps instead of decoding to text.
                            embedding_np = semantic_embedding.detach().cpu().numpy()
                            processor.vectors_list.append(embedding_np)
                            processor.timestamps.append({
                                'start': embedding_data['start_timestamp'],
                                'end': embedding_data['end_timestamp']
                            })
                            if len(processor.vectors_list) % 100 == 0:
                                log.info(
                                    f"Collected {len(processor.vectors_list)} vectors")
                            continue

                        lines = mode.step(semantic_embedding)
                        if lines:
                            output = "\n".join(lines)
                            print(output)
                            if processor.log_file:
                                processor.log_file.write(output + "\n")
                                processor.log_file.flush()

                        # A successful window resets the failure counter.
                        consecutive_errors = 0
                    except Exception as e:
                        print(f"Error: {e}", file=sys.stderr)
                        consecutive_errors += 1
                        if consecutive_errors >= 5:
                            # Propagates out of the for-loop; matched by the
                            # "Too many" check below and re-raised to the
                            # outermost handler.
                            raise RuntimeError("Too many consecutive errors")

                time.sleep(0.01)
            except Exception as e:
                if "Too many" in str(e):
                    raise
                print(f"Error: {e}", file=sys.stderr)
                consecutive_errors += 1
                if consecutive_errors >= 5:
                    raise
                time.sleep(1)
    except KeyboardInterrupt:
        pass
    except Exception as e:
        print(f"Fatal: {e}", file=sys.stderr)
    finally:
        if processor.save_vectors and processor.vectors_list:
            processor.save_vectors_to_disk()
        processor.eeg_stream.stop()
|
| 367 |
+
|
| 368 |
+
def main():
    """Command-line entry point: build the parser and dispatch a subcommand.

    Subcommands:
        record — record EEG from an OpenBCI Cyton+Daisy board
        index  — create/inspect/rebuild the text embedding index
        decode — run live EEG → text decoding
    """
    parser = argparse.ArgumentParser(
        prog='morphism',
        description='EEG-to-text semantic search',
    )
    subparsers = parser.add_subparsers(dest='command')

    # --- record ---
    record_p = subparsers.add_parser('record', help='Record EEG from OpenBCI Cyton+Daisy')
    record_p.add_argument('--port', '-p', default='/dev/ttyUSB0')
    record_p.add_argument('--output', '-o', default=None)
    record_p.add_argument('--sample-rate', type=int, default=1000)
    record_p.add_argument('--sd', action='store_true', help='Record to SD card')
    record_p.add_argument('--duration', default='G')
    record_p.add_argument('--remote', action='store_true', help='Stream via SSH')

    # --- index ---
    index_p = subparsers.add_parser('index', help='Manage the text embedding index')
    index_p.add_argument('action', choices=['create', 'info', 'rebuild'])
    index_p.add_argument('--corpus', '-c', default=None)
    index_p.add_argument('--db', default='morphism.db')
    index_p.add_argument('--index', default='morphism')
    index_p.add_argument('--split-mode', default='line',
                         choices=['line', 'block', 'sentence', 'chunk'])
    index_p.add_argument('--chunk-size', type=int, default=512)
    index_p.add_argument('--chunk-overlap', type=int, default=64)
    index_p.add_argument('--batch-size', type=int, default=128)
    index_p.add_argument('--task', default='text-matching')
    index_p.add_argument('--model', default='jinaai/jina-embeddings-v3')

    # --- decode ---
    decode_p = subparsers.add_parser('decode', help='Run EEG → text decoding')
    decode_p.add_argument('--mode', default='flood', choices=['flood', 'drift', 'focus', 'layered'])
    decode_p.add_argument('--eeg', '-f', required=True)
    decode_p.add_argument('--autoencoder', '-a', required=True)
    decode_p.add_argument('--semantic', '-s', required=True)
    decode_p.add_argument('--db', default='morphism.db')
    decode_p.add_argument('--index', default='morphism')
    decode_p.add_argument('--window-size', type=int, default=624)
    decode_p.add_argument('--stride', type=int, default=32)
    decode_p.add_argument('--batch-size', type=int, default=32)
    decode_p.add_argument('--device', default=None)
    decode_p.add_argument('--search-k', type=int, default=1024)
    decode_p.add_argument('--final-k', type=int, default=1024)
    decode_p.add_argument('--last-n', type=int, default=128)
    decode_p.add_argument('--raw-eeg', action='store_true')
    decode_p.add_argument('--input-dim', type=int, default=None)
    decode_p.add_argument('--save-vectors', action='store_true')
    decode_p.add_argument('--vector-output', default='semantic_vectors.npz')

    args = parser.parse_args()

    # No subcommand given: show usage and exit cleanly.
    if args.command is None:
        parser.print_help()
        sys.exit(0)

    if args.command == 'record':
        cmd_record(args)
    elif args.command == 'index':
        # create/rebuild read .txt files, so a corpus directory is mandatory.
        if args.action in ('create', 'rebuild') and not args.corpus:
            print("--corpus is required for create/rebuild")
            sys.exit(1)
        cmd_index(args)
    elif args.command == 'decode':
        cmd_decode(args)
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == '__main__':
    main()
|