Brain MRI SigLIP Freeze
Brain MRI SigLIP Freeze is the Stage 1 checkpoint from the brain_mri_siglip_run_0509 experiment. In this stage, the text tower initialized from google/medsiglip-448 was frozen while the 3D MRI vision tower and projection layers were trained with a SigLIP-style image-text contrastive objective.
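The SigLIP-style objective scores every image-text pair in a batch independently with a sigmoid. It does not use a softmax over the batch. Below is a minimal NumPy sketch of that pairwise sigmoid loss, written for illustration and not taken from this repository's training code. The temperature `t` and bias `b` are fixed illustrative values here, although SigLIP learns them. In Stage 1 of this checkpoint, gradients from such a loss would update only the vision tower and projection layers, because the text tower was frozen.

```python
import numpy as np

def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)) = -log(1 + exp(-x)).
    return -np.logaddexp(0.0, -x)

def siglip_loss(image_embeds, text_embeds, t=10.0, b=-10.0):
    """Pairwise sigmoid loss: matched pairs (diagonal) are positives, all others negatives."""
    img = image_embeds / np.linalg.norm(image_embeds, axis=1, keepdims=True)
    txt = text_embeds / np.linalg.norm(text_embeds, axis=1, keepdims=True)
    logits = t * img @ txt.T + b             # [N, N] scaled cosine similarities
    labels = 2.0 * np.eye(len(img)) - 1.0    # +1 on the diagonal, -1 elsewhere
    return float(-log_sigmoid(labels * logits).sum() / len(img))

rng = np.random.default_rng(0)
image = rng.normal(size=(8, 1152))
aligned = siglip_loss(image, image)                        # every pair matches perfectly
shuffled = siglip_loss(image, rng.normal(size=(8, 1152)))  # unrelated "text" embeddings
print(aligned < shuffled)  # True
```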
This checkpoint is intended as a research visual encoder for brain MRI representation learning and as a warm-start checkpoint for downstream VLM or medical imaging tasks. It is not a clinical diagnostic device.
Model Summary
- Base text tower: google/medsiglip-448
- Model class: BrainMRISiglipModel
- Vision input: single-channel 3D MRI volumes
- Expected volume shape: [1, 128, 192, 192]
- Projection dimension: 1152
- Patch size: [8, 16, 16]
- Training precision: bf16
- Training input format: preprocessed .pt tensors, float16, value range [-1, 1]
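The volume shape and patch size together determine how many vision tokens a VLM connector would have to handle. The card gives the patch size but does not describe the patch embedding, so the sketch below assumes non-overlapping 3D patches over the [D, H, W] axes. The axis names are also assumptions.

```python
import math

volume_shape = (1, 128, 192, 192)  # [C, D, H, W], single-channel volume (axis names assumed)
patch_size = (8, 16, 16)           # [pD, pH, pW]

def patch_grid(volume_shape, patch_size):
    """Grid and token count for non-overlapping 3D patches (assumed, not stated by the card)."""
    spatial = volume_shape[1:]
    if any(dim % p for dim, p in zip(spatial, patch_size)):
        raise ValueError(f"{spatial} is not divisible by patch size {patch_size}")
    grid = tuple(dim // p for dim, p in zip(spatial, patch_size))
    return grid, math.prod(grid)

grid, num_tokens = patch_grid(volume_shape, patch_size)
print(grid, num_tokens)       # (16, 12, 12) 2304
print(math.prod(patch_size))  # 2048 voxels per patch
```

Under this assumption, each volume yields 2,304 patch tokens. That count is why a token compressor is suggested before connecting the encoder to an LLM.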
Training Context
This repository contains the final stage1_freeze_text checkpoint from brain_mri_siglip_run_0509.
Training summary:
- Training samples: 950,720
- Validation samples: 67,450
- Validation samples with metadata_text: 32,278
- Stage 1 epochs: 12
- World size: 5
- Per-device batch size: 196
- Contrastive forward batch: 980
- Text tower: frozen
- Vision tower: trainable
- Gradient checkpointing: enabled for the vision tower, disabled for the text tower
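The batch figures above are consistent with each other: 5 devices × 196 samples per device = 980. The card does not state the step count. Assuming one optimizer step per 980-sample global batch, with no gradient accumulation and no sampler padding, the steps per epoch and in total can be estimated as follows.

```python
import math

train_samples = 950_720
world_size = 5
per_device_batch = 196
epochs = 12

global_batch = world_size * per_device_batch                # 980, the contrastive forward batch
full_batches, leftover = divmod(train_samples, global_batch)
steps_if_kept = math.ceil(train_samples / global_batch)     # last partial batch kept
steps_if_dropped = full_batches                             # last partial batch dropped

print(global_batch, full_batches, leftover)                 # 980 970 120
print(epochs * steps_if_kept, epochs * steps_if_dropped)    # 11652 11640
```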
Training-time retrieval evaluation used capped validation subsets and should be treated as monitoring rather than a final benchmark.
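Retrieval monitoring of this kind is commonly reported as Recall@K. The sketch below computes image-to-text Recall@K from paired embeddings, where row i of each array is one matched pair. It is not the repository's evaluation code, and `max_pairs` is a hypothetical parameter that mimics a capped validation subset.

```python
import numpy as np

def recall_at_k(image_embeds, text_embeds, k=1, max_pairs=None):
    """Fraction of images whose matched text ranks in the top k by cosine similarity."""
    if max_pairs is not None:  # hypothetical cap, mimicking a capped validation subset
        image_embeds, text_embeds = image_embeds[:max_pairs], text_embeds[:max_pairs]
    img = image_embeds / np.linalg.norm(image_embeds, axis=1, keepdims=True)
    txt = text_embeds / np.linalg.norm(text_embeds, axis=1, keepdims=True)
    sim = img @ txt.T                                  # [N, N] cosine similarities
    topk = np.argsort(-sim, axis=1)[:, :k]             # indices of the k best texts per image
    hits = (topk == np.arange(len(img))[:, None]).any(axis=1)
    return float(hits.mean())

rng = np.random.default_rng(0)
image = rng.normal(size=(500, 64))
text = image + rng.normal(scale=1.5, size=image.shape)  # noisy "matched" text embeddings
r_full = recall_at_k(image, text, k=1)
r_capped = recall_at_k(image, text, k=1, max_pairs=50)
print(r_full, r_capped)
```

A capped subset contains fewer distractors per query, so Recall@K measured on it is usually higher than on the full validation set. Numbers from capped subsets therefore track training progress but are not comparable to a full benchmark.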
Loading
This model uses custom Transformers code. Load it with trust_remote_code=True.
import torch
from transformers import AutoModel, AutoProcessor
repo_id = "shenxiaochen/brain-mri-siglip-freeze"
device = "cuda" if torch.cuda.is_available() else "cpu"
# The modeling code ships with the repository, hence trust_remote_code=True.
# Load in bf16 (the training precision) on GPU and in fp32 on CPU.
model = AutoModel.from_pretrained(
    repo_id,
    trust_remote_code=True,
    dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
).to(device).eval()
processor = AutoProcessor.from_pretrained(
    repo_id,
    trust_remote_code=True,
)
NIfTI Preprocessing
For reproducible inference from NIfTI files, pass file paths directly to the saved processor. The repository includes a preprocessing implementation aligned with the offline pipeline that produced the training tensors, so inference inputs match the training distribution.
nifti_path = "/path/to/brain_mri.nii.gz"
inputs = processor(
    volumes=nifti_path,
    return_tensors="pt",
)
pixel_values = inputs["pixel_values"].to(device)
if torch.cuda.is_available():
    # Match the bf16 weights loaded above.
    pixel_values = pixel_values.to(dtype=torch.bfloat16)
with torch.inference_mode():
    image_embeds = model.get_image_features(pixel_values=pixel_values)
print(pixel_values.shape)  # [1, 1, 128, 192, 192]
print(image_embeds.shape)  # [1, 1152]
The saved path-based preprocessing recipe is:
- canonicalize image orientation to the closest RAS orientation
- build a foreground mask with threshold 1e-3
- keep the largest connected foreground component
- crop to the foreground with a 5 mm margin
- normalize foreground intensities with the 0.5/99.5 percentiles
- map intensities to [-1, 1]
- resample to spacing (1.25, 1.0, 1.0)
- downscale to fit [128, 192, 192]
- center-pad with background value -1.0
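Several steps of this recipe can be approximated in NumPy, as shown below. This is an illustrative sketch and not the saved processor, which should be used for real inference. It makes the following assumptions:
- The input is already RAS-oriented at spacing (1.25, 1.0, 1.0) mm.
- Largest-component filtering and actual resampling are omitted.
- Percentiles are computed over foreground voxels only, and intensities are clipped to that window.

```python
import numpy as np

TARGET_SHAPE = (128, 192, 192)

def crop_to_foreground(vol, mask, spacing_mm, margin_mm=5.0):
    """Crop to the mask's bounding box plus a physical margin converted to voxels."""
    idx = np.argwhere(mask)
    margin = np.ceil(margin_mm / np.asarray(spacing_mm, dtype=float)).astype(int)
    lo = np.maximum(idx.min(axis=0) - margin, 0)
    hi = np.minimum(idx.max(axis=0) + 1 + margin, vol.shape)
    return vol[tuple(slice(a, b) for a, b in zip(lo, hi))]

def normalize_to_unit_range(vol, threshold=1e-3, pct=(0.5, 99.5)):
    """Clip to foreground percentiles and map to [-1, 1] (assumed details)."""
    vol = vol.astype(np.float64)
    lo, hi = np.percentile(vol[vol > threshold], pct)
    scaled = (np.clip(vol, lo, hi) - lo) / max(hi - lo, 1e-8)  # now in [0, 1]
    # Background voxels lie below lo, so they clip to 0 here and map to -1 below.
    return (2.0 * scaled - 1.0).astype(np.float32)

def fit_scale(shape, target=TARGET_SHAPE):
    """Largest uniform factor <= 1 that makes every axis fit inside the target."""
    return min(1.0, *(t / s for t, s in zip(target, shape)))

def center_pad(vol, target=TARGET_SHAPE, value=-1.0):
    """Place vol in the middle of a target-shaped array filled with the background value."""
    out = np.full(target, value, dtype=vol.dtype)
    start = [(t - s) // 2 for t, s in zip(target, vol.shape)]
    out[tuple(slice(a, a + s) for a, s in zip(start, vol.shape))] = vol
    return out

# Synthetic stand-in for an oriented, already-resampled scan.
rng = np.random.default_rng(0)
raw = np.zeros((100, 150, 140), dtype=np.float32)
raw[20:80, 30:120, 25:115] = rng.uniform(0.1, 1.0, size=(60, 90, 90))

cropped = crop_to_foreground(raw, raw > 1e-3, spacing_mm=(1.25, 1.0, 1.0))
normed = normalize_to_unit_range(cropped)
scale = fit_scale(normed.shape)  # 1.0 means no downscaling is needed
volume = center_pad(normed)[None]  # add channel axis: [1, 128, 192, 192]
print(cropped.shape, scale, volume.shape)
```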
The exact settings are saved in preprocessor_config.json and processor_config.json.
Using Preprocessed .pt Inputs
If your data is already stored as the same offline preprocessed tensors used during training, you can load it directly:
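payload = torch.load("/path/to/sample.pt", map_location="cpu")
# Accept either a bare tensor or a dict with a "pixel_values" key.
pixel_values = payload["pixel_values"] if isinstance(payload, dict) else payload
# Add a batch axis: [1, 128, 192, 192] -> [1, 1, 128, 192, 192].
if pixel_values.ndim == 4:
    pixel_values = pixel_values.unsqueeze(0)
# Stored tensors are float16; cast to the dtype the model was loaded in (bf16 on GPU, fp32 on CPU).
pixel_values = pixel_values.to(device=device, dtype=model.dtype)
with torch.inference_mode():
    image_embeds = model.get_image_features(pixel_values=pixel_values)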
Expected tensor format:
- shape [1, 128, 192, 192] for one volume, or [B, 1, 128, 192, 192] for a batch
- values in [-1, 1]
- padded background voxels near -1.0
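A small check like the following can catch shape or range mismatches before `get_image_features` is called. It is a sketch based only on the format listed above. `check_pixel_values` is a hypothetical helper, and the `atol` tolerance for float16 rounding is an assumption.

```python
import numpy as np

EXPECTED_VOLUME_SHAPE = (1, 128, 192, 192)  # [C, D, H, W] for one volume

def check_pixel_values(x, atol=1e-3):
    """Return x with a batch axis, or raise ValueError if it does not match the expected format.

    Uses only .ndim, .shape, x[None] and .min()/.max(), so it should also accept
    torch tensors (assumed; exercised here with NumPy arrays).
    """
    if x.ndim == 4:
        x = x[None]  # [1, 128, 192, 192] -> [1, 1, 128, 192, 192]
    if x.ndim != 5 or tuple(x.shape[1:]) != EXPECTED_VOLUME_SHAPE:
        raise ValueError(f"expected [B, 1, 128, 192, 192], got {tuple(x.shape)}")
    lo, hi = float(x.min()), float(x.max())
    if lo < -1.0 - atol or hi > 1.0 + atol:
        raise ValueError(f"values outside [-1, 1]: min={lo:.4f}, max={hi:.4f}")
    return x

# Example: an all-background float16 volume, as produced by center-padding with -1.0.
volume = np.full(EXPECTED_VOLUME_SHAPE, -1.0, dtype=np.float16)
print(check_pixel_values(volume).shape)  # (1, 1, 128, 192, 192)
```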
VLM Integration Notes
For VLM construction, use the 3D vision tower as a visual backbone and add a projector, Q-Former, Perceiver resampler, or other token compressor before connecting to an LLM.
A practical downstream recipe is:
- Freeze this MRI encoder and train only the multimodal projector/resampler.
- Evaluate downstream classification, retrieval, report alignment, or instruction-following behavior.
- Optionally unfreeze the top vision layers with a much smaller learning rate.
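The connector in this recipe can be as simple as a set of learned queries that cross-attend to the encoder's patch tokens and are then projected to the LLM's width. The NumPy forward pass below shows a single cross-attention layer of that kind. It is a simplified stand-in for a Perceiver resampler or Q-Former, which stack several such layers with self-attention and MLPs. Several details are assumptions and are not stated by the card:
- `vision_dim=1152`, which assumes the patch tokens have the projection-dimension width.
- The 2,304-token count, which assumes non-overlapping [8, 16, 16] patches.
- `llm_dim=2048` and `num_queries=32`, which are arbitrary.

The card also does not say how to obtain patch tokens from `BrainMRISiglipModel`, so random tokens stand in for them. In a PyTorch implementation, you would call `model.requires_grad_(False)` on the encoder and pass only the resampler's parameters to the optimizer.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

class CrossAttentionResampler:
    """Hypothetical single-head resampler: learned queries attend to vision tokens."""

    def __init__(self, vision_dim=1152, llm_dim=2048, num_queries=32, seed=0):
        rng = np.random.default_rng(seed)
        s = 1.0 / np.sqrt(vision_dim)
        self.queries = rng.normal(0.0, s, (num_queries, vision_dim))  # trainable in practice
        self.w_k = rng.normal(0.0, s, (vision_dim, vision_dim))
        self.w_v = rng.normal(0.0, s, (vision_dim, vision_dim))
        self.w_out = rng.normal(0.0, s, (vision_dim, llm_dim))        # projector to LLM width

    def __call__(self, tokens):  # tokens: [B, N, vision_dim]
        k = tokens @ self.w_k
        v = tokens @ self.w_v
        scores = np.einsum("qd,bnd->bqn", self.queries, k) / np.sqrt(k.shape[-1])
        attn = softmax(scores, axis=-1)  # each query's weights over the N tokens sum to 1
        pooled = attn @ v                # [B, num_queries, vision_dim]
        return pooled @ self.w_out       # [B, num_queries, llm_dim]

rng = np.random.default_rng(0)
tokens = rng.normal(size=(1, 2304, 1152))  # stand-in for frozen-encoder patch tokens
out = CrossAttentionResampler()(tokens)
print(out.shape)  # (1, 32, 2048)
```

In this sketch, the LLM receives a fixed 32 tokens per scan instead of 2,304. Freezing the encoder and training only this connector keeps the number of trainable parameters small.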
Limitations
- This checkpoint was trained for representation learning, not diagnosis.
- Performance should be validated on task-specific subject-level or study-level splits.
- Scanner, protocol, site, and preprocessing differences can affect embeddings.
- External users should preserve the saved preprocessing pipeline for NIfTI inference.
- Retrieval monitoring during training is not a substitute for downstream clinical validation.
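A subject-level split keeps every scan of one person on the same side of the split, so evaluation cannot reward memorizing individual subjects. The sketch below assigns splits with a stable hash of a subject identifier. The record layout and the `subject_id`/`study_id` field names are hypothetical. The test fraction is only approximate, because whole subjects are assigned together.

```python
import hashlib

def subject_level_split(records, test_fraction=0.2, key="subject_id"):
    """Put all scans of a subject in the same split, using a stable hash of the subject ID."""
    train, test = [], []
    for rec in records:
        digest = hashlib.sha256(str(rec[key]).encode("utf-8")).digest()
        bucket = int.from_bytes(digest[:8], "big") / 2**64  # deterministic value in [0, 1)
        (test if bucket < test_fraction else train).append(rec)
    return train, test

# Hypothetical records: three studies per subject.
records = [
    {"subject_id": f"sub-{s:03d}", "study_id": f"ses-{v}"}
    for s in range(50)
    for v in range(3)
]
train, test = subject_level_split(records)
overlap = {r["subject_id"] for r in train} & {r["subject_id"] for r in test}
print(len(train), len(test), overlap)  # overlap is always empty
```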
Citation
If you use this checkpoint, please cite this model repository and the upstream MedSigLIP model where appropriate.