Brain MRI SigLIP Freeze

Brain MRI SigLIP Freeze is the Stage 1 checkpoint from the brain_mri_siglip_run_0509 experiment. In this stage, the text tower initialized from google/medsiglip-448 was frozen while the 3D MRI vision tower and projection layers were trained with a SigLIP-style image-text contrastive objective.

This checkpoint is intended as a research visual encoder for brain MRI representation learning and as a warm-start checkpoint for downstream VLM or medical imaging tasks. It is not a clinical diagnostic device.

Model Summary

  • Base text tower: google/medsiglip-448
  • Model class: BrainMRISiglipModel
  • Vision input: single-channel 3D MRI volumes
  • Expected volume shape: [1, 128, 192, 192]
  • Projection dimension: 1152
  • Patch size: [8, 16, 16]
  • Training precision: bf16
  • Training input format: preprocessed .pt tensors, float16, value range [-1, 1]
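
As a quick sanity check on the values above, the patch grid implied by the volume shape and patch size can be worked out directly. This is plain arithmetic on the listed numbers. Whether the vision tower adds extra tokens (for example a pooled or class token) is not stated here, so the count below covers patch tokens only.

```python
# Patch-grid arithmetic from the Model Summary values.
volume_shape = (128, 192, 192)  # three spatial axes (channel axis omitted)
patch_size = (8, 16, 16)

# Each axis must divide evenly for a non-overlapping patch grid.
assert all(v % p == 0 for v, p in zip(volume_shape, patch_size))

grid = tuple(v // p for v, p in zip(volume_shape, patch_size))
num_patches = grid[0] * grid[1] * grid[2]
voxels_per_patch = patch_size[0] * patch_size[1] * patch_size[2]

print(grid)              # (16, 12, 12)
print(num_patches)       # 2304
print(voxels_per_patch)  # 2048
```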

Training Context

This repository contains the final stage1_freeze_text checkpoint from brain_mri_siglip_run_0509.

Training summary:

  • Training samples: 950,720
  • Validation samples: 67,450
  • Validation samples with metadata_text: 32,278
  • Stage 1 epochs: 12
  • World size: 5
  • Per-device batch size: 196
  • Contrastive forward batch: 980
  • Text tower: frozen
  • Vision tower: trainable
  • Gradient checkpointing: vision enabled, text disabled
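
The contrastive forward batch follows from the two lines above it: 5 devices × 196 samples per device = 980 image-text pairs per step. For readers unfamiliar with the objective, here is a minimal NumPy sketch of the pairwise sigmoid loss described in the SigLIP paper. It is not the training code from this run; the function name is illustrative, and the scale and bias defaults (10 and -10) are the paper's initial values, not values taken from this checkpoint.

```python
import numpy as np

def siglip_loss(image_embeds, text_embeds, logit_scale=10.0, logit_bias=-10.0):
    """Pairwise sigmoid contrastive loss; row i of each input is a matched pair."""
    # L2-normalize so the dot product is a cosine similarity.
    img = image_embeds / np.linalg.norm(image_embeds, axis=-1, keepdims=True)
    txt = text_embeds / np.linalg.norm(text_embeds, axis=-1, keepdims=True)
    logits = logit_scale * (img @ txt.T) + logit_bias  # [B, B]
    # +1 on the diagonal (matched pairs), -1 for every other pair.
    labels = 2.0 * np.eye(logits.shape[0]) - 1.0
    # -log(sigmoid(z)) == log(1 + exp(-z)), computed stably with logaddexp.
    per_pair = np.logaddexp(0.0, -labels * logits)
    return per_pair.sum() / logits.shape[0]

# Illustrative inputs only: random vectors standing in for real embeddings.
rng = np.random.default_rng(0)
image_embeds = rng.normal(size=(8, 1152))
text_embeds = image_embeds + 0.1 * rng.normal(size=(8, 1152))
print(siglip_loss(image_embeds, text_embeds))
```

Every image is scored against every text in the batch, so a batch of 980 pairs yields 980 positives and 980 × 979 negatives per step.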

Training-time retrieval evaluation used capped validation subsets and should be treated as monitoring rather than a final benchmark.
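
To illustrate why capped subsets limit what these numbers mean, here is a hedged sketch of an image-to-text recall@k computation. It is an assumption about how such a metric is typically computed, not the evaluation code from this run, and `recall_at_k` is a hypothetical helper name.

```python
import numpy as np

def recall_at_k(image_embeds, text_embeds, k=5):
    """Fraction of images whose paired text (same row index) ranks in the top k."""
    img = image_embeds / np.linalg.norm(image_embeds, axis=-1, keepdims=True)
    txt = text_embeds / np.linalg.norm(text_embeds, axis=-1, keepdims=True)
    sims = img @ txt.T  # [N, N] cosine similarities
    # Rank of the correct text = number of texts that score strictly higher.
    correct = np.diag(sims)[:, None]
    ranks = (sims > correct).sum(axis=1)
    return float((ranks < k).mean())

# Illustrative inputs only: noisy copies stand in for real paired embeddings.
rng = np.random.default_rng(0)
image_embeds = rng.normal(size=(100, 1152))
text_embeds = image_embeds + 0.5 * rng.normal(size=(100, 1152))
print(recall_at_k(image_embeds, text_embeds, k=1))
```

Recall@k depends on the size of the candidate pool: the same embeddings score higher against 100 candidates than against 10,000. Scores from capped subsets are therefore not comparable with scores from a full validation set.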

Loading

This model uses custom Transformers code. Load it with trust_remote_code=True.

import torch
from transformers import AutoModel, AutoProcessor

repo_id = "shenxiaochen/brain-mri-siglip-freeze"
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModel.from_pretrained(
    repo_id,
    trust_remote_code=True,
    # Older Transformers releases expect this argument as torch_dtype.
    dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
).to(device).eval()

processor = AutoProcessor.from_pretrained(
    repo_id,
    trust_remote_code=True,
)

NIfTI Preprocessing

For reproducible inference from NIfTI files, pass paths directly to the saved processor. This repository includes the offline-aligned preprocessing implementation used to match the training tensor distribution.

nifti_path = "/path/to/brain_mri.nii.gz"

inputs = processor(
    volumes=nifti_path,
    return_tensors="pt",
)
pixel_values = inputs["pixel_values"].to(device)

if torch.cuda.is_available():
    pixel_values = pixel_values.to(dtype=torch.bfloat16)

with torch.inference_mode():
    image_embeds = model.get_image_features(pixel_values=pixel_values)

print(pixel_values.shape)  # [1, 1, 128, 192, 192]
print(image_embeds.shape)  # [1, 1152]

The saved path-based preprocessing recipe is:

  • canonicalize image orientation to closest RAS
  • build foreground mask with threshold 1e-3
  • keep the largest connected foreground component
  • crop foreground with a 5 mm margin
  • normalize foreground intensities with 0.5/99.5 percentiles
  • map intensities to [-1, 1]
  • resample to spacing (1.25, 1.0, 1.0)
  • downscale to fit [128, 192, 192]
  • center-pad with background value -1.0

The exact settings are saved in preprocessor_config.json and processor_config.json.
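
To make the recipe more concrete, the sketch below covers three of its steps in NumPy: percentile normalization, mapping to [-1, 1], and center-padding. It is an illustration under stated assumptions, not the saved processor's implementation. The helper names are hypothetical, the mask is assumed to be `volume > threshold`, and background voxels are assumed to map to -1.0. Orientation, connected-component filtering, cropping, and resampling are omitted. For real inference, use the saved processor.

```python
import numpy as np

def normalize_foreground(volume, threshold=1e-3, lo_pct=0.5, hi_pct=99.5):
    """Clip foreground intensities to percentiles and map them to [-1, 1]."""
    mask = volume > threshold  # assumption: simple intensity threshold
    if not mask.any():
        raise ValueError("no foreground voxels above threshold")
    lo, hi = np.percentile(volume[mask], [lo_pct, hi_pct])
    scaled = (np.clip(volume, lo, hi) - lo) / max(hi - lo, 1e-8)  # [0, 1]
    out = scaled * 2.0 - 1.0                                      # [-1, 1]
    out[~mask] = -1.0  # assumption: background is set to the pad value
    return out.astype(np.float32)

def center_pad(volume, target=(128, 192, 192), value=-1.0):
    """Pad each axis symmetrically up to the target size (no cropping)."""
    pads = []
    for size, want in zip(volume.shape, target):
        total = max(want - size, 0)
        pads.append((total // 2, total - total // 2))
    return np.pad(volume, pads, mode="constant", constant_values=value)

# Illustrative input only: random intensities standing in for an MRI volume.
rng = np.random.default_rng(0)
raw = rng.uniform(0.0, 100.0, size=(100, 150, 160))
prepared = center_pad(normalize_foreground(raw))
print(prepared.shape)  # (128, 192, 192)
```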

Using Preprocessed .pt Inputs

If your data is already stored as the same offline preprocessed tensors used during training, you can load it directly:

payload = torch.load("/path/to/sample.pt", map_location="cpu")
pixel_values = payload["pixel_values"] if isinstance(payload, dict) else payload

if pixel_values.ndim == 4:
    pixel_values = pixel_values.unsqueeze(0)

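# Match the dtype the model was loaded with (bf16 on GPU, float32 on CPU).
model_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
pixel_values = pixel_values.to(device=device, dtype=model_dtype)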

with torch.inference_mode():
    image_embeds = model.get_image_features(pixel_values=pixel_values)

Expected tensor format:

  • shape [1, 128, 192, 192] for one volume, or [B, 1, 128, 192, 192] for a batch
  • values in [-1, 1]
  • padded background voxels near -1.0
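
Before running inference on your own tensors, a small check against this format can catch shape or scaling mistakes early. The helper below is a hypothetical convenience, not part of this repository. It takes a shape and the minimum and maximum values, so it works with any tensor library, and the tolerance is an arbitrary choice.

```python
EXPECTED_VOLUME_SHAPE = (1, 128, 192, 192)  # channel, then three spatial axes

def check_volume_format(shape, vmin, vmax, tol=1e-3):
    """Validate shape and value range; return the shape with a batch axis."""
    shape = tuple(shape)
    if len(shape) == 4:
        shape = (1,) + shape  # single volume: add the batch axis
    if len(shape) != 5 or shape[1:] != EXPECTED_VOLUME_SHAPE:
        raise ValueError(f"expected [B, 1, 128, 192, 192], got {list(shape)}")
    if vmin < -1.0 - tol or vmax > 1.0 + tol:
        raise ValueError(f"values should lie in [-1, 1], got [{vmin}, {vmax}]")
    return shape

print(check_volume_format((1, 128, 192, 192), -1.0, 0.97))  # (1, 1, 128, 192, 192)
```

With a PyTorch tensor, this could be called as `check_volume_format(pixel_values.shape, pixel_values.min().item(), pixel_values.max().item())`.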

VLM Integration Notes

For VLM construction, use the 3D vision tower as a visual backbone and add a projector, Q-Former, Perceiver resampler, or other token compressor before connecting to an LLM.

A practical downstream recipe is:

  1. Freeze this MRI encoder and train only the multimodal projector/resampler.
  2. Evaluate downstream classification, retrieval, report alignment, or instruction-following behavior.
  3. Optionally unfreeze the top vision layers with a much smaller learning rate.
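
The freezing pattern in steps 1 and 3 can be sketched in PyTorch as follows. Everything here is a toy stand-in: `StandInEncoder` is a hypothetical placeholder for the MRI vision tower (whose token-level output interface is not documented here), `TokenProjector` uses simple average pooling as one possible compressor, and the dimensions and learning rates are illustrative. The patch grid implied by the Model Summary is 16 × 12 × 12 = 2,304 patches per volume, which is why some form of token compression is usually needed before an LLM.

```python
import torch
from torch import nn

class StandInEncoder(nn.Module):
    """Hypothetical placeholder for the MRI vision tower; returns token features."""

    def __init__(self, in_dim=16, hidden_dim=32, num_layers=2):
        super().__init__()
        self.embed = nn.Linear(in_dim, hidden_dim)
        self.layers = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)]
        )

    def forward(self, x):
        x = self.embed(x)
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x  # [B, num_tokens, hidden_dim]

class TokenProjector(nn.Module):
    """Average-pools groups of tokens, then maps them to the LLM embedding width."""

    def __init__(self, hidden_dim, llm_dim, pool=4):
        super().__init__()
        self.pool = pool
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, llm_dim), nn.GELU(), nn.Linear(llm_dim, llm_dim)
        )

    def forward(self, tokens):
        b, n, d = tokens.shape
        # Simple compression: merge every `pool` consecutive tokens.
        tokens = tokens.reshape(b, n // self.pool, self.pool, d).mean(dim=2)
        return self.mlp(tokens)

encoder = StandInEncoder()
projector = TokenProjector(hidden_dim=32, llm_dim=64)

# Step 1: freeze the encoder so only the projector receives gradients.
encoder.requires_grad_(False).eval()
optimizer = torch.optim.AdamW(projector.parameters(), lr=1e-4)

# Step 3 (optional): unfreeze the top encoder layer with a much smaller learning rate.
top_layer = encoder.layers[-1]
top_layer.requires_grad_(True)
optimizer.add_param_group({"params": list(top_layer.parameters()), "lr": 1e-6})

visual_tokens = projector(encoder(torch.randn(2, 8, 16)))
print(visual_tokens.shape)  # torch.Size([2, 2, 64])
```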

Limitations

  • This checkpoint was trained for representation learning, not diagnosis.
  • Performance should be validated on task-specific subject-level or study-level splits.
  • Scanner, protocol, site, and preprocessing differences can affect embeddings.
  • For NIfTI inference, use the saved preprocessing pipeline unchanged; volumes preprocessed differently may not match the training tensor distribution.
  • Retrieval monitoring during training is not a substitute for downstream clinical validation.

Citation

If you use this checkpoint, please cite this model repository and the upstream MedSigLIP model where appropriate.
