Galaxy Diffusion: latent diffusion weights (Galaxy10 DECaLS)

Conditional latent diffusion model (VAE + classifier-free guidance) for generating galaxy images by morphology class, trained on Galaxy10 DECaLS (17,736 RGB images, 256×256, 10 morphological classes).

These are the .safetensors weights. The model uses a custom architecture: it is not a transformers / diffusers model and does not load via AutoModel. You need the galaxy_diffusion package from the code repository to instantiate it.

Files

| File | Contents |
| --- | --- |
| latent_diffusion_galaxy10_xattn_v1.model.safetensors | UNet denoiser (LatentUNetCA, cross-attention conditioning), ~27.9M params |
| latent_diffusion_galaxy10_xattn_v1.vae.safetensors | VAE (image ↔ 4×32×32 latent), ~1.09M params |
| latent_diffusion_galaxy10_xattn_v1.config.json | constructor args (vae_config, unet_config, unet_type) + latent normalisation stats (latents_mean, latents_std) |
| galaxy10_classifier.model.safetensors | GalaxyCNN evaluation classifier, ~1.75M params (val acc 0.829) |
| galaxy10_classifier.config.json | classifier metadata (val_acc, epoch) |
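
The latents_mean and latents_std entries are presumably used to standardise VAE latents before diffusion and to undo that before decoding. The sketch below shows that round trip in plain Python. It assumes per-channel statistics and the usual (z - mean) / std convention. The function names and the example numbers are hypothetical and are not taken from the galaxy_diffusion package or the config file.

```python
def normalise(latent, mean, std):
    """Standardise a latent given as a list of channels (each a flat list of floats)."""
    return [[(v - m) / s for v in channel]
            for channel, m, s in zip(latent, mean, std)]


def denormalise(latent, mean, std):
    """Invert normalise(): map standardised latents back to the VAE's own scale."""
    return [[v * s + m for v in channel]
            for channel, m, s in zip(latent, mean, std)]


# Hypothetical 2-channel latent with 3 values per channel (not real model statistics).
z = [[0.5, 1.5, 2.5], [-1.0, 0.0, 1.0]]
mean, std = [1.5, 0.0], [2.0, 0.5]

z_norm = normalise(z, mean, std)          # what the denoiser would see
z_back = denormalise(z_norm, mean, std)   # what the VAE decoder would receive
```

If the package follows this convention, a sampler would call the equivalent of denormalise on the final latents before passing them to the VAE decoder.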

Installation

pip install "git+https://github.com/LLapsus/galaxy-diffusion.git"
pip install huggingface_hub safetensors

Load the weights

import json
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

from galaxy_diffusion.models.vae import VAE
from galaxy_diffusion.models.unet import LatentUNetCA

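path = snapshot_download("llapsus/galaxy-diffusion")  # downloads all files
# Read the constructor args and latent statistics; the context manager closes the file.
with open(f"{path}/latent_diffusion_galaxy10_xattn_v1.config.json") as f:
    cfg = json.load(f)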

vae = VAE(**cfg["vae_config"])
vae.load_state_dict(load_file(f"{path}/latent_diffusion_galaxy10_xattn_v1.vae.safetensors"))
vae.eval()

unet = LatentUNetCA(**cfg["unet_config"])
unet.load_state_dict(load_file(f"{path}/latent_diffusion_galaxy10_xattn_v1.model.safetensors"))
unet.eval()

Generate images

from galaxy_diffusion.diffusion.ddpm import cosine_schedule, sample_cfg

device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU (slow) without a GPU
vae, unet = vae.to(device), unet.to(device)

_, alpha, alpha_bar = cosine_schedule(1000)
alpha, alpha_bar = alpha.to(device), alpha_bar.to(device)

latents_mean = torch.tensor(cfg["latents_mean"])
latents_std  = torch.tensor(cfg["latents_std"])

images = sample_cfg(
    unet, vae,
    classes=list(range(10)),                       # one image per class
    alpha=alpha, alpha_bar=alpha_bar,
    latent_shape=(cfg["unet_config"]["latent_channels"], 32, 32),
    latents_mean=latents_mean, latents_std=latents_std,
    device=device,
    guidance_scale=2.5,                            # see "Guidance scale" below
    cfg_rescale=0.7,                               # CFG rescaling (Lin et al., 2023)
)   # -> tensor (10, 3, 256, 256) in [-1, 1]

The classifier is loaded analogously with GalaxyCNN from galaxy_diffusion.models.classifier.

Model details

  • VAE: 8× spatial compression, 3×256×256 ↔ 4×32×32, KL-regularised.
  • UNet (LatentUNetCA): time conditioning via AdaGN; class conditioning via a cross-attention block after each encoder/decoder level and the bottleneck; cosine noise schedule (T=1000); trained with a Min-SNR-weighted MSE loss and 10% label dropout for classifier-free guidance.
  • Classifier (GalaxyCNN): trained on VAE-reconstructed images (to match the distribution of diffusion outputs) for evaluating class fidelity of generated samples.
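
For readers unfamiliar with the cosine schedule and Min-SNR weighting named above, here is a minimal plain-Python sketch of the standard formulations. It is not the package's code. cosine_schedule_sketch is a hypothetical stand-in for galaxy_diffusion's cosine_schedule, and it assumes that function's discarded first return value is beta. The offset s=0.008 and the 0.999 clip are the defaults from Nichol & Dhariwal (2021). gamma=5 and the noise-prediction form of the weight are assumptions, since the values used in training are not stated here.

```python
import math


def cosine_schedule_sketch(T, s=0.008, max_beta=0.999):
    """Cosine noise schedule; returns lists (beta, alpha, alpha_bar) of length T."""
    def f(t):
        return math.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2

    # beta_t is derived from the ratio of consecutive cumulative products, then clipped.
    beta = [min(1 - f(t + 1) / f(t), max_beta) for t in range(T)]
    alpha = [1 - b for b in beta]

    alpha_bar, running = [], 1.0
    for a in alpha:
        running *= a
        alpha_bar.append(running)
    return beta, alpha, alpha_bar


def min_snr_weight(alpha_bar_t, gamma=5.0):
    """Min-SNR loss weight for noise prediction: min(SNR, gamma) / SNR."""
    snr = alpha_bar_t / (1 - alpha_bar_t)
    return min(snr, gamma) / snr


beta, alpha, alpha_bar = cosine_schedule_sketch(1000)
weights = [min_snr_weight(ab) for ab in alpha_bar]  # per-timestep multiplier on the MSE
```

The weight equals 1 at noisy timesteps (SNR below gamma) and shrinks at nearly clean timesteps, so low-noise steps do not dominate the loss.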

Guidance scale

Classifier recall on generated images peaks around w ≈ 3, but latent-space coverage analysis shows w ≈ 2.5 is the better fidelity/diversity operating point (matched within-class spread). Higher w over-extrapolates samples toward neighbouring classes. See the coverage analysis in the code repository.
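
To make the role of w and cfg_rescale concrete, the sketch below applies classifier-free guidance with the rescaling of Lin et al. (2023) to flat lists of floats. It assumes the convention eps = eps_uncond + w * (eps_cond - eps_uncond). Whether sample_cfg uses exactly this convention is not confirmed here. guided_noise and the example values are hypothetical.

```python
from statistics import pstdev


def guided_noise(eps_cond, eps_uncond, w, rescale=0.0):
    """Classifier-free guidance on flat lists of floats, with optional rescaling.

    w = 0 gives the unconditional prediction, w = 1 the conditional one,
    and w > 1 extrapolates past the conditional prediction.
    """
    guided = [u + w * (c - u) for c, u in zip(eps_cond, eps_uncond)]
    std_guided = pstdev(guided)
    if rescale > 0.0 and std_guided > 0.0:
        # Match the guided prediction's std to the conditional prediction's std,
        # then blend the rescaled and un-rescaled results with weight `rescale`.
        factor = pstdev(eps_cond) / std_guided
        guided = [g * (rescale * factor + 1.0 - rescale) for g in guided]
    return guided


eps_c = [0.2, -0.4, 1.0, 0.6]   # hypothetical conditional prediction
eps_u = [0.1, -0.1, 0.5, 0.3]   # hypothetical unconditional prediction

plain = guided_noise(eps_c, eps_u, w=2.5)
rescaled = guided_noise(eps_c, eps_u, w=2.5, rescale=0.7)
```

Guidance above w = 1 inflates the spread of the prediction. Rescaling pulls it back toward the spread of the conditional prediction, which is the motivation for using it at higher guidance scales.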

Training data

Galaxy10 DECaLS: https://astronn.readthedocs.io/en/latest/galaxy10.html (17,736 images; 10 classes; not redistributed here).
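
For convenience, the class indices can be mapped to the names and image counts listed in the astroNN Galaxy10 DECaLS documentation. The sketch assumes the model was trained with the dataset's original label order, which is not stated here. GALAXY10_DECALS_CLASSES and class_name are hypothetical helpers, not part of galaxy_diffusion.

```python
# Class index -> (name, image count), per the astroNN Galaxy10 DECaLS documentation.
GALAXY10_DECALS_CLASSES = {
    0: ("Disturbed Galaxies", 1081),
    1: ("Merging Galaxies", 1853),
    2: ("Round Smooth Galaxies", 2645),
    3: ("In-between Round Smooth Galaxies", 2027),
    4: ("Cigar Shaped Smooth Galaxies", 334),
    5: ("Barred Spiral Galaxies", 2043),
    6: ("Unbarred Tight Spiral Galaxies", 1829),
    7: ("Unbarred Loose Spiral Galaxies", 2628),
    8: ("Edge-on Galaxies without Bulge", 1423),
    9: ("Edge-on Galaxies with Bulge", 1873),
}


def class_name(index):
    """Return the human-readable name for a Galaxy10 DECaLS class index."""
    return GALAXY10_DECALS_CLASSES[index][0]


# Sanity check against the dataset size quoted above.
total_images = sum(count for _, count in GALAXY10_DECALS_CLASSES.values())
```

The counts show a strong class imbalance (334 cigar-shaped galaxies against 2,645 round smooth ones), which is worth keeping in mind when comparing per-class sample quality.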
