CONFLUX
Conditional 3D latent generative models for medical imaging.
CONFLUX synthesizes full 3D medical volumes from structured clinical metadata: a VAE tokenizer compresses a volume into a compact latent, a single-stream rectified-flow transformer generates in that latent space, and a Flow-GRPO reinforcement-learning stage sharpens label faithfulness. This repository holds the released checkpoints, one self-contained folder per modality.
Paper (coming soon) • Datasets • Code (coming soon)
Available checkpoints
| Folder | Modality | Resolution | Conditioning | Dataset |
|---|---|---|---|---|
| chest-ct/ | Chest CT | 216 × 176 × 200 | 18 findings + sex + age + kernel | conflux-chest-ct |
More modalities (e.g. brain MRI, abdominal CT) will be added as they are trained, each as a new self-contained folder.
Each modality folder contains:
<modality>/
├── vae.safetensors   3D VAE (encoder + decoder)
├── dit.safetensors   rectified-flow transformer (final, RL-post-trained)
└── config.json       architecture + latent normalization for this modality
Both weight files are needed: the transformer generates a latent, and the VAE decoder turns it into the volume.
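Before loading, it can help to confirm that a locally downloaded modality folder is complete. The sketch below assumes only the three file names listed above; missing_files is a hypothetical helper for illustration, not part of the CONFLUX code.

```python
from pathlib import Path

# File names taken from the modality folder layout described above.
REQUIRED_FILES = ("vae.safetensors", "dit.safetensors", "config.json")


def missing_files(modality_dir):
    """Return the required files that are absent from a local modality folder.

    Hypothetical helper: it only checks that each file exists, not that
    its contents are valid.
    """
    root = Path(modality_dir)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]
```

An empty list from missing_files("chest-ct") means all three files are present; any names it returns still need to be downloaded.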
Usage
Model classes live in the code repo. Point MODALITY at the folder you want.
import json

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from models import build_vae, build_dit, flow_sample  # from github.com/mxvp/CONFLUX

REPO, MODALITY = "gevaertlab/conflux", "chest-ct"

# Architecture and latent normalization for the chosen modality.
with open(hf_hub_download(REPO, f"{MODALITY}/config.json")) as f:
    cfg = json.load(f)

# Build both networks, load their weights, and move them to the GPU in eval mode.
vae = build_vae({"vae": cfg["vae"]})
vae.load_state_dict(load_file(hf_hub_download(REPO, f"{MODALITY}/vae.safetensors")))
vae.eval().cuda()

dit = build_dit({"dit": cfg["dit"]})
dit.load_state_dict(load_file(hf_hub_download(REPO, f"{MODALITY}/dit.safetensors")))
dit.eval().cuda()

# conditioning vector layout is in cfg["cond_layout"]; for chest-ct:
# [findings(18), sex(1), age one-hot(7), kernel one-hot(16)] = 42
cond = torch.zeros(1, cfg["dit"]["cond_dim"], device="cuda")
cond[0, 2] = 1.0  # e.g. Cardiomegaly (finding index 2)

sc, sh = cfg["latent"]["scale"], cfg["latent"]["shift"]
with torch.no_grad():
    # Sample a latent with the transformer, then undo the latent normalization and decode.
    z = flow_sample(dit, (1, cfg["dit"]["latent_channels"], *cfg["latent"]["spatial"]),
                    steps=50, cond=cond, device="cuda")
    vol = vae.decode(z / sc + sh)  # (1,1,*spatial_size), model units ~ HU/1000
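The conditioning layout in the comment above can be filled in with a small helper. This is a minimal sketch: the block sizes (18, 1, 7, 16) come from that comment, but build_cond itself, the numeric encoding of sex, and the meaning of each age-bin and kernel index are assumptions for illustration. cfg["cond_layout"] remains the authoritative description.

```python
# Block sizes from the chest-ct layout comment: 18 + 1 + 7 + 16 = 42.
N_FINDINGS, N_SEX, N_AGE_BINS, N_KERNELS = 18, 1, 7, 16
COND_DIM = N_FINDINGS + N_SEX + N_AGE_BINS + N_KERNELS


def _check(name, index, size):
    """Raise if a block-local index falls outside its block."""
    if not 0 <= index < size:
        raise ValueError(f"{name} index {index} out of range [0, {size})")


def build_cond(findings=(), sex=0.0, age_bin=None, kernel=None):
    """Return a flat list of COND_DIM floats (hypothetical helper).

    findings: indices (0..17) of findings to switch on.
    sex:      value for the single sex slot (encoding is an assumption).
    age_bin:  index (0..6) into the age one-hot block, or None to leave it zero.
    kernel:   index (0..15) into the kernel one-hot block, or None to leave it zero.
    """
    cond = [0.0] * COND_DIM
    for i in findings:
        _check("finding", i, N_FINDINGS)
        cond[i] = 1.0
    cond[N_FINDINGS] = float(sex)
    age_start = N_FINDINGS + N_SEX
    kernel_start = age_start + N_AGE_BINS
    if age_bin is not None:
        _check("age_bin", age_bin, N_AGE_BINS)
        cond[age_start + age_bin] = 1.0
    if kernel is not None:
        _check("kernel", kernel, N_KERNELS)
        cond[kernel_start + kernel] = 1.0
    return cond
```

To use the result with the snippet above, wrap it in a batch dimension, for example torch.tensor([build_cond(findings=[2], age_bin=3, kernel=0)], device="cuda"), which yields a tensor of shape (1, 42).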
Intended use
Research use: synthetic medical-volume generation, augmentation, and method development. Not for clinical use. Outputs are synthetic and must not inform any patient-facing decision.
Citation
Paper and citation details coming soon.
License
CC BY-NC-SA 4.0 (non-commercial research use).
