stardust-coder committed on
Commit
e648c90
·
1 Parent(s): 5ec4552

[add] directed graph

Browse files
Files changed (4) hide show
  1. requirements.txt +2 -1
  2. src/loader.py +309 -33
  3. src/metrics.py +73 -0
  4. src/streamlit_app.py +409 -122
requirements.txt CHANGED
@@ -5,4 +5,5 @@ scipy
5
  mne
6
  h5py
7
  networkx
8
- python-louvain
 
 
5
  mne
6
  h5py
7
  networkx
8
+ python-louvain
9
+ tensorpac
src/loader.py CHANGED
@@ -42,12 +42,123 @@ def same_stem(a_name: str, b_name: str) -> bool:
42
  return a_stem == b_stem
43
 
44
 
45
- def extract_electrode_positions_2d(set_path: str) -> np.ndarray:
46
  """
47
- EEGLABファイルから電極位置(2D)を抽出。
48
 
49
  Returns:
50
- pos: (C, 2) 電極の2D座標、取得できない場合はNone
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  """
52
  try:
53
  # MNEで読み込み
@@ -55,43 +166,186 @@ def extract_electrode_positions_2d(set_path: str) -> np.ndarray:
55
  montage = raw.get_montage()
56
 
57
  if montage is None:
58
- return None
59
 
60
  # 3D座標を取得
61
- pos_3d = montage.get_positions()['ch_pos']
62
 
63
- if not pos_3d:
64
- return None
65
 
66
  # チャンネル名順に並べ替え
67
  ch_names = raw.ch_names
68
- positions = []
69
  for ch_name in ch_names:
70
- if ch_name in pos_3d:
71
- positions.append(pos_3d[ch_name])
72
  else:
73
  # 座標がないチャンネルは原点に配置
74
- positions.append([0, 0, 0])
75
 
76
- positions = np.array(positions)
 
 
 
 
 
77
 
78
  # 3D -> 2D 投影(上から見た図)
79
- # x, y座標を使用し、正規化
80
- pos_2d = positions[:, :2]
81
 
82
- # 正規化: 最大距離が1になるようにスケーリング
83
- max_dist = np.max(np.sqrt(np.sum(pos_2d**2, axis=1)))
84
- if max_dist > 0:
85
- pos_2d = pos_2d / max_dist * 0.85 # 0.85倍で頭の輪郭内に収める
86
 
87
- return pos_2d.astype(np.float32)
88
 
89
  except Exception as e:
90
  print(f"電極位置の抽出に失敗: {e}")
91
- return None
92
 
93
 
94
- def _load_eeglab_hdf5(set_path: str, fdt_path: Optional[str] = None, debug: bool = False) -> Tuple[np.ndarray, float]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  """
96
  EEGLABファイルから電極位置(2D)を抽出。
97
 
@@ -322,13 +576,15 @@ def load_eeglab_tc_from_bytes(
322
  set_name: str,
323
  fdt_bytes: Optional[bytes] = None,
324
  fdt_name: Optional[str] = None,
325
- ) -> Tuple[np.ndarray, float, Optional[np.ndarray]]:
326
  """
327
  Load EEGLAB .set (and optional .fdt) from bytes using MNE or h5py.
328
  Returns:
329
- x_tc: (T, C) float32
330
- fs: sampling rate (Hz)
331
- electrode_pos: (C, 2) float32 or None - 電極の2D座標
 
 
332
 
333
  Notes:
334
  - 多くのEEGLABは .set が .fdt を参照するため、同じディレクトリに同名で置く必要があります。
@@ -359,9 +615,13 @@ def load_eeglab_tc_from_bytes(
359
  x_tc = raw.get_data().T # (T,C)
360
 
361
  # 電極位置を取得
362
- electrode_pos = extract_electrode_positions_2d(set_path)
 
 
 
 
363
 
364
- return x_tc.astype(np.float32), fs, electrode_pos
365
 
366
  except Exception as e_raw:
367
  # 2) Epochsとして読む(エポックデータ用)
@@ -375,9 +635,13 @@ def load_eeglab_tc_from_bytes(
375
  x_tc = x_mean.T # (T,C)
376
 
377
  # 電極位置を取得(epochsからも取得可能)
378
- electrode_pos = extract_electrode_positions_2d(set_path)
 
 
 
 
379
 
380
- return x_tc.astype(np.float32), fs, electrode_pos
381
 
382
  except Exception as e_ep:
383
  # 3) HDF5形式として読む(MATLAB v7.3)
@@ -388,20 +652,32 @@ def load_eeglab_tc_from_bytes(
388
  import sys
389
  if 'streamlit' in sys.modules:
390
  debug = True
391
- x_tc, fs = _load_eeglab_hdf5(set_path, fdt_path=fdt_path, debug=debug)
392
 
393
- # HDF5の場合は電極位置を取得できない(参照形式のため)
394
- electrode_pos = None
 
 
 
 
 
 
 
 
 
 
 
395
 
396
- return x_tc, fs, electrode_pos
397
 
398
  except Exception as e_hdf5:
 
399
  # すべて失敗した場合
400
  msg = (
401
  "EEGLABの読み込みに失敗しました。\n"
402
  f"- read_raw_eeglab error: {e_raw}\n"
403
  f"- read_epochs_eeglab error: {e_ep}\n"
404
  f"- HDF5読み込み error: {e_hdf5}\n"
 
405
  )
406
  raise RuntimeError(msg) from e_hdf5
407
 
 
42
  return a_stem == b_stem
43
 
44
 
45
def extract_electrode_positions_from_hdf5(set_path: str) -> tuple:
    """
    Extract electrode positions from an EEGLAB .set file stored as MATLAB v7.3 (HDF5).

    Looks for the chanlocs struct under "EEG/chanlocs" or "chanlocs" and reads
    the per-channel X/Y/Z fields, which are stored either as plain numeric
    datasets or as HDF5 object references that must be dereferenced one by one.

    Args:
        set_path: path to the .set file (HDF5 container).

    Returns:
        tuple: (pos_2d, pos_3d)
            pos_2d: (C, 2) float32 2D coordinates (top-down projection, scaled
                    to max radius 0.85), or None when unavailable.
            pos_3d: (C, 3) float32 3D coordinates (unit-normalized), or None.
    """
    if h5py is None:
        # h5py is optional; without it HDF5 files cannot be read at all.
        return None, None

    try:
        with h5py.File(set_path, "r") as f:
            # Locate the EEGLAB chanlocs struct (layout differs between files).
            chanlocs = None
            for path in ("EEG/chanlocs", "chanlocs"):
                if path in f:
                    chanlocs = f[path]
                    break
            if chanlocs is None:
                return None, None

            if not ("X" in chanlocs and "Y" in chanlocs and "Z" in chanlocs):
                return None, None

            x_data = chanlocs["X"][()]
            y_data = chanlocs["Y"][()]
            z_data = chanlocs["Z"][()]

            xs, ys, zs = [], [], []
            if x_data.dtype == h5py.ref_dtype:
                # MATLAB v7.3 stores struct-array fields as object references;
                # dereference each and pull out the scalar coordinate.
                # Assumes a (C, 1)-shaped reference array — TODO confirm for
                # other EEGLAB exporters.
                for i in range(len(x_data)):
                    try:
                        x_val = f[x_data[i, 0]][()]
                        y_val = f[y_data[i, 0]][()]
                        z_val = f[z_data[i, 0]][()]
                        xs.append(float(x_val.flat[0]) if hasattr(x_val, 'flat') else float(x_val))
                        ys.append(float(y_val.flat[0]) if hasattr(y_val, 'flat') else float(y_val))
                        zs.append(float(z_val.flat[0]) if hasattr(z_val, 'flat') else float(z_val))
                    except Exception:
                        # BUGFIX: was a bare `except:`, which also swallowed
                        # KeyboardInterrupt/SystemExit. Unreadable coordinates
                        # still fall back to the origin.
                        xs.append(0.0)
                        ys.append(0.0)
                        zs.append(0.0)
            else:
                # Plain numeric datasets: use the values directly.
                xs = x_data.flatten().astype(float)
                ys = y_data.flatten().astype(float)
                zs = z_data.flatten().astype(float)

            xs = np.array(xs, dtype=float)
            ys = np.array(ys, dtype=float)
            zs = np.array(zs, dtype=float)

            if len(xs) == 0:
                return None, None

            # Give up only if every coordinate triple contains a NaN.
            valid_mask = ~(np.isnan(xs) | np.isnan(ys) | np.isnan(zs))
            if not np.any(valid_mask):
                return None, None

            # Replace invalid coordinates with the mean of the valid ones.
            if not np.all(valid_mask):
                xs[~valid_mask] = np.nanmean(xs)
                ys[~valid_mask] = np.nanmean(ys)
                zs[~valid_mask] = np.nanmean(zs)

            positions_3d = np.column_stack([xs, ys, zs])

            # Normalize 3D positions so the farthest electrode sits at radius 1.
            dists = np.sqrt(np.sum(positions_3d**2, axis=1))
            max_dist_3d = np.max(dists[dists > 0]) if np.any(dists > 0) else 1.0
            if max_dist_3d > 0:
                positions_3d = positions_3d / max_dist_3d

            # Top-down 2D projection, rescaled to fit inside the head outline.
            pos_2d = positions_3d[:, :2]
            dists_2d = np.sqrt(np.sum(pos_2d**2, axis=1))
            max_dist_2d = np.max(dists_2d[dists_2d > 0]) if np.any(dists_2d > 0) else 1.0
            if max_dist_2d > 0:
                pos_2d = pos_2d / max_dist_2d * 0.85

            print(f"HDF5から電極位置を取得: {len(xs)} channels")
            return pos_2d.astype(np.float32), positions_3d.astype(np.float32)

    except Exception as e:
        print(f"HDF5から電極位置の抽出に失敗: {e}")
        import traceback
        traceback.print_exc()
        return None, None
152
+
153
+
154
+ def extract_electrode_positions_2d(set_path: str):
155
+ """
156
+ EEGLABファイルから電極位置(2D, 3D)を抽出。
157
+
158
+ Returns:
159
+ tuple: (pos_2d, pos_3d)
160
+ pos_2d: (C, 2) 電極の2D座標、取得できない場合はNone
161
+ pos_3d: (C, 3) 電極の3D座標、取得できない場合はNone
162
  """
163
  try:
164
  # MNEで読み込み
 
166
  montage = raw.get_montage()
167
 
168
  if montage is None:
169
+ return None, None
170
 
171
  # 3D座標を取得
172
+ pos_3d_dict = montage.get_positions()['ch_pos']
173
 
174
+ if not pos_3d_dict:
175
+ return None, None
176
 
177
  # チャンネル名順に並べ替え
178
  ch_names = raw.ch_names
179
+ positions_3d = []
180
  for ch_name in ch_names:
181
+ if ch_name in pos_3d_dict:
182
+ positions_3d.append(pos_3d_dict[ch_name])
183
  else:
184
  # 座標がないチャンネルは原点に配置
185
+ positions_3d.append([0, 0, 0])
186
 
187
+ positions_3d = np.array(positions_3d)
188
+
189
+ # 3D座標を正規化
190
+ max_dist_3d = np.max(np.sqrt(np.sum(positions_3d**2, axis=1)))
191
+ if max_dist_3d > 0:
192
+ positions_3d = positions_3d / max_dist_3d
193
 
194
  # 3D -> 2D 投影(上から見た図)
195
+ pos_2d = positions_3d[:, :2]
 
196
 
197
+ # 2D座標を正規化: 最大距離が0.85になるようにスケーリング
198
+ max_dist_2d = np.max(np.sqrt(np.sum(pos_2d**2, axis=1)))
199
+ if max_dist_2d > 0:
200
+ pos_2d = pos_2d / max_dist_2d * 0.85
201
 
202
+ return pos_2d.astype(np.float32), positions_3d.astype(np.float32)
203
 
204
  except Exception as e:
205
  print(f"電極位置の抽出に失敗: {e}")
206
+ return None, None
207
 
208
 
209
def _load_eeglab_hdf5(set_path: str, fdt_path: Optional[str] = None, debug: bool = False):
    """
    Load an EEGLAB .set file saved in MATLAB v7.3 (HDF5) format using h5py.

    Args:
        set_path: path to the .set file (HDF5 container).
        fdt_path: optional path to the companion .fdt file holding the raw
            samples when the .set only stores a reference to them.
        debug: when True, print the HDF5 structure and shape decisions.

    Returns:
        (x_tc, fs): x_tc is a (T, C) float32 array, fs the sampling rate in Hz.

    Raises:
        RuntimeError: if h5py is not installed.
        ValueError: if the sampling rate or the EEG data cannot be located, or
            a required .fdt file is missing.
    """
    if h5py is None:
        raise RuntimeError("EEGLAB .set ファイルが MATLAB v7.3 (HDF5) 形式ですが、h5py がインストールされていません。pip install h5py を実行してください。")

    with h5py.File(set_path, "r") as f:
        if debug:
            print("=== HDF5 file structure ===")
            def print_structure(name, obj):
                if isinstance(obj, h5py.Dataset):
                    print(f"Dataset: {name}, shape: {obj.shape}, dtype: {obj.dtype}")
                elif isinstance(obj, h5py.Group):
                    print(f"Group: {name}")
            f.visititems(print_structure)
            print("===========================")

        def _read_scalar(paths, cast):
            # EEGLAB metadata may live at the top level or under "EEG", and is
            # often stored as a 1x1 array — hence the .flat access.
            for path in paths:
                if path in f:
                    ds = f[path]
                    if isinstance(ds, h5py.Dataset):
                        val = ds[()]
                        return cast(val.flat[0]) if hasattr(val, 'flat') else cast(val)
            return None

        # Sampling rate is mandatory; channel/sample counts are best-effort.
        fs = _read_scalar(("EEG/srate", "srate"), float)
        if fs is None:
            raise ValueError("サンプリングレート (srate) が見つかりません")
        nbchan = _read_scalar(("EEG/nbchan", "nbchan"), int)
        pnts = _read_scalar(("EEG/pnts", "pnts"), int)

        if debug:
            print(f"nbchan: {nbchan}, pnts: {pnts}, fs: {fs}")
            print(f"Checking for data, fdt_path provided: {fdt_path is not None}")
            if fdt_path:
                print(f"fdt_path exists: {os.path.exists(fdt_path)}")

        data = None

        # Pattern 1: EEG/data — either actual samples or a reference that
        # points at the external .fdt file.
        if "EEG" in f and "data" in f["EEG"]:
            data_ref = f["EEG"]["data"]
            if isinstance(data_ref, h5py.Dataset):
                if debug:
                    print(f"EEG/data dtype: {data_ref.dtype}, shape: {data_ref.shape}, size: {data_ref.size}")
                if data_ref.dtype == h5py.ref_dtype:
                    if debug:
                        print("EEG/data is reference type - data should be in .fdt file")
                    if fdt_path is not None and os.path.exists(fdt_path):
                        data = _load_fdt_file(fdt_path, nbchan, pnts, debug=debug)
                    else:
                        raise ValueError(".fdt ファイルが必要ですが見つかりません。.set と .fdt の両方をアップロードしてください。")
                elif data_ref.size > 100:
                    # Large enough to be real samples rather than a reference list.
                    data = data_ref[()]
                    if debug:
                        print(f"EEG/data contains actual data, shape: {data.shape}")
                else:
                    # Tiny array = reference list; the samples live in the .fdt.
                    if debug:
                        print(f"EEG/data is small array (size={data_ref.size}), assuming reference to .fdt")
                    if fdt_path is not None and os.path.exists(fdt_path):
                        data = _load_fdt_file(fdt_path, nbchan, pnts, debug=debug)
                    else:
                        raise ValueError(".fdt ファイルが必要ですが見つかりません。.set と .fdt の両方をアップロードしてください。")

        # Pattern 2: top-level "data" dataset.
        if data is None and "data" in f:
            data_obj = f["data"]
            if isinstance(data_obj, h5py.Dataset):
                data = data_obj[()]

        if data is None:
            raise ValueError("EEGデータが見つかりません。.fdt ファイルが必要な可能性があります。")

        if debug:
            print(f"Data shape: {data.shape if hasattr(data, 'shape') else 'loaded from fdt'}")

        if data.ndim != 2:
            raise ValueError(f"予期しないデータ次元: {data.ndim}")

        dim0, dim1 = data.shape

        # Orient to (T, C): trust nbchan when it matches one axis; otherwise
        # assume the smaller axis is the channel axis.
        if nbchan is not None and dim0 == nbchan:
            x_tc = data.T.astype(np.float32)   # stored as (C, T)
        elif nbchan is not None and dim1 == nbchan:
            x_tc = data.astype(np.float32)     # stored as (T, C)
        elif dim0 < dim1:
            x_tc = data.T.astype(np.float32)
        else:
            x_tc = data.astype(np.float32)

        if debug:
            print(f"Final shape (T, C): {x_tc.shape}")

        return x_tc, fs
349
  """
350
  EEGLABファイルから電極位置(2D)を抽出。
351
 
 
576
  set_name: str,
577
  fdt_bytes: Optional[bytes] = None,
578
  fdt_name: Optional[str] = None,
579
+ ):
580
  """
581
  Load EEGLAB .set (and optional .fdt) from bytes using MNE or h5py.
582
  Returns:
583
+ tuple: (x_tc, fs, electrode_pos_2d, electrode_pos_3d)
584
+ x_tc: (T, C) float32
585
+ fs: sampling rate (Hz)
586
+ electrode_pos_2d: (C, 2) float32 or None - 電極の2D座標
587
+ electrode_pos_3d: (C, 3) float32 or None - 電極の3D座標
588
 
589
  Notes:
590
  - 多くのEEGLABは .set が .fdt を参照するため、同じディレクトリに同名で置く必要があります。
 
615
  x_tc = raw.get_data().T # (T,C)
616
 
617
  # 電極位置を取得
618
+ result = extract_electrode_positions_2d(set_path)
619
+ if result is not None:
620
+ electrode_pos_2d, electrode_pos_3d = result
621
+ else:
622
+ electrode_pos_2d, electrode_pos_3d = None, None
623
 
624
+ return x_tc.astype(np.float32), fs, electrode_pos_2d, electrode_pos_3d
625
 
626
  except Exception as e_raw:
627
  # 2) Epochsとして読む(エポックデータ用)
 
635
  x_tc = x_mean.T # (T,C)
636
 
637
  # 電極位置を取得(epochsからも取得可能)
638
+ result = extract_electrode_positions_2d(set_path)
639
+ if result is not None:
640
+ electrode_pos_2d, electrode_pos_3d = result
641
+ else:
642
+ electrode_pos_2d, electrode_pos_3d = None, None
643
 
644
+ return x_tc.astype(np.float32), fs, electrode_pos_2d, electrode_pos_3d
645
 
646
  except Exception as e_ep:
647
  # 3) HDF5形式として読む(MATLAB v7.3)
 
652
  import sys
653
  if 'streamlit' in sys.modules:
654
  debug = True
 
655
 
656
+ try:
657
+ x_tc, fs = _load_eeglab_hdf5(set_path, fdt_path=fdt_path, debug=debug)
658
+ except Exception as e_hdf5_inner:
659
+ import traceback
660
+ print("HDF5読み込みの詳細エラー:")
661
+ print(traceback.format_exc())
662
+ raise e_hdf5_inner
663
+
664
+ # HDF5の場合、電極位置をHDF5から直接取得を試みる
665
+ electrode_pos_2d, electrode_pos_3d = extract_electrode_positions_from_hdf5(set_path)
666
+
667
+ if debug and electrode_pos_2d is not None:
668
+ print(f"HDF5から電極位置を取得しました: {electrode_pos_2d.shape}")
669
 
670
+ return x_tc, fs, electrode_pos_2d, electrode_pos_3d
671
 
672
  except Exception as e_hdf5:
673
+ import traceback
674
  # すべて失敗した場合
675
  msg = (
676
  "EEGLABの読み込みに失敗しました。\n"
677
  f"- read_raw_eeglab error: {e_raw}\n"
678
  f"- read_epochs_eeglab error: {e_ep}\n"
679
  f"- HDF5読み込み error: {e_hdf5}\n"
680
+ f"\n詳細トレースバック:\n{traceback.format_exc()}"
681
  )
682
  raise RuntimeError(msg) from e_hdf5
683
 
src/metrics.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from scipy import stats
3
+
4
+ def chatterjee_phase_to_amp(phi, amp, agg="max"):
5
+ """
6
+ phi: phase in radians (1D)
7
+ amp: amplitude (1D)
8
+ agg: 'max' | 'mean' | 'rss'
9
+ """
10
+ s = np.sin(phi)
11
+ c = np.cos(phi)
12
+
13
+ xi_s = stats.chatterjeexi(s, amp).statistic
14
+ xi_c = stats.chatterjeexi(c, amp).statistic
15
+
16
+ if agg == "max":
17
+ xi = np.nanmax([xi_s, xi_c])
18
+ elif agg == "mean":
19
+ xi = np.nanmean([xi_s, xi_c])
20
+ elif agg == "rss":
21
+ xi = np.sqrt(xi_s**2 + xi_c**2)
22
+ xi = float(np.clip(xi, 0.0, 1.0))
23
+ else:
24
+ raise ValueError("agg must be 'max', 'mean', or 'rss'")
25
+
26
+ return xi #, {"xi_sin": xi_s, "xi_cos": xi_c}
27
+
28
def circular_correlation(rho, theta, mu=None, tau=None):
    """
    Fisher–Lee circular correlation coefficient between two angle series.

    Args:
        rho, theta: angles in radians (array-like, same length).
        mu, tau: circular means of rho / theta; estimated from the data via
            the argument of the mean resultant vector when not given.

    Returns:
        Correlation value in [-1, 1].
    """
    rho_arr = np.asarray(rho)
    theta_arr = np.asarray(theta)

    # Circular mean = argument of the mean resultant vector.
    if mu is None:
        mu = np.angle(np.exp(1j * rho_arr).mean())
    if tau is None:
        tau = np.angle(np.exp(1j * theta_arr).mean())

    dev_rho = np.sin(rho_arr - mu)
    dev_theta = np.sin(theta_arr - tau)

    numerator = np.mean(dev_rho * dev_theta)
    denominator = np.sqrt(np.var(dev_rho) * np.var(dev_theta))
    return numerator / denominator
41
+
42
+
43
def modulation_index(phase, amp, n_bins=18, eps=1e-12):
    """
    Tort et al. (2010) Modulation Index for phase-amplitude coupling.

    The phase axis is split into n_bins equal bins over (-pi, pi], the mean
    amplitude per bin is normalized into a distribution, and the KL divergence
    from the uniform distribution is rescaled by log(n_bins).

    Args:
        phase: phase angles in radians (-pi, pi].
        amp: amplitude envelope (>= 0), same length as phase.
        n_bins: number of phase bins.
        eps: numerical floor inside the logarithm.

    Returns:
        Modulation index in [0, 1], or np.nan when there is no amplitude mass.
    """
    phase_flat = np.asarray(phase).ravel()
    amp_flat = np.asarray(amp).ravel()

    # Drop samples where either series is NaN/inf.
    finite = np.isfinite(phase_flat) & np.isfinite(amp_flat)
    phase_flat = phase_flat[finite]
    amp_flat = amp_flat[finite]

    # Assign each sample to a phase bin; clip guards the boundary values.
    edges = np.linspace(-np.pi, np.pi, n_bins + 1)
    bin_idx = np.clip(np.digitize(phase_flat, edges) - 1, 0, n_bins - 1)

    mean_amp = np.zeros(n_bins)
    for k in range(n_bins):
        in_bin = bin_idx == k
        if np.any(in_bin):
            mean_amp[k] = amp_flat[in_bin].mean()

    total = mean_amp.sum()
    if total == 0:
        return np.nan

    p = mean_amp / total
    uniform = 1.0 / n_bins
    kl = np.sum(p * np.log((p + eps) / uniform))
    return kl / np.log(n_bins)
src/streamlit_app.py CHANGED
@@ -22,6 +22,8 @@ from loader import (
22
  load_mat_candidates,
23
  )
24
 
 
 
25
  st.set_page_config(page_title="EEG Viewer + Network Estimation", layout="wide")
26
 
27
 
@@ -50,6 +52,19 @@ def ensure_tc(x: np.ndarray) -> np.ndarray:
50
  x = x.T
51
  return x
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  # ============================================================
55
  # Signal processing
@@ -107,7 +122,7 @@ def preprocess_all_eeglab(
107
  EEGLAB bytes -> load -> auto preprocess (bandpass + hilbert).
108
  fsは読み込んだデータのものを使う。
109
  """
110
- x_tc, fs, electrode_pos = load_eeglab_tc_from_bytes(
111
  set_bytes=set_bytes,
112
  set_name=set_name,
113
  fdt_bytes=fdt_bytes,
@@ -117,8 +132,10 @@ def preprocess_all_eeglab(
117
  result = preprocess_tc(x_tc, cfg)
118
 
119
  # 電極位置を追加
120
- if electrode_pos is not None:
121
- result["electrode_pos"] = electrode_pos
 
 
122
 
123
  return result
124
 
@@ -236,30 +253,103 @@ def estimate_network_envelope_corr(X_tc: np.ndarray) -> np.ndarray:
236
  np.fill_diagonal(W, 0.0)
237
  return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
238
 
239
-
240
  def estimate_network_phase_corr(X_tc: np.ndarray) -> np.ndarray:
241
  """
242
- Phase の circular correlation (位相同期指標) を計算。
243
  Input: X_tc (T, C) - phase データ (ラジアン)
244
  Output: W (C, C) - circular correlation
245
 
246
- Circular correlation は以下で計算:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  r_ij = |⟨exp(i*(θ_i - θ_j))⟩_t|
248
- これは Phase Locking Value (PLV) とも呼ばれます。
249
  """
250
  T, C = X_tc.shape
251
  W = np.zeros((C, C), dtype=np.float32)
252
 
253
- # 各チャンネルペアについて circular correlation を計算
 
254
  for i in range(C):
255
  for j in range(i + 1, C):
256
  # 位相差
257
  phase_diff = X_tc[:, i] - X_tc[:, j]
258
- # PLV: |mean(exp(i*phase_diff))|
259
  plv = np.abs(np.mean(np.exp(1j * phase_diff)))
260
  W[i, j] = plv
261
  W[j, i] = plv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
264
 
265
 
@@ -275,18 +365,47 @@ def estimate_network_dummy(X_tc: np.ndarray) -> np.ndarray:
275
  return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
276
 
277
 
278
- def threshold_edges(W: np.ndarray, thr: float) -> List[Tuple[int, int, float]]:
 
 
 
 
 
 
 
 
 
 
 
 
279
  C = W.shape[0]
280
  edges: List[Tuple[int, int, float]] = []
281
- for i in range(C):
282
- for j in range(i + 1, C):
283
- w = float(W[i, j])
284
- if w >= thr:
285
- edges.append((i, j, w))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
  edges.sort(key=lambda x: x[2], reverse=True)
287
  return edges
288
 
289
 
 
290
  def adjacency_at_threshold(W: np.ndarray, thr: float, weighted: bool) -> np.ndarray:
291
  if weighted:
292
  A = W.copy()
@@ -320,9 +439,9 @@ def compute_louvain_clusters(W: np.ndarray, thr: float) -> np.ndarray:
320
 
321
  # 閾値以上のエッジを追加
322
  for i in range(C):
323
- for j in range(i + 1, C):
324
  if W[i, j] >= thr:
325
- G.add_edge(i, j, weight=W[i, j])
326
 
327
  # Louvain法でコミュニティ検出
328
  partition = community_louvain.best_partition(G, weight='weight')
@@ -376,37 +495,101 @@ def get_electrode_positions(prep: dict) -> np.ndarray:
376
  ys = np.sin(angles)
377
  return np.column_stack([xs, ys])
378
 
379
-
380
- def get_head_outline() -> dict:
 
 
 
 
381
  """
382
- 脳の輪郭(頭のアウトライン)
383
-
384
- Returns:
385
- outline: {'head': (x, y), 'nose': (x, y), 'ears': [(x_left, y_left), (x_right, y_right)]}
386
  """
387
- # 頭の円
388
- theta = np.linspace(0, 2*np.pi, 100)
389
- head_x = np.cos(theta)
390
- head_y = np.sin(theta)
391
 
392
- # 鼻(上部の三角形)
393
- nose_x = np.array([0, -0.1, 0.1, 0])
394
- nose_y = np.array([1.0, 1.15, 1.15, 1.0])
395
 
396
- # 耳(左右突起)
397
- ear_theta = np.linspace(-np.pi/4, np.pi/4, 20)
398
- ear_left_x = -1.0 + 0.08 * np.cos(ear_theta)
399
- ear_left_y = 0.08 * np.sin(ear_theta)
 
 
 
 
 
 
400
 
401
- ear_right_x = 1.0 - 0.08 * np.cos(ear_theta)
402
- ear_right_y = 0.08 * np.sin(ear_theta)
 
 
 
 
403
 
404
- return {
405
- 'head': (head_x, head_y),
406
- 'nose': (nose_x, nose_y),
407
- 'ear_left': (ear_left_x, ear_left_y),
408
- 'ear_right': (ear_right_x, ear_right_y),
409
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
 
411
 
412
  def make_network_figure(
@@ -414,7 +597,6 @@ def make_network_figure(
414
  thr: float,
415
  use_louvain: bool = True,
416
  electrode_pos: np.ndarray = None,
417
- show_head: bool = True,
418
  ) -> tuple[go.Figure, int]:
419
  C = W.shape[0]
420
 
@@ -451,68 +633,106 @@ def make_network_figure(
451
  r, g, b = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
452
  return f'rgba({int(255*r)}, {int(255*g)}, {int(255*b)}, 0.7)'
453
 
454
- # 脳の輪郭を描画
455
- if show_head:
456
- outline = get_head_outline()
457
-
458
- # 頭の円
459
- fig.add_trace(go.Scatter(
460
- x=outline['head'][0], y=outline['head'][1],
461
- mode='lines',
462
- line=dict(color='rgba(150,150,150,0.5)', width=2),
463
- showlegend=False,
464
- hoverinfo='skip',
465
- ))
466
-
467
- # 鼻
468
- fig.add_trace(go.Scatter(
469
- x=outline['nose'][0], y=outline['nose'][1],
470
- mode='lines',
471
- line=dict(color='rgba(150,150,150,0.5)', width=2),
472
- showlegend=False,
473
- hoverinfo='skip',
474
- ))
475
-
476
- # 左耳
477
- fig.add_trace(go.Scatter(
478
- x=outline['ear_left'][0], y=outline['ear_left'][1],
479
- mode='lines',
480
- line=dict(color='rgba(150,150,150,0.5)', width=2),
481
- showlegend=False,
482
- hoverinfo='skip',
483
- ))
484
-
485
- # 右耳
486
- fig.add_trace(go.Scatter(
487
- x=outline['ear_right'][0], y=outline['ear_right'][1],
488
- mode='lines',
489
- line=dict(color='rgba(150,150,150,0.5)', width=2),
490
- showlegend=False,
491
- hoverinfo='skip',
492
- ))
493
-
494
  # エッジを描画(重みに応じて色と太さを変える)
495
- for (i, j, w) in edges:
496
- # 正規化された重み (0-1)
497
- norm_w = (w - min_w) / weight_range if weight_range > 0 else 0.5
498
-
499
- # レインボーカラー: 弱い(青) → 中間(緑/黄) → 強い(赤)
500
- color = get_rainbow_color(norm_w)
501
-
502
- # 太さ: 重みに比例 (0.5-4の範囲)
503
- line_width = 0.5 + 3.5 * norm_w
504
-
505
- fig.add_trace(
506
- go.Scatter(
507
- x=[xs[i], xs[j]],
508
- y=[ys[i], ys[j]],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
509
  mode="lines",
510
  hoverinfo="text",
511
- hovertext=f"ch{i} - ch{j}<br>weight: {w:.4f}",
512
  line=dict(width=line_width, color=color),
513
  showlegend=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514
  )
515
- )
516
 
517
  # Louvainクラスタリング
518
  if use_louvain and LOUVAIN_AVAILABLE:
@@ -607,8 +827,12 @@ st.sidebar.header("Input format")
607
  input_mode = st.sidebar.radio("データ形式", ["EEGLAB (.set + .fdt)", "MATLAB (.mat)"], index=0)
608
 
609
  st.sidebar.header("Preprocess (auto)")
610
- f_low = st.sidebar.number_input("Bandpass low (Hz)", min_value=0.0, value=8.0, step=0.5)
611
- f_high = st.sidebar.number_input("Bandpass high (Hz)", min_value=0.1, value=12.0, step=0.5)
 
 
 
 
612
 
613
  st.sidebar.header("Viewer controls")
614
  win_sec = st.sidebar.number_input("Window length (sec)", min_value=0.1, value=5.0, step=0.1)
@@ -643,17 +867,26 @@ if input_mode.startswith("EEGLAB"):
643
  else:
644
  try:
645
  with st.spinner("Loading EEGLAB + preprocessing (bandpass + hilbert)..."):
646
- prep = preprocess_all_eeglab(
 
 
 
 
 
 
 
 
647
  set_bytes=set_file.getvalue(),
648
  fdt_bytes=fdt_file.getvalue(),
649
  set_name=set_file.name,
650
  fdt_name=fdt_file.name,
651
- f_low=float(f_low),
652
- f_high=float(f_high),
653
  )
654
- st.session_state["prep"] = prep
 
655
  st.session_state["W"] = None
656
- st.success(f"Loaded & preprocessed. (T,C)={prep['raw'].shape} fs={prep['fs']:.2f}Hz")
657
  except Exception as e:
658
  st.session_state.pop("prep", None)
659
  st.session_state["W"] = None
@@ -714,10 +947,13 @@ else:
714
  st.sidebar.write(f"選択した配列: shape={x.shape}, dtype={x.dtype}")
715
  try:
716
  with st.spinner("Preprocessing (bandpass + hilbert)..."):
717
- cfg = PreprocessConfig(fs=float(fs_mat), f_low=float(f_low), f_high=float(f_high))
718
  prep = preprocess_tc(x, cfg)
 
 
719
 
720
  st.session_state["prep"] = prep
 
721
  st.session_state["W"] = None
722
  st.success(f"Loaded MAT '{key}'. (T,C)={prep['raw'].shape} fs={prep['fs']:.2f}Hz")
723
  except Exception as e:
@@ -866,7 +1102,7 @@ with col2:
866
  st.subheader("Data info")
867
  signal_desc = {
868
  "raw": "生信号(前処理なし)",
869
- "filtered": f"バンドパスフィルタ後 ({f_low}-{f_high} Hz)",
870
  "amplitude": "Hilbert振幅 (envelope)",
871
  "phase": "Hilbert位相 (-π ~ π)"
872
  }
@@ -894,20 +1130,29 @@ estimation_method = st.radio(
894
  "推定手法を選択",
895
  options=[
896
  "envelope_corr",
 
897
  "phase_corr",
 
 
898
  ],
899
  format_func=lambda x: {
900
- "envelope_corr": "Envelope correlation (振幅の相関)",
901
- "phase_corr": "Phase circular correlation (位相同期, PLV)",
 
 
 
902
  }[x],
903
  horizontal=True,
904
- help="envelope_corr: 振幅包絡線のPearson相関係数 | phase_corr: 位相の circular correlation (Phase Locking Value)",
905
  )
906
 
907
  # 推定手法の説明
908
  method_info = {
909
  "envelope_corr": "**Envelope correlation**: 振幅包絡線(Hilbert amplitude)間のPearson相関係数を計算します。振幅が同期して変動するチャンネル間の結合を検出します。",
910
- "phase_corr": "**Phase circular correlation (PLV)**: 位相間の circular correlation を計算します。Phase Locking Value (PLV) とも呼ばれ、位相同期を検出します。0(非同期)〜1(完全同期)の値を取ります。",
 
 
 
911
  }
912
  st.info(method_info[estimation_method])
913
 
@@ -919,13 +1164,27 @@ W = st.session_state.get("W")
919
  need_estimation = (W is None) or (last_method != estimation_method)
920
 
921
  if need_estimation:
 
922
  with st.spinner(f"推定中... ({estimation_method})"):
923
  if estimation_method == "envelope_corr":
924
  X_in = prep["amplitude"]
925
  W = estimate_network_envelope_corr(X_in)
 
 
 
926
  elif estimation_method == "phase_corr":
927
  X_in = prep["phase"]
928
  W = estimate_network_phase_corr(X_in)
 
 
 
 
 
 
 
 
 
 
929
  else:
930
  st.error("未知の推定手法です")
931
  st.stop()
@@ -941,13 +1200,13 @@ else:
941
  # 閾値スライダーとネットワーク図の表示
942
  wmax = float(np.max(W)) if np.isfinite(np.max(W)) else 1.0
943
 
944
- col_thr1, col_thr2, col_thr3 = st.columns([2, 1, 1])
945
  with col_thr1:
946
  thr = st.slider(
947
  "閾値 (threshold) ※下げるほどエッジが増えます",
948
  min_value=0.0,
949
  max_value=max(0.0001, wmax),
950
- value=min(0.5, wmax),
951
  step=max(wmax / 200, 0.001),
952
  )
953
  with col_thr2:
@@ -957,21 +1216,28 @@ with col_thr2:
957
  disabled=not LOUVAIN_AVAILABLE,
958
  help="ノードの色をコミュニティ検出結果で塗り分けます"
959
  )
960
- with col_thr3:
961
- show_head = st.checkbox(
962
- "脳の輪郭を表示",
963
- value=True,
964
- help="頭部のアウトラインを表示します"
965
- )
966
 
967
  # 電極位置を取得
968
  electrode_pos = prep.get("electrode_pos", None)
 
 
 
 
 
 
 
969
 
970
  if electrode_pos is not None:
971
  st.info(f"✓ 電極位置を使用してネットワークを配置 ({electrode_pos.shape[0]} channels)")
972
  else:
973
  st.info("ℹ️ 電極位置が取得できなかったため、円形配置を使用します")
974
 
 
 
 
 
 
 
975
  net_col1, net_col2 = st.columns([2, 1])
976
  with net_col1:
977
  fig_net, edge_n = make_network_figure(
@@ -979,10 +1245,31 @@ with net_col1:
979
  float(thr),
980
  use_louvain=use_louvain,
981
  electrode_pos=electrode_pos,
982
- show_head=show_head,
983
  )
984
  st.plotly_chart(fig_net)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
985
 
986
  with net_col2:
987
  st.metric("Edges", edge_n)
988
- st.plotly_chart(make_edgecount_curve(W))
 
 
 
 
 
22
  load_mat_candidates,
23
  )
24
 
25
+ import metrics
26
+
27
  st.set_page_config(page_title="EEG Viewer + Network Estimation", layout="wide")
28
 
29
 
 
52
  x = x.T
53
  return x
54
 
55
+ def _quad_bezier_points(p0, p1, c, n=20):
56
+ """2次Bezierを点列にして返す (n点)"""
57
+ ts = np.linspace(0, 1, n)
58
+ pts = (1-ts)[:,None]**2 * p0 + 2*(1-ts)[:,None]*ts[:,None]*c + ts[:,None]**2 * p1
59
+ return pts # shape (n,2)
60
+
61
+ def _quad_bezier_point_and_tangent(p0, p1, c, t):
62
+ """2次Bezierの点と接線ベクトル(微分)を返す"""
63
+ # B(t) = (1-t)^2 p0 + 2(1-t)t c + t^2 p1
64
+ pt = (1-t)**2 * p0 + 2*(1-t)*t * c + t**2 * p1
65
+ # B'(t) = 2(1-t)(c-p0) + 2t(p1-c)
66
+ tan = 2*(1-t)*(c-p0) + 2*t*(p1-c)
67
+ return pt, tan
68
 
69
  # ============================================================
70
  # Signal processing
 
122
  EEGLAB bytes -> load -> auto preprocess (bandpass + hilbert).
123
  fsは読み込んだデータのものを使う。
124
  """
125
+ x_tc, fs, electrode_pos_2d, electrode_pos_3d = load_eeglab_tc_from_bytes(
126
  set_bytes=set_bytes,
127
  set_name=set_name,
128
  fdt_bytes=fdt_bytes,
 
132
  result = preprocess_tc(x_tc, cfg)
133
 
134
  # 電極位置を追加
135
+ if electrode_pos_2d is not None:
136
+ result["electrode_pos"] = electrode_pos_2d
137
+ if electrode_pos_3d is not None:
138
+ result["electrode_pos_3d"] = electrode_pos_3d
139
 
140
  return result
141
 
 
253
  np.fill_diagonal(W, 0.0)
254
  return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
255
 
 
256
  def estimate_network_phase_corr(X_tc: np.ndarray) -> np.ndarray:
257
  """
258
+ Phase の PLV を計算。
259
  Input: X_tc (T, C) - phase データ (ラジアン)
260
  Output: W (C, C) - circular correlation
261
 
262
+ circular correlationは以下で計算:
263
+
264
+ """
265
+ T, C = X_tc.shape
266
+ W = np.zeros((C, C), dtype=np.float32)
267
+
268
+ # 各チャンネルペアについて PLV を計算
269
+ for i in range(C):
270
+ for j in range(i + 1, C):
271
+ #Jammalamadaka–Sengupta circular correlation
272
+ corr = metrics.circular_correlation(X_tc[:, i], X_tc[:, j])
273
+ W[i, j] = corr
274
+ W[j, i] = corr
275
+
276
+ return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
277
+
278
+ def estimate_network_phase_PLV(X_tc: np.ndarray, progress) -> np.ndarray:
279
+ """
280
+ Phase の PLV を計算。
281
+ Input: X_tc (T, C) - phase データ (ラジアン)
282
+ Output: W (C, C) - PLV
283
+
284
+ PLV は以下で計算:
285
  r_ij = |⟨exp(i*(θ_i - θ_j))⟩_t|
 
286
  """
287
  T, C = X_tc.shape
288
  W = np.zeros((C, C), dtype=np.float32)
289
 
290
+ # 各チャンネルペアについて PLV を計算
291
+ tmp_ = 0
292
  for i in range(C):
293
  for j in range(i + 1, C):
294
  # 位相差
295
  phase_diff = X_tc[:, i] - X_tc[:, j]
 
296
  plv = np.abs(np.mean(np.exp(1j * phase_diff)))
297
  W[i, j] = plv
298
  W[j, i] = plv
299
+ tmp_ += 1
300
+ progress.progress(tmp_ / (int(C*(C-1)/2)))
301
+
302
+ return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
303
+
304
+
305
+ def estimate_network_pac_tort(X_tc1, X_tc2, progress):
306
+ """
307
+ PACを目的としてModulation Indexを計算
308
+ Input: X_tc1 (T, C) - phase データ (ラジアン)
309
+ Input: X_tc2 (T, C) - envelope データ
310
+ Output: W (C, C) - Modulation Index
311
+ """
312
+ assert X_tc1.shape == X_tc2.shape
313
+ T, C = X_tc1.shape
314
+ W = np.zeros((C, C), dtype=np.float32)
315
 
316
+ # 各チャンネルペアについて Chatterjee correlation を計算
317
+ tmp_ = 0
318
+ for i in range(C):
319
+ for j in range(C):
320
+ if i == j:
321
+ continue
322
+ # Modulation Index from Tort et al.(2010)
323
+ mi_ = metrics.modulation_index(X_tc1[:, i], X_tc2[:, j])
324
+ W[i, j] = mi_
325
+ tmp_ += 1
326
+ progress.progress(tmp_ / (C*C))
327
+
328
+ return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
329
+
330
+ def estimate_network_pac_chatterjee(X_tc1, X_tc2, progress):
331
+ """
332
+ PACを目的としてChatterjee相関を計算
333
+ Input: X_tc1 (T, C) - phase データ (ラジアン)
334
+ Input: X_tc2 (T, C) - envelope データ
335
+ Output: W (C, C) - Chatterjee correlation from phase to envelope
336
+ """
337
+ assert X_tc1.shape == X_tc2.shape
338
+ T, C = X_tc1.shape
339
+ W = np.zeros((C, C), dtype=np.float32)
340
+
341
+ # 各チャンネルペアについて Chatterjee correlation を計算
342
+ tmp_ = 0
343
+ for i in range(C):
344
+ for j in range(C):
345
+ if i == j:
346
+ continue
347
+ # Chatterjee相関係数
348
+ corr_ = metrics.chatterjee_phase_to_amp(X_tc1[:, i], X_tc2[:, j])
349
+ W[i, j] = corr_
350
+ tmp_ += 1
351
+ progress.progress(tmp_ / (C*C))
352
+
353
  return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
354
 
355
 
 
365
  return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
366
 
367
 
368
+ def threshold_edges(
369
+ W: np.ndarray,
370
+ thr: float,
371
+ ) -> List[Tuple[int, int, float]]:
372
+ """
373
+ エッジ抽出関数
374
+
375
+ - W が対称 → 無向グラフとして i < j のみ抽出
376
+ - W が非対称 → 有向グラフとして i -> j をすべて抽出
377
+
378
+ Returns:
379
+ (i, j, w): 対称の場合は無向、非対称の場合は i→j
380
+ """
381
  C = W.shape[0]
382
  edges: List[Tuple[int, int, float]] = []
383
+
384
+ is_symmetric = np.allclose(W, W.T, atol=1e-12, rtol=0)
385
+
386
+ if is_symmetric:
387
+ # --- 無向グラフ ---
388
+ for i in range(C):
389
+ for j in range(i + 1, C):
390
+ w = float(W[i, j])
391
+ if w >= thr:
392
+ edges.append((i, j, w))
393
+ else:
394
+ # --- 有向グラフ ---
395
+ for i in range(C):
396
+ for j in range(C):
397
+ if i == j:
398
+ continue
399
+ w = float(W[i, j])
400
+ if w >= thr:
401
+ edges.append((i, j, w))
402
+
403
+ # 重みの大きい順にソート
404
  edges.sort(key=lambda x: x[2], reverse=True)
405
  return edges
406
 
407
 
408
+
409
  def adjacency_at_threshold(W: np.ndarray, thr: float, weighted: bool) -> np.ndarray:
410
  if weighted:
411
  A = W.copy()
 
439
 
440
  # 閾値以上のエッジを追加
441
  for i in range(C):
442
+ for j in range(C):
443
  if W[i, j] >= thr:
444
+ G.add_edge(i, j, weight=max(W[i, j],W[j, i]))
445
 
446
  # Louvain法でコミュニティ検出
447
  partition = community_louvain.best_partition(G, weight='weight')
 
495
  ys = np.sin(angles)
496
  return np.column_stack([xs, ys])
497
 
498
+ def make_network_figure_3d(
499
+ W: np.ndarray,
500
+ thr: float,
501
+ electrode_pos_3d: np.ndarray,
502
+ use_louvain: bool = True,
503
+ ) -> go.Figure:
504
  """
505
+ 3Dネッワーク図(ドラッグで回転可能)
 
 
 
506
  """
507
+ C = W.shape[0]
508
+ xs = electrode_pos_3d[:, 0]
509
+ ys = electrode_pos_3d[:, 1]
510
+ zs = electrode_pos_3d[:, 2]
511
 
512
+ edges = threshold_edges(W, thr)
513
+ fig = go.Figure()
 
514
 
515
+ # エッジ重みの範囲を取得
516
+ if edges:
517
+ weights = [w for _, _, w in edges]
518
+ min_w = min(weights)
519
+ max_w = max(weights)
520
+ weight_range = max_w - min_w if max_w > min_w else 1.0
521
+ else:
522
+ min_w = 0
523
+ max_w = 1
524
+ weight_range = 1.0
525
 
526
+ # レインボーカラーマップ関数
527
+ def get_rainbow_color(norm_val):
528
+ import colorsys
529
+ hue = (1.0 - norm_val) * 0.67
530
+ r, g, b = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
531
+ return f'rgb({int(255*r)}, {int(255*g)}, {int(255*b)})'
532
 
533
+ # エッジを描画
534
+ for (i, j, w) in edges:
535
+ norm_w = (w - min_w) / weight_range if weight_range > 0 else 0.5
536
+ color = get_rainbow_color(norm_w)
537
+ line_width = 1 + 4 * norm_w
538
+
539
+ fig.add_trace(go.Scatter3d(
540
+ x=[xs[i], xs[j], None],
541
+ y=[ys[i], ys[j], None],
542
+ z=[zs[i], zs[j], None],
543
+ mode='lines',
544
+ line=dict(color=color, width=line_width),
545
+ hoverinfo='skip',
546
+ showlegend=False,
547
+ ))
548
+
549
+ # Louvainクラスタリング
550
+ if use_louvain and LOUVAIN_AVAILABLE:
551
+ clusters = compute_louvain_clusters(W, thr)
552
+ node_colors = get_cluster_colors(clusters)
553
+ n_clusters = len(np.unique(clusters))
554
+ title_suffix = f" | Louvain clusters: {n_clusters}"
555
+ else:
556
+ node_colors = ['#FFD700'] * C
557
+ clusters = np.zeros(C, dtype=int)
558
+ title_suffix = ""
559
+
560
+ # ノードを描画
561
+ fig.add_trace(go.Scatter3d(
562
+ x=xs,
563
+ y=ys,
564
+ z=zs,
565
+ mode='markers+text',
566
+ text=[f"{k}" for k in range(C)],
567
+ textposition='top center',
568
+ textfont=dict(size=8),
569
+ marker=dict(
570
+ size=8,
571
+ color=node_colors,
572
+ line=dict(color='white', width=1),
573
+ ),
574
+ hoverinfo='text',
575
+ hovertext=[f"channel {k}<br>cluster: {clusters[k]}" for k in range(C)],
576
+ showlegend=False,
577
+ ))
578
+
579
+ fig.update_layout(
580
+ title=f"3D Network (thr={thr:.3f}) edges={len(edges)}{title_suffix}",
581
+ height=700,
582
+ scene=dict(
583
+ xaxis=dict(visible=False),
584
+ yaxis=dict(visible=False),
585
+ zaxis=dict(visible=False),
586
+ bgcolor='rgba(0,0,0,0.9)',
587
+ ),
588
+ paper_bgcolor='rgba(0,0,0,0.9)',
589
+ margin=dict(l=0, r=0, t=50, b=0),
590
+ )
591
+
592
+ return fig
593
 
594
 
595
  def make_network_figure(
 
597
  thr: float,
598
  use_louvain: bool = True,
599
  electrode_pos: np.ndarray = None,
 
600
  ) -> tuple[go.Figure, int]:
601
  C = W.shape[0]
602
 
 
633
  r, g, b = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
634
  return f'rgba({int(255*r)}, {int(255*g)}, {int(255*b)}, 0.7)'
635
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
636
  # エッジを描画(重みに応じて色と太さを変える)
637
+
638
+ # --- 有向のときだけ:矢印(三角マーカー)を終端側に置く ---
639
+ is_symmetric = np.allclose(W, W.T, atol=1e-12, rtol=0)
640
+ if (not is_symmetric):
641
+ curve_strength = 0.1 # 湾曲のさ(要調整)
642
+ node_radius = 0.08 # ノード中心からどれくらい手前に終点/矢印を置くか(要調整)
643
+ bezier_n = 18 # 曲線の分割数(増やすほど滑らか)
644
+ t_arrow = 0.90 # 矢印を置く位置(0〜1)
645
+ for (i, j, w) in edges:
646
+ norm_w = (w - min_w) / weight_range if weight_range > 0 else 0.5
647
+ color = get_rainbow_color(norm_w)
648
+ line_width = 0.5 + 3.5 * norm_w
649
+
650
+ p0 = np.array([xs[i], ys[i]], dtype=float)
651
+ p1 = np.array([xs[j], ys[j]], dtype=float)
652
+
653
+ v = p1 - p0
654
+ dist = np.hypot(v[0], v[1])
655
+ if dist < 1e-9:
656
+ continue
657
+ u = v / dist
658
+
659
+ # ノードに重ならないよう端点を縮める
660
+ p0s = p0 + u * node_radius
661
+ p1s = p1 - u * node_radius
662
+
663
+ # 垂直方向(法線)
664
+ n = np.array([-u[1], u[0]])
665
+
666
+ # ★ 有向エッジは全部曲げる(規則的に)
667
+ sign = 1.0 #if i < j else -1.0
668
+
669
+ # 制御点
670
+ mid = 0.5 * (p0s + p1s)
671
+ c = mid + sign * curve_strength * dist * n
672
+
673
+ # 曲線点列
674
+ pts = _quad_bezier_points(p0s, p1s, c, n=bezier_n)
675
+
676
+ fig.add_trace(go.Scatter(
677
+ x=pts[:, 0],
678
+ y=pts[:, 1],
679
  mode="lines",
680
  hoverinfo="text",
681
+ hovertext=f"ch{i} ch{j}<br>weight: {w:.4f}",
682
  line=dict(width=line_width, color=color),
683
  showlegend=False,
684
+ ))
685
+
686
+ # 矢印(曲線接線方向)
687
+ pt, tan = _quad_bezier_point_and_tangent(p0s, p1s, c, t_arrow)
688
+
689
+ # 接線がゼロに近い場合の保険
690
+ tx, ty = float(tan[0]), float(tan[1])
691
+ if tx*tx + ty*ty < 1e-18:
692
+ tx, ty = float(p1s[0] - p0s[0]), float(p1s[1] - p0s[1])
693
+
694
+ theta = np.degrees(np.arctan2(ty, tx)) # 接線の角度(+x基準)
695
+ ANGLE_OFFSET = -90.0 # triangle-up(上向き) を接線方向に合わせる補正
696
+ ang = (theta + ANGLE_OFFSET) % 360
697
+
698
+ fig.add_trace(go.Scatter(
699
+ x=[pt[0]],
700
+ y=[pt[1]],
701
+ mode="markers",
702
+ hoverinfo="skip",
703
+ marker=dict(
704
+ symbol="triangle-up",
705
+ size=10,
706
+ angle=-ang,
707
+ angleref="up",
708
+ color=color,
709
+ line=dict(width=0),
710
+ ),
711
+ showlegend=False,
712
+ ))
713
+ else:
714
+ for (i, j, w) in edges:
715
+ # 正規化された重み (0-1)
716
+ norm_w = (w - min_w) / weight_range if weight_range > 0 else 0.5
717
+
718
+ # レインボーカラー: 弱い(青) → 中間(緑/黄) → 強い(赤)
719
+ color = get_rainbow_color(norm_w)
720
+
721
+ # 太さ: 重みに比例 (0.5-4の範囲)
722
+ line_width = 0.5 + 3.5 * norm_w
723
+
724
+ fig.add_trace(
725
+ go.Scatter(
726
+ x=[xs[i], xs[j]],
727
+ y=[ys[i], ys[j]],
728
+ mode="lines",
729
+ hoverinfo="text",
730
+ hovertext=f"ch{i} - ch{j}<br>weight: {w:.4f}",
731
+ line=dict(width=line_width, color=color),
732
+ showlegend=False,
733
+ )
734
  )
735
+
736
 
737
  # Louvainクラスタリング
738
  if use_louvain and LOUVAIN_AVAILABLE:
 
827
  input_mode = st.sidebar.radio("データ形式", ["EEGLAB (.set + .fdt)", "MATLAB (.mat)"], index=0)
828
 
829
  st.sidebar.header("Preprocess (auto)")
830
+ f_low_src = st.sidebar.number_input("Bandpass low (Hz)", min_value=0.0, value=4.0, step=1.0, key="low_src")
831
+ f_high_src = st.sidebar.number_input("Bandpass high (Hz)", min_value=0.1, value=8.0, step=1.0, key="high_src")
832
+
833
+ st.sidebar.header("if you use CFC+PAC:")
834
+ f_low_tgt = st.sidebar.number_input("Bandpass low (Hz)", min_value=0.0, value=25.0, step=1.0, key="low_tgt")
835
+ f_high_tgt = st.sidebar.number_input("Bandpass high (Hz)", min_value=0.1, value=40.0, step=1.0, key="high_tgt")
836
 
837
  st.sidebar.header("Viewer controls")
838
  win_sec = st.sidebar.number_input("Window length (sec)", min_value=0.1, value=5.0, step=0.1)
 
867
  else:
868
  try:
869
  with st.spinner("Loading EEGLAB + preprocessing (bandpass + hilbert)..."):
870
+ prep_src = preprocess_all_eeglab(
871
+ set_bytes=set_file.getvalue(),
872
+ fdt_bytes=fdt_file.getvalue(),
873
+ set_name=set_file.name,
874
+ fdt_name=fdt_file.name,
875
+ f_low=float(f_low_src),
876
+ f_high=float(f_high_src),
877
+ )
878
+ prep_tgt = preprocess_all_eeglab(
879
  set_bytes=set_file.getvalue(),
880
  fdt_bytes=fdt_file.getvalue(),
881
  set_name=set_file.name,
882
  fdt_name=fdt_file.name,
883
+ f_low=float(f_low_tgt),
884
+ f_high=float(f_high_tgt),
885
  )
886
+ st.session_state["prep"] = prep_src
887
+ st.session_state["prep_tgt"] = prep_tgt
888
  st.session_state["W"] = None
889
+ st.success(f"Loaded & preprocessed. (T,C)={prep_src['raw'].shape} fs={prep_src['fs']:.2f}Hz")
890
  except Exception as e:
891
  st.session_state.pop("prep", None)
892
  st.session_state["W"] = None
 
947
  st.sidebar.write(f"選択した配列: shape={x.shape}, dtype={x.dtype}")
948
  try:
949
  with st.spinner("Preprocessing (bandpass + hilbert)..."):
950
+ cfg = PreprocessConfig(fs=float(fs_mat), f_low=float(f_low_src), f_high=float(f_high_src))
951
  prep = preprocess_tc(x, cfg)
952
+ cfg_tgt = PreprocessConfig(fs=float(fs_mat), f_low=float(f_low_tgt), f_high=float(f_high_tgt))
953
+ prep_tgt = preprocess_tc(x, cfg_tgt)
954
 
955
  st.session_state["prep"] = prep
956
+ st.session_state["prep_tgt"] = prep_tgt
957
  st.session_state["W"] = None
958
  st.success(f"Loaded MAT '{key}'. (T,C)={prep['raw'].shape} fs={prep['fs']:.2f}Hz")
959
  except Exception as e:
 
1102
  st.subheader("Data info")
1103
  signal_desc = {
1104
  "raw": "生信号(前処理なし)",
1105
+ "filtered": f"バンドパスフィルタ後 ({f_low_src}-{f_high_src} Hz)",
1106
  "amplitude": "Hilbert振幅 (envelope)",
1107
  "phase": "Hilbert位相 (-π ~ π)"
1108
  }
 
1130
  "推定手法を選択",
1131
  options=[
1132
  "envelope_corr",
1133
+ "phase_PLV",
1134
  "phase_corr",
1135
+ "pac_tort",
1136
+ "pac_chatterjee"
1137
  ],
1138
  format_func=lambda x: {
1139
+ "envelope_corr": "Envelope Pearson correlation (振幅の相関)",
1140
+ "phase_PLV": "PLV(位相同期, PLV",
1141
+ "phase_corr": "Circular correlation",
1142
+ "pac_tort": "Modulation Index(位相と振幅のPAC指標)",
1143
+ "pac_chatterjee": "Chatterjee correlation(位相→振幅の相関)",
1144
  }[x],
1145
  horizontal=True,
1146
+ help="envelope_corr: 振幅包絡線のPearson相関係数 | phase_PLV: 位相のPhase Locking Value | phase_corr: 位相の相関係数 | pac_tort: Modulation index | pac_chatterjee: 位相から振幅へのChatterjee相関",
1147
  )
1148
 
1149
  # 推定手法の説明
1150
  method_info = {
1151
  "envelope_corr": "**Envelope correlation**: 振幅包絡線(Hilbert amplitude)間のPearson相関係数を計算します。振幅が同期して変動するチャンネル間の結合を検出します。",
1152
+ "phase_PLV": "**PLV**: 位相間のPhase locking valueを計算します。位相同期を検出します。0(非同期)〜1(完全同期)の値を取ります。",
1153
+ "phase_corr": "**Circular correlation**: 位相間の相関係数を計算します。位相同期を検出します。0(非同期)〜1(完全同期)の値を取ります。",
1154
+ "pac_tort": "Modulation Index(位相と振幅のPAC指標)",
1155
+ "pac_chatterjee": "Chatterjee correlation(位相→振幅の相関)",
1156
  }
1157
  st.info(method_info[estimation_method])
1158
 
 
1164
  need_estimation = (W is None) or (last_method != estimation_method)
1165
 
1166
  if need_estimation:
1167
+ progress = st.progress(0.0)
1168
  with st.spinner(f"推定中... ({estimation_method})"):
1169
  if estimation_method == "envelope_corr":
1170
  X_in = prep["amplitude"]
1171
  W = estimate_network_envelope_corr(X_in)
1172
+ elif estimation_method == "phase_PLV":
1173
+ X_in = prep["phase"]
1174
+ W = estimate_network_phase_PLV(X_in, progress)
1175
  elif estimation_method == "phase_corr":
1176
  X_in = prep["phase"]
1177
  W = estimate_network_phase_corr(X_in)
1178
+ elif estimation_method == "pac_tort":
1179
+ X_in_low_phase = prep["phase"]
1180
+ prep_tgt = st.session_state["prep_tgt"]
1181
+ X_in_high_amplitude = prep_tgt["amplitude"]
1182
+ W = estimate_network_pac_tort(X_in_low_phase,X_in_high_amplitude,progress)
1183
+ elif estimation_method == "pac_chatterjee":
1184
+ X_in_low_phase = prep["phase"]
1185
+ prep_tgt = st.session_state["prep_tgt"]
1186
+ X_in_high_amplitude = prep_tgt["amplitude"]
1187
+ W = estimate_network_pac_chatterjee(X_in_low_phase,X_in_high_amplitude,progress)
1188
  else:
1189
  st.error("未知の推定手法です")
1190
  st.stop()
 
1200
  # 閾値スライダーとネットワーク図の表示
1201
  wmax = float(np.max(W)) if np.isfinite(np.max(W)) else 1.0
1202
 
1203
+ col_thr1, col_thr2 = st.columns([3, 1])
1204
  with col_thr1:
1205
  thr = st.slider(
1206
  "閾値 (threshold) ※下げるほどエッジが増えます",
1207
  min_value=0.0,
1208
  max_value=max(0.0001, wmax),
1209
+ value=wmax/2,
1210
  step=max(wmax / 200, 0.001),
1211
  )
1212
  with col_thr2:
 
1216
  disabled=not LOUVAIN_AVAILABLE,
1217
  help="ノードの色をコミュニティ検出結果で塗り分けます"
1218
  )
 
 
 
 
 
 
1219
 
1220
  # 電極位置を取得
1221
  electrode_pos = prep.get("electrode_pos", None)
1222
+ # 2D座標を90度左回転(上が正面になる向きに)
1223
+ if electrode_pos is not None:
1224
+ electrode_pos = np.asarray(electrode_pos, dtype=np.float32)
1225
+ if electrode_pos.ndim == 2 and electrode_pos.shape[1] >= 2:
1226
+ pos2 = electrode_pos[:, :2]
1227
+ electrode_pos = np.column_stack([-pos2[:, 1], pos2[:, 0]])
1228
+ electrode_pos_3d = prep.get("electrode_pos_3d", None)
1229
 
1230
  if electrode_pos is not None:
1231
  st.info(f"✓ 電極位置を使用してネットワークを配置 ({electrode_pos.shape[0]} channels)")
1232
  else:
1233
  st.info("ℹ️ 電極位置が取得できなかったため、円形配置を使用します")
1234
 
1235
+ # 3D座標の有無を表示
1236
+ if electrode_pos_3d is not None:
1237
+ st.success(f"✓ 3D電極座標を取得しました ({electrode_pos_3d.shape[0]} channels) - 下部に3Dビューアを表示します")
1238
+ else:
1239
+ st.info("ℹ️ 3D電極座標が取得できませんでした - 2D表示のみです")
1240
+
1241
  net_col1, net_col2 = st.columns([2, 1])
1242
  with net_col1:
1243
  fig_net, edge_n = make_network_figure(
 
1245
  float(thr),
1246
  use_louvain=use_louvain,
1247
  electrode_pos=electrode_pos,
 
1248
  )
1249
  st.plotly_chart(fig_net)
1250
+ # 3Dネットワーク表示(3D座標がある場合のみ)
1251
+ if electrode_pos_3d is not None:
1252
+ electrode_pos_3d = np.asarray(electrode_pos_3d, dtype=np.float32)
1253
+ if electrode_pos_3d.ndim == 2 and electrode_pos_3d.shape[0] == W.shape[0] and electrode_pos_3d.shape[1] == 3:
1254
+ st.subheader("3D Viewer")
1255
+ fig_3d = make_network_figure_3d(
1256
+ W=W,
1257
+ thr=float(thr),
1258
+ electrode_pos_3d=electrode_pos_3d,
1259
+ use_louvain=use_louvain,
1260
+ )
1261
+ st.plotly_chart(
1262
+ fig_3d,
1263
+ width="stretch",
1264
+ config={"displayModeBar": True, "scrollZoom": True},
1265
+ )
1266
+ else:
1267
+ st.warning(f"3D座標のshapeが不正です: {electrode_pos_3d.shape}(期待: (C,3), C={W.shape[0]})")
1268
 
1269
  with net_col2:
1270
  st.metric("Edges", edge_n)
1271
+ st.plotly_chart(make_edgecount_curve(W))
1272
+
1273
+
1274
+ st.write("# Hypothesis testing")
1275
+ st.write("Coming soon ...")