eugenehp committed
Commit c664669 · verified · 1 Parent(s): 7f78cf4

Upload Brain-JEPA model card, weights, and gradient mapping

Files changed (4):
  1. README.md +176 -0
  2. benchmark.png +0 -0
  3. brainjepa.safetensors +3 -0
  4. gradient_mapping_450.csv +0 -0
README.md ADDED
@@ -0,0 +1,176 @@
+ ---
+ license: mit
+ language:
+ - en
+ tags:
+ - fmri
+ - neuroscience
+ - brain
+ - foundation-model
+ - vision-transformer
+ - jepa
+ - burn
+ - rust
+ datasets:
+ - ukbiobank
+ pipeline_tag: feature-extraction
+ library_name: brainjepa-rs
+ ---
+
+ # Brain-JEPA (safetensors)
+
+ Pretrained weights for **Brain-JEPA** (NeurIPS 2024, Spotlight) converted to safetensors format for use with [brainjepa-rs](https://github.com/eugenehp/brainjepa-rs).
+
+ ## Model description
+
+ Brain-JEPA is a brain dynamics foundation model that maps parcellated fMRI time series (450 ROIs x T time points) to latent representations using a Vision Transformer with:
+
+ - **Brain gradient positioning** for spatial (ROI) embeddings
+ - **Temporal patch embedding** via 1D convolution along time
+ - **JEPA architecture** (Joint Embedding Predictive Architecture)
+
+ The encoder is a 12-layer ViT-Base (768-dim, 12 heads, ~86M params) pretrained on UK Biobank resting-state fMRI for 300 epochs.
+
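+ To make the patching concrete, here is a minimal PyTorch sketch of the temporal patch embedding, inferred from the `encoder.patch_embed.proj.weight` shape `[768, 1, 1, 16]` listed below: a convolution with kernel and stride 16 along the time axis turns a `[1, 1, 450, 160]` input into 450 x 10 = 4500 patch tokens of dimension 768. The layer here is illustrative, not the library's API.
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ # Patch-embedding sketch: a conv with kernel (1, 16) over [batch, 1, ROIs, time]
+ # matches the [768, 1, 1, 16] weight shape of encoder.patch_embed.proj.
+ proj = nn.Conv2d(in_channels=1, out_channels=768,
+                  kernel_size=(1, 16), stride=(1, 16))
+
+ x = torch.randn(1, 1, 450, 160)             # one sample: 450 ROIs x 160 time points
+ tokens = proj(x)                            # [1, 768, 450, 10]
+ tokens = tokens.flatten(2).transpose(1, 2)  # [1, 4500, 768] patch tokens
+ print(tokens.shape)                         # torch.Size([1, 4500, 768])
+ ```
+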
+ ## Files
+
+ | File | Description | Shape info |
+ |---|---|---|
+ | `brainjepa.safetensors` | All weights (encoder + predictor + target_encoder) | 384 tensors, ~709 MB |
+ | `gradient_mapping_450.csv` | Brain gradient coordinates for positional embeddings | 450 rows x 30 columns |
+
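+ For a quick sanity check, the gradient file can be loaded as a plain numeric matrix; this sketch assumes comma-separated values with no header row:
+
+ ```python
+ import numpy as np
+
+ # Load the 450 x 30 gradient coordinate matrix used for positional embeddings.
+ grad = np.loadtxt("gradient_mapping_450.csv", delimiter=",")
+ assert grad.shape == (450, 30), grad.shape
+ ```
+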
+ ### Weight key structure
+
+ Keys are prefixed by component (`encoder.`, `predictor.`, `target_encoder.`):
+
+ ```
+ encoder.patch_embed.proj.weight [768, 1, 1, 16]
+ encoder.blocks.{i}.norm1.weight [768]
+ encoder.blocks.{i}.attn.qkv.weight [2304, 768]
+ encoder.blocks.{i}.attn.proj.weight [768, 768]
+ encoder.blocks.{i}.mlp.fc1.weight [3072, 768]
+ encoder.blocks.{i}.mlp.fc2.weight [768, 3072]
+ encoder.norm.weight [768]
+ ...
+ ```
+
+ For inference, use `target_encoder.*` keys (EMA-smoothed weights from pretraining).
+
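+ To inspect the weight file without loading all ~709 MB into memory, the safetensors header can be read lazily:
+
+ ```python
+ from collections import Counter
+ from safetensors import safe_open
+
+ # Open lazily: only the header is parsed until a tensor is requested.
+ with safe_open("brainjepa.safetensors", framework="pt", device="cpu") as f:
+     keys = list(f.keys())
+
+ print(len(keys))                               # 384 tensors
+ print(Counter(k.split(".")[0] for k in keys))  # encoder / predictor / target_encoder
+ ```
+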
+ ## Usage with brainjepa-rs (Rust)
+
+ ```sh
+ # Install
+ git clone https://github.com/eugenehp/brainjepa-rs
+ cd brainjepa-rs
+
+ # Download weights from this repo
+ # Place brainjepa.safetensors and gradient_mapping_450.csv in data/
+
+ # Run inference (CPU)
+ cargo run --release --bin infer -- \
+   --weights data/brainjepa.safetensors \
+   --gradient data/gradient_mapping_450.csv \
+   --input data/fmri_sample.safetensors
+
+ # Run inference (GPU, Metal/Vulkan)
+ cargo run --release --no-default-features --features wgpu --bin infer -- \
+   --weights data/brainjepa.safetensors \
+   --gradient data/gradient_mapping_450.csv \
+   --input data/fmri_sample.safetensors
+ ```
+
+ ### Rust library
+
+ ```rust
+ use brainjepa_rs::{BrainJepaEncoder, ModelConfig, DataConfig};
+
+ let (encoder, _) = BrainJepaEncoder::<B>::from_weights(
+     "data/brainjepa.safetensors",
+     "data/gradient_mapping_450.csv",
+     &ModelConfig::default(),
+     &DataConfig::default(),
+     &device,
+ )?;
+ let result = encoder.encode_safetensors("data/fmri.safetensors")?;
+ // result.embeddings: [4500, 768] float32
+ ```
+
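+ Here `B` is the Burn backend type parameter; any backend the crate is built with should work (e.g. NdArray on CPU or wgpu on GPU, matching the feature flags above and the benchmark below).
+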
+ ## Usage with original Python code
+
+ These weights were converted from the original PyTorch checkpoint. To use them with the original code:
+
+ ```python
+ import torch
+ from safetensors.torch import load_file
+
+ tensors = load_file("brainjepa.safetensors")
+ # Filter for target_encoder weights and strip prefix:
+ state_dict = {
+     k.removeprefix("target_encoder."): v
+     for k, v in tensors.items()
+     if k.startswith("target_encoder.")
+ }
+ # `model` is a Brain-JEPA encoder instantiated from the original repo.
+ model.load_state_dict(state_dict)
+ ```
+
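+ With the default `strict=True`, `load_state_dict` raises on any missing or unexpected keys, which doubles as a quick check that the filtered dict matches the model definition.
+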
+ ## Conversion
+
+ Weights were converted from the original PyTorch checkpoint using:
+
+ ```sh
+ python scripts/convert_weights.py \
+   --input jepa-ep300.pth.tar \
+   --output brainjepa.safetensors
+ ```
+
+ The conversion script strips the `module.` prefix from DDP-wrapped state dicts, converts all tensors to float32, and saves the result in safetensors format.
+
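+ The core of that conversion can be sketched in a few lines. This is a simplified illustration, assuming the checkpoint stores separate `encoder`, `predictor`, and `target_encoder` state dicts; the actual script is in the brainjepa-rs repo:
+
+ ```python
+ import torch
+ from safetensors.torch import save_file
+
+ ckpt = torch.load("jepa-ep300.pth.tar", map_location="cpu")
+
+ tensors = {}
+ for component in ("encoder", "predictor", "target_encoder"):  # assumed checkpoint layout
+     for k, v in ckpt[component].items():
+         k = k.removeprefix("module.")  # strip the DDP wrapper prefix
+         tensors[f"{component}.{k}"] = v.to(torch.float32).contiguous()
+
+ save_file(tensors, "brainjepa.safetensors")
+ ```
+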
+ ## Benchmark
+
+ Tested on Mac Mini M4 Pro (14 cores, 64 GB).
+ Input: `[1, 1, 450, 160]` (single sample, ViT-Base 86M params). Best-of-3 encode time.
+
+ | Backend | Encode | vs PyTorch CPU |
+ |---|---|---|
+ | Rust — NdArray + Rayon (CPU) | 28,778 ms | 0.06x |
+ | Rust — NdArray + Accelerate (CPU) | 21,092 ms | 0.08x |
+ | Python — PyTorch (CPU) | 1,782 ms | 1.0x |
+ | Python — PyTorch MPS (GPU) | 581 ms | 3.1x |
+ | **Rust — wgpu f32 / Metal (GPU)** | **83 ms** | **21.5x** |
+ | **Rust — wgpu f16 / Metal (GPU)** | **85 ms** | **21.0x** |
+
+ The Rust wgpu GPU backends are ~7x faster than PyTorch MPS and ~21x faster than PyTorch CPU.
+
+ ![benchmark](benchmark.png)
+
+ ## Architecture details
+
+ | Parameter | Value |
+ |---|---|
+ | Model | ViT-Base |
+ | Embedding dim | 768 |
+ | Encoder depth | 12 layers |
+ | Predictor depth | 6 layers |
+ | Attention heads | 12 |
+ | Head dim | 64 |
+ | MLP ratio | 4x (hidden=3072) |
+ | Patch size | 16 (temporal) |
+ | Input size | 450 ROIs x 160 time points |
+ | Output | 4500 patches x 768 dims |
+ | Normalization | LayerNorm (eps=1e-6) |
+ | Activation | GELU |
+ | Pretraining | 300 epochs on UK Biobank |
+ | Loss | Smooth L1 (JEPA representation matching) |
+ | Optimizer | AdamW (lr=1e-3, warmup=40 epochs, cosine decay) |
+
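+ As a rough cross-check of the ~86M figure, the encoder's weight count follows from the table (biases, norms, and positional embeddings add the remainder):
+
+ ```python
+ dim, depth, mlp = 768, 12, 3072
+
+ per_block = (
+     dim * 3 * dim  # attn.qkv weight [2304, 768]
+     + dim * dim    # attn.proj weight [768, 768]
+     + dim * mlp    # mlp.fc1 weight [3072, 768]
+     + mlp * dim    # mlp.fc2 weight [768, 3072]
+ )
+ print(f"{depth * per_block / 1e6:.1f}M")  # 84.9M -> ~86M with biases etc.
+ ```
+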
+ ## Source
+
+ Original paper and code:
+
+ > Zijian Dong, Ruilin Li, Yilei Wu, et al.
+ > **Brain-JEPA: Brain Dynamics Foundation Model with Gradient Positioning and Spatiotemporal Masking.**
+ > NeurIPS 2024 (Spotlight). [arXiv:2409.19407](https://arxiv.org/abs/2409.19407)
+
+ - Paper: [arxiv.org/abs/2409.19407](https://arxiv.org/abs/2409.19407)
+ - Original code: [github.com/hzlab/Brain-JEPA](https://github.com/hzlab/Brain-JEPA)
+ - Rust inference: [github.com/eugenehp/brainjepa-rs](https://github.com/eugenehp/brainjepa-rs)
benchmark.png ADDED
brainjepa.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b49db440a2f74d2448791aa73f1f49464197fd93f1e64e324da0d8754320baaa
+ size 742948128
gradient_mapping_450.csv ADDED
The diff for this file is too large to render.