---
tags:
- astronomy
- multimodal
- classification
datasets:
- AstroMLCore/AstroM3Processed
- AstroMLCore/AstroM3Dataset
---

AstroM³ is a self-supervised multimodal model for astronomy that integrates time-series photometry, spectra, and metadata into a unified embedding space for classification and other downstream tasks. AstroM³ is trained on [AstroM3Processed](https://huggingface.co/datasets/AstroMLCore/AstroM3Processed), the pre-processed version of [AstroM3Dataset](https://huggingface.co/datasets/AstroMLCore/AstroM3Dataset). For more details on the AstroM³ architecture, training, and results, please refer to the [paper](https://arxiv.org/abs/2411.08842).


Figure 1: Overview of the multimodal CLIP framework adapted for astronomy, incorporating three data modalities: photometric time-series, spectra, and metadata. Each modality is processed by a dedicated encoder to create embeddings, which are then mapped into a shared embedding space through projection heads. Pairwise similarity matrices align the embeddings across modalities, and a symmetric cross-entropy loss, computed over these matrices, optimizes the model. The total loss, derived from all pairwise losses, guides the model’s trimodal learning.
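In code, the pairwise objective the caption describes is the standard CLIP-style symmetric cross-entropy. The sketch below is a minimal illustration rather than the repository's actual implementation; in particular, the temperature value and function names here are placeholders:

```python
import torch
import torch.nn.functional as F

def pairwise_clip_loss(emb_a, emb_b, temperature=0.07):
    """Symmetric cross-entropy over the similarity matrix of one modality pair.
    Illustrative only; the temperature value is a placeholder."""
    emb_a = F.normalize(emb_a, dim=-1)
    emb_b = F.normalize(emb_b, dim=-1)
    logits = emb_a @ emb_b.T / temperature                      # pairwise similarity matrix
    targets = torch.arange(emb_a.size(0), device=emb_a.device)  # matched pairs lie on the diagonal
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets)) / 2

# The trimodal total loss sums the three pairwise losses:
# loss = pairwise_clip_loss(p, s) + pairwise_clip_loss(p, m) + pairwise_clip_loss(s, m)
```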

To use AstroM³ for inference, install the AstroM3 library from our [GitHub repo](https://github.com/MeriDK/AstroM3):

```sh
git clone https://github.com/MeriDK/AstroM3.git
cd AstroM3
```

Create a virtual environment (tested with Python 3.10.14), then install the required dependencies:

```sh
uv venv venv --python 3.10.14
source venv/bin/activate
uv pip install -r requirements.txt
```

## A simple example to get started

1. Data Loading & Preprocessing

```python
from datasets import load_dataset
from src.data import process_photometry

# Load the test dataset
test_dataset = load_dataset('AstroMLCore/AstroM3Processed', name='full_42', split='test')

# Process photometry to have a fixed sequence length of 200 (center-cropped)
test_dataset = test_dataset.map(process_photometry, batched=True, fn_kwargs={'seq_len': 200, 'how': 'center'})
test_dataset = test_dataset.with_format('torch')
```

2. Model Loading & Embedding Extraction

```python
import torch
from src.model import AstroM3

# Load the base AstroM3-CLIP model
model = AstroM3.from_pretrained('AstroMLCore/AstroM3-CLIP')

# Retrieve the first sample (batch size = 1)
sample = test_dataset[0:1]
photometry = sample['photometry']
photometry_mask = sample['photometry_mask']
spectra = sample['spectra']
metadata = sample['metadata']

# Example 1: Generate embeddings when all modalities are present
p_emb, s_emb, m_emb = model.get_embeddings(photometry, photometry_mask, spectra, metadata)
multimodal_emb = (p_emb + s_emb + m_emb) / 3
print('Multimodal Embedding (All Modalities):', multimodal_emb)

# Example 2: Generate embeddings when the spectra modality is missing
dummy_spectra = torch.zeros_like(spectra)  # Dummy tensor for the missing spectra
p_emb, s_emb, m_emb = model.get_embeddings(photometry, photometry_mask, dummy_spectra, metadata)
multimodal_emb_missing = (p_emb + m_emb) / 2
print('Multimodal Embedding (Spectra Missing):', multimodal_emb_missing)
```

3. Classification Examples

```python
from src.model import AstroM3, Informer, GalSpecNet, MetaModel

# Photometry classification
photo_model = Informer.from_pretrained('AstroMLCore/AstroM3-CLIP-photo')
prediction = photo_model(photometry, photometry_mask).argmax(dim=1).item()
print('Photometry Classification:', test_dataset.features['label'].int2str(prediction))

# Spectra classification
spectra_model = GalSpecNet.from_pretrained('AstroMLCore/AstroM3-CLIP-spectra')
prediction = spectra_model(spectra).argmax(dim=1).item()
print('Spectra Classification:', test_dataset.features['label'].int2str(prediction))

# Metadata classification
meta_model = MetaModel.from_pretrained('AstroMLCore/AstroM3-CLIP-meta')
prediction = meta_model(metadata).argmax(dim=1).item()
print('Metadata Classification:', test_dataset.features['label'].int2str(prediction))

# Multimodal classification
all_model = AstroM3.from_pretrained('AstroMLCore/AstroM3-CLIP-all')
prediction = all_model(photometry, photometry_mask, spectra, metadata).argmax(dim=1).item()
print('Multimodal Classification:', test_dataset.features['label'].int2str(prediction))
```

## The AstroM³ Family

| Model | Description |
| :--- | :--- |
| [AstroM3-CLIP](https://huggingface.co/AstroMLCore/AstroM3-CLIP) | The base model pre-trained using the trimodal CLIP approach. |
| [AstroM3-CLIP-meta](https://huggingface.co/AstroMLCore/AstroM3-CLIP-meta) | Fine-tuned for metadata-only classification. |
| [AstroM3-CLIP-spectra](https://huggingface.co/AstroMLCore/AstroM3-CLIP-spectra) | Fine-tuned for spectra-only classification. |
| [AstroM3-CLIP-photo](https://huggingface.co/AstroMLCore/AstroM3-CLIP-photo) | Fine-tuned for photometry-only classification. |
| [AstroM3-CLIP-all](https://huggingface.co/AstroMLCore/AstroM3-CLIP-all) | Fine-tuned for multimodal classification. |

## AstroM3-CLIP Variants

These variants of the base AstroM3-CLIP model are trained using different random seeds (42, 0, 66, 12, 123); ensure that the dataset is loaded with the corresponding seed for consistency.

| Model | Description |
| :--- | :--- |
| [AstroM3-CLIP-42](https://huggingface.co/AstroMLCore/AstroM3-CLIP-42) | The base model pre-trained with random seed 42 (identical to AstroM3-CLIP). |
| [AstroM3-CLIP-0](https://huggingface.co/AstroMLCore/AstroM3-CLIP-0) | AstroM3-CLIP pre-trained with random seed 0 (use dataset with seed 0). |
| [AstroM3-CLIP-66](https://huggingface.co/AstroMLCore/AstroM3-CLIP-66) | AstroM3-CLIP pre-trained with random seed 66 (use dataset with seed 66). |
| [AstroM3-CLIP-12](https://huggingface.co/AstroMLCore/AstroM3-CLIP-12) | AstroM3-CLIP pre-trained with random seed 12 (use dataset with seed 12). |
| [AstroM3-CLIP-123](https://huggingface.co/AstroMLCore/AstroM3-CLIP-123) | AstroM3-CLIP pre-trained with random seed 123 (use dataset with seed 123). |

## Using your own data

Note that the data in the AstroM3Processed dataset is already pre-processed. If you want to use the model with your own data, you must pre-process it in the same way:

1. **Spectra**: Each spectrum is interpolated to a fixed wavelength grid (3850–9000 Å), normalized using mean and MAD, and log-MAD is added as an auxiliary feature.
2. **Photometry**: Light curves are deduplicated, sorted by time, normalized using mean and MAD, time-scaled to [0, 1], and augmented with auxiliary features like log-MAD and time span.
3. **Metadata**: Scalar metadata is transformed via domain-specific functions (e.g., absolute magnitude, log, sin/cos), then normalized using dataset-level statistics.

For a detailed description, read the [paper](https://arxiv.org/abs/2411.08842). To see exactly how we performed this preprocessing, refer to [`preprocess.py`](https://huggingface.co/datasets/AstroMLCore/AstroM3Dataset/blob/main/preprocess.py) in the AstroM3Dataset repo.
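As a rough illustration, the sketch below shows what the mean/MAD normalization might look like for a single light curve. The helper names and the exact MAD convention are assumptions made here for clarity; `preprocess.py` remains the authoritative reference:

```python
import numpy as np

def normalize_mean_mad(values: np.ndarray):
    """Hypothetical helper: center by the mean and scale by the median absolute
    deviation (MAD). The exact MAD convention AstroM3 uses is defined in preprocess.py."""
    mean = values.mean()
    mad = np.median(np.abs(values - mean))
    return (values - mean) / mad, np.log(mad)  # log-MAD is kept as an auxiliary feature

def preprocess_photometry(time: np.ndarray, flux: np.ndarray):
    """Illustrative photometry pipeline following the steps listed above:
    deduplicate, sort by time, normalize flux, and rescale time to [0, 1]."""
    time, idx = np.unique(time, return_index=True)  # deduplicate and sort by time
    flux = flux[idx]
    flux_norm, log_mad = normalize_mean_mad(flux)
    time_span = time[-1] - time[0]
    time_scaled = (time - time[0]) / time_span
    return time_scaled, flux_norm, (log_mad, time_span)  # aux features: log-MAD, time span
```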
---

## Citation

🤗 If you find this model useful, please cite our paper 🤗

```bibtex
@article{rizhko2024astrom,
  title={AstroM$^3$: A self-supervised multimodal model for astronomy},
  author={Rizhko, Mariia and Bloom, Joshua S},
  journal={arXiv preprint arXiv:2411.08842},
  year={2024}
}
```