stardust-coder committed on
Commit
b11ec91
·
1 Parent(s): 1442b78

[add] app files

Browse files
Files changed (4) hide show
  1. requirements.txt +7 -3
  2. src/loader.py +393 -0
  3. src/preprocess.py +59 -0
  4. src/streamlit_app.py +853 -37
requirements.txt CHANGED
@@ -1,3 +1,7 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
1
+ streamlit
2
+ numpy
3
+ plotly
4
+ scipy
5
+ mne
6
+ h5py
7
+ networkx
src/loader.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import List, Tuple, Optional, Dict, Any
4
+ import io
5
+ import os
6
+ import tempfile
7
+
8
+ import numpy as np
9
+
10
+ import mne
11
+ from scipy.io import loadmat
12
+
13
+ try:
14
+ import h5py # MAT v7.3 (HDF5)
15
+ except Exception: # pragma: no cover
16
+ h5py = None
17
+
18
+
19
+ # ============================================================
20
+ # EEGLAB loader (.set + .fdt)
21
+ # ============================================================
22
def pick_set_fdt(files) -> Tuple[Optional[object], Optional[object]]:
    """
    Find the .set and .fdt entries among Streamlit-uploaded files
    (the list produced by accept_multiple_files=True).
    Returns: (set_file, fdt_file) — either slot is None when absent.
    """
    found = {".set": None, ".fdt": None}
    for uploaded in files:
        fname = (getattr(uploaded, "name", "") or "").lower()
        for ext in found:
            if fname.endswith(ext):
                found[ext] = uploaded
    return found[".set"], found[".fdt"]
36
+
37
+
38
def same_stem(a_name: str, b_name: str) -> bool:
    """True when both paths share the same basename once extensions are removed."""
    stems = [os.path.splitext(os.path.basename(n))[0] for n in (a_name, b_name)]
    return stems[0] == stems[1]
43
+
44
+
45
def _load_eeglab_hdf5(set_path: str, fdt_path: Optional[str] = None, debug: bool = False) -> Tuple[np.ndarray, float]:
    """
    Load EEGLAB .set file saved in MATLAB v7.3 (HDF5) format using h5py.

    Probes both the 'EEG/<field>' and top-level '<field>' layouts for
    srate/nbchan/pnts, then locates the sample data either inside the
    .set itself or in the companion .fdt file (when EEG/data is only a
    reference array).

    Returns: (x_tc, fs) where x_tc is (T, C) float32.
    Raises ValueError when srate or the data cannot be found.
    """
    if h5py is None:
        raise RuntimeError("EEGLAB .set ファイルが MATLAB v7.3 (HDF5) 形式ですが、h5py がインストールされていません。pip install h5py を実行してください。")

    with h5py.File(set_path, "r") as f:
        # Debug: dump the whole HDF5 tree
        if debug:
            print("=== HDF5 file structure ===")
            def print_structure(name, obj):
                if isinstance(obj, h5py.Dataset):
                    print(f"Dataset: {name}, shape: {obj.shape}, dtype: {obj.dtype}")
                elif isinstance(obj, h5py.Group):
                    print(f"Group: {name}")
            f.visititems(print_structure)
            print("===========================")

        # Sampling rate
        fs = None
        for path in ["EEG/srate", "srate"]:
            if path in f:
                srate_data = f[path]
                if isinstance(srate_data, h5py.Dataset):
                    val = srate_data[()]
                    # MATLAB scalars arrive as arrays; take the first element
                    fs = float(val.flat[0]) if hasattr(val, 'flat') else float(val)
                    break

        if fs is None:
            raise ValueError("サンプリングレート (srate) が見つかりません")

        # Channel count
        nbchan = None
        for path in ["EEG/nbchan", "nbchan"]:
            if path in f:
                nbchan_data = f[path]
                if isinstance(nbchan_data, h5py.Dataset):
                    val = nbchan_data[()]
                    nbchan = int(val.flat[0]) if hasattr(val, 'flat') else int(val)
                    break

        # Sample count
        pnts = None
        for path in ["EEG/pnts", "pnts"]:
            if path in f:
                pnts_data = f[path]
                if isinstance(pnts_data, h5py.Dataset):
                    val = pnts_data[()]
                    pnts = int(val.flat[0]) if hasattr(val, 'flat') else int(val)
                    break

        if debug:
            print(f"nbchan: {nbchan}, pnts: {pnts}, fs: {fs}")

        # Locate the data — first look inside the .set itself
        data = None
        data_shape = None

        if debug:
            print(f"Checking for data, fdt_path provided: {fdt_path is not None}")
            if fdt_path:
                print(f"fdt_path exists: {os.path.exists(fdt_path)}")

        # Pattern 1: EEG/data — may be a reference array pointing at the .fdt
        if "EEG" in f and "data" in f["EEG"]:
            data_ref = f["EEG"]["data"]
            if isinstance(data_ref, h5py.Dataset):
                if debug:
                    print(f"EEG/data dtype: {data_ref.dtype}, shape: {data_ref.shape}, size: {data_ref.size}")

                if data_ref.dtype == h5py.ref_dtype:
                    # Reference type — samples normally live in the .fdt file
                    if debug:
                        print("EEG/data is reference type - data should be in .fdt file")
                    # The companion .fdt is required here
                    if fdt_path is not None and os.path.exists(fdt_path):
                        data = _load_fdt_file(fdt_path, nbchan, pnts, debug=debug)
                    else:
                        raise ValueError(".fdt ファイルが必要ですが見つかりません。.set と .fdt の両方をアップロードしてください。")
                elif data_ref.size > 100:  # real samples, not a reference list
                    data = data_ref[()]
                    data_shape = data.shape
                    if debug:
                        print(f"EEG/data contains actual data, shape: {data_shape}")
                else:
                    # Tiny array == reference list; the .fdt file is required
                    if debug:
                        print(f"EEG/data is small array (size={data_ref.size}), assuming reference to .fdt")
                    if fdt_path is not None and os.path.exists(fdt_path):
                        data = _load_fdt_file(fdt_path, nbchan, pnts, debug=debug)
                    else:
                        raise ValueError(".fdt ファイルが必要ですが見つかりません。.set と .fdt の両方をアップロードしてください。")

        # Pattern 2: a top-level "data" dataset
        if data is None and "data" in f:
            data_obj = f["data"]
            if isinstance(data_obj, h5py.Dataset):
                data = data_obj[()]
                data_shape = data.shape

        if data is None:
            raise ValueError("EEGデータが見つかりません。.fdt ファイルが必要な可能性があります。")

        if debug:
            print(f"Data shape: {data.shape if hasattr(data, 'shape') else 'loaded from fdt'}")

        # Normalize the orientation to (T, C)
        if data.ndim != 2:
            raise ValueError(f"予期しないデータ次元: {data.ndim}")

        dim0, dim1 = data.shape

        # Prefer the recorded channel count when deciding orientation
        if nbchan is not None:
            if dim0 == nbchan:
                # stored as (C, T)
                x_tc = data.T.astype(np.float32)
            elif dim1 == nbchan:
                # stored as (T, C)
                x_tc = data.astype(np.float32)
            else:
                # nbchan matches neither axis: assume the smaller axis is channels
                if dim0 < dim1:
                    x_tc = data.T.astype(np.float32)
                else:
                    x_tc = data.astype(np.float32)
        else:
            # Generic heuristic: the smaller axis is channels
            if dim0 < dim1:
                x_tc = data.T.astype(np.float32)
            else:
                x_tc = data.astype(np.float32)

        if debug:
            print(f"Final shape (T, C): {x_tc.shape}")

        return x_tc, fs
185
+
186
+
187
+ def _load_fdt_file(fdt_path: str, nbchan: Optional[int], pnts: Optional[int], debug: bool = False) -> np.ndarray:
188
+ """
189
+ Load .fdt file (raw binary float32 data).
190
+ EEGLAB .fdt files are stored as float32 in (C, T) order.
191
+ """
192
+ if debug:
193
+ print(f"Loading .fdt file: {fdt_path}")
194
+
195
+ # .fdt ファイルは float32 のバイナリデータ
196
+ data = np.fromfile(fdt_path, dtype=np.float32)
197
+
198
+ if debug:
199
+ print(f"Loaded {data.size} float32 values from .fdt")
200
+
201
+ # チャンネル数とサンプル数がわかっている場合はリシェイプ
202
+ if nbchan is not None and pnts is not None:
203
+ expected_size = nbchan * pnts
204
+ if data.size == expected_size:
205
+ # EEGLAB は (C, T) 順で保存
206
+ data = data.reshape(nbchan, pnts)
207
+ if debug:
208
+ print(f"Reshaped to ({nbchan}, {pnts})")
209
+ else:
210
+ if debug:
211
+ print(f"Warning: expected {expected_size} values but got {data.size}")
212
+ # 可能な限りリシェイプを試みる
213
+ if data.size % nbchan == 0:
214
+ data = data.reshape(nbchan, -1)
215
+ elif data.size % pnts == 0:
216
+ data = data.reshape(-1, pnts)
217
+ else:
218
+ raise ValueError(f"Cannot reshape data of size {data.size} with nbchan={nbchan}, pnts={pnts}")
219
+ else:
220
+ raise ValueError("nbchan と pnts の情報が必要です")
221
+
222
+ return data
223
+
224
+
225
def load_eeglab_tc_from_bytes(
    set_bytes: bytes,
    set_name: str,
    fdt_bytes: Optional[bytes] = None,
    fdt_name: Optional[str] = None,
) -> Tuple[np.ndarray, float]:
    """
    Load EEGLAB .set (and optional .fdt) from bytes using MNE or h5py.

    Returns:
        x_tc: (T, C) float32
        fs: sampling rate (Hz)

    Notes:
        - Most EEGLAB .set files reference a same-stem .fdt, so both are
          written into one temporary directory before reading.
        - If the .set is self-contained the fdt_* arguments may be omitted.
        - MATLAB v7.3 (HDF5) .set files are handled as a last-resort fallback.

    Raises:
        ValueError: fdt_* given inconsistently, or stems do not match.
        RuntimeError: every reading strategy failed (message lists all errors).
    """
    # Both fdt_* arguments must be given together, and stems must match
    # so EEGLAB's .fdt reference can be resolved.
    if fdt_bytes is not None or fdt_name is not None:
        if fdt_bytes is None or fdt_name is None:
            raise ValueError("fdt_bytes と fdt_name は両方指定してください。")
        if not same_stem(set_name, fdt_name):
            raise ValueError(f".set と .fdt のファイル名(拡張子除く)が一致していません: {set_name} vs {fdt_name}")

    with tempfile.TemporaryDirectory() as tmpdir:
        set_path = os.path.join(tmpdir, os.path.basename(set_name))
        with open(set_path, "wb") as f:
            f.write(set_bytes)

        fdt_path = None  # initialized so the HDF5 fallback can always receive it
        if fdt_bytes is not None and fdt_name is not None:
            fdt_path = os.path.join(tmpdir, os.path.basename(fdt_name))
            with open(fdt_path, "wb") as f:
                f.write(fdt_bytes)

        # 1) Try reading as continuous Raw data (the common EEGLAB layout)
        try:
            raw = mne.io.read_raw_eeglab(set_path, preload=True, verbose=False)
            fs = float(raw.info["sfreq"])
            x_tc = raw.get_data().T  # (T,C)
            return x_tc.astype(np.float32), fs

        except Exception as e_raw:
            # 2) Try reading as Epochs (epoched datasets)
            try:
                epochs = mne.io.read_epochs_eeglab(set_path, verbose=False, montage_units="cm")
                fs = float(epochs.info["sfreq"])
                x = epochs.get_data(copy=True)  # (n_epochs, n_channels, n_times)

                # Policy choice here: average across epochs (concatenation
                # would be the alternative).
                x_mean = x.mean(axis=0)  # (C,T)
                x_tc = x_mean.T  # (T,C)
                return x_tc.astype(np.float32), fs

            except Exception as e_ep:
                # 3) Fall back to MATLAB v7.3 (HDF5) parsing
                try:
                    # Debug output controlled via environment variable...
                    debug = os.environ.get("EEGLAB_DEBUG", "0") == "1"
                    # ...and forced on when running under Streamlit
                    import sys
                    if 'streamlit' in sys.modules:
                        debug = True
                    x_tc, fs = _load_eeglab_hdf5(set_path, fdt_path=fdt_path, debug=debug)
                    return x_tc, fs

                except Exception as e_hdf5:
                    # Every strategy failed: surface all three errors at once
                    msg = (
                        "EEGLABの読み込みに失敗しました。\n"
                        f"- read_raw_eeglab error: {e_raw}\n"
                        f"- read_epochs_eeglab error: {e_ep}\n"
                        f"- HDF5読み込み error: {e_hdf5}\n"
                    )
                    raise RuntimeError(msg) from e_hdf5
299
+
300
+
301
+
302
+ # ============================================================
303
+ # MAT loader (.mat)
304
+ # ============================================================
305
+ def _mat_keys_loadmat(mat_dict: Dict[str, Any]) -> List[str]:
306
+ return sorted([k for k in mat_dict.keys() if not k.startswith("__")])
307
+
308
+
309
+ def _try_get_numeric_arrays_loadmat(mat_dict: Dict[str, Any]) -> Dict[str, np.ndarray]:
310
+ """
311
+ loadmatで読んだdictから、1D/2Dの数値ndarrayだけ抽出して返す。
312
+ 3次元配列も含める(エポックデータの可能性)。
313
+ """
314
+ out: Dict[str, np.ndarray] = {}
315
+ for k in _mat_keys_loadmat(mat_dict):
316
+ v = mat_dict[k]
317
+ if isinstance(v, np.ndarray) and v.size > 0:
318
+ # 数値型かどうかチェック
319
+ if np.issubdtype(v.dtype, np.number):
320
+ if v.ndim in (1, 2):
321
+ out[k] = v
322
+ elif v.ndim == 3:
323
+ # 3次元配列の場合は (epochs, channels, time) の可能性
324
+ # 平均を取って2次元にする、または連結する
325
+ out[k + "_mean"] = v.mean(axis=0) # (C, T)
326
+ out[k + "_concat"] = v.reshape(-1, v.shape[-1]) # (epochs*C, T)
327
+ return out
328
+
329
+
330
+ def _load_mat_v72(bytes_data: bytes) -> Dict[str, Any]:
331
+ # v7.2以前のMAT(一般的なMAT)
332
+ return loadmat(io.BytesIO(bytes_data), squeeze_me=False, struct_as_record=False)
333
+
334
+
335
def _load_mat_v73_candidates(bytes_data: bytes) -> Dict[str, np.ndarray]:
    """
    Collect numeric 1D/2D/3D datasets from a MAT v7.3 (HDF5) file.

    Keys are HDF5 paths (e.g. 'group/data'); 3-D datasets additionally
    yield '<path>_mean' / '<path>_concat' entries.

    Note: newer h5py versions may not open BytesIO directly, so the bytes
    are spooled to a temporary file first.
    """
    if h5py is None:
        raise RuntimeError("MAT v7.3(HDF5) 形式の可能性がありますが、h5py が入っていません。pip install h5py を実行してください。")

    out: Dict[str, np.ndarray] = {}

    # Write to a named temp file because h5py.File may reject BytesIO.
    with tempfile.NamedTemporaryFile(suffix='.mat', delete=False) as tmp:
        tmp.write(bytes_data)
        tmp_path = tmp.name

    try:
        with h5py.File(tmp_path, "r") as f:

            def visitor(name, obj):
                # Only leaf datasets are interesting
                if not isinstance(obj, h5py.Dataset):
                    return
                try:
                    arr = obj[()]
                except Exception:
                    # unreadable dataset (e.g. object references) — skip it
                    return

                # Keep numeric arrays only; MATLAB strings/references are dropped
                if isinstance(arr, np.ndarray) and arr.size > 0 and np.issubdtype(arr.dtype, np.number):
                    if arr.ndim in (1, 2):
                        out[name] = arr
                    elif arr.ndim == 3:
                        # 3-D: also expose epoch-mean / concatenated views
                        out[name + "_mean"] = arr.mean(axis=0)
                        out[name + "_concat"] = arr.reshape(-1, arr.shape[-1])

            f.visititems(lambda name, obj: visitor(name, obj))
    finally:
        # Always remove the temp file, even if parsing failed
        try:
            os.unlink(tmp_path)
        except Exception:
            pass

    return out
381
+
382
+
383
def load_mat_candidates(bytes_data: bytes) -> Dict[str, np.ndarray]:
    """
    Return dict: variable_name -> ndarray(1D/2D numeric)
    Tries v7.2 (scipy.io.loadmat). If it fails, tries v7.3 (h5py).
    """
    try:
        parsed = _load_mat_v72(bytes_data)
        return _try_get_numeric_arrays_loadmat(parsed)
    except Exception:
        # loadmat cannot read v7.3 (HDF5) files — fall back to the h5py path
        return _load_mat_v73_candidates(bytes_data)
src/preprocess.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # preprocessing.py
2
+ from __future__ import annotations
3
+ from dataclasses import dataclass
4
+ import numpy as np
5
+ import mne
6
+
7
+
8
@dataclass(frozen=True)
class PreprocessConfig:
    """Immutable bandpass-preprocessing settings."""
    # Sampling rate of the signal, in Hz.
    fs: float
    # Lower bandpass cutoff, in Hz.
    f_low: float
    # Upper bandpass cutoff, in Hz.
    f_high: float
14
+
15
def to_time_channel(x: np.ndarray) -> np.ndarray:
    """Coerce a 1-D or 2-D signal into (time, channel) layout.

    A 1-D vector becomes a single-channel column. For 2-D input, when the
    first axis is short (<= 256) and shorter than the second, the array is
    assumed to be (channel, time) and transposed.
    """
    if x.ndim == 1:
        return x.reshape(-1, 1)
    if x.ndim != 2:
        raise ValueError(f"Expected 1D or 2D array, got {x.shape}")
    n_time, n_chan = x.shape
    looks_channel_major = n_time <= 256 and n_chan > n_time
    return x.T if looks_channel_major else x
24
+
25
+
26
def bandpass_tc(x_tc: np.ndarray, cfg: PreprocessConfig) -> np.ndarray:
    """Band-pass filter every channel of a (T, C) array via MNE.

    The array is wrapped in an MNE RawArray (which expects (C, T) data),
    filtered between cfg.f_low and cfg.f_high, and returned as (T, C).
    """
    channel_names = [f"ch{i}" for i in range(x_tc.shape[1])]
    info = mne.create_info(
        ch_names=channel_names,
        sfreq=cfg.fs,
        ch_types="eeg",
    )
    raw = mne.io.RawArray(x_tc.T, info, verbose=False)
    filtered = raw.copy().filter(cfg.f_low, cfg.f_high, verbose=False)
    return filtered.get_data().T
35
+
36
+
37
def hilbert_envelope_tc(x_tc: np.ndarray) -> np.ndarray:
    """Per-channel analytic-signal envelope of a (T, C) array.

    Implements the Hilbert transform in the frequency domain: keep DC
    (and Nyquist for even length) at weight 1, double the positive
    frequencies, zero the negative ones, then take the magnitude of the
    inverse FFT.
    """
    spectrum = np.fft.fft(x_tc, axis=0)
    n = spectrum.shape[0]
    weights = np.zeros(n)
    half = n // 2
    if n % 2 == 0:
        weights[0] = weights[half] = 1
        weights[1:half] = 2
    else:
        weights[0] = 1
        weights[1:half + 1] = 2
    analytic = np.fft.ifft(spectrum * weights[:, None], axis=0)
    return np.abs(analytic).astype(np.float32)
49
+
50
+
51
def preprocess_pipeline(x: np.ndarray, cfg: PreprocessConfig):
    """Run the full chain: layout fix -> bandpass -> Hilbert envelope.

    Returns a dict with the (T, C) array at each stage:
    'raw', 'filtered', and 'envelope'.
    """
    raw_tc = to_time_channel(x)
    filtered_tc = bandpass_tc(raw_tc, cfg)
    envelope_tc = hilbert_envelope_tc(filtered_tc)
    return {
        "raw": raw_tc,
        "filtered": filtered_tc,
        "envelope": envelope_tc,
    }
src/streamlit_app.py CHANGED
@@ -1,40 +1,856 @@
1
- import altair as alt
 
 
 
2
  import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import io
2
+ from dataclasses import dataclass
3
+ from typing import List, Tuple
4
+
5
  import numpy as np
 
6
  import streamlit as st
7
+ import plotly.graph_objects as go
8
+ import mne
9
+ from scipy.signal import hilbert
10
+
11
+ try:
12
+ import community as community_louvain
13
+ import networkx as nx
14
+ LOUVAIN_AVAILABLE = True
15
+ except ImportError:
16
+ LOUVAIN_AVAILABLE = False
17
+ st.warning("⚠️ Louvainクラスタリングを使用するには `pip install python-louvain networkx` を実行してください。")
18
+
19
+ from loader import (
20
+ pick_set_fdt,
21
+ load_eeglab_tc_from_bytes,
22
+ load_mat_candidates,
23
+ )
24
+
25
+ st.set_page_config(page_title="EEG Viewer + Network Estimation", layout="wide")
26
+
27
+
28
+ # ============================================================
29
+ # Preprocess config
30
+ # ============================================================
31
@dataclass(frozen=True)
class PreprocessConfig:
    """Immutable bandpass-preprocessing settings for the app."""
    # Sampling rate of the signal, in Hz.
    fs: float
    # Lower bandpass cutoff, in Hz.
    f_low: float
    # Upper bandpass cutoff, in Hz.
    f_high: float
36
+
37
+
38
+ # ============================================================
39
+ # Helpers
40
+ # ============================================================
41
def ensure_tc(x: np.ndarray) -> np.ndarray:
    """Return the signal as (T, C); 1-D input becomes a single column.

    2-D input whose first axis is short (<= 256) and shorter than the
    second is treated as (C, T) and transposed (heuristic).
    """
    arr = np.asarray(x)
    if arr.ndim == 1:
        return arr.reshape(-1, 1)
    if arr.ndim != 2:
        raise ValueError(f"2次元配列のみ対応です: shape={arr.shape}")
    n_rows, n_cols = arr.shape
    if n_rows <= 256 and n_cols > n_rows:  # rows look like channels
        return arr.T
    return arr
52
+
53
+
54
+ # ============================================================
55
+ # Signal processing
56
+ # ============================================================
57
def bandpass_tc(x_tc: np.ndarray, cfg: PreprocessConfig) -> np.ndarray:
    """Band-pass each channel with MNE. Input and output are (T, C); output float32."""
    n_channels = x_tc.shape[1]
    info = mne.create_info(
        ch_names=[f"ch{idx}" for idx in range(n_channels)],
        sfreq=float(cfg.fs),
        ch_types="eeg",
    )
    raw = mne.io.RawArray(x_tc.T, info, verbose=False)  # MNE wants (C, T)
    filtered = raw.copy().filter(l_freq=cfg.f_low, h_freq=cfg.f_high, verbose=False)
    return filtered.get_data().T.astype(np.float32)
67
+
68
+
69
def hilbert_envelope_tc(x_tc: np.ndarray) -> np.ndarray:
    """Amplitude envelope per channel via SciPy's analytic signal. (T, C) in/out."""
    return np.abs(hilbert(x_tc, axis=0)).astype(np.float32)
73
+
74
+
75
def hilbert_phase_tc(x_tc: np.ndarray) -> np.ndarray:
    """Instantaneous phase (radians) per channel via SciPy's analytic signal. (T, C) in/out."""
    return np.angle(hilbert(x_tc, axis=0)).astype(np.float32)
79
+
80
+
81
def preprocess_tc(x_tc: np.ndarray, cfg: PreprocessConfig) -> dict:
    """Run raw (T, C) data through bandpass + Hilbert and bundle every stage."""
    raw_tc = ensure_tc(x_tc).astype(np.float32)
    filtered = bandpass_tc(raw_tc, cfg)
    envelope = hilbert_envelope_tc(filtered)
    phase = hilbert_phase_tc(filtered)
    return {
        "fs": float(cfg.fs),
        "raw": raw_tc,
        "filtered": filtered,
        "envelope": envelope,
        "amplitude": envelope,  # alias of envelope
        "phase": phase,
    }
95
+
96
+
97
@st.cache_data(show_spinner=False)
def preprocess_all_eeglab(
    set_bytes: bytes,
    fdt_bytes: bytes,
    set_name: str,
    fdt_name: str,
    f_low: float,
    f_high: float,
) -> dict:
    """
    EEGLAB bytes -> load -> auto preprocess (bandpass + hilbert).

    The sampling rate comes from the loaded file itself, not from the UI.
    Streamlit caches the result keyed on the byte content and parameters,
    so UI interactions do not trigger a re-load.
    """
    x_tc, fs = load_eeglab_tc_from_bytes(
        set_bytes=set_bytes,
        set_name=set_name,
        fdt_bytes=fdt_bytes,
        fdt_name=fdt_name,
    )
    cfg = PreprocessConfig(fs=float(fs), f_low=float(f_low), f_high=float(f_high))
    return preprocess_tc(x_tc, cfg)
118
+
119
+
120
@st.cache_data(show_spinner=False)
def load_mat_candidates_cached(mat_bytes: bytes) -> dict:
    """Cached wrapper around load_mat_candidates so UI reruns skip re-parsing the MAT bytes."""
    return load_mat_candidates(mat_bytes)
124
+
125
+
126
+ # ============================================================
127
+ # Viewer
128
+ # ============================================================
129
def window_slice(X_tc: np.ndarray, start_idx: int, end_idx: int, decim: int) -> np.ndarray:
    """Clip [start_idx, end_idx) to valid bounds and return every `decim`-th row.

    Guarantees at least one sample in the result and a step of at least 1.
    """
    n_samples = X_tc.shape[0]
    lo = min(start_idx, n_samples - 1)
    if lo < 0:
        lo = 0
    hi = min(end_idx, n_samples)
    if hi <= lo:
        hi = lo + 1
    step = int(decim)
    if step < 1:
        step = 1
    return X_tc[lo:hi:step, :]
134
+
135
+
136
def make_timeseries_figure(
    X_tc: np.ndarray,
    selected_channels: List[int],
    fs: float,
    start_sec: float,
    win_sec: float,
    decim: int,
    offset_mode: bool,
    show_rangeslider: bool,
    signal_type: str = "filtered",
) -> go.Figure:
    """Build the Plotly time-series figure for the selected channels.

    The window [start_sec, start_sec+win_sec) is cut out of X_tc (T, C),
    decimated, and plotted. In offset mode each trace is shifted vertically
    so channels do not overlap; phase traces instead get a fixed y-range.
    """
    # Convert the requested window from seconds to sample indices.
    start_idx = int(round(start_sec * fs))
    end_idx = int(round((start_sec + win_sec) * fs))

    Xw = window_slice(X_tc, start_idx, end_idx, decim)
    Tw = Xw.shape[0]
    # Time axis of the decimated window, in seconds.
    t = (np.arange(Tw) * decim + start_idx) / fs

    fig = go.Figure()

    if not selected_channels:
        # Nothing selected: return an empty, labeled figure.
        fig.update_layout(
            title="Timeseries (no channel selected)",
            height=450,
            xaxis_title="time (s)",
            yaxis_title="amplitude",
        )
        return fig

    # Phase data is handled specially (no vertical offsets, fixed y-range).
    is_phase = signal_type == "phase"

    if offset_mode and len(selected_channels) > 1 and not is_phase:
        # Stack traces vertically; step is 5x the median per-channel std.
        per_ch_std = np.std(Xw[:, selected_channels], axis=0)
        base = float(np.median(per_ch_std)) if np.isfinite(np.median(per_ch_std)) and np.median(per_ch_std) > 0 else 1.0
        offset = 5.0 * base

        for k, ch in enumerate(selected_channels):
            y = Xw[:, ch] + k * offset
            fig.add_trace(go.Scatter(x=t, y=y, mode="lines", name=f"ch{ch}", line=dict(width=1)))
        ylab = "amplitude (offset)"
    else:
        for ch in selected_channels:
            fig.add_trace(go.Scatter(x=t, y=Xw[:, ch], mode="lines", name=f"ch{ch}", line=dict(width=1)))

        if is_phase:
            ylab = "phase (rad)"
        else:
            ylab = "amplitude"

    # Leave room for the range slider when it is shown.
    plot_height = 550 if show_rangeslider else 450
    bottom_margin = 150 if show_rangeslider else 80

    title_text = f"Timeseries: {signal_type} (window={win_sec:.2f}s, start={start_sec:.2f}s, decim={decim})"

    fig.update_layout(
        title=title_text,
        height=plot_height,
        xaxis_title="time (s)",
        yaxis_title=ylab,
        legend=dict(orientation="h"),
        margin=dict(l=60, r=20, t=80, b=bottom_margin),
    )

    # Pin the y-range to just beyond [-pi, pi] for phase plots.
    if is_phase:
        fig.update_yaxes(range=[-np.pi - 0.5, np.pi + 0.5])

    if show_rangeslider:
        fig.update_xaxes(
            rangeslider=dict(
                visible=True,
                thickness=0.05,
            )
        )
    else:
        fig.update_xaxes(rangeslider=dict(visible=False))

    return fig
216
+
217
+
218
+ # ============================================================
219
+ # Network (multiple methods) + export
220
+ # ============================================================
221
def estimate_network_envelope_corr(X_tc: np.ndarray) -> np.ndarray:
    """
    Connectivity from envelope (amplitude) time courses.

    Computes the absolute Pearson correlation between every pair of
    columns of X_tc (T, C), zeroes the diagonal, and replaces any
    NaN/inf entries by 0. Returns (C, C) float32.
    """
    centered = X_tc - X_tc.mean(axis=0, keepdims=True)
    weights = np.abs(np.corrcoef(centered, rowvar=False))
    np.fill_diagonal(weights, 0.0)
    weights = np.nan_to_num(weights, nan=0.0, posinf=0.0, neginf=0.0)
    return weights.astype(np.float32)
232
+
233
+
234
def estimate_network_phase_corr(X_tc: np.ndarray) -> np.ndarray:
    """
    Phase-synchronization network from per-channel phase time courses.

    For each channel pair (i, j) computes the Phase Locking Value
        PLV_ij = |mean_t exp(1j * (theta_i(t) - theta_j(t)))|,
    i.e. the circular correlation of the two phase series.

    Input:  X_tc (T, C) phases in radians.
    Output: W (C, C) float32, symmetric, zero diagonal.

    Implementation note: instead of a Python loop over all C*(C-1)/2
    pairs, all PLVs are obtained at once from the Gram matrix of the
    unit phasors Z = exp(1j*theta): |Z^H Z| / T.
    """
    T, C = X_tc.shape
    # Unit-magnitude phasors for every sample/channel.
    Z = np.exp(1j * np.asarray(X_tc, dtype=np.float64))
    # (Z^H Z)[i, j] = sum_t exp(1j*(theta_j(t) - theta_i(t))); dividing by T
    # and taking the magnitude yields the (symmetric) PLV matrix.
    W = np.abs(Z.conj().T @ Z) / T
    np.fill_diagonal(W, 0.0)
    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
258
+
259
+
260
def estimate_network_dummy(X_tc: np.ndarray) -> np.ndarray:
    """
    Legacy placeholder kept for backward compatibility: absolute Pearson
    correlation between channels, diagonal zeroed, NaN/inf replaced by 0,
    returned as float32.
    """
    demeaned = X_tc - X_tc.mean(axis=0, keepdims=True)
    corr_abs = np.abs(np.corrcoef(demeaned, rowvar=False))
    np.fill_diagonal(corr_abs, 0.0)
    return np.nan_to_num(corr_abs, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
270
+
271
+
272
def threshold_edges(W: np.ndarray, thr: float) -> List[Tuple[int, int, float]]:
    """Upper-triangle edges with weight >= thr, sorted by weight descending.

    Ties keep their row-major (i, j) enumeration order (stable sort).
    """
    rows, cols = np.triu_indices(W.shape[0], k=1)
    edges = [
        (int(i), int(j), float(W[i, j]))
        for i, j in zip(rows, cols)
        if float(W[i, j]) >= thr
    ]
    edges.sort(key=lambda e: e[2], reverse=True)
    return edges
282
+
283
+
284
def adjacency_at_threshold(W: np.ndarray, thr: float, weighted: bool) -> np.ndarray:
    """Adjacency matrix at threshold `thr`.

    weighted=True: keep weights >= thr, zero the rest (diagonal cleared).
    weighted=False: binary 0/1 matrix of weights >= thr (diagonal cleared).
    """
    if weighted:
        adj = np.where(W >= thr, W, np.zeros_like(W))
        np.fill_diagonal(adj, 0.0)
        return adj
    adj = (W >= thr).astype(int)
    np.fill_diagonal(adj, 0)
    return adj
293
+
294
+
295
def compute_louvain_clusters(W: np.ndarray, thr: float) -> np.ndarray:
    """
    Community detection on the thresholded weight graph via the Louvain method.

    Args:
        W: weight matrix (C, C)
        thr: threshold; edges with weight below it are dropped

    Returns:
        clusters: per-node community id, shape (C,)
    """
    if not LOUVAIN_AVAILABLE:
        # Fallback when python-louvain/networkx are missing: one big cluster
        return np.zeros(W.shape[0], dtype=int)

    # Build an undirected NetworkX graph over all channels
    G = nx.Graph()
    C = W.shape[0]
    G.add_nodes_from(range(C))

    # Keep only edges at or above the threshold
    for i in range(C):
        for j in range(i + 1, C):
            if W[i, j] >= thr:
                G.add_edge(i, j, weight=W[i, j])

    # Louvain community detection on edge weights
    # NOTE(review): best_partition is randomized by default; consider
    # passing random_state for reproducible clusterings — confirm.
    partition = community_louvain.best_partition(G, weight='weight')

    # Convert the node -> community dict into an array indexed by node
    clusters = np.array([partition[i] for i in range(C)])

    return clusters
328
+
329
+
330
def get_cluster_colors(clusters: np.ndarray) -> List[str]:
    """Map each node's cluster id to an 'rgb(r, g, b)' CSS color string.

    Hues are spread evenly over the number of distinct clusters
    (saturation 0.8, value 0.95).
    """
    import colorsys

    n_clusters = max(len(np.unique(clusters)), 1)

    def to_css(cluster_id) -> str:
        red, green, blue = colorsys.hsv_to_rgb(cluster_id / n_clusters, 0.8, 0.95)
        return f'rgb({int(255*red)}, {int(255*green)}, {int(255*blue)})'

    return [to_css(cid) for cid in clusters]
352
+
353
+
354
def make_network_figure(W: np.ndarray, thr: float, use_louvain: bool = True) -> tuple[go.Figure, int]:
    """Draw the thresholded network on a circular layout.

    Nodes sit on the unit circle; edges with weight >= thr are drawn with
    rainbow color and line width scaled by weight, and nodes are colored
    by Louvain community when the optional dependencies are available.

    Returns (figure, edge_count).
    """
    C = W.shape[0]
    # Evenly spaced node positions on the unit circle
    angles = np.linspace(0, 2 * np.pi, C, endpoint=False)
    xs = np.cos(angles)
    ys = np.sin(angles)

    edges = threshold_edges(W, thr)
    fig = go.Figure()

    # Weight range of the surviving edges (for color/width scaling)
    if edges:
        weights = [w for _, _, w in edges]
        min_w = min(weights)
        max_w = max(weights)
        weight_range = max_w - min_w if max_w > min_w else 1.0
    else:
        min_w = 0
        max_w = 1
        weight_range = 1.0

    # Rainbow colormap (0 = blue -> 0.5 = green/yellow -> 1 = red)
    def get_rainbow_color(norm_val):
        """Rainbow color for a normalized value in [0, 1]."""
        import colorsys
        # Map HSV hue from 240° (blue) down to 0° (red)
        hue = (1.0 - norm_val) * 0.67  # 0.67 ≈ 240/360 (blue)
        r, g, b = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
        return f'rgba({int(255*r)}, {int(255*g)}, {int(255*b)}, 0.7)'

    # Draw edges; color and width both encode the weight
    for (i, j, w) in edges:
        # Weight normalized to [0, 1]
        norm_w = (w - min_w) / weight_range if weight_range > 0 else 0.5

        # Rainbow: weak (blue) -> medium (green/yellow) -> strong (red)
        color = get_rainbow_color(norm_w)

        # Width proportional to weight (range 0.5-4)
        line_width = 0.5 + 3.5 * norm_w

        fig.add_trace(
            go.Scatter(
                x=[xs[i], xs[j]],
                y=[ys[i], ys[j]],
                mode="lines",
                hoverinfo="text",
                hovertext=f"ch{i} - ch{j}<br>weight: {w:.4f}",
                line=dict(width=line_width, color=color),
                showlegend=False,
            )
        )

    # Louvain clustering drives the node colors
    if use_louvain and LOUVAIN_AVAILABLE:
        clusters = compute_louvain_clusters(W, thr)
        node_colors = get_cluster_colors(clusters)
        n_clusters = len(np.unique(clusters))
        title_suffix = f" | Louvain clusters: {n_clusters}"
    else:
        node_colors = ['#FFD700'] * C  # default gold
        clusters = np.zeros(C, dtype=int)
        title_suffix = ""

    # Draw the nodes
    fig.add_trace(
        go.Scatter(
            x=xs,
            y=ys,
            mode="markers+text",
            text=[f"{k}" for k in range(C)],
            textposition="bottom center",
            marker=dict(
                size=14,
                color=node_colors,
                line=dict(width=2, color='white')
            ),
            hoverinfo="text",
            hovertext=[f"channel {k}<br>cluster: {clusters[k]}" for k in range(C)],
            showlegend=False,
        )
    )

    fig.update_layout(
        title=f"Estimated Network (thr={thr:.3f}) edges={len(edges)}{title_suffix}",
        height=500,
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        margin=dict(l=10, r=10, t=50, b=10),
        paper_bgcolor='rgba(0,0,0,0.9)',
        plot_bgcolor='rgba(0,0,0,0.9)',
    )
    # Keep the circular layout circular
    fig.update_yaxes(scaleanchor="x", scaleratio=1)

    # Textual legend standing in for a colorbar
    if edges:
        fig.add_annotation(
            text=f"Edge color/width: weak (blue/thin) → medium (green/yellow) → strong (red/thick)<br>Weight range: {min_w:.3f} - {max_w:.3f}",
            xref="paper", yref="paper",
            x=0.5, y=-0.05,
            showarrow=False,
            font=dict(size=10, color='white'),
            xanchor='center',
        )

    return fig, len(edges)
459
+
460
+
461
def make_edgecount_curve(W: np.ndarray) -> go.Figure:
    """Plot how many edges survive as the threshold sweeps from max to min weight.

    W is a symmetric (C, C) weight matrix; only the strict upper triangle is
    considered, so each undirected edge is counted once.
    """
    upper = np.sort(W[np.triu_indices(W.shape[0], k=1)])
    # Sweep from the largest weight down to the smallest (120 steps); with no
    # off-diagonal entries, fall back to a single zero threshold.
    if upper.size:
        thresholds = np.linspace(float(upper.max()), float(upper.min()), 120)
    else:
        thresholds = np.array([0.0])
    edge_counts = [len(threshold_edges(W, float(t))) for t in thresholds]

    curve = go.Figure()
    curve.add_trace(go.Scatter(x=thresholds, y=edge_counts, mode="lines"))
    curve.update_layout(
        title="Edge count vs threshold (lower thr => more edges)",
        xaxis_title="threshold",
        yaxis_title="edge count",
        height=300,
    )
    return curve
475
+
476
+
477
def to_csv_bytes_matrix(mat: np.ndarray, fmt: str) -> bytes:
    """Serialize a numeric matrix to UTF-8 CSV bytes.

    fmt is a printf-style format string passed straight to numpy.savetxt
    (e.g. "%.6f"); values are comma-separated, one matrix row per line.
    """
    sink = io.StringIO()
    np.savetxt(sink, mat, delimiter=",", fmt=fmt)
    csv_text = sink.getvalue()
    return csv_text.encode("utf-8")
481
+
482
+
483
def to_csv_bytes_edges(edges: List[Tuple[int, int, float]]) -> bytes:
    """Serialize an edge list to UTF-8 CSV bytes.

    Output has a "source,target,weight" header followed by one row per
    (i, j, w) tuple, with the weight fixed at six decimal places.
    """
    rows = ["source,target,weight\n"]
    rows.extend(f"{src},{dst},{weight:.6f}\n" for src, dst, weight in edges)
    return "".join(rows).encode("utf-8")
489
+
490
+
491
# ============================================================
# Sidebar UI
# ============================================================
st.sidebar.header("Input format")
input_mode = st.sidebar.radio("データ形式", ["EEGLAB (.set + .fdt)", "MATLAB (.mat)"], index=0)

st.sidebar.header("Preprocess (auto)")
# Bandpass edges for automatic preprocessing (defaults cover the alpha band).
f_low = st.sidebar.number_input("Bandpass low (Hz)", min_value=0.0, value=8.0, step=0.5)
f_high = st.sidebar.number_input("Bandpass high (Hz)", min_value=0.1, value=12.0, step=0.5)

st.sidebar.header("Viewer controls")
win_sec = st.sidebar.number_input("Window length (sec)", min_value=0.1, value=5.0, step=0.1)
decim = st.sidebar.selectbox("Decimation (間引き)", options=[1, 2, 5, 10, 20, 50], index=1)
offset_mode = st.sidebar.checkbox("重ね描画のオフセット表示", value=True)
show_rangeslider = st.sidebar.checkbox("Plotly rangesliderを表示", value=False)
# Which stage of the preprocessing pipeline to display in the time-series view.
signal_view = st.sidebar.radio(
    "表示する信号",
    ["raw", "filtered", "amplitude", "phase"],
    index=1,
    help="raw: 生信号, filtered: バンドパス後, amplitude: Hilbert振幅(envelope), phase: Hilbert位相"
)

st.title("EEG timeseries viewer + network estimation")
514
+
515
+
516
# ============================================================
# Load + preprocess (EEGLAB / MAT)
# ============================================================
if input_mode.startswith("EEGLAB"):
    st.sidebar.header("Upload (.set + .fdt)")
    uploaded_files = st.sidebar.file_uploader(
        "Upload EEGLAB files",
        type=["set", "fdt"],
        accept_multiple_files=True,
    )

    if uploaded_files:
        set_file, fdt_file = pick_set_fdt(uploaded_files)
        if set_file is None or fdt_file is None:
            # EEGLAB stores metadata in .set and raw samples in .fdt; both are required.
            st.warning("`.set` と `.fdt` の両方をアップロードしてください。")
        else:
            try:
                with st.spinner("Loading EEGLAB + preprocessing (bandpass + hilbert)..."):
                    prep = preprocess_all_eeglab(
                        set_bytes=set_file.getvalue(),
                        fdt_bytes=fdt_file.getvalue(),
                        set_name=set_file.name,
                        fdt_name=fdt_file.name,
                        f_low=float(f_low),
                        f_high=float(f_high),
                    )
                st.session_state["prep"] = prep
                # Invalidate any previously estimated network for the old data.
                st.session_state["W"] = None
                st.success(f"Loaded & preprocessed. (T,C)={prep['raw'].shape} fs={prep['fs']:.2f}Hz")
            except Exception as e:
                st.session_state.pop("prep", None)
                st.session_state["W"] = None
                st.error(f"読み込み/前処理エラー: {e}")

else:
    st.sidebar.header("Upload (.mat)")
    mat_file = st.sidebar.file_uploader("Upload .mat", type=["mat"])

    if mat_file is not None:
        mat_bytes = mat_file.getvalue()
        try:
            cands = load_mat_candidates_cached(mat_bytes)
            if not cands:
                st.error("数値の1D/2D配列が見つかりませんでした。")
                st.info("MATファイルの構造を確認しています...")

                # Debug aid: list every variable scipy's loadmat can see.
                try:
                    from scipy.io import loadmat
                    mat_data = loadmat(io.BytesIO(mat_bytes))
                    st.write("**MATファイルに含まれる変数:**")
                    for k, v in mat_data.items():
                        if k.startswith('__'):
                            continue
                        if isinstance(v, np.ndarray):
                            st.write(f"- `{k}`: shape={v.shape}, dtype={v.dtype}, ndim={v.ndim}")
                        else:
                            st.write(f"- `{k}`: type={type(v).__name__}")
                except Exception as e:
                    st.write(f"デバッグ情報の取得に失敗: {e}")

                # Second attempt: probe the file as MAT v7.3 (HDF5). h5py needs a
                # real path, so spill the bytes to a temporary file first.
                try:
                    import h5py
                    import tempfile
                    with tempfile.NamedTemporaryFile(suffix='.mat', delete=False) as tmp:
                        tmp.write(mat_bytes)
                        tmp_path = tmp.name

                    st.write("**HDF5形式として読み込み中...**")
                    with h5py.File(tmp_path, 'r') as f:
                        def show_structure(name, obj):
                            # Report every dataset found while walking the HDF5 tree.
                            if isinstance(obj, h5py.Dataset):
                                st.write(f"- `{name}`: shape={obj.shape}, dtype={obj.dtype}")
                        f.visititems(show_structure)

                    import os
                    os.unlink(tmp_path)
                except Exception as e2:
                    st.write(f"HDF5としても読み込めませんでした: {e2}")
            else:
                key = st.sidebar.selectbox("EEG配列(変数)を選択", options=list(cands.keys()))
                fs_mat = st.sidebar.number_input("Sampling rate (Hz)", min_value=0.1, value=256.0, step=0.1)

                # Preprocess automatically as soon as a variable is selected.
                if key:
                    x = cands[key]
                    st.sidebar.write(f"選択した配列: shape={x.shape}, dtype={x.dtype}")
                    try:
                        with st.spinner("Preprocessing (bandpass + hilbert)..."):
                            cfg = PreprocessConfig(fs=float(fs_mat), f_low=float(f_low), f_high=float(f_high))
                            prep = preprocess_tc(x, cfg)

                        st.session_state["prep"] = prep
                        st.session_state["W"] = None
                        st.success(f"Loaded MAT '{key}'. (T,C)={prep['raw'].shape} fs={prep['fs']:.2f}Hz")
                    except Exception as e:
                        st.session_state.pop("prep", None)
                        st.session_state["W"] = None
                        st.error(f"前処理エラー: {e}")
                        import traceback
                        st.code(traceback.format_exc())
        except Exception as e:
            st.session_state.pop("prep", None)
            st.session_state["W"] = None
            st.error(f".mat 読み込みエラー: {e}")
            import traceback
            st.code(traceback.format_exc())


# Nothing to render until a dataset has been loaded and preprocessed.
if "prep" not in st.session_state:
    st.info("左のサイドバーからデータをアップロードしてください。")
    st.stop()
628
+
629
+
630
# ============================================================
# Viewer
# ============================================================
prep = st.session_state["prep"]
fs = float(prep["fs"])
X_tc = prep[signal_view]  # (T, C) array for the chosen signal view
T, C = X_tc.shape

# Total recording length and the furthest window start that still fits.
duration_sec = (T - 1) / fs if T > 1 else 0.0
max_start = max(0.0, float(duration_sec - win_sec))

start_sec = st.sidebar.slider(
    "Start time (sec)",
    min_value=0.0,
    max_value=float(max_start),
    value=0.0,
    # Step scales with the window so scrubbing stays smooth for any length.
    step=float(max(0.01, win_sec / 200)),
)
648
+
649
st.sidebar.header("Channels")

# Quick select-all / deselect-all buttons.
col_ch1, col_ch2 = st.sidebar.columns(2)
with col_ch1:
    select_all = st.button("全選択")
with col_ch2:
    deselect_all = st.button("全解除")

# Select a contiguous channel range.
with st.sidebar.expander("📊 範囲で選択"):
    range_start = st.number_input("開始ch", min_value=0, max_value=C - 1, value=0, step=1)
    range_end = st.number_input("終了ch", min_value=0, max_value=C - 1, value=min(C - 1, 7), step=1)
    if st.button("範囲を選択"):
        st.session_state["selected_channels"] = list(range(int(range_start), int(range_end) + 1))

# Region presets — assumes a 64ch cap layout split into four 16ch groups
# (clamped to C so smaller montages still work).
with st.sidebar.expander("⚡ プリセット"):
    preset_col1, preset_col2 = st.columns(2)
    with preset_col1:
        if st.button("前頭部 (0-15)"):
            st.session_state["selected_channels"] = list(range(min(16, C)))
    with preset_col2:
        if st.button("頭頂部 (16-31)"):
            st.session_state["selected_channels"] = list(range(16, min(32, C)))
    preset_col3, preset_col4 = st.columns(2)
    with preset_col3:
        if st.button("側頭部 (32-47)"):
            st.session_state["selected_channels"] = list(range(32, min(48, C)))
    with preset_col4:
        if st.button("後頭部 (48-63)"):
            st.session_state["selected_channels"] = list(range(48, min(64, C)))

# First run: default to the first few channels.
if "selected_channels" not in st.session_state:
    st.session_state["selected_channels"] = list(range(min(C, 8)))

# Apply the select-all / deselect-all buttons after presets so they win.
if select_all:
    st.session_state["selected_channels"] = list(range(C))
if deselect_all:
    st.session_state["selected_channels"] = []

# A multiselect becomes unwieldy past this many channels.
max_display = 20
if C <= max_display:
    selected_channels = st.sidebar.multiselect(
        f"表示するチャンネル(全{C}ch)",
        options=list(range(C)),
        default=st.session_state["selected_channels"],
        key="ch_select",
    )
else:
    # Too many channels for a multiselect: compact add/remove UI instead.
    st.sidebar.caption(f"選択中: {len(st.session_state['selected_channels'])} / {C} channels")

    # Add or remove one channel at a time by index.
    add_ch = st.sidebar.number_input(
        "チャンネルを追加",
        min_value=0,
        max_value=C - 1,
        value=0,
        step=1,
        key="add_ch_input"
    )
    col_add, col_remove = st.sidebar.columns(2)
    with col_add:
        if st.button("➕ 追加"):
            if add_ch not in st.session_state["selected_channels"]:
                st.session_state["selected_channels"].append(int(add_ch))
                st.session_state["selected_channels"].sort()
    with col_remove:
        if st.button("➖ 削除"):
            if add_ch in st.session_state["selected_channels"]:
                st.session_state["selected_channels"].remove(int(add_ch))

    # Show the current selection, truncated to the first ten entries.
    if st.session_state["selected_channels"]:
        selected_str = ", ".join(map(str, st.session_state["selected_channels"][:10]))
        if len(st.session_state["selected_channels"]) > 10:
            selected_str += f", ... (+{len(st.session_state['selected_channels']) - 10})"
        st.sidebar.text(f"選択済み: {selected_str}")

    selected_channels = st.session_state["selected_channels"]

# Keep session state in sync when the multiselect path was used.
if C <= max_display:
    st.session_state["selected_channels"] = selected_channels
737
+
738
col1, col2 = st.columns([2, 1])
with col1:
    # Main time-series plot for the selected window/channels/view.
    fig_ts = make_timeseries_figure(
        X_tc=X_tc,
        selected_channels=selected_channels,
        fs=fs,
        start_sec=float(start_sec),
        win_sec=float(win_sec),
        decim=int(decim),
        offset_mode=bool(offset_mode),
        show_rangeslider=bool(show_rangeslider),
        signal_type=signal_view,
    )
    st.plotly_chart(fig_ts)

with col2:
    # Summary of the loaded recording and the active view.
    st.subheader("Data info")
    signal_desc = {
        "raw": "生信号(前処理なし)",
        "filtered": f"バンドパスフィルタ後 ({f_low}-{f_high} Hz)",
        "amplitude": "Hilbert振幅 (envelope)",
        "phase": "Hilbert位相 (-π ~ π)"
    }
    st.write(f"- view: **{signal_view}** ({signal_desc.get(signal_view, '')})")
    st.write(f"- fs: **{fs:.2f} Hz**")
    st.write(f"- T: {T} samples")
    st.write(f"- C: {C} channels")
    st.write(f"- duration: {duration_sec:.2f} sec")

    if signal_view == "phase":
        st.caption("※ 位相は -π (rad) から π (rad) の範囲で表示されます")

    st.caption("※ 大規模データは window + decimation 推奨。rangesliderは重い場合OFF。")

st.divider()
773
+
774
+
775
# ============================================================
# Estimation
# ============================================================
st.subheader("Network estimation")

# Connectivity estimator choice.
estimation_method = st.radio(
    "推定手法を選択",
    options=[
        "envelope_corr",
        "phase_corr",
    ],
    format_func=lambda x: {
        "envelope_corr": "Envelope correlation (振幅の相関)",
        "phase_corr": "Phase circular correlation (位相同期, PLV)",
    }[x],
    horizontal=True,
    help="envelope_corr: 振幅包絡線のPearson相関係数 | phase_corr: 位相の circular correlation (Phase Locking Value)",
)

# Short explanation of the currently selected estimator.
method_info = {
    "envelope_corr": "**Envelope correlation**: 振幅包絡線(Hilbert amplitude)間のPearson相関係数を計算します。振幅が同期して変動するチャンネル間の結合を検出します。",
    "phase_corr": "**Phase circular correlation (PLV)**: 位相間の circular correlation を計算します。Phase Locking Value (PLV) とも呼ばれ、位相同期を検出します。0(非同期)〜1(完全同期)の値を取ります。",
}
st.info(method_info[estimation_method])

# Cached estimate from a previous rerun, if any.
last_method = st.session_state.get("last_estimation_method")
W = st.session_state.get("W")

# Re-estimate only on first run or when the method changed.
need_estimation = (W is None) or (last_method != estimation_method)

if need_estimation:
    with st.spinner(f"推定中... ({estimation_method})"):
        if estimation_method == "envelope_corr":
            X_in = prep["amplitude"]
            W = estimate_network_envelope_corr(X_in)
        elif estimation_method == "phase_corr":
            X_in = prep["phase"]
            W = estimate_network_phase_corr(X_in)
        else:
            st.error("未知の推定手法です")
            st.stop()

    # Cache the result so reruns (slider moves etc.) skip re-estimation.
    st.session_state["W"] = W
    st.session_state["last_estimation_method"] = estimation_method
    st.success(f"✅ 推定完了: {estimation_method} (ネットワークサイズ: {W.shape[0]} nodes)")
else:
    st.success(f"✓ 推定済み: **{estimation_method}** (ネットワークサイズ: {W.shape[0]} nodes)")

# W is guaranteed to exist from here on.
# Upper bound for the threshold slider; fall back to 1.0 for non-finite weights.
w_peak = np.max(W)
wmax = float(w_peak) if np.isfinite(w_peak) else 1.0

col_thr1, col_thr2 = st.columns([3, 1])
with col_thr1:
    thr = st.slider(
        "閾値 (threshold) ※下げるほどエッジが増えます",
        min_value=0.0,
        max_value=max(0.0001, wmax),
        value=min(0.5, wmax),
        step=max(wmax / 200, 0.001),
    )
with col_thr2:
    use_louvain = st.checkbox(
        "Louvainクラスタリング",
        value=True,
        disabled=not LOUVAIN_AVAILABLE,
        help="ノードの色をコミュニティ検出結果で塗り分けます"
    )

net_col1, net_col2 = st.columns([2, 1])
with net_col1:
    fig_net, edge_n = make_network_figure(W, float(thr), use_louvain=use_louvain)
    st.plotly_chart(fig_net)
853
 
854
# Right-hand column: surviving-edge count plus the threshold-sweep curve.
with net_col2:
    st.metric("Edges", edge_n)
    st.plotly_chart(make_edgecount_curve(W))