"""Quickstart: Gaussian likelihood for one Gaia photometric measurement.

This is a compact reference script for collaborators loading the released
isochrone emulator from Hugging Face. It follows the same pattern as the main
README examples:

1. download a bundle with ``Emulator.from_pretrained(...)``;
2. freeze a JAX callable with ``make_frozen_apply(jit=False)``;
3. explicitly normalize physical inputs into the bundle's canonical space;
4. explicitly denormalize canonical outputs back to physical magnitudes;
5. evaluate a simple diagonal Gaussian log likelihood in one outer jit.

The input vector is ``[log10_age_yr, eep, feh]``. The measurement vector is
absolute ``[G_mag, BP_mag, RP_mag]``. If your data are apparent magnitudes,
include distance modulus, extinction, calibration offsets, or other nuisance
terms in your own likelihood around the emulator prediction.

Related examples in the source repository:

- examples/basic/02_load_bundle_predict.py
- examples/basic/04_use_bundle_in_map_fit.py
"""
|
|
| from __future__ import annotations |
|
|
| import jax |
| import jax.numpy as jnp |
| import numpy as np |
|
|
| from astro_emulators_toolkit import Emulator, denormalize_tree, normalize_tree |
|
|
# Bundle location; these three values are forwarded unchanged to
# ``Emulator.from_pretrained`` in ``main``.
REPO_ID = "RozanskiT/isochrones-mlp"
REVISION = None  # NOTE(review): presumably selects the repo default; confirm
CACHE_DIR = ".emuspec_cache"

# Key of the output-tree leaf read in ``main`` and the channel labels used
# when printing it.
OUTPUT_LEAF = "magnitudes"
OUTPUT_CHANNELS = ("G_mag", "BP_mag", "RP_mag")

# Physical input vector, ordered as [log10_age_yr, eep, feh].
THETA_PHYSICAL = np.array([9.4, 300.0, 0.0], dtype=np.float32)

# Observed absolute magnitudes and the per-channel Gaussian sigma, both
# ordered like OUTPUT_CHANNELS.
OBSERVED_MAGNITUDES = np.array([6.94, 7.52, 6.21], dtype=np.float32)
OBSERVED_SIGMA_MAG = np.array([0.03, 0.03, 0.03], dtype=np.float32)
|
|
|
|
def main() -> None:
    """Load the released emulator and print one Gaussian log likelihood.

    Fetches the bundle named by ``REPO_ID``, predicts magnitudes for
    ``THETA_PHYSICAL``, and scores them against ``OBSERVED_MAGNITUDES`` with
    independent per-channel uncertainties ``OBSERVED_SIGMA_MAG``. The
    parameters, prediction, observation, uncertainties, and log likelihood
    are written to stdout.

    Raises:
        ValueError: If the bundle provides no reference scaling for its
            inputs or for its outputs.
    """
    emulator = Emulator.from_pretrained(
        REPO_ID,
        revision=REVISION,
        cache_dir=CACHE_DIR,
        verbose=True,
    )
    # The forward pass is requested un-jitted; compilation is applied once,
    # around the full objective defined further down.
    forward = emulator.make_frozen_apply(jit=False)

    input_scaling = emulator.reference_scaling_inputs
    output_scaling = emulator.reference_scaling_outputs
    scaling_missing = input_scaling is None or output_scaling is None
    if scaling_missing:
        raise ValueError(
            "This likelihood example requires reference_scaling_inputs and "
            "reference_scaling_outputs in the bundle metadata."
        )

    observed = jnp.asarray(OBSERVED_MAGNITUDES, dtype=jnp.float32)
    sigma = jnp.asarray(OBSERVED_SIGMA_MAG, dtype=jnp.float32)

    def predict_magnitudes(theta):
        """Return physical magnitudes for one physical parameter vector."""
        # Add a leading axis for the emulator call; it is dropped on return.
        batched = {"parameters": theta[None, :]}
        canonical_in = normalize_tree(
            batched,
            input_scaling["min_tree"],
            input_scaling["max_tree"],
        )
        canonical_out = forward(canonical_in)
        physical_out = denormalize_tree(
            canonical_out,
            output_scaling["min_tree"],
            output_scaling["max_tree"],
        )
        return physical_out[OUTPUT_LEAF][0]

    def objective(theta):
        """Return the predicted magnitudes and the diagonal Gaussian log-L."""
        predicted = predict_magnitudes(theta)
        scaled_resid = (observed - predicted) / sigma
        log_norm = jnp.sum(jnp.log(2.0 * jnp.pi * sigma**2))
        return predicted, -0.5 * (jnp.sum(scaled_resid**2) + log_norm)

    evaluate_likelihood = jax.jit(objective)

    theta0 = jnp.asarray(THETA_PHYSICAL, dtype=jnp.float32)
    predicted_dev, logp_dev = evaluate_likelihood(theta0)
    # Wait for the device results before converting them to host values.
    predicted_host = np.asarray(jax.block_until_ready(predicted_dev))
    logp = float(jax.block_until_ready(logp_dev))

    def print_channels(header, values):
        """Print a header line followed by one labelled value per channel."""
        print(header)
        for label, number in zip(OUTPUT_CHANNELS, values, strict=True):
            print(f" {label}: {number:.6f}")

    print("theta_physical [age, eep, feh]:", THETA_PHYSICAL.tolist())
    print_channels("model absolute magnitudes:", predicted_host)
    print_channels("observed absolute magnitudes:", OBSERVED_MAGNITUDES)
    print("sigma_mag:", OBSERVED_SIGMA_MAG.tolist())
    print("log_likelihood:", f"{logp:.6f}")
|
|
|
|
# Script entry point: run the example only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|