Solaris Small Patch 4

This repository contains a Solaris-Small checkpoint trained for 12-hour multi-wavelength solar forecasting. Training follows the Solaris pretraining setup described in the paper "Solaris: A Foundation Model of the Sun".

This run uses patch size 4. The earlier patch-size-8 checkpoint is published separately as hrrsmjd/solaris_small_patch8.

The checkpoint was trained on hrrsmjd/AIA_12hour_512x512 for 7750 optimizer steps using two history frames (t-12h, t) to predict all eight pretraining wavelengths at t+12h.
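As an illustration of the two-frame history described above, the sketch below pairs each timestamp t with its neighbours at t-12h and t+12h and keeps only the triples for which all three frames exist. The helper `make_triples` and the example dates are hypothetical and are not taken from the Solaris data loader, which may assemble samples differently.

```python
from datetime import datetime, timedelta

STEP = timedelta(hours=12)  # forecast horizon and history spacing from this card


def make_triples(timestamps):
    """Return (t-12h, t, t+12h) triples for which all three frames exist.

    Hypothetical helper for illustration only.
    """
    available = set(timestamps)
    return [
        (t - STEP, t, t + STEP)
        for t in sorted(available)
        if t - STEP in available and t + STEP in available
    ]


# Toy timeline at 12-hour cadence with one missing frame (dates are made up).
start = datetime(2015, 1, 1, 0, 0)
timestamps = [start + i * STEP for i in range(5)]
del timestamps[3]  # simulate a gap in the archive

triples = make_triples(timestamps)
for history_1, history_2, target in triples:
    print(history_1, history_2, "->", target)
```

Dropping incomplete triples is one simple way to handle gaps; the training pipeline may instead interpolate or use a tolerance window.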

Files

  • solaris_small_patch4_model_state_dict.pt: reusable PyTorch checkpoint containing model_state_dict, learned normalization coefficients, wavelengths, scale factors, patch size, seed, training step, and final training loss.
  • config.json: lightweight metadata for reconstructing the model and normalization.
  • assets/solaris_small_patch4_test0_prediction.png: example qualitative test prediction plot.
  • eval/solaris_pretrain_paperloss_p4_ema_seed42_test_mse_subset_0352.md: full test-split raw-scale MSE/RMSE/MAE report.

Example Plot

The plot below shows one test sample with rows for input t-12h, input t, target t+12h, prediction t+12h, and prediction - target across all eight wavelengths.

[Figure: Solaris Small Patch 4 test prediction, assets/solaris_small_patch4_test0_prediction.png]

Model Details

  • Architecture: SolarisSmall
  • Patch size: 4
  • Embedding dimension: 256
  • Encoder depths: (2, 6, 2)
  • Decoder depths: (2, 6, 2)
  • Output wavelengths: 0094, 0131, 0171, 0193, 0211, 0304, 0335, 1600
  • Training dataset: hrrsmjd/AIA_12hour_512x512
  • Training target: 12-hour forecast
  • Training budget: 7750 optimizer steps, batch size 8, gradient accumulation 4
  • Seed: 42
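The training budget above can be turned into a rough count of samples seen. The arithmetic below is a sketch that assumes the batch size of 8 is the per-micro-batch size on a single device, so each optimizer step consumes 8 × 4 samples; if 8 is already the effective batch size, the total is 7750 × 8 = 62,000 instead.

```python
optimizer_steps = 7750
batch_size = 8          # assumed to be the micro-batch size, not the effective one
grad_accumulation = 4

# Samples consumed per optimizer step under the assumption above.
samples_per_step = batch_size * grad_accumulation
total_samples = optimizer_steps * samples_per_step

print(f"samples per optimizer step: {samples_per_step}")
print(f"total sample presentations: {total_samples}")
```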

Loading

import torch

from solaris.model.solaris import SolarisSmall

# weights_only=False unpickles arbitrary Python objects, so only load checkpoints you trust.
checkpoint = torch.load("solaris_small_patch4_model_state_dict.pt", map_location="cpu", weights_only=False)

# Rebuild the architecture from the metadata stored alongside the weights.
model = SolarisSmall(
    out_levels=len(checkpoint["wavelengths"]),
    patch_size=checkpoint["patch_size"],
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # switch to inference mode

# Normalization metadata saved with the checkpoint.
scale_factors = torch.tensor(checkpoint["scale_factors"], dtype=torch.float32)
norm_coeff_1 = checkpoint["norm_coeff_1"]
norm_coeff_2 = checkpoint["norm_coeff_2"]

Inputs should be normalized with the Solaris transform used during training. Model outputs are normalized intensities; multiply by the per-wavelength scale factors before comparing to raw-intensity targets.
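The rescaling step described above can be sketched with broadcasting. The snippet uses the scale factors listed under Training Notes and a random tensor as a stand-in for the model output; the `(batch, wavelength, height, width)` layout is an assumption, not something this card confirms, and the sketch covers only the scale-factor multiplication, not the full Solaris transform.

```python
import torch

# Per-wavelength scale factors from Training Notes, ordered
# 0094, 0131, 0171, 0193, 0211, 0304, 0335, 1600.
scale_factors = torch.tensor(
    [
        58.224720422037755, 216.21549287451052, 1616.446579054541,
        2551.0149615718674, 1190.0182024885178, 887.1800787601859,
        112.33733897339224, 266.61844876445224,
    ],
    dtype=torch.float32,
)

# Stand-in for a normalized model output.
# Assumed layout: (batch, wavelength, height, width).
prediction_norm = torch.rand(2, 8, 16, 16)

# Reshape to (1, C, 1, 1) so each wavelength channel gets its own factor.
prediction_raw = prediction_norm * scale_factors.view(1, -1, 1, 1)

print(prediction_raw.shape)
```

If the model returns a different layout, only the `view` call needs to change so that the factors line up with the wavelength axis.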

Test Metrics

Metrics below use all 352 test samples and are computed on the raw intensity scale. The regular (non-EMA) final weights are recommended: the EMA weights from the training checkpoint scored worse on the full test split and are not included in this model-state checkpoint.

Wavelength (Å)   MSE        RMSE      MAE
0094             8.87581    2.97923   0.240353
0131             243.534    15.6056   1.446
0171             9067.35    95.2226   37.4301
0193             18337.8    135.417   56.6422
0211             3811.31    61.7358   23.776
0304             1089.31    33.0047   12.0236
0335             31.7769    5.6371    1.63412
1600             54.0763    7.35366   3.33375
Mean             4080.5     44.6195   17.0658
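A quick consistency check on the table, using only the values reported above: each RMSE is the square root of the corresponding MSE, and the Mean row is the plain average of the per-wavelength entries. In particular, the mean RMSE is the average of the eight RMSE values, not the square root of the mean MSE, which would be noticeably larger.

```python
import math

# Values copied from the test-metrics table above.
mse = {
    "0094": 8.87581, "0131": 243.534, "0171": 9067.35, "0193": 18337.8,
    "0211": 3811.31, "0304": 1089.31, "0335": 31.7769, "1600": 54.0763,
}
rmse = {
    "0094": 2.97923, "0131": 15.6056, "0171": 95.2226, "0193": 135.417,
    "0211": 61.7358, "0304": 33.0047, "0335": 5.6371, "1600": 7.35366,
}

# Per-wavelength RMSE agrees with sqrt(MSE) up to rounding in the table.
rmse_matches_mse = all(
    math.isclose(math.sqrt(mse[wl]), rmse[wl], rel_tol=1e-4) for wl in mse
)

mean_mse = sum(mse.values()) / len(mse)
mean_rmse = sum(rmse.values()) / len(rmse)   # what the Mean row reports
rmse_of_mean_mse = math.sqrt(mean_mse)       # a different, larger quantity

print(rmse_matches_mse, round(mean_mse, 1), round(mean_rmse, 4))
print(round(rmse_of_mean_mse, 1))
```

The gap between the two aggregates comes from the 0171 and 0193 channels, whose large errors dominate the mean MSE.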

Training Notes

Scale factors were computed as half the average per-image maximum over unique train-split timestamps:

[58.224720422037755, 216.21549287451052, 1616.446579054541, 2551.0149615718674, 1190.0182024885178, 887.1800787601859, 112.33733897339224, 266.61844876445224]
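The scale-factor rule above can be sketched as follows: take the maximum of each wavelength channel per image, average those maxima over images with distinct timestamps, and halve the result. The function name, the `(timestamp, image)` input format, and the toy data are hypothetical; this is not the script that produced the values listed here.

```python
import torch


def compute_scale_factors(samples):
    """Half the mean per-image maximum, per channel.

    `samples` is an iterable of (timestamp, image) pairs with image shaped
    (C, H, W). Hypothetical helper for illustration only.
    """
    seen = set()
    per_image_max = []
    for timestamp, image in samples:
        if timestamp in seen:  # count each timestamp once
            continue
        seen.add(timestamp)
        per_image_max.append(image.amax(dim=(-2, -1)))  # shape (C,)
    return 0.5 * torch.stack(per_image_max).mean(dim=0)


# Toy data: two channels, 2x2 images, one duplicated timestamp.
frame_a = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[10.0, 0.0], [0.0, 0.0]]])
frame_b = torch.tensor([[[8.0, 0.0], [0.0, 0.0]], [[30.0, 0.0], [0.0, 0.0]]])
samples = [("t0", frame_a), ("t0", frame_a), ("t1", frame_b)]

scale = compute_scale_factors(samples)
print(scale)  # tensor([ 3., 10.])
```

Deduplicating by timestamp matters when the same frame appears in several samples, for example as a history frame in one and a target in another; without it, those frames would be over-weighted in the average.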

The final logged training-batch metrics at step 7750 were:

weighted MAE: 0.009133
mean raw RMSE: 25.278
per-wavelength raw RMSE: [2.535, 4.554, 62.817, 67.771, 28.201, 22.793, 2.572, 10.982]