Galaxy Diffusion: latent diffusion weights (Galaxy10 DECaLS)
Conditional latent diffusion model (VAE + classifier-free guidance) for generating galaxy images by morphology class, trained on Galaxy10 DECaLS (17,736 RGB images, 256×256, 10 morphological classes).
These are the .safetensors weights. The model uses a custom architecture: it is
not a transformers / diffusers model and does not load via AutoModel. You need
the galaxy_diffusion package from the code repository to instantiate it.
- Code: https://github.com/LLapsus/galaxy-diffusion
- License: CC0 1.0 (public domain)
Files
| File | Contents |
|---|---|
| latent_diffusion_galaxy10_xattn_v1.model.safetensors | UNet denoiser (LatentUNetCA, cross-attention conditioning), ~27.9M params |
| latent_diffusion_galaxy10_xattn_v1.vae.safetensors | VAE (image → 4×32×32 latent), ~1.09M params |
| latent_diffusion_galaxy10_xattn_v1.config.json | constructor args (vae_config, unet_config, unet_type) + latent normalisation stats (latents_mean, latents_std) |
| galaxy10_classifier.model.safetensors | GalaxyCNN evaluation classifier, ~1.75M params (val acc 0.829) |
| galaxy10_classifier.config.json | classifier metadata (val_acc, epoch) |
Installation
pip install "git+https://github.com/LLapsus/galaxy-diffusion.git"
pip install huggingface_hub safetensors
Load the weights
import json
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from galaxy_diffusion.models.vae import VAE
from galaxy_diffusion.models.unet import LatentUNetCA
path = snapshot_download("llapsus/galaxy-diffusion") # downloads all files
with open(f"{path}/latent_diffusion_galaxy10_xattn_v1.config.json") as f:  # close the file after reading
    cfg = json.load(f)
vae = VAE(**cfg["vae_config"])
vae.load_state_dict(load_file(f"{path}/latent_diffusion_galaxy10_xattn_v1.vae.safetensors"))
vae.eval()
unet = LatentUNetCA(**cfg["unet_config"])
unet.load_state_dict(load_file(f"{path}/latent_diffusion_galaxy10_xattn_v1.model.safetensors"))
unet.eval()
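Before constructing the models, it can help to confirm that the downloaded config contains the top-level keys listed in the Files table. The following is a minimal sketch: `EXPECTED_KEYS` and `missing_config_keys` are hypothetical helpers, not part of galaxy_diffusion, and the example config uses placeholder values rather than the real file contents.

```python
import json

# Top-level keys listed for the config file in the Files table.
EXPECTED_KEYS = ("vae_config", "unet_config", "unet_type", "latents_mean", "latents_std")


def missing_config_keys(cfg: dict) -> list:
    """Return the expected top-level keys that are absent from cfg."""
    return [key for key in EXPECTED_KEYS if key not in cfg]


# Placeholder config, only to demonstrate the check (not the real file).
example_cfg = json.loads('{"vae_config": {}, "unet_config": {}, "latents_mean": [0.0]}')
print(missing_config_keys(example_cfg))  # ['unet_type', 'latents_std']
```

With the real file, `missing_config_keys(cfg)` returning an empty list means every key used by the loading and sampling snippets is present.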
Generate images
from galaxy_diffusion.diffusion.ddpm import cosine_schedule, sample_cfg
device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no GPU is present
vae, unet = vae.to(device), unet.to(device)
_, alpha, alpha_bar = cosine_schedule(1000)
alpha, alpha_bar = alpha.to(device), alpha_bar.to(device)
latents_mean = torch.tensor(cfg["latents_mean"])
latents_std = torch.tensor(cfg["latents_std"])
images = sample_cfg(
    unet, vae,
    classes=list(range(10)),  # one image per class
    alpha=alpha, alpha_bar=alpha_bar,
    latent_shape=(cfg["unet_config"]["latent_channels"], 32, 32),
    latents_mean=latents_mean, latents_std=latents_std,
    device=device,
    guidance_scale=2.5,  # see "Guidance scale" below
    cfg_rescale=0.7,  # CFG rescaling (Lin et al., 2023)
)  # -> tensor (10, 3, 256, 256) in [-1, 1]
The classifier is loaded analogously with GalaxyCNN from
galaxy_diffusion.models.classifier.
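The sampling snippet notes that the returned images lie in [-1, 1]. To save or display them they usually need to be mapped to 8-bit pixel values. The sketch below shows that arithmetic on plain floats, assuming a linear mapping from [-1, 1] to [0, 255]; `to_uint8` is a hypothetical helper and is not part of galaxy_diffusion.

```python
def to_uint8(value: float) -> int:
    """Map one pixel value from [-1, 1] to an integer in [0, 255]."""
    clamped = max(-1.0, min(1.0, value))  # guard against small overshoots
    return int(round((clamped + 1.0) / 2.0 * 255.0))


# Black, mid-grey, and white in the model's output range.
print([to_uint8(v) for v in (-1.0, 0.0, 1.0)])  # [0, 128, 255]
```

On a torch tensor the same mapping can be written as `((images.clamp(-1, 1) + 1) / 2 * 255).round().to(torch.uint8)`.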
Model details
- VAE: 8× spatial compression, 3×256×256 → 4×32×32, KL-regularised.
- UNet (LatentUNetCA): time conditioning via AdaGN, class conditioning via a cross-attention block after each encoder/decoder level + bottleneck; cosine noise schedule (T=1000); trained with Min-SNR-weighted MSE and 10% CFG label dropout.
- Classifier (GalaxyCNN): trained on VAE-reconstructed images (to match the distribution of diffusion outputs) for evaluating class fidelity of generated samples.
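To make the cosine schedule and Min-SNR weighting mentioned above concrete, here is a minimal pure-Python sketch. It assumes the standard cosine schedule of Nichol and Dhariwal (offset s = 0.008, betas clipped at 0.999) and the Min-SNR weight for a noise-prediction target with γ = 5. The offset, the clipping, the prediction target, and γ are all assumptions; the repository's `cosine_schedule` and training loss may differ. The function names are illustrative only.

```python
import math


def cosine_schedule_sketch(num_steps: int, s: float = 0.008, max_beta: float = 0.999):
    """Return (betas, alphas, alpha_bars) for a cosine noise schedule."""
    def f(t: int) -> float:
        # Squared-cosine curve that sets the cumulative signal level at step t.
        return math.cos((t / num_steps + s) / (1 + s) * math.pi / 2) ** 2

    betas, alphas, alpha_bars = [], [], []
    running = 1.0
    for t in range(1, num_steps + 1):
        beta = min(1.0 - f(t) / f(t - 1), max_beta)  # clip to avoid beta = 1
        running *= 1.0 - beta                        # cumulative product of alphas
        betas.append(beta)
        alphas.append(1.0 - beta)
        alpha_bars.append(running)
    return betas, alphas, alpha_bars


def min_snr_weight(alpha_bar_t: float, gamma: float = 5.0) -> float:
    """Loss weight min(SNR, gamma) / SNR, with SNR = alpha_bar / (1 - alpha_bar)."""
    snr = alpha_bar_t / (1.0 - alpha_bar_t)
    return min(snr, gamma) / snr


betas, alphas, alpha_bars = cosine_schedule_sketch(1000)
print(min_snr_weight(alpha_bars[0]), min_snr_weight(alpha_bars[-1]))
```

Under these assumptions the weight is small for nearly clean steps (very high SNR) and equals 1 for noisy steps, so low-noise timesteps do not dominate the MSE.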
Guidance scale
Classifier recall on generated images peaks around w ≈ 3, but latent-space coverage
analysis shows w ≈ 2.5 is the better fidelity/diversity operating point (matched
within-class spread). Higher w over-extrapolates samples toward neighbouring classes.
See the coverage analysis in the code repository.
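As a rough illustration of how the guidance scale w and `cfg_rescale` interact, the sketch below combines an unconditional and a conditional noise prediction, then applies the standard-deviation-matching rescale described by Lin et al. (2023). It assumes the convention in which w = 1 reproduces the conditional prediction; some implementations use 1 + w instead, and the repository's `sample_cfg` may differ in this and other details. `cfg_combine` is a hypothetical helper operating on flat lists of floats.

```python
from statistics import pstdev


def cfg_combine(eps_uncond, eps_cond, guidance_scale=2.5, cfg_rescale=0.7):
    """Classifier-free guidance followed by a std-matching rescale."""
    # Extrapolate from the unconditional towards the conditional prediction.
    guided = [u + guidance_scale * (c - u) for u, c in zip(eps_uncond, eps_cond)]
    std_cond, std_guided = pstdev(eps_cond), pstdev(guided)
    if std_guided == 0.0:
        return guided  # nothing to rescale for a constant prediction
    # Shrink the guided prediction back to the conditional prediction's spread.
    rescaled = [g * std_cond / std_guided for g in guided]
    # Blend the rescaled and raw guided predictions.
    return [cfg_rescale * r + (1.0 - cfg_rescale) * g for r, g in zip(rescaled, guided)]


eps_uncond = [0.0, 0.2, -0.1, 0.4]
eps_cond = [0.1, 0.5, -0.3, 0.6]
print(cfg_combine(eps_uncond, eps_cond))
```

With `cfg_rescale=1.0` the output has the same standard deviation as the conditional prediction; with `cfg_rescale=0.0` it reduces to plain classifier-free guidance.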
Training data
Galaxy10 DECaLS: https://astronn.readthedocs.io/en/latest/galaxy10.html (17,736 images; 10 classes; not redistributed here).