acb committed on
Commit
5e284bb
·
verified ·
1 Parent(s): 55fdccb

Upload 5 files

Browse files
Files changed (5) hide show
  1. cyton.py +381 -0
  2. decode.py +612 -0
  3. eegembed.py +543 -0
  4. embed.py +424 -0
  5. morphism.py +436 -0
cyton.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import sys
3
+ import os
4
+ import serial
5
+ import time
6
+ import paramiko
7
+ import io
8
+ from pathlib import Path
9
+ from datetime import datetime
10
+
11
def create_ssh_connection():
    """Open an SSH session to the remote server, or return None on failure.

    Host-key verification is relaxed (AutoAddPolicy) so a first-time
    connection does not need a known_hosts entry.
    """
    client = paramiko.SSHClient()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        client.connect('topos.exypno.tech', port=420, username='albert')
    except Exception as e:
        print(f"Failed to connect to remote server: {e}")
        return None
    return client
21
+
22
def set_gain(ser, gain=8):
    """Set the amplifier gain on all 16 channels (Cyton + Daisy).

    Args:
        ser: Open serial connection to the board.
        gain: Gain multiplier; one of 1, 2, 4, 6, 8, 12, 24 (default 8).

    Raises:
        ValueError: If `gain` is not one of the supported values.
    """
    # Validate up front so a bad value produces a clear error instead of
    # the cryptic ValueError that list.index() would raise mid-sequence.
    gain_mapping = [1, 2, 4, 6, 8, 12, 24]
    if gain not in gain_mapping:
        raise ValueError(f"Unsupported gain {gain}. Supported gains: {gain_mapping}")

    print(f"Setting {gain}x gain on all channels...")

    # Stop any streaming first
    ser.write(b's')
    time.sleep(0.5)

    # The board command encodes gain as the index into the mapping table.
    gain_val = gain_mapping.index(gain)

    # Main board channels (1-8)
    main_channels = ['1', '2', '3', '4', '5', '6', '7', '8']
    # Daisy board channels (9-16 represented as Q-I)
    daisy_channels = ['Q', 'W', 'E', 'R', 'T', 'Y', 'U', 'I']

    # Combine all per-channel 'x...X' configuration commands into one write.
    commands = ''.join(f'x{ch}0{gain_val}0000X' for ch in main_channels + daisy_channels)

    # Send all commands at once
    ser.write(commands.encode())
    time.sleep(0.5)

    # Clear any response from the serial buffer
    ser.reset_input_buffer()

    print("Gain settings updated")
51
+
52
def set_sample_rate(ser, freq):
    """Set sample rate using the '~' command"""
    # Board command codes for each supported sampling frequency (Hz).
    freq_mapping = {
        16000: '0',
        8000: '1',
        4000: '2',
        2000: '3',
        1000: '4',
        500: '5',
        250: '6'
    }

    if freq not in freq_mapping:
        raise ValueError(f"Unsupported frequency {freq}Hz. Supported rates: {list(freq_mapping.keys())}")

    # Halt any active streaming before reconfiguring the board.
    ser.write(b's')
    time.sleep(0.5)

    # Apply the new rate via the '~<code>' command.
    rate_code = freq_mapping[freq]
    ser.write(f"~{rate_code}".encode())
    time.sleep(0.5)

    # Discard the board's textual acknowledgement.
    ser.reset_input_buffer()
    print(f"Sample rate set to {freq}Hz")
80
+
81
def init_board(ser):
    """Initialize the OpenBCI board for 16 channels.

    Performs, in order: stop streaming, soft reset, buffer flush,
    16-channel (Daisy) mode, per-channel power-on, a switch to 230400
    baud, and a default gain setting. The command/delay sequence is
    timing-sensitive; do not reorder.

    Args:
        ser: Open serial connection to the board (initially at 115200 baud).
    """
    print("Initializing board...")

    # Stop any previous streaming
    ser.write(b's')
    time.sleep(1)

    # Soft reset
    ser.write(b'v')
    time.sleep(2)

    # Clear buffers
    ser.reset_input_buffer()
    ser.reset_output_buffer()

    # Enable 16 channel mode
    ser.write(b'C')
    time.sleep(1)

    # Enable all channels (1-16)
    # First 8 channels
    commands = [b'!', b'@', b'#', b'$', b'%', b'^', b'&', b'*']
    # Next 8 channels (Daisy module)
    commands.extend([b'Q', b'W', b'E', b'R', b'T', b'Y', b'U', b'I'])

    for cmd in commands:
        ser.write(cmd)
        time.sleep(0.1)

    # Set high-speed mode
    ser.write(b'\xF0\x06')  # Set baud rate to 230400
    time.sleep(1)
    # Re-point the local serial object at the board's new baud rate.
    ser.baudrate = 230400

    # NOTE(review): gain is set to 2 here, but process_packet() scales
    # samples assuming a gain of 24 — confirm which gain is intended,
    # otherwise the microvolt values are off by a constant factor.
    set_gain(ser, gain=2)

    print("Board initialized")
119
+
120
def find_packet_start(ser):
    """Scan the serial stream for the 0xA0 packet-header byte.

    Args:
        ser: Open serial connection.

    Returns:
        The single header byte (b'\xa0') once found, or None when the
        stream runs dry (ser.read() returns no data, e.g. on timeout).
    """
    while True:
        byte = ser.read()
        if not byte:
            # BUG FIX: the original indexed byte[0] unconditionally, so an
            # empty read (serial timeout) raised IndexError. Return None so
            # the caller can retry or bail out cleanly.
            return None
        if byte[0] == 0xA0:  # Header byte
            return byte
127
+
128
def read_complete_packet(ser):
    """Read one properly aligned 33-byte packet from the serial stream.

    Returns the full packet (0xA0 header + 32 payload bytes) or None
    when no well-formed packet can be obtained.
    """
    header = find_packet_start(ser)
    if header:
        body = ser.read(32)
        # A valid packet has exactly 32 bytes after the header and ends
        # with a footer byte in the 0xC0-0xCF range.
        if len(body) == 32 and (body[31] & 0xF0) == 0xC0:
            return header + body
    return None
145
+
146
def process_packet(packet):
    """Decode one 33-byte OpenBCI packet into 8 channel values in microvolts.

    Each channel sample is a big-endian 24-bit two's-complement integer;
    the eight samples start at byte 2 (after the header and sample-number
    bytes). Values are scaled assuming a gain of 24 with a 4.5 V reference.

    Returns:
        A list of 8 floats (microvolts), or None if the packet length
        is not exactly 33 bytes.
    """
    if len(packet) != 33:
        return None

    # Convert to microvolts: 4.5V / gain / (2^23 - 1); hoisted out of the
    # loop since it is identical for every channel. (Gain of 24.)
    scale_factor = 4.5 / (24.0 * 8388607.0) * 1000000

    channels = []
    for i in range(8):
        offset = 2 + i * 3
        # int.from_bytes with signed=True performs the same 24-bit
        # two's-complement conversion the original did by hand.
        raw = int.from_bytes(packet[offset:offset + 3], 'big', signed=True)
        channels.append(raw * scale_factor)

    return channels
169
+
170
def start_sd_recording(ser, duration='G'):
    """Start recording to SD card with specified duration
    Duration codes:
        A = 5MIN
        S = 15MIN
        F = 30MIN
        G = 1HR (default)
        H = 2HR
        J = 4HR
        K = 12HR
        L = 24HR
        a = ~14sec (test)
    """
    valid_durations = {'A', 'S', 'F', 'G', 'H', 'J', 'K', 'L', 'a'}
    if duration not in valid_durations:
        raise ValueError(f"Invalid duration code. Valid codes: {valid_durations}")

    print(f"Starting SD card recording with duration code {duration}")
    # Select the SD file duration, then 'b' starts streaming, which also
    # begins writing samples to the card.
    for cmd in (duration.encode(), b'b'):
        ser.write(cmd)
        time.sleep(0.5)
192
+
193
def stop_sd_recording(ser):
    """Stop SD card recording"""
    print("Stopping SD card recording")
    # 's' halts streaming, 'j' closes the file on the SD card.
    for cmd in (b's', b'j'):
        ser.write(cmd)
        time.sleep(0.5)
200
+
201
def sd_record(port, duration='G', sample_rate=1000):
    """Record a session to the board's SD card, blocking until done.

    Opens the serial port, initializes the board, sets the sample rate,
    starts SD recording, then waits out the requested duration (Ctrl-C
    stops early). Recording is always stopped and the port closed, even
    on interrupt or error.

    Args:
        port: Serial device path (e.g. '/dev/ttyUSB0').
        duration: One-letter SD duration code (see start_sd_recording).
        sample_rate: Sampling frequency in Hz (see set_sample_rate).
    """
    # Wall-clock seconds to wait for each SD duration code.
    duration_map = {
        'A': 5*60,      # 5 minutes
        'S': 15*60,     # 15 minutes
        'F': 30*60,     # 30 minutes
        'G': 60*60,     # 1 hour
        'H': 2*60*60,   # 2 hours
        'J': 4*60*60,   # 4 hours
        'K': 12*60*60,  # 12 hours
        'L': 24*60*60,  # 24 hours
        'a': 14         # ~14 seconds (test)
    }

    # Open serial port
    ser = serial.Serial(port, 115200)
    time.sleep(2)

    try:
        # Initialize board
        init_board(ser)

        # Set sample rate
        set_sample_rate(ser, sample_rate)

        # Start recording
        start_sd_recording(ser, duration)

        # Calculate wait time
        wait_time = duration_map[duration]
        start_time = time.time()

        print(f"Recording to SD card for {wait_time} seconds...")

        try:
            # Busy-wait with a countdown display until the duration elapses.
            while (time.time() - start_time) < wait_time:
                remaining = wait_time - (time.time() - start_time)
                print(f"\rRecording... {remaining:.1f} seconds remaining ", end='')
                time.sleep(0.1)

        except KeyboardInterrupt:
            print("\nRecording interrupted by user")

        finally:
            # Always stop recording
            stop_sd_recording(ser)
            print("\nRecording complete")

    finally:
        ser.close()
251
+
252
def main():
    """CLI entry point for the OpenBCI recording tool.

    Modes:
      --sd       : record onto the board's SD card (no PC streaming).
      --remote   : stream CSV rows over SFTP to the remote server.
      default    : stream CSV rows into a local file.

    BUG FIX: the 'Started recording to' and 'Data saved to:' status
    messages previously contained the literal text '(unknown)' instead of
    interpolating the output filename; they now report the real file.
    """
    parser = argparse.ArgumentParser(description='OpenBCI EEG Recording Tool')
    parser.add_argument('--port', '-p', type=str, default='/dev/ttyUSB0',
                        help='Serial port to use (default: /dev/ttyUSB0)')
    parser.add_argument('--filename', '-o', type=str,
                        help='Output filename (default: openbci_<timestamp>.txt)')
    parser.add_argument('--sd', action='store_true',
                        help='Record to SD card instead of streaming to PC')
    parser.add_argument('--duration', type=str, default='G',
                        help='SD card recording duration code (default: G = 1 hour)')
    parser.add_argument('--sample-rate', type=int, default=1000,
                        help='Sample rate in Hz (default: 1000)')
    parser.add_argument('--remote', action='store_true',
                        help='Write to remote server instead of local file')
    args = parser.parse_args()

    if args.sd:
        sd_record(args.port, args.duration, args.sample_rate)
        return

    if args.filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        args.filename = f"openbci_{timestamp}.txt"

    # Open serial port
    ser = serial.Serial(args.port, 115200)
    time.sleep(2)
    init_board(ser)

    filename = args.filename

    if args.remote:
        ssh = create_ssh_connection()
        if not ssh:
            print("Failed to establish SSH connection. Exiting.")
            return

        sftp = ssh.open_sftp()
        remote_file = sftp.open(filename, 'w')

        # Write CSV header row (timestamp + 16 channel columns).
        header = "Timestamp,"
        header += ",".join([f"Channel{i+1}" for i in range(16)])
        header += "\n"
        remote_file.write(header)
    else:
        # Local file: write the header now; data rows are appended later.
        with open(filename, 'w') as f:
            header = "Timestamp,"
            header += ",".join([f"Channel{i+1}" for i in range(16)])
            header += "\n"
            f.write(header)

    # Start streaming
    ser.write(b'b')
    time.sleep(0.5)
    ser.reset_input_buffer()

    print(f"Started recording to {filename}")

    packet_count = 0
    start_time = time.time()
    buffer = io.StringIO()
    last_write = time.time()

    try:
        while True:
            # Read two properly aligned packets: with the Daisy module the
            # board interleaves main-board and Daisy packets, so one sample
            # of all 16 channels is two consecutive packets.
            packet1 = read_complete_packet(ser)
            if packet1:
                packet2 = read_complete_packet(ser)
                if packet2:
                    # Process both packets
                    data1 = process_packet(packet1)
                    data2 = process_packet(packet2)

                    if data1 and data2:
                        packet_count += 1
                        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
                        all_channels = data1 + data2  # Combine all 16 channels
                        data_str = [f"{x:.6f}" for x in all_channels]
                        line = timestamp + "," + ",".join(data_str) + "\n"

                        if args.remote:
                            buffer.write(line)

                            # Flush buffered rows to the remote file every 100ms
                            if time.time() - last_write >= 0.1:
                                remote_file.write(buffer.getvalue())
                                buffer = io.StringIO()
                                last_write = time.time()
                        else:
                            with open(filename, 'a') as f:
                                f.write(line)

                        # Periodic status line (every 125 packet pairs).
                        if packet_count % 125 == 0:
                            elapsed_time = time.time() - start_time
                            rate = packet_count / elapsed_time
                            print(f"\rRecording... {rate:.1f} Hz, {packet_count} packets", end='')

            # Check for buffer overflow
            if ser.in_waiting > 1000:
                print(f"\nWarning: Buffer building up ({ser.in_waiting} bytes)")
                ser.reset_input_buffer()

    except KeyboardInterrupt:
        # Stop streaming
        ser.write(b's')
        ser.close()

        if args.remote:
            # Write any remaining data in buffer
            if buffer.getvalue():
                remote_file.write(buffer.getvalue())
            remote_file.close()
            sftp.close()
            ssh.close()

        # Print final statistics
        elapsed_time = time.time() - start_time
        rate = packet_count / elapsed_time
        print(f"\n\nRecording stopped")
        print(f"Duration: {elapsed_time:.1f} seconds")
        print(f"Packets recorded: {packet_count}")
        print(f"Average sample rate: {rate:.1f} Hz")
        print(f"Data saved to: {filename}")
379
+
380
+ if __name__ == "__main__":
381
+ main()
decode.py ADDED
@@ -0,0 +1,612 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import sys
4
+ import time
5
+ import hashlib
6
+ import numpy as np
7
+ import torch
8
+ import sqlite3
9
+ import logging
10
+ import argparse
11
+ import random
12
+ import traceback
13
+ import faiss
14
+ import pickle
15
+ from datetime import datetime
16
+ from collections import deque, defaultdict
17
+ from typing import List, Dict, Tuple, Optional, Union, Any
18
+ from pathlib import Path
19
+
20
+ # Import from our streaming module
21
+ from eegembed import EEGEmbeddingStream
22
+
23
+ PRINT_DEBUG_HASH = False
24
+
25
def fix_encoding(s):
    """Best-effort cleanup of mojibake text.

    Args:
        s: A str (possibly containing surrogate escapes) or raw bytes.

    Returns:
        The cleaned string; "" when mojibake marker characters
        ('ì', 'í', 'ï') are detected; the input unchanged when falsy.
    """
    if not s:
        return s

    if isinstance(s, str):
        b = s.encode('utf-8', 'surrogateescape')
    else:
        b = s

    fixed = b.decode('utf-8', 'replace')
    # BUG FIX: test markers against the decoded string, not the raw input —
    # the original checked `'ì' in s`, which raises TypeError when s is bytes.
    if 'ì' in fixed or 'í' in fixed or 'ï' in fixed:
        return ""

    return fixed
39
+
40
+ # Set up logging
41
+ logging.basicConfig(
42
+ level=logging.INFO,
43
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
44
+ )
45
+ logger = logging.getLogger("EEGSemanticStream")
46
+
47
def setup_eeg_logger(eeg_file_path):
    """Set up a file logger based on the EEG filename.

    Creates ./session_logs next to this module (if needed) and opens a
    timestamped log file for the session. Returns the open file handle;
    the caller owns (and must close) it.
    """
    base_name = os.path.basename(eeg_file_path)
    stem = os.path.splitext(base_name)[0]

    # Session logs live alongside this module.
    logs_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "session_logs")
    os.makedirs(logs_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_path = os.path.join(logs_dir, f"{stem}_{timestamp}.log")

    log_file = open(log_path, "w", encoding="utf-8")
    log_file.write(f"--- Session started at {timestamp} for EEG file: {base_name} ---\n")
    log_file.flush()

    return log_file
64
+
65
+
66
class EmbeddingIndex:
    """FAISS-backed nearest-neighbor index over message embeddings.

    Embeddings are L2-normalized and stored in an inner-product (IndexFlatIP)
    index, so inner product equals cosine similarity. Runs on GPU when FAISS
    reports one available, falling back to CPU on any GPU error.
    """
    def __init__(self, dim=1536, use_gpu=True):
        # Dimensionality of the stored embedding vectors.
        self.dim = dim
        # Only attempt GPU when FAISS actually sees a device.
        self.use_gpu = use_gpu and faiss.get_num_gpus() > 0
        self.index = None
        self.gpu_resources = None
        # Maps FAISS row position -> database message_id.
        self.message_ids = []

        if self.use_gpu:
            self.gpu_resources = faiss.StandardGpuResources()

    def add_embeddings(self, embeddings: np.ndarray, message_ids: List[int]):
        """Build a fresh index; row i of `embeddings` is labeled message_ids[i]."""
        logger.info(f"Building FAISS index with {len(embeddings)} embeddings")

        # Normalize rows (epsilon guards zero vectors) so IP == cosine sim.
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        embeddings = embeddings / (norms + 1e-8)

        cpu_index = faiss.IndexFlatIP(self.dim)
        cpu_index.add(embeddings.astype(np.float32))

        if self.use_gpu:
            try:
                self.index = faiss.index_cpu_to_gpu(self.gpu_resources, 0, cpu_index)
                logger.info("Using GPU FAISS index")
            except Exception as e:
                # Fall back to CPU and remember the decision for save().
                logger.warning(f"GPU failed: {e}. Using CPU.")
                self.index = cpu_index
                self.use_gpu = False
        else:
            self.index = cpu_index
            logger.info("Using CPU FAISS index")

        self.message_ids = message_ids

    def get_current_count(self):
        """Return the number of vectors in the index (0 when unbuilt)."""
        if self.index is None:
            return 0
        return self.index.ntotal

    def search(self, query: np.ndarray, k: int = 10) -> Tuple[np.ndarray, np.ndarray]:
        """Return (cosine distances, message_id labels) of the k nearest rows.

        Raises:
            RuntimeError: If the index has not been built or loaded yet.
        """
        if self.index is None:
            raise RuntimeError("Index not initialized")

        # Normalize the query so inner product is cosine similarity.
        norm = np.linalg.norm(query)
        if norm > 0:
            query = query / norm

        actual_k = min(k, self.get_current_count())
        if actual_k == 0:
            return np.array([]), np.array([])

        similarities, indices = self.index.search(query.astype(np.float32), actual_k)
        # Convert similarity to cosine distance.
        distances = 1.0 - similarities

        # Translate FAISS row positions into database message ids.
        labels = np.array([[self.message_ids[idx] for idx in row] for row in indices])

        return distances, labels

    def save(self, path: str):
        """Persist the index to {path}.index and ids to {path}_message_ids.pkl."""
        if self.index is None:
            raise RuntimeError("Cannot save uninitialized index")

        # GPU indexes must be copied back to host memory before writing.
        if self.use_gpu:
            cpu_index = faiss.index_gpu_to_cpu(self.index)
            faiss.write_index(cpu_index, f"{path}.index")
        else:
            faiss.write_index(self.index, f"{path}.index")

        with open(f"{path}_message_ids.pkl", 'wb') as f:
            pickle.dump(self.message_ids, f)

    @classmethod
    def load(cls, path: str, use_gpu: bool = True) -> 'EmbeddingIndex':
        """Load an index previously written by save(), moving it to GPU if possible.

        NOTE(review): the instance is constructed with the default dim (1536)
        regardless of the loaded index's actual dimensionality — confirm all
        saved indexes use that dim.
        """
        with open(f"{path}_message_ids.pkl", 'rb') as f:
            message_ids = pickle.load(f)

        index = cls(use_gpu=use_gpu)
        index.message_ids = message_ids

        cpu_index = faiss.read_index(f"{path}.index")

        if index.use_gpu:
            try:
                index.index = faiss.index_cpu_to_gpu(index.gpu_resources, 0, cpu_index)
                logger.info("Loaded existing index and moved to GPU")
            except Exception as e:
                logger.warning(f"Failed to move loaded index to GPU: {e}. Using CPU.")
                index.index = cpu_index
                index.use_gpu = False
        else:
            index.index = cpu_index
            logger.info("Loaded existing index on CPU")

        return index
160
+
161
+
162
class EEGSemanticProcessor:
    """
    Process EEG data through autoencoder and semantic model pipeline,
    then lookup similar messages.

    Pipeline: EEGEmbeddingStream (windowed autoencoder embeddings) ->
    traced semantic model (maps EEG embedding to text-embedding space) ->
    FAISS similarity search over a database of message embeddings.
    """
    def __init__(
        self,
        autoencoder_model_path: str,
        semantic_model_path: str,
        nexus_db_path: str,
        embeddings_db_path: str,
        index_path: str = None,
        eeg_file_path: str = None,
        window_size: int = 624,
        stride: int = 64,
        batch_size: int = 32,
        normalize: bool = True,
        device: str = None,
        search_k: int = 180,
        final_k: int = 90,
        use_raw_eeg: bool = False,
        last_n_messages: int = 3,
        input_dim_override: int = None,
        save_vectors: bool = False,
        vector_output_path: str = None
    ):
        # Pick CUDA automatically unless the caller pinned a device.
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        logger.info(f"Using device: {self.device}")

        # How many previous batches to subtract in the repetition filter.
        # NOTE(review): callers may pass None here (see main's --last_n),
        # which makes previous_message_sets unbounded — confirm intent.
        self.last_n_messages = last_n_messages
        self.use_raw_eeg = use_raw_eeg
        self.input_dim_override = input_dim_override

        # Initialize EEG stream (windows the file and runs the autoencoder).
        self.eeg_stream = EEGEmbeddingStream(
            file_path=eeg_file_path if eeg_file_path else "",
            model_path=autoencoder_model_path,
            window_size=window_size,
            stride=stride,
            normalize=normalize,
            batch_size=batch_size,
            device=self.device
        )

        # Load traced semantic model
        logger.info(f"Loading traced semantic model from {semantic_model_path}")
        self.semantic_model = torch.jit.load(semantic_model_path, map_location=self.device)
        self.semantic_model.eval()

        # Probe to get input/output dims by trying candidate input sizes;
        # a TorchScript model raises on a mismatched shape, so the first
        # size that runs is taken as the model's input dim.
        self._semantic_input_dim = None
        self._semantic_output_dim = None
        for test_dim in [64, 10112]:
            try:
                dummy = torch.randn(1, test_dim, device=self.device)
                with torch.no_grad():
                    out = self.semantic_model(dummy)
                self._semantic_input_dim = test_dim
                self._semantic_output_dim = out.shape[1]
                logger.info(f"Semantic model: input_dim={self._semantic_input_dim}, output_dim={self._semantic_output_dim}")
                break
            except Exception:
                continue

        if self._semantic_input_dim is None:
            logger.warning("Could not auto-detect semantic model input dim. Will adapt at runtime.")

        # Per-session log file (None when no EEG file path was given).
        self.log_file = setup_eeg_logger(eeg_file_path) if eeg_file_path else None

        # Initialize database connections
        self.nexus_conn = sqlite3.connect(nexus_db_path)
        self.embeddings_conn = sqlite3.connect(embeddings_db_path)

        # Message tracking system: retrieve search_k candidates, return final_k.
        self.search_k = search_k
        self.final_k = final_k
        self.message_counts = defaultdict(int)
        self.recent_messages = deque(maxlen=10)
        self.repetition_penalty = 1.5

        logger.info("Creating embedding index")
        self.embedding_index = self._create_index(index_path)

        # Abort streaming after this many consecutive per-item failures.
        self.error_count = 0
        self.max_consecutive_errors = 5

        # When enabled, semantic vectors are collected instead of searched.
        self.save_vectors = save_vectors
        self.vector_output_path = vector_output_path

        if self.save_vectors:
            self.vectors_list = []
            self.timestamps = []
            logger.info(f"Vector saving enabled. Output will be saved to {self.vector_output_path}")

        # Line sets from the last N batches, used by _print_unique_lines.
        self.previous_message_sets = deque(maxlen=self.last_n_messages)

    def _create_index(self, index_path: str = None) -> EmbeddingIndex:
        """Create or load the embedding index for similarity search.

        An existing on-disk index is reused only when its saved metadata
        (row count and max message id) matches the embeddings database;
        otherwise the index is rebuilt from the database and re-saved.
        """

        cursor = self.embeddings_conn.cursor()

        cursor.execute("SELECT COUNT(*) FROM embeddings")
        db_count = cursor.fetchone()[0]

        cursor.execute("SELECT MAX(message_id) FROM embeddings")
        db_max_id = cursor.fetchone()[0]

        if index_path and os.path.exists(f"{index_path}.index"):
            try:
                logger.info(f"Checking existing index at {index_path}")

                index = EmbeddingIndex.load(index_path)

                metadata_path = f"{index_path}_metadata.npz"
                if os.path.exists(metadata_path):
                    metadata = np.load(metadata_path, allow_pickle=True)
                    saved_count = int(metadata.get('count', 0))
                    saved_max_id = int(metadata.get('max_message_id', 0))

                    logger.info(f"Saved index: {saved_count} items, max_id={saved_max_id}")
                    logger.info(f"Database: {db_count} items, max_id={db_max_id}")

                    # Staleness check: any mismatch forces a rebuild below.
                    if db_count != saved_count or db_max_id != saved_max_id:
                        logger.info("Database has changed. Recreating index...")
                    else:
                        logger.info("Database unchanged. Using existing index...")
                        return index

            except Exception as e:
                logger.warning(f"Error checking existing index: {str(e)}")
                logger.info("Will create new index")

        logger.info("Creating new index from database...")

        cursor.execute("SELECT message_id, embedding FROM embeddings ORDER BY message_id")

        embeddings = []
        message_ids = []

        # Embeddings are stored as raw float32 blobs.
        for message_id, emb in cursor.fetchall():
            embedding = np.frombuffer(emb, dtype=np.float32)
            embeddings.append(embedding)
            message_ids.append(message_id)

        if not embeddings:
            raise ValueError("No embeddings found in database")

        embeddings = np.vstack(embeddings)
        logger.info(f"Loaded {len(embeddings)} embeddings with shape: {embeddings.shape}")

        index = EmbeddingIndex(dim=embeddings.shape[1])
        index.add_embeddings(embeddings, message_ids)

        if index_path:
            logger.info(f"Saving index to {index_path}")
            index.save(index_path)

            # Persist the freshness metadata used by the staleness check above.
            metadata = {
                'count': db_count,
                'max_message_id': db_max_id
            }
            np.savez(f"{index_path}_metadata.npz", **metadata)

        return index

    def process_eeg_embedding(self, eeg_embedding: np.ndarray) -> torch.Tensor:
        """Convert EEG embedding to text embedding using the traced semantic model.

        Flattens the input to (batch, features) and zero-pads or truncates
        the feature dimension to match the probed semantic-model input dim.
        """
        with torch.no_grad():
            tensor = torch.tensor(eeg_embedding, dtype=torch.float32).to(self.device)

            # Ensure a batch dimension exists.
            if len(tensor.shape) < 2:
                tensor = tensor.unsqueeze(0)

            batch_size = tensor.shape[0]
            tensor = tensor.reshape(batch_size, -1)

            # Adapt dimensions if needed: pad with zeros or truncate.
            if self._semantic_input_dim is not None:
                current_features = tensor.shape[1]
                if current_features != self._semantic_input_dim:
                    if current_features < self._semantic_input_dim:
                        padded = torch.zeros(batch_size, self._semantic_input_dim, device=self.device)
                        padded[:, :current_features] = tensor
                        tensor = padded
                    else:
                        tensor = tensor[:, :self._semantic_input_dim]

            return self.semantic_model(tensor)

    def find_similar_messages(self, embedding: torch.Tensor, assistant_only=False) -> List[str]:
        """Find similar messages using the embedding index.

        Looks up the top search_k embedding neighbors, resolves each
        message id against the nexus DB (optionally restricted to
        assistant-role messages), and returns up to final_k contents.
        Returns [] on any search/lookup failure.
        """
        embedding_np = embedding.detach().cpu().numpy()
        if len(embedding_np.shape) > 1:
            embedding_np = embedding_np.reshape(1, -1)

        try:
            distances, indices = self.embedding_index.search(embedding_np, self.search_k)
            distances = distances.flatten()
            indices = indices.flatten()

            cursor = self.nexus_conn.cursor()
            candidates = []

            if assistant_only:
                query = """
                SELECT content FROM messages
                WHERE id = ? AND role = 'assistant'
                """
            else:
                query = """
                SELECT content FROM messages
                WHERE id = ?
                """

            # Results arrive in ascending-distance order; ids missing from
            # the messages table (or filtered by role) are simply skipped.
            for message_id, distance in zip(indices, distances):
                cursor.execute(query, (int(message_id),))
                if result := cursor.fetchone():
                    content = result[0]
                    candidates.append(content)

            return candidates[:self.final_k]

        except Exception as e:
            logger.error(f"Error during similarity search: {str(e)}")
            traceback.print_exc()
            return []

    def save_vectors_to_disk(self):
        """Save the collected vectors and timestamps to disk as an .npz file."""
        if not self.vectors_list:
            logger.warning("No vectors to save")
            return

        output_dir = os.path.dirname(self.vector_output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)

        vectors_array = np.vstack(self.vectors_list)
        # Each timestamp entry is a {'start': ..., 'end': ...} dict.
        timestamps_array = np.array(self.timestamps)

        logger.info(f"Saving {len(self.vectors_list)} vectors to {self.vector_output_path}")
        np.savez(
            self.vector_output_path,
            vectors=vectors_array,
            timestamps=timestamps_array
        )
        logger.info(f"Vectors saved successfully to {self.vector_output_path}")

    def process_streaming_embeddings(self, callback=None):
        """
        Process streaming EEG embeddings through the semantic model
        and find similar messages.

        Runs until KeyboardInterrupt or too many consecutive errors.
        In save_vectors mode, similarity search is skipped and vectors
        are collected for save_vectors_to_disk() instead. Each result is
        passed to `callback` when given, otherwise printed via
        _print_unique_lines().
        """
        self.eeg_stream.start()

        try:
            consecutive_errors = 0
            while True:
                try:
                    for embedding_data in self.eeg_stream.get_embeddings(timeout=0.5):
                        try:
                            autoencoder_embedding = embedding_data['embedding']
                            semantic_embedding = self.process_eeg_embedding(autoencoder_embedding)

                            if self.save_vectors:
                                embedding_np = semantic_embedding.detach().cpu().numpy()
                                self.vectors_list.append(embedding_np)
                                self.timestamps.append({
                                    'start': embedding_data['start_timestamp'],
                                    'end': embedding_data['end_timestamp']
                                })

                                if len(self.vectors_list) % 100 == 0:
                                    logger.info(f"Collected {len(self.vectors_list)} vectors so far")

                                # Collection-only mode: skip similarity search.
                                continue

                            similar_messages = self.find_similar_messages(semantic_embedding)

                            result = {
                                'start_timestamp': embedding_data['start_timestamp'],
                                'end_timestamp': embedding_data['end_timestamp'],
                                'processing_time': 0,
                                'similar_messages': similar_messages
                            }

                            if callback:
                                callback(result)
                            else:
                                self._print_unique_lines(result)

                            # A successful item resets the failure streak.
                            consecutive_errors = 0

                        except Exception as e:
                            print(f"Error: {str(e)}", file=sys.stderr)
                            consecutive_errors += 1

                            if consecutive_errors >= self.max_consecutive_errors:
                                raise RuntimeError(f"Too many consecutive errors ({consecutive_errors})")

                    time.sleep(0.01)

                except Exception as e:
                    # Let the give-up error propagate; retry everything else.
                    if "Too many consecutive errors" in str(e):
                        raise
                    print(f"Error: {str(e)}", file=sys.stderr)
                    consecutive_errors += 1
                    if consecutive_errors >= self.max_consecutive_errors:
                        raise RuntimeError(f"Too many consecutive errors ({consecutive_errors})")
                    time.sleep(1)

        except KeyboardInterrupt:
            pass
        except Exception as e:
            print(f"Fatal error: {str(e)}", file=sys.stderr)
        finally:
            if self.save_vectors and self.vectors_list:
                self.save_vectors_to_disk()

            self.eeg_stream.stop()

    def _print_unique_lines(self, result):
        """Print only lines that aren't in common with the last n batches of messages.

        Samples up to 42 of the retrieved messages, splits them into
        stripped lines, removes lines seen in the previous batches, and
        prints (and logs) the survivors after mojibake cleanup.
        """
        if not result['similar_messages']:
            return

        # Random sample keeps the output volume bounded per batch.
        sample_size = min(42, len(result['similar_messages']))
        current_messages = random.sample(result['similar_messages'], sample_size)

        current_lines = set()
        for message in current_messages:
            for line in message.splitlines():
                line = line.strip()
                if line:
                    current_lines.add(line)

        # Subtract everything seen in the last N batches.
        unique_lines = current_lines.copy()
        for previous_lines in self.previous_message_sets:
            unique_lines -= previous_lines

        self.previous_message_sets.append(current_lines)

        # NOTE(review): this flag is always False, so the elif branch below
        # is dead code — presumably a leftover debug switch.
        __uniq_log_empty = False
        if unique_lines:
            if PRINT_DEBUG_HASH:
                # Prefix each line with a short md5 of itself for debugging.
                unique_lines = [f"{hash} | {line}" for (hash, line) in zip(
                    map(lambda s: hashlib.md5(s.encode()).hexdigest()[:8], unique_lines),
                    unique_lines)]

            # Drop lines that fix_encoding rejects as mojibake.
            unique_lines = filter(lambda s: bool(s), map(fix_encoding, unique_lines))

            output_text = "\n".join(sorted(unique_lines))
            print(output_text)

            if hasattr(self, 'log_file') and self.log_file:
                try:
                    self.log_file.write(output_text + "\n")
                    self.log_file.flush()
                except Exception as e:
                    print(f"Error writing to log file: {str(e)}", file=sys.stderr)

        elif __uniq_log_empty:
            logger.info(f"No unique lines")
530
+
531
+
532
def main():
    """CLI entry point: build an EEGSemanticProcessor and stream results.

    BUG FIX: --last_n previously defaulted to None, which overrode
    EEGSemanticProcessor's documented default of 3 and produced
    deque(maxlen=None) — an unbounded repetition-filter window. The
    default is now 3, matching the constructor.
    """
    parser = argparse.ArgumentParser(description='Process EEG data through semantic model and lookup similar messages')

    parser.add_argument('--autoencoder', '-a', type=str, required=True,
                        help='Path to the traced autoencoder encoder model')
    parser.add_argument('--semantic-model', '-s', type=str, required=True,
                        help='Path to the traced semantic model')

    parser.add_argument('--nexus-db', '-n', type=str,
                        default=os.path.expanduser('~/.nexus/data/nexus-new.db'),
                        help='Path to the nexus database')
    parser.add_argument('--embeddings-db', '-e', type=str, default='emb_full.db',
                        help='Path to the embeddings database')
    parser.add_argument('--index', '-i', type=str, default='embedding_index',
                        help='Path to save/load the FAISS index')

    parser.add_argument('--eeg-file', '-f', type=str, required=True,
                        help='Path to the EEG data file to monitor')
    parser.add_argument('--window-size', type=int, default=624,
                        help='Window size in samples')
    parser.add_argument('--stride', type=int, default=32,
                        help='Stride between windows')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size for processing')
    parser.add_argument('--no-normalize', dest='normalize', action='store_false',
                        help='Disable normalization of EEG data')

    parser.add_argument('--search-k', type=int, default=180,
                        help='Number of candidates to retrieve for selection')
    parser.add_argument('--final-k', type=int, default=90,
                        help='Number of results to show')

    parser.add_argument('--device', type=str, default=None,
                        help='Device to use (cuda or cpu)')

    parser.add_argument('--last_n', type=int, default=3,
                        help='Window queue size for repetition filter')

    parser.add_argument('--use-raw-eeg', action='store_true',
                        help='Use raw EEG data with semantic model (skip autoencoder)')
    parser.add_argument('--input-dim', type=int,
                        help='Override the input dimension for the semantic model')

    parser.add_argument('--save-vectors', action='store_true',
                        help='Save semantic vectors to disk without generating output')
    parser.add_argument('--vector-output', type=str, default='semantic_vectors.npz',
                        help='Path to save the semantic vectors')

    args = parser.parse_args()

    processor = EEGSemanticProcessor(
        autoencoder_model_path=args.autoencoder,
        semantic_model_path=args.semantic_model,
        nexus_db_path=args.nexus_db,
        embeddings_db_path=args.embeddings_db,
        index_path=args.index,
        last_n_messages=args.last_n,
        eeg_file_path=args.eeg_file,
        window_size=args.window_size,
        stride=args.stride,
        batch_size=args.batch_size,
        normalize=args.normalize,
        device=args.device,
        search_k=args.search_k,
        final_k=args.final_k,
        use_raw_eeg=args.use_raw_eeg,
        input_dim_override=args.input_dim,
        save_vectors=args.save_vectors,
        vector_output_path=args.vector_output
    )

    try:
        processor.process_streaming_embeddings()
    except KeyboardInterrupt:
        pass
    except Exception as e:
        print(f"Error: {str(e)}", file=sys.stderr)
609
+
610
+
611
+ if __name__ == "__main__":
612
+ main()
eegembed.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import csv
4
+ import numpy as np
5
+ import torch
6
+ import threading
7
+ import queue
8
+ from datetime import datetime
9
+ from pathlib import Path
10
+ from typing import List, Dict, Tuple, Optional, Callable, Generator, Union, Any
11
+ import logging
12
+
13
+ logging.basicConfig(
14
+ level=logging.INFO,
15
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
16
+ )
17
+ logger = logging.getLogger("EEGStream")
18
+
19
+
20
class EncoderExtractor:
    """Wraps a TorchScript-traced EEG encoder and exposes embedding inference."""

    def __init__(self, model_path, device=None, force_sequence_length=None):
        # Mirror the rest of the pipeline: prefer CUDA when it is available.
        self.device = device if device is not None else torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        self.force_sequence_length = force_sequence_length
        logger.info(f"Loading traced encoder from {model_path} to {self.device}")
        self.model = torch.jit.load(model_path, map_location=self.device)
        self.model.eval()

        # Probe the model once with a dummy batch to discover the embedding width.
        probe = torch.randn(1, 16, force_sequence_length or 624, device=self.device)
        with torch.no_grad():
            self._embedding_size = self.model(probe).shape[1]
        logger.info(f"Embedding size: {self._embedding_size}")

    def get_embedding_size(self):
        """Return the dimensionality of the embeddings produced by the encoder."""
        return self._embedding_size

    def embed(self, data):
        """Encode a batch of EEG windows, resampling the time axis if needed.

        Args:
            data: Tensor of shape [batch, channels, time].
        """
        needs_resample = (
            self.force_sequence_length
            and data.shape[2] != self.force_sequence_length
        )
        if needs_resample:
            data = torch.nn.functional.interpolate(
                data, size=self.force_sequence_length, mode='linear', align_corners=False
            )
        with torch.no_grad():
            return self.model(data.to(self.device))
47
+
48
class EEGFileWatcher:
    """
    Watches a CSV file for new data and yields new lines as they appear.

    A daemon thread polls the file for growth and pushes every complete
    line (header first) onto an internal queue; consumers drain the queue
    with get_new_lines().
    """
    def __init__(self, file_path: str, poll_interval: float = 0.1):
        """
        Initialize the file watcher.

        Args:
            file_path: Path to the CSV file to watch
            poll_interval: How often to check for new data (in seconds)
        """
        self.file_path = Path(file_path)
        self.poll_interval = poll_interval
        self.last_position = 0  # offset of the end of the last fully-consumed line
        self.running = False
        self.thread = None
        self.queue = queue.Queue()
        self.header = None

    def start(self):
        """Start watching the file in a background thread."""
        if self.running:
            return

        self.running = True
        self.thread = threading.Thread(target=self._watch_file, daemon=True)
        self.thread.start()

    def stop(self):
        """Stop watching the file."""
        self.running = False
        if self.thread:
            self.thread.join(timeout=1.0)

    def _watch_file(self):
        """Background thread that watches the file for changes."""
        # Wait for the file to exist
        while self.running and not self.file_path.exists():
            logger.info(f"Waiting for file {self.file_path} to exist...")
            time.sleep(self.poll_interval)

        logger.info(f"File {self.file_path} found, starting to watch")

        # Keep track of the file position
        self.last_position = 0

        # Read header first so consumers can interpret the columns.
        try:
            with open(self.file_path, 'r') as f:
                self.header = f.readline().strip()
                self.last_position = f.tell()
                self.queue.put(self.header)
        except Exception as e:
            logger.error(f"Error reading header: {e}")

        while self.running:
            try:
                # Check if the file has grown
                current_size = self.file_path.stat().st_size
                if current_size > self.last_position:
                    # Read new data
                    with open(self.file_path, 'r') as f:
                        f.seek(self.last_position)
                        new_data = f.read()
                        self.last_position = f.tell()

                    # Process new lines (excluding partial lines)
                    lines = new_data.split('\n')
                    if not new_data.endswith('\n'):
                        # The last line might be incomplete; rewind so it is
                        # re-read on the next poll once it is complete.
                        # NOTE(review): len() counts characters while tell() is
                        # a byte offset — these differ for non-ASCII data;
                        # confirm the EEG CSV is ASCII-only.
                        self.last_position -= len(lines[-1])
                        lines = lines[:-1]

                    # Add complete lines to the queue
                    for line in lines:
                        if line.strip():  # Skip empty lines
                            self.queue.put(line)
            except Exception as e:
                logger.error(f"Error watching file: {e}")

            time.sleep(self.poll_interval)

    def get_new_lines(self, timeout: Optional[float] = None) -> List[str]:
        """
        Get any new lines that have been read since the last call.

        Args:
            timeout: How long to wait for the first line (in seconds).
                None means don't wait at all.

        Returns:
            List of new lines (might be empty if no new data)
        """
        lines = []
        try:
            # Bug fix: Queue.get(timeout=None) blocks *forever*, which
            # contradicted the documented "None means don't wait" contract.
            # Use a non-blocking get when no timeout is requested.
            if timeout is None:
                lines.append(self.queue.get_nowait())
            else:
                lines.append(self.queue.get(timeout=timeout))

            # Drain anything else already queued without waiting.
            while True:
                try:
                    lines.append(self.queue.get_nowait())
                except queue.Empty:
                    break
        except queue.Empty:
            pass

        return lines
158
+
159
+
160
class SlidingWindowProcessor:
    """
    Turns an append-only stream of sampled data points into fixed-length,
    optionally z-score-normalized windows with a configurable stride.
    """
    def __init__(
        self,
        window_size: int,
        stride: int,
        num_channels: int,
        channel_means: List[float],
        channel_stds: List[float],
        normalize: bool = True
    ):
        """
        Initialize the sliding window processor.

        Args:
            window_size: Number of data points in each window
            stride: Number of data points to advance between windows
            num_channels: Number of data channels
            channel_means: Per-channel mean used for normalization
            channel_stds: Per-channel standard deviation used for normalization
            normalize: Whether to z-score the data per channel
        """
        self.window_size = window_size
        self.stride = stride
        self.num_channels = num_channels
        self.channel_means = np.array(channel_means)
        self.channel_stds = np.array(channel_stds)
        self.normalize = normalize

        # Pending samples that have not yet been emitted in a window.
        self.buffer = []
        # Index into self.buffer where the next window begins.
        self.current_pos = 0

    def add_data(self, data_points: List[Dict[str, Union[str, float]]]):
        """
        Append newly arrived data points to the internal buffer.

        Args:
            data_points: List of dicts, each carrying a 'timestamp' key and
                per-channel values keyed 'Channel1' .. 'ChannelN'.
        """
        self.buffer.extend(data_points)

    def get_windows(self) -> Generator[Tuple[List[str], np.ndarray], None, None]:
        """
        Yield every complete window currently buffered.

        Yields:
            (timestamps, data) pairs where data has shape
            [num_channels, window_size].  After the generator is exhausted,
            samples that can no longer appear in any future window are
            discarded from the buffer.
        """
        while self.current_pos + self.window_size <= len(self.buffer):
            window = self.buffer[self.current_pos:self.current_pos + self.window_size]

            stamps = [sample['timestamp'] for sample in window]

            frame = np.zeros((self.num_channels, self.window_size), dtype=np.float32)
            for col, sample in enumerate(window):
                for ch in range(self.num_channels):
                    key = f'Channel{ch+1}'
                    if key in sample:
                        frame[ch, col] = sample[key]

            if self.normalize:
                for ch in range(self.num_channels):
                    if self.channel_stds[ch] > 0:
                        frame[ch] = (frame[ch] - self.channel_means[ch]) / self.channel_stds[ch]

            yield stamps, frame

            self.current_pos += self.stride

        # Trim consumed samples, keeping the (window_size - stride) overlap
        # that the next window still needs.
        if self.current_pos > 0:
            keep_from = max(0, self.current_pos - (self.window_size - self.stride))
            self.buffer = self.buffer[keep_from:]
            self.current_pos = max(0, self.current_pos - keep_from)
247
+
248
+
249
class EEGEmbeddingStream:
    """
    Stream of EEG embeddings from a live CSV file.

    Glues together an EEGFileWatcher (tails the growing CSV), a
    SlidingWindowProcessor (windowing + per-channel normalization) and an
    EncoderExtractor (traced encoder) into one generator-based pipeline.
    """
    def __init__(
        self,
        file_path: str,
        model_path: str,
        window_size: int = 256,
        stride: int = 64,
        normalizer_params: Dict[str, List[float]] = None,
        poll_interval: float = 0.1,
        batch_size: int = 32,
        normalize: bool = True,
        device: str = None,
        start_from_timestamp: str = None,
        force_sequence_length: int = None  # New parameter
    ):
        """
        Initialize the EEG embedding stream.

        Args:
            file_path: Path to the CSV file to watch
            model_path: Path to the trained model checkpoint
            window_size: Number of data points in each window
            stride: Number of data points to advance between windows
            normalizer_params: Dictionary with 'means' and 'stds' for each channel
                If None, default values will be used
            poll_interval: How often to check for new data (in seconds)
            batch_size: How many windows to encode at once
            normalize: Whether to normalize the data
            device: Device to use for encoding ('cuda' or 'cpu')
            start_from_timestamp: Only process data from this timestamp onwards
            force_sequence_length: Force the model to use this sequence length (to match training)
        """
        self.file_path = file_path
        self.poll_interval = poll_interval
        self.window_size = window_size
        self.stride = stride
        self.normalize = normalize
        self.batch_size = batch_size
        self.start_from_timestamp = start_from_timestamp

        # Set default normalizer parameters if not provided.
        # These 16-channel statistics are baked in from a prior recording
        # session — presumably raw Cyton+Daisy counts; confirm they match the
        # normalizer used when the encoder was trained.
        if normalizer_params is None:
            self.channel_means = [-70446.6562, -51197.2070, -42351.2812, -32628.9004, -58139.0547,
                                 -56271.2852, -48508.2305, -57654.8711, -69949.6484, -49663.8398,
                                 -43010.7070, -30252.7207, -56295.6250, -56075.9375, -48470.3086,
                                 -56338.5820]
            self.channel_stds = [76037.4453, 56048.1445, 71950.6328, 60051.6523, 64877.7422,
                                59371.3203, 56742.6055, 62344.4805, 75861.9141, 55614.6055,
                                70795.6719, 59312.4180, 64780.2109, 60292.6992, 56598.4609,
                                61472.3633]
        else:
            self.channel_means = normalizer_params['means']
            self.channel_stds = normalizer_params['stds']

        # Determine the number of channels
        self.num_channels = len(self.channel_means)

        # Initialize components
        self.file_watcher = EEGFileWatcher(file_path, poll_interval)
        self.window_processor = SlidingWindowProcessor(
            window_size, stride, self.num_channels,
            self.channel_means, self.channel_stds, normalize
        )

        # Set the device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        # Load the encoder with the forced sequence length
        self.encoder = EncoderExtractor(model_path, self.device, force_sequence_length)

        # CSV header (parsed lazily from the first line the watcher delivers)
        self.header = None

        # Running flag
        self.running = False

        # Statistics
        self.stats = {
            'windows_processed': 0,
            'start_time': None,
            'last_timestamp': None
        }

    def start(self):
        """Start the embedding stream."""
        if self.running:
            return

        self.running = True
        self.stats['start_time'] = time.time()
        self.file_watcher.start()

    def stop(self):
        """Stop the embedding stream and log throughput statistics."""
        self.running = False
        self.file_watcher.stop()

        if self.stats['start_time'] is not None:
            elapsed = time.time() - self.stats['start_time']
            windows_processed = self.stats['windows_processed']
            if windows_processed > 0 and elapsed > 0:
                rate = windows_processed / elapsed
                logger.info(f"Processed {windows_processed} windows in {elapsed:.2f}s ({rate:.2f} windows/s)")

    def _parse_csv_line(self, line: str) -> Dict[str, Union[str, float]]:
        """
        Parse a CSV line into a data point.

        The first line ever seen is consumed as the header (the file watcher
        enqueues the header line like any other line).

        Args:
            line: CSV line

        Returns:
            Dictionary with timestamp and channel values, or None if header
        """
        if not self.header:
            # First line is the header
            self.header = line.split(',')
            logger.info(f"CSV header: {self.header}")
            return None

        values = line.split(',')
        if len(values) != len(self.header):
            logger.warning(f"Line has wrong number of values: {line}")
            return None

        data_point = {}
        for i, column in enumerate(self.header):
            if i == 0:
                # Timestamp column
                data_point['timestamp'] = values[i]

                # Skip if before start_from_timestamp.
                # NOTE(review): this is a lexicographic string comparison — it
                # assumes sortable (e.g. ISO-8601) timestamps; confirm format.
                if self.start_from_timestamp and values[i] < self.start_from_timestamp:
                    return None
            else:
                # Channel column
                try:
                    data_point[column] = float(values[i])
                except ValueError:
                    logger.warning(f"Could not parse value {values[i]} as float for column {column}")
                    data_point[column] = 0.0

        return data_point

    def get_embeddings(self, timeout: Optional[float] = None) -> Generator[Dict[str, Any], None, None]:
        """
        Get embeddings for new data.

        Args:
            timeout: How long to wait for new data (in seconds). None means don't wait.

        Yields:
            Dictionary with window information and embedding
        """
        if not self.running:
            self.start()

        # Get new lines from the file
        new_lines = self.file_watcher.get_new_lines(timeout)
        if not new_lines:
            return

        # Parse CSV lines
        data_points = []
        for line in new_lines:
            data_point = self._parse_csv_line(line)
            if data_point:
                data_points.append(data_point)
                self.stats['last_timestamp'] = data_point['timestamp']

        if not data_points:
            return

        # Add to the window processor
        self.window_processor.add_data(data_points)

        # Get windows and batch them for embedding
        windows = list(self.window_processor.get_windows())
        if not windows:
            return

        for batch_start in range(0, len(windows), self.batch_size):
            batch_end = min(batch_start + self.batch_size, len(windows))
            batch = windows[batch_start:batch_end]

            # Extract timestamps and data
            batch_timestamps = [window[0] for window in batch]
            batch_data = [window[1] for window in batch]

            # Convert to tensors; shape is [batch, channels, window_size]
            batch_tensor = torch.tensor(np.array(batch_data), dtype=torch.float32)

            # Generate embeddings
            embeddings = self.encoder.embed(batch_tensor)

            # Convert to numpy and yield
            embeddings_np = embeddings.cpu().numpy()

            for i in range(len(batch)):
                self.stats['windows_processed'] += 1
                yield {
                    'start_timestamp': batch_timestamps[i][0],
                    'end_timestamp': batch_timestamps[i][-1],
                    'embedding': embeddings_np[i],
                    'window_index': self.stats['windows_processed'] - 1
                }

    def get_streaming_embeddings(self, callback: Optional[Callable[[Dict[str, Any]], None]] = None) -> Generator[Dict[str, Any], None, None]:
        """
        Continuously generate embeddings and call the callback function with each one.

        NOTE: this is a generator in both modes — it must be iterated for any
        work (including callback invocation) to happen.

        Args:
            callback: Function to call with each embedding. If None, embeddings are yielded.

        Yields:
            If no callback is provided, yields dictionaries with window information and embedding
        """
        self.start()

        try:
            while self.running:
                any_embeddings = False
                for embedding in self.get_embeddings(timeout=self.poll_interval):
                    any_embeddings = True
                    if callback:
                        callback(embedding)
                    else:
                        yield embedding

                if not any_embeddings:
                    # No new embeddings, just wait a bit
                    time.sleep(self.poll_interval)
        finally:
            self.stop()

    def get_stats(self) -> Dict[str, Any]:
        """
        Get statistics about the streaming process.

        Returns:
            Dictionary with statistics (windows processed, elapsed time,
            throughput, last timestamp seen)
        """
        stats = dict(self.stats)
        if stats['start_time'] is not None:
            stats['elapsed'] = time.time() - stats['start_time']
            if stats['windows_processed'] > 0 and stats['elapsed'] > 0:
                stats['windows_per_second'] = stats['windows_processed'] / stats['elapsed']
        return stats
503
+
504
+ # Example usage
505
+ def example():
506
+ def handle_embedding(embedding):
507
+ """Callback function to handle new embeddings."""
508
+ start_time = embedding['start_timestamp']
509
+ end_time = embedding['end_timestamp']
510
+ embedding_data = embedding['embedding']
511
+
512
+ print(f"Got embedding for window from {start_time} to {end_time}")
513
+ print(f"Embedding shape: {embedding_data.shape}")
514
+ print(f"First few values: {embedding_data.flatten()[:5]}")
515
+
516
+ # Create the embedding stream
517
+ stream = EEGEmbeddingStream(
518
+ file_path="eeg_data.csv",
519
+ model_path="models/eeg_autoencoder.pth",
520
+ window_size=256, # Number of data points in each window
521
+ stride=128, # How much to advance between windows
522
+ poll_interval=0.5 # Check for new data every 0.5 seconds
523
+ )
524
+
525
+ print("Starting embedding stream...")
526
+ print("Press Ctrl+C to stop")
527
+
528
+ try:
529
+ # Method 1: Using callback
530
+ stream.get_streaming_embeddings(callback=handle_embedding)
531
+
532
+ # Method 2: Using generator
533
+ # for embedding in stream.get_streaming_embeddings():
534
+ # handle_embedding(embedding)
535
+ except KeyboardInterrupt:
536
+ print("\nStopping...")
537
+ finally:
538
+ stream.stop()
539
+ print("Stopped")
540
+
541
+
542
# Run the streaming demo when executed as a script.
if __name__ == "__main__":
    example()
embed.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Text embedding script with SQLite storage (using numpy buffers)
4
+ Now with flexible text splitting modes!
5
+
6
+ Usage: python embed_flex.py <directory_path> <db_path> [--split-mode MODE]
7
+
8
+ Split modes:
9
+ - line (default): Each non-empty line becomes one embedding
10
+ - block: Double-newline separated blocks (paragraphs)
11
+ - sentence: Split on sentence boundaries (., !, ?)
12
+ - chunk: Fixed token-ish chunks with overlap (for long docs)
13
+ """
14
+
15
+ import os
16
+ import sys
17
+ import argparse
18
+ import sqlite3
19
+ import numpy as np
20
+ from tqdm import tqdm
21
+ from transformers import AutoModel, AutoTokenizer
22
+ import torch
23
+ import gc
24
+ import random
25
+ import re
26
+
27
+ INITIAL_BATCH_SIZE = 128
28
+ MIN_BATCH_SIZE = 1
29
+ SHUFFLE_SEED = 42
30
+
31
+ # Chunk mode settings
32
+ DEFAULT_CHUNK_SIZE = 512 # characters
33
+ DEFAULT_CHUNK_OVERLAP = 64
34
+
35
+
36
def create_index_if_possible(cursor):
    """Best-effort creation of an index on messages.content.

    Silently ignores the OperationalError raised when the messages table
    does not exist yet.
    """
    ddl = """
        CREATE INDEX IF NOT EXISTS idx_content ON messages(content)
    """
    try:
        cursor.execute(ddl)
    except sqlite3.OperationalError:
        pass
43
+
44
+
45
def get_existing_content(cursor):
    """Return the set of message contents already stored.

    Falls back to an empty set when the messages table does not exist yet.
    """
    try:
        cursor.execute("SELECT content FROM messages")
    except sqlite3.OperationalError:
        return set()
    return {content for (content,) in cursor.fetchall()}
51
+
52
+
53
def clear_gpu_memory():
    """Release cached CUDA allocations and trigger a garbage-collection pass.

    No-op when CUDA is unavailable.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
57
+
58
+
59
+ # =============================================================================
60
+ # SPLITTING STRATEGIES
61
+ # =============================================================================
62
+
63
def split_by_lines(text):
    """Original behavior: each non-empty line is one unit."""
    stripped = (candidate.strip() for candidate in text.split('\n'))
    return [candidate for candidate in stripped if candidate]
71
+
72
+
73
def split_by_blocks(text):
    """Split text into paragraph units separated by blank lines, collapsing
    internal whitespace inside each block."""
    collapsed = (' '.join(raw.split()) for raw in re.split(r'\n\s*\n+', text))
    return [block for block in collapsed if block]
81
+
82
+
83
def split_by_sentences(text):
    """
    Split on sentence boundaries.
    Handles common abbreviations somewhat gracefully.
    """
    # Collapse all whitespace first so the boundary regex only sees spaces.
    normalized = ' '.join(text.split())

    # A boundary is ., ! or ? followed by whitespace and a capital letter.
    # Imperfect, but reasonable for most prose.
    boundary = r'(?<=[.!?])\s+(?=[A-Z])'

    pieces = (piece.strip() for piece in re.split(boundary, normalized))
    return [piece for piece in pieces if piece]
102
+
103
+
104
def split_by_chunks(text, chunk_size=None, overlap=None):
    """
    Fixed-size character chunks with overlap.
    Good for long documents where you want sliding window coverage.

    Args:
        text: Raw input text; internal whitespace is collapsed first.
        chunk_size: Target chunk length in characters
            (None -> DEFAULT_CHUNK_SIZE).
        overlap: Characters shared between consecutive chunks
            (None -> DEFAULT_CHUNK_OVERLAP).

    Returns:
        List of non-empty chunk strings.
    """
    # Late-bind the module defaults so callers always see the current values.
    if chunk_size is None:
        chunk_size = DEFAULT_CHUNK_SIZE
    if overlap is None:
        overlap = DEFAULT_CHUNK_OVERLAP

    # Normalize whitespace
    text = ' '.join(text.split())

    if len(text) <= chunk_size:
        return [text] if text else []

    chunks = []
    start = 0
    while start < len(text):
        end = start + chunk_size
        chunk = text[start:end]

        # Try to break at a word boundary if not at the end
        if end < len(text):
            last_space = chunk.rfind(' ')
            if last_space > chunk_size // 2:  # Only if we're not losing too much
                chunk = chunk[:last_space]
                end = start + last_space

        chunk = chunk.strip()
        if chunk:
            chunks.append(chunk)

        # Move forward with overlap.
        # Bug fix: the original guard was `if start <= chunks[-1] if chunks
        # else 0:` which compares the integer position against the last chunk
        # *string* and raises TypeError for any text longer than chunk_size.
        # The intent — preventing an infinite loop when overlap >= the actual
        # advance — is implemented by requiring forward progress.
        next_start = end - overlap
        if next_start <= start:
            next_start = end
        start = next_start

    return chunks
138
+
139
+
140
def get_splitter(mode, chunk_size=DEFAULT_CHUNK_SIZE, chunk_overlap=DEFAULT_CHUNK_OVERLAP):
    """Return the appropriate splitting function."""
    dispatch = {
        'line': split_by_lines,
        'block': split_by_blocks,
        'sentence': split_by_sentences,
        'chunk': lambda text: split_by_chunks(text, chunk_size, chunk_overlap),
    }
    if mode not in dispatch:
        raise ValueError(f"Unknown split mode: {mode}")
    return dispatch[mode]
152
+
153
+
154
+ # =============================================================================
155
+ # PROCESSING
156
+ # =============================================================================
157
+
158
def process_batch(model, batch_lines, cursor, task="text-matching"):
    """Encode one batch of text units and store them with their embeddings.

    Inserts a messages row and a matching embeddings row (float32 buffer)
    for each unit; per-row sqlite errors are logged and skipped.

    Returns:
        True if the batch was encoded; False on a CUDA out-of-memory error
        (GPU cache is cleared so the caller can retry with a smaller batch).
        Any other error propagates.
    """
    try:
        with torch.no_grad():
            batch_embeddings = model.encode(batch_lines, task=task, device="cuda")

        for line_text, embedding in zip(batch_lines, batch_embeddings):
            try:
                cursor.execute(
                    "INSERT INTO messages (content, role) VALUES (?, ?)",
                    (line_text, "system")
                )
                message_id = cursor.lastrowid

                # Coerce whatever the model returned into a numpy array.
                if torch.is_tensor(embedding):
                    vector = embedding.cpu().numpy()
                elif isinstance(embedding, np.ndarray):
                    vector = embedding
                else:
                    vector = np.array(embedding)

                cursor.execute(
                    "INSERT INTO embeddings (message_id, embedding) VALUES (?, ?)",
                    (message_id, vector.astype(np.float32).tobytes())
                )
            except sqlite3.Error as e:
                print(f"Error processing entry: {e}")
                continue

        return True

    except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
        if "out of memory" not in str(e).lower():
            raise
        clear_gpu_memory()
        return False
196
+
197
+
198
def convert_existing_pickles(cursor, conn):
    """Convert any existing pickle embeddings to numpy buffers"""
    import pickle

    def looks_like_numpy(blob):
        # Heuristic: a raw float32 buffer deserializes cleanly and non-empty.
        try:
            arr = np.frombuffer(blob, dtype=np.float32)
            return arr.ndim >= 1 and len(arr) > 0
        except Exception:
            return False

    def to_numpy(blob):
        # Best-effort unpickle to an ndarray; None signals "leave untouched".
        try:
            obj = pickle.loads(blob)
            if isinstance(obj, np.ndarray):
                return obj
            if torch.is_tensor(obj):
                return obj.cpu().numpy()
            return np.array(obj)
        except Exception:
            return None

    cursor.execute("SELECT COUNT(*) FROM embeddings")
    total_embeddings = cursor.fetchone()[0]

    if total_embeddings == 0:
        return

    print(f"Checking {total_embeddings} existing embeddings for pickle->numpy conversion...")

    cursor.execute("SELECT message_id, embedding FROM embeddings")
    rows = cursor.fetchall()

    converted_count = 0
    for message_id, blob in rows:
        if looks_like_numpy(blob):
            continue
        arr = to_numpy(blob)
        if arr is None:
            continue
        cursor.execute(
            "UPDATE embeddings SET embedding = ? WHERE message_id = ?",
            (arr.astype(np.float32).tobytes(), message_id)
        )
        converted_count += 1

    if converted_count > 0:
        conn.commit()
        print(f"Converted {converted_count} pickle embeddings to numpy buffers")
250
+
251
+
252
def main():
    """CLI entry point: embed all .txt files under a directory into SQLite.

    Parses options, loads the embedding model onto the GPU, ensures the
    messages/embeddings schema exists, splits each file with the selected
    strategy, shuffles deterministically (so interrupted runs resume over a
    stable order), skips units already stored, then encodes in dynamically
    sized batches — halving the batch on CUDA OOM and growing it back after
    sustained success.
    """
    parser = argparse.ArgumentParser(
        description='Generate embeddings for text files with flexible splitting modes',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Split Modes:
  line      Each non-empty line = one embedding (default, original behavior)
  block     Double-newline separated paragraphs = one embedding each
  sentence  Split on sentence boundaries (., !, ?)
  chunk     Fixed-size character chunks with overlap (good for long docs)

Examples:
  %(prog)s ~/docs embeddings.db                       # line mode (default)
  %(prog)s ~/docs embeddings.db --split-mode block    # paragraph mode
  %(prog)s ~/docs embeddings.db --split-mode sentence # sentence mode
  %(prog)s ~/docs embeddings.db --split-mode chunk --chunk-size 1024 --chunk-overlap 128
        """
    )

    parser.add_argument('directory',
                        help='Directory containing .txt files to process')
    parser.add_argument('database',
                        help='SQLite database path (will be created if not exists)')
    parser.add_argument('--split-mode', '-s',
                        choices=['line', 'block', 'sentence', 'chunk'],
                        default='line',
                        help='Text splitting strategy (default: line)')
    parser.add_argument('--chunk-size', type=int, default=DEFAULT_CHUNK_SIZE,
                        help=f'Character chunk size for chunk mode (default: {DEFAULT_CHUNK_SIZE})')
    parser.add_argument('--chunk-overlap', type=int, default=DEFAULT_CHUNK_OVERLAP,
                        help=f'Overlap between chunks (default: {DEFAULT_CHUNK_OVERLAP})')
    parser.add_argument('--batch-size', type=int, default=INITIAL_BATCH_SIZE,
                        help=f'Initial batch size (default: {INITIAL_BATCH_SIZE})')
    parser.add_argument('--task', default='text-matching',
                        help='Encoding task (default: text-matching)')
    parser.add_argument('--model', default='jinaai/jina-embeddings-v3',
                        help='Model name (default: jinaai/jina-embeddings-v3)')
    parser.add_argument('--skip-conversion', action='store_true',
                        help='Skip checking/converting existing pickle embeddings')

    args = parser.parse_args()

    directory_path = os.path.expanduser(args.directory)
    db_path = os.path.expanduser(args.database)

    if not os.path.isdir(directory_path):
        print(f"Error: Directory '{directory_path}' does not exist")
        sys.exit(1)

    print(f"Processing directory: {directory_path}")
    print(f"Database: {db_path}")
    print(f"Split mode: {args.split_mode}")
    if args.split_mode == 'chunk':
        print(f"Chunk size: {args.chunk_size}, overlap: {args.chunk_overlap}")
    print(f"Initial batch size: {args.batch_size}")

    # Get splitter function
    splitter = get_splitter(args.split_mode, args.chunk_size, args.chunk_overlap)

    # Initialize model
    print(f"Loading model: {args.model}")
    model = AutoModel.from_pretrained(args.model, trust_remote_code=True).cuda()
    model.eval()

    # Set up SQLite
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    cursor.execute("""
        CREATE TABLE IF NOT EXISTS messages (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            content TEXT,
            role TEXT
        )
    """)

    # NOTE(review): the foreign key references messages(message_id), but the
    # messages table's primary-key column is named `id`.  Harmless while
    # SQLite foreign-key enforcement is off (the default), but worth fixing.
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS embeddings (
            message_id INTEGER PRIMARY KEY,
            embedding BLOB,
            FOREIGN KEY (message_id) REFERENCES messages(message_id) ON DELETE CASCADE
        )
    """)
    conn.commit()

    create_index_if_possible(cursor)
    conn.commit()

    if not args.skip_conversion:
        convert_existing_pickles(cursor, conn)

    existing_content = get_existing_content(cursor)
    print(f"Already processed: {len(existing_content)} entries")

    # Collect all text units using the selected splitter
    all_units = []
    txt_files = [f for f in os.listdir(directory_path) if f.lower().endswith(".txt")]

    if not txt_files:
        print(f"Warning: No .txt files found in {directory_path}")
        conn.close()
        return

    print(f"Found {len(txt_files)} .txt files")

    for filename in txt_files:
        filepath = os.path.join(directory_path, filename)
        with open(filepath, "r", encoding="utf-8", errors="ignore") as f:
            content = f.read()
        units = splitter(content)
        all_units.extend(units)

    print(f"Total units from source ({args.split_mode} mode): {len(all_units)}")

    # Deterministic shuffle: a fixed seed makes interrupted runs resume over
    # the same ordering.
    random.seed(SHUFFLE_SEED)
    random.shuffle(all_units)

    # Filter out already processed (set membership, O(1) per unit)
    new_units = [u for u in all_units if u not in existing_content]

    print(f"Remaining to process: {len(new_units)}")

    if not new_units:
        print("Nothing new to process.")
        conn.close()
        return

    # Process with dynamic batch sizing
    batch_size = args.batch_size
    total = len(new_units)
    task = args.task

    idx = 0
    processed_count = 0

    with tqdm(total=total, desc="Processing") as pbar:
        while idx < total:
            end_idx = min(idx + batch_size, total)
            batch = new_units[idx:end_idx]

            success = process_batch(model, batch, cursor, task)

            if success:
                try:
                    conn.commit()
                except sqlite3.Error as e:
                    print(f"Error committing batch: {e}")

                batch_processed = len(batch)
                pbar.update(batch_processed)
                processed_count += batch_processed
                idx = end_idx

                # Gradually restore the batch size after earlier OOM shrinkage.
                if batch_size < args.batch_size and processed_count % (batch_size * 10) == 0:
                    batch_size = min(batch_size * 2, args.batch_size)
            else:
                # CUDA OOM: halve the batch and retry; at the minimum size,
                # skip the single offending unit so the run can proceed.
                if batch_size > MIN_BATCH_SIZE:
                    batch_size = max(batch_size // 2, MIN_BATCH_SIZE)
                    print(f"\nOOM - batch size -> {batch_size}")
                else:
                    print(f"\nSkipping: {batch[0][:100]}...")
                    idx += 1
                    pbar.update(1)
                    processed_count += 1

    conn.close()
    print(f"\nProcessed {processed_count:,} entries total.")
    print("All embeddings stored as numpy buffers (float32).")
421
+
422
+
423
# Allow running the embedder as a standalone script.
if __name__ == "__main__":
    main()
morphism.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ morphism — EEG-to-text semantic search
4
+
5
+ Usage:
6
+ morphism record [options]
7
+ morphism index create|info|rebuild [options]
8
+ morphism decode [options]
9
+ """
10
+
11
+ import sys
12
+ import os
13
+ import argparse
14
+
15
+ from retrieval import FloodMode, DriftMode, FocusMode, LayeredMode
16
+
17
def cmd_record(args):
    """Record EEG data from an OpenBCI Cyton+Daisy board.

    Streams 16-channel samples over the serial port and writes them as CSV
    (timestamp + 16 channel values per row) either to a local file or, with
    --remote, to a file on the remote server via SFTP. With --sd, recording
    is delegated entirely to the board's SD card. Stop with Ctrl+C.
    """
    from cyton import (
        init_board, set_sample_rate, read_complete_packet, process_packet,
        start_sd_recording, stop_sd_recording, create_ssh_connection, sd_record
    )
    import serial, time, io
    from datetime import datetime

    # SD-card recording is handled on the board side; nothing to stream here.
    if args.sd:
        sd_record(args.port, args.duration, args.sample_rate)
        return

    filename = args.output
    if filename is None:
        filename = f"openbci_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"

    ser = serial.Serial(args.port, 115200)
    time.sleep(2)  # give the board time to reset after the port opens
    init_board(ser)

    if args.sample_rate != 1000:
        set_sample_rate(ser, args.sample_rate)

    ssh, sftp, remote_file, local_file = None, None, None, None
    if args.remote:
        ssh = create_ssh_connection()
        if not ssh:
            print("SSH connection failed.")
            return
        sftp = ssh.open_sftp()
        remote_file = sftp.open(filename, 'w')

    header = "Timestamp," + ",".join(f"Channel{i+1}" for i in range(16)) + "\n"
    if args.remote:
        remote_file.write(header)
    else:
        # Keep one file handle open for the whole session instead of
        # re-opening the file in append mode for every sample (~1000x/s).
        local_file = open(filename, 'w')
        local_file.write(header)

    ser.write(b'b')  # 'b' = start streaming
    time.sleep(0.5)
    ser.reset_input_buffer()

    print(f"Recording to {filename} — Ctrl+C to stop")

    pkt_count = 0
    t0 = time.time()
    buf = io.StringIO()
    last_flush = time.time()

    try:
        while True:
            # Cyton+Daisy sends two interleaved packets per 16-channel sample:
            # main-board channels 1-8, then daisy channels 9-16.
            p1 = read_complete_packet(ser)
            if not p1:
                continue
            p2 = read_complete_packet(ser)
            if not p2:
                continue

            d1, d2 = process_packet(p1), process_packet(p2)
            if not (d1 and d2):
                continue

            pkt_count += 1
            ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
            line = ts + "," + ",".join(f"{x:.6f}" for x in d1 + d2) + "\n"

            if args.remote:
                # Batch SFTP writes (~10 Hz) to avoid a network round-trip
                # per sample.
                buf.write(line)
                if time.time() - last_flush >= 0.1:
                    remote_file.write(buf.getvalue())
                    buf = io.StringIO()
                    last_flush = time.time()
            else:
                local_file.write(line)

            if pkt_count % 125 == 0:
                rate = pkt_count / (time.time() - t0)
                print(f"\r {rate:.1f} Hz, {pkt_count} packets", end='')

            # If we fall behind the board, drop the backlog — stale samples
            # are worse than a short gap.
            if ser.in_waiting > 1000:
                ser.reset_input_buffer()

    except KeyboardInterrupt:
        pass
    finally:
        # Always stop streaming and release resources, even if an unexpected
        # exception (not just Ctrl+C) aborts the loop.
        try:
            ser.write(b's')  # 's' = stop streaming
        except Exception:
            pass  # best-effort: the port may already be unusable
        ser.close()
        if args.remote:
            if buf.getvalue():
                remote_file.write(buf.getvalue())
            remote_file.close()
            sftp.close()
            ssh.close()
        elif local_file is not None:
            local_file.close()

    elapsed = time.time() - t0
    print(f"\n\nDone — {pkt_count} packets in {elapsed:.1f}s ({pkt_count/elapsed:.1f} Hz)")
    print(f"Saved to {filename}")
+
117
def cmd_index(args):
    """Manage the text embedding index.

    Actions:
        info    — print database and FAISS-index statistics.
        create  — embed any not-yet-indexed text units from --corpus.
        rebuild — drop all existing rows, then re-embed from scratch.

    create/rebuild load the embedding model on GPU, split every .txt file in
    the corpus into units, embed them in (OOM-adaptive) batches into SQLite,
    then build the FAISS index files.
    """
    from embed import (
        get_splitter, process_batch, create_index_if_possible,
        get_existing_content, INITIAL_BATCH_SIZE, MIN_BATCH_SIZE, SHUFFLE_SEED
    )
    import sqlite3, random
    from tqdm import tqdm

    db_path = os.path.expanduser(args.db)
    index_prefix = args.index

    if args.action == 'info':
        if not os.path.exists(db_path):
            print(f"No database at {db_path}")
            return

        conn = sqlite3.connect(db_path)
        c = conn.cursor()
        c.execute("SELECT COUNT(*) FROM messages")
        msg_count = c.fetchone()[0]
        c.execute("SELECT COUNT(*) FROM embeddings")
        emb_count = c.fetchone()[0]
        conn.close()

        index_exists = os.path.exists(f"{index_prefix}.index")

        print(f"Database: {db_path}")
        print(f"Messages: {msg_count:,}")
        print(f"Embeddings: {emb_count:,}")
        print(f"FAISS index: {'exists' if index_exists else 'not built'} ({index_prefix}.index)")
        return

    if args.action in ('create', 'rebuild'):
        corpus = os.path.expanduser(args.corpus)
        if not os.path.isdir(corpus):
            print(f"Not a directory: {corpus}")
            sys.exit(1)

        splitter = get_splitter(args.split_mode, args.chunk_size, args.chunk_overlap)

        print(f"Loading model: {args.model}")
        from transformers import AutoModel
        model = AutoModel.from_pretrained(args.model, trust_remote_code=True).cuda()
        model.eval()

        conn = sqlite3.connect(db_path)
        c = conn.cursor()

        if args.action == 'rebuild':
            print("Dropping existing data...")
            c.execute("DELETE FROM embeddings")
            c.execute("DELETE FROM messages")
            conn.commit()

        c.execute("""CREATE TABLE IF NOT EXISTS messages (
            id INTEGER PRIMARY KEY AUTOINCREMENT, content TEXT, role TEXT)""")
        # BUGFIX: the parent table's primary key column is `id`, not
        # `message_id` — the old REFERENCES messages(message_id) pointed at a
        # non-existent column and would fail with foreign keys enforced.
        c.execute("""CREATE TABLE IF NOT EXISTS embeddings (
            message_id INTEGER PRIMARY KEY, embedding BLOB,
            FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE)""")
        conn.commit()
        create_index_if_possible(c)
        conn.commit()

        existing = get_existing_content(c)
        print(f"Already indexed: {len(existing):,}")

        txt_files = [f for f in os.listdir(corpus) if f.lower().endswith('.txt')]
        if not txt_files:
            print(f"No .txt files in {corpus}")
            conn.close()
            return

        units = []
        for fn in txt_files:
            with open(os.path.join(corpus, fn), 'r', encoding='utf-8', errors='ignore') as f:
                units.extend(splitter(f.read()))

        # Deterministic shuffle so interrupted runs resume over the same order.
        random.seed(SHUFFLE_SEED)
        random.shuffle(units)
        new_units = [u for u in units if u not in existing]
        print(f"New units to embed: {len(new_units):,}")

        if not new_units:
            print("Nothing new.")
            conn.close()
            return

        batch_size = args.batch_size
        idx = 0
        processed = 0

        with tqdm(total=len(new_units), desc="Embedding") as pbar:
            while idx < len(new_units):
                batch = new_units[idx:idx + batch_size]
                ok = process_batch(model, batch, c, args.task)
                if ok:
                    conn.commit()
                    pbar.update(len(batch))
                    processed += len(batch)
                    idx += len(batch)
                else:
                    # process_batch failed (typically CUDA OOM): halve the
                    # batch size; once at the minimum, skip the offending unit.
                    if batch_size > MIN_BATCH_SIZE:
                        batch_size = max(batch_size // 2, MIN_BATCH_SIZE)
                        print(f"\nOOM — batch size → {batch_size}")
                    else:
                        # Consistent with embed.py: report what gets dropped
                        # instead of silently counting it as processed.
                        print(f"\nSkipping: {batch[0][:100]}...")
                        idx += 1
                        pbar.update(1)
                        processed += 1

        conn.close()
        print(f"Embedded {processed:,} units.")

        print("Building FAISS index...")
        _build_faiss_index(db_path, index_prefix)
        print("Done.")
+
234
+
235
def _build_faiss_index(db_path, index_prefix):
    """Build and save a FAISS index from the embeddings database.

    Loads every embedding blob from `db_path` (float32 buffers), stacks them
    into one matrix, writes `{index_prefix}.index` via EmbeddingIndex, and
    records row count / max message id in `{index_prefix}_metadata.npz` so a
    stale index can be detected later.
    """
    import sqlite3
    import numpy as np
    from decode import EmbeddingIndex

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT message_id, embedding FROM embeddings ORDER BY message_id")
        rows = cursor.fetchall()
    finally:
        # Close even if the query fails (e.g. missing table).
        conn.close()

    if not rows:
        print(" No embeddings found.")
        return

    ids = [mid for mid, _ in rows]
    vectors = np.vstack([np.frombuffer(blob, dtype=np.float32) for _, blob in rows])
    print(f" {len(vectors):,} vectors, dim={vectors.shape[1]}")

    index = EmbeddingIndex(dim=vectors.shape[1])
    index.add_embeddings(vectors, ids)
    index.save(index_prefix)

    # Derive the metadata from the rows already in memory instead of opening
    # a second connection to re-run COUNT(*) / MAX(message_id) on the same
    # table (as the original did).
    np.savez(f"{index_prefix}_metadata.npz", count=len(ids), max_message_id=max(ids))
+
270
+
271
def cmd_decode(args):
    """Run EEG → text decoding.

    Builds an EEGSemanticProcessor from the CLI arguments, starts the EEG
    stream, and feeds each windowed EEG embedding through the selected
    retrieval mode, printing (and optionally logging) the retrieved text.
    With --save-vectors, semantic vectors are collected instead of decoded
    and flushed to disk on exit.
    """
    # Hoisted: the original re-executed `import sys/time/logging` inside the
    # streaming loop on every iteration; `sys` is already a module import.
    import logging
    import time

    from decode import EEGSemanticProcessor

    processor = EEGSemanticProcessor(
        autoencoder_model_path=args.autoencoder,
        semantic_model_path=args.semantic,
        nexus_db_path=args.db,
        embeddings_db_path=args.db,
        index_path=args.index,
        eeg_file_path=args.eeg,
        window_size=args.window_size,
        stride=args.stride,
        batch_size=args.batch_size,
        device=args.device,
        search_k=args.search_k,
        final_k=args.final_k,
        use_raw_eeg=args.raw_eeg,
        input_dim_override=args.input_dim,
        save_vectors=args.save_vectors,
        vector_output_path=args.vector_output,
        last_n_messages=args.last_n,
    )

    # Lazy constructors so only the selected retrieval mode is instantiated.
    modes = {
        'flood': lambda: FloodMode(processor.embedding_index, processor.nexus_conn,
                                   search_k=args.search_k, final_k=args.final_k,
                                   last_n=args.last_n),
        'drift': lambda: DriftMode(processor.embedding_index, processor.nexus_conn,
                                   search_k=64),
        'focus': lambda: FocusMode(processor.embedding_index, processor.nexus_conn,
                                   search_k=48),
        'layered': lambda: LayeredMode(processor.embedding_index, processor.nexus_conn),
    }
    mode = modes[args.mode]()

    processor.eeg_stream.start()
    try:
        consecutive_errors = 0
        while True:
            try:
                for embedding_data in processor.eeg_stream.get_embeddings(timeout=0.5):
                    try:
                        semantic_embedding = processor.process_eeg_embedding(
                            embedding_data['embedding'])

                        if processor.save_vectors:
                            # Vector-collection mode: store, don't decode.
                            embedding_np = semantic_embedding.detach().cpu().numpy()
                            processor.vectors_list.append(embedding_np)
                            processor.timestamps.append({
                                'start': embedding_data['start_timestamp'],
                                'end': embedding_data['end_timestamp']
                            })
                            if len(processor.vectors_list) % 100 == 0:
                                logging.getLogger("EEGSemanticStream").info(
                                    f"Collected {len(processor.vectors_list)} vectors")
                            continue

                        lines = mode.step(semantic_embedding)
                        if lines:
                            output = "\n".join(lines)
                            print(output)
                            if processor.log_file:
                                processor.log_file.write(output + "\n")
                                processor.log_file.flush()

                        consecutive_errors = 0
                    except Exception as e:
                        # Per-embedding failure: keep streaming unless errors
                        # persist, then escalate to abort the session.
                        print(f"Error: {e}", file=sys.stderr)
                        consecutive_errors += 1
                        if consecutive_errors >= 5:
                            raise RuntimeError("Too many consecutive errors")

                time.sleep(0.01)
            except Exception as e:
                # The "Too many consecutive errors" escalation must propagate;
                # anything else gets a retry with backoff.
                if "Too many" in str(e):
                    raise
                print(f"Error: {e}", file=sys.stderr)
                consecutive_errors += 1
                if consecutive_errors >= 5:
                    raise
                time.sleep(1)
    except KeyboardInterrupt:
        pass
    except Exception as e:
        print(f"Fatal: {e}", file=sys.stderr)
    finally:
        if processor.save_vectors and processor.vectors_list:
            processor.save_vectors_to_disk()
        processor.eeg_stream.stop()
+
368
def main():
    """Parse command-line arguments and dispatch to the chosen subcommand."""
    parser = argparse.ArgumentParser(
        prog='morphism',
        description='EEG-to-text semantic search',
    )
    subparsers = parser.add_subparsers(dest='command')

    # --- record ---
    record_p = subparsers.add_parser('record', help='Record EEG from OpenBCI Cyton+Daisy')
    record_p.add_argument('--port', '-p', default='/dev/ttyUSB0')
    record_p.add_argument('--output', '-o', default=None)
    record_p.add_argument('--sample-rate', type=int, default=1000)
    record_p.add_argument('--sd', action='store_true', help='Record to SD card')
    record_p.add_argument('--duration', default='G')
    record_p.add_argument('--remote', action='store_true', help='Stream via SSH')

    # --- index ---
    index_p = subparsers.add_parser('index', help='Manage the text embedding index')
    index_p.add_argument('action', choices=['create', 'info', 'rebuild'])
    index_p.add_argument('--corpus', '-c', default=None)
    index_p.add_argument('--db', default='morphism.db')
    index_p.add_argument('--index', default='morphism')
    index_p.add_argument('--split-mode', default='line',
                         choices=['line', 'block', 'sentence', 'chunk'])
    index_p.add_argument('--chunk-size', type=int, default=512)
    index_p.add_argument('--chunk-overlap', type=int, default=64)
    index_p.add_argument('--batch-size', type=int, default=128)
    index_p.add_argument('--task', default='text-matching')
    index_p.add_argument('--model', default='jinaai/jina-embeddings-v3')

    # --- decode ---
    decode_p = subparsers.add_parser('decode', help='Run EEG → text decoding')
    decode_p.add_argument('--mode', default='flood', choices=['flood', 'drift', 'focus', 'layered'])
    decode_p.add_argument('--eeg', '-f', required=True)
    decode_p.add_argument('--autoencoder', '-a', required=True)
    decode_p.add_argument('--semantic', '-s', required=True)
    decode_p.add_argument('--db', default='morphism.db')
    decode_p.add_argument('--index', default='morphism')
    decode_p.add_argument('--window-size', type=int, default=624)
    decode_p.add_argument('--stride', type=int, default=32)
    decode_p.add_argument('--batch-size', type=int, default=32)
    decode_p.add_argument('--device', default=None)
    decode_p.add_argument('--search-k', type=int, default=1024)
    decode_p.add_argument('--final-k', type=int, default=1024)
    decode_p.add_argument('--last-n', type=int, default=128)
    decode_p.add_argument('--raw-eeg', action='store_true')
    decode_p.add_argument('--input-dim', type=int, default=None)
    decode_p.add_argument('--save-vectors', action='store_true')
    decode_p.add_argument('--vector-output', default='semantic_vectors.npz')

    args = parser.parse_args()

    # No subcommand given: show usage and exit cleanly.
    if args.command is None:
        parser.print_help()
        sys.exit(0)

    # create/rebuild need a corpus directory; validate before dispatching.
    if args.command == 'index':
        if args.action in ('create', 'rebuild') and not args.corpus:
            print("--corpus is required for create/rebuild")
            sys.exit(1)

    handlers = {'record': cmd_record, 'index': cmd_index, 'decode': cmd_decode}
    handlers[args.command](args)
+
434
+
435
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == '__main__':
    main()