# Source: PhysioJEPA / scripts / e0_alignment_check.py
# Hugging Face page header: uploader avatar "guychuk";
# commit message "Upload folder using huggingface_hub"; commit 31e2456 (verified).
"""Validate alignment using PPG foot (onset) rather than systolic peak.
Foot = minimum between two consecutive systolic peaks. This is the feature
that physiologically corresponds to the pulse arrival time. Using it should
collapse the bimodal PTT distribution.
"""
from __future__ import annotations
import json
import os
import random
import re
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from dotenv import load_dotenv
from scipy.signal import butter, filtfilt, find_peaks
load_dotenv()
os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", ""))
from datasets import load_from_disk
from huggingface_hub import snapshot_download
# Hugging Face dataset repository id handed to snapshot_download() in main().
REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg"
# Output root: the "docs" folder beside the directory that contains this script.
OUT = Path(__file__).resolve().parent.parent / "docs"
# Figures go into a sub-folder, created at import time (parents included) if missing.
FIG = OUT / "figures"
FIG.mkdir(parents=True, exist_ok=True)
# Deterministically seeded RNG; no function visible in this file references it.
RNG = random.Random(7)
def bandpass(x, fs, lo, hi, order=3):
    """Apply a forward-backward Butterworth band-pass filter to ``x``.

    Args:
        x: Signal to filter.
        fs: Sampling rate of ``x`` in Hz.
        lo: Lower cut-off frequency in Hz.
        hi: Upper cut-off frequency in Hz.
        order: Order passed to ``scipy.signal.butter``.

    Returns:
        The filtered signal, same length as ``x``.
    """
    nyquist = 0.5 * fs
    low_edge = lo / nyquist
    # Keep the normalised upper edge strictly below 1.0 (Nyquist), which
    # butter() requires, even when ``hi`` is at or above fs / 2.
    high_edge = min(hi / nyquist, 0.99)
    numer, denom = butter(order, [low_edge, high_edge], btype="band")
    # "gust" selects Gustafsson's edge handling instead of signal padding.
    return filtfilt(numer, denom, x, method="gust")
def r_peaks(ecg, fs):
    """Locate R-peak sample indices in a single-lead ECG trace.

    The signal is band-passed to 5-15 Hz, differentiated, squared and
    smoothed with a ~120 ms moving average. Peaks of that envelope that
    exceed ``mean + 0.5 * std`` and are at least ~300 ms apart become beat
    candidates; each candidate is then moved to the maximum of the
    band-passed signal within roughly 60 ms on either side.

    Args:
        ecg: 1-D ECG samples.
        fs: Sampling rate of ``ecg`` in Hz.

    Returns:
        1-D integer array of sample indices into ``ecg``. The dtype is
        integer even when no beat is found, so the result can always be used
        for indexing (matching ``ppg_feet``).
    """
    x = bandpass(ecg, fs, 5.0, 15.0)
    # Squared first difference; prepend keeps the length equal to len(x).
    energy = np.diff(x, prepend=x[:1]) ** 2
    w = max(int(0.12 * fs), 1)
    mwa = np.convolve(energy, np.ones(w) / w, mode="same")
    thr = mwa.mean() + 0.5 * mwa.std()
    # find_peaks rejects distance < 1, so guard it like the other windows.
    min_gap = max(int(0.3 * fs), 1)
    candidates, _ = find_peaks(mwa, height=thr, distance=min_gap)
    snap = max(int(0.06 * fs), 1)
    refined = []
    for q in candidates:
        # Half-open search window [q - snap, q + snap), clipped to the signal.
        start = max(0, q - snap)
        stop = min(len(x), q + snap)
        refined.append(start + int(np.argmax(x[start:stop])))
    # Explicit dtype: an empty list would otherwise become a float64 array.
    return np.asarray(refined, dtype=int)
def ppg_feet(ppg, fs):
    """Locate PPG pulse feet as the minimum between consecutive systolic peaks.

    The signal is band-passed to 0.5-8 Hz and systolic peaks are detected on
    it (at least ``int(0.3 * fs)`` samples apart, height above
    ``mean + 0.3 * std``, prominence of at least ``0.1 * std``). For every
    pair of neighbouring peaks the foot is the index of the lowest filtered
    sample from the first peak up to, but excluding, the second.

    Args:
        ppg: 1-D PPG samples.
        fs: Sampling rate of ``ppg`` in Hz.

    Returns:
        1-D integer array of foot sample indices, one per pair of consecutive
        peaks (empty when fewer than two peaks are found).
    """
    filtered = bandpass(ppg, fs, 0.5, 8.0)
    sigma = filtered.std()
    systolic, _ = find_peaks(
        filtered,
        distance=int(0.3 * fs),
        height=filtered.mean() + 0.3 * sigma,
        prominence=0.1 * sigma,
    )
    onsets = [
        start + int(np.argmin(filtered[start:stop]))
        for start, stop in zip(systolic[:-1], systolic[1:])
    ]
    return np.asarray(onsets, dtype=int)
def clean_ptts_via_foot(ecg, ecg_fs, ppg, ppg_fs, t0e, t0p):
    """Compute unambiguous R-peak-to-PPG-foot delays for one segment.

    R-peaks and PPG feet are detected independently and converted to times
    using each signal's start time and sampling rate. An R-peak contributes a
    delay only when exactly one foot falls 50-500 ms (inclusive) after it.

    Args:
        ecg: 1-D ECG samples.
        ecg_fs: ECG sampling rate in Hz.
        ppg: 1-D PPG samples.
        ppg_fs: PPG sampling rate in Hz.
        t0e: Time of the first ECG sample, in seconds.
        t0p: Time of the first PPG sample, in seconds.

    Returns:
        List of delays in milliseconds; empty when fewer than three R-peaks
        or fewer than three feet were detected.
    """
    r_idx = r_peaks(ecg, ecg_fs)
    foot_idx = ppg_feet(ppg, ppg_fs)
    if min(len(r_idx), len(foot_idx)) < 3:
        return []

    r_times = t0e + r_idx / ecg_fs
    foot_times = t0p + foot_idx / ppg_fs

    delays_ms = []
    for r_time in r_times:
        window_start = r_time + 0.050
        window_end = r_time + 0.500
        in_window = foot_times[(foot_times >= window_start) & (foot_times <= window_end)]
        # Zero or several feet in the window means the pairing is ambiguous.
        if len(in_window) != 1:
            continue
        delays_ms.append((in_window[0] - r_time) * 1000.0)
    return delays_ms
def _lead_ii_ptts(row):
    """Return foot-based PTTs (ms) for one dataset row, or None without lead II.

    The lead names are checked before the waveforms are converted to arrays,
    so rows lacking lead "II" are skipped without that conversion cost.
    """
    names = list(row["ecg_names"])
    if "II" not in names:
        return None
    ecg = np.asarray(row["ecg"], dtype=np.float32)
    ppg = np.asarray(row["ppg"], dtype=np.float32)
    # ecg is indexed by the position of the lead name; ppg uses its first entry.
    return clean_ptts_via_foot(
        ecg[names.index("II")],
        float(row["ecg_fs"]),
        ppg[0],
        float(row["ppg_fs"]),
        float(row["ecg_time_s"][0]),
        float(row["ppg_time_s"][0]),
    )


def main():
    """Run the foot-based alignment check and write its report.

    Downloads every 20th shard (indices 0, 20, ..., 400) of ``REPO``, scans up
    to the first 30 rows of each shard and keeps segments that yield at least
    5 clean beats, stopping after 100 such segments. Summary statistics are
    written to ``OUT / "e0_alignment.json"`` and printed, and a histogram of
    all delays is saved to ``FIG / "ptt_histogram_foot.png"``.

    Raises:
        SystemExit: If no segment produced enough clean beats. Without this
            check the statistics would be computed on empty lists, giving NaN
            values (which ``json.dumps`` emits as the non-standard token
            ``NaN``) or an exception, depending on the NumPy version.
    """
    max_segments = 100
    rows_per_shard = 30
    min_beats = 5

    want = list(range(0, 412, 20))
    root = Path(
        snapshot_download(
            REPO,
            repo_type="dataset",
            allow_patterns=[f"shard_{i:05d}/*" for i in want],
            max_workers=12,
        )
    )
    # Only keep shards that actually arrived on disk.
    shards = [s for s in want if (root / f"shard_{s:05d}" / "dataset_info.json").exists()]

    all_ptts = []
    stds = []
    good = 0
    for sidx in shards:
        if good >= max_segments:
            break
        ds = load_from_disk(str(root / f"shard_{sidx:05d}"))
        for i in range(min(len(ds), rows_per_shard)):
            if good >= max_segments:
                break
            ptts = _lead_ii_ptts(ds[i])
            if ptts is None or len(ptts) < min_beats:
                continue
            all_ptts.extend(ptts)
            stds.append(float(np.std(ptts)))
            good += 1

    if not all_ptts:
        raise SystemExit(
            "e0_alignment_check: no segment with lead II produced enough clean beats; "
            "nothing to summarise"
        )

    res = {
        "n_clean_beats": len(all_ptts),
        "n_good_segments": good,
        "ptt_foot_median_ms": float(np.median(all_ptts)),
        "ptt_foot_p5_ms": float(np.percentile(all_ptts, 5)),
        "ptt_foot_p95_ms": float(np.percentile(all_ptts, 95)),
        "within_segment_std_median_ms": float(np.median(stds)),
        "within_segment_std_p90_ms": float(np.percentile(stds, 90)),
    }

    plt.figure(figsize=(7, 4))
    plt.hist(all_ptts, bins=60, color="#36a", edgecolor="black")
    # Reference lines at 100 ms and 400 ms.
    plt.axvline(100, color="red", linestyle="--", alpha=0.5, label="100 ms")
    plt.axvline(400, color="red", linestyle="--", alpha=0.5, label="400 ms")
    plt.xlabel("PTT (ECG R-peak → PPG foot) (ms)")
    plt.ylabel("count")
    plt.title(f"PTT via PPG foot — {len(all_ptts)} beats, {good} segments")
    plt.legend()
    plt.tight_layout()
    plt.savefig(FIG / "ptt_histogram_foot.png", dpi=120)
    plt.close()

    (OUT / "e0_alignment.json").write_text(json.dumps(res, indent=2))
    print(json.dumps(res, indent=2))
if __name__ == "__main__":
main()