metadata
language: en
license: apache-2.0
tags:
  - marine-biology
  - metagenomics
  - environmental-modeling
  - protein-domains
  - tara-oceans
  - vicreg
  - joint-embedding
  - self-supervised-learning
  - pytorch
library_name: pytorch
pipeline_tag: tabular-regression

TARA-WorldModel-VICReg

Joint environment-genome embedding model for predicting marine ecosystem productivity, trained with Variance-Invariance-Covariance Regularization (VICReg) on TARA Oceans data.

Model Description

The World Model learns a shared latent space that aligns environmental context (satellite-derived variables) with microalgal protein-domain composition (Pfam module abundances). It then predicts three marine productivity indicators from the joint embedding: chlorophyll-a, particulate organic carbon (POC), and normalized fluorescence line height (NFLH).

Architecture

Training:
  env (24 dims)  → EncoderE(128 → 32)       → z_env  ─┐
                                                      ├→ VICReg loss
  pfam (20 dims) → EncoderP(256 → 128 → 32) → z_pfam ─┘
  z_env → Predictor(64 → 3) → productivity

Inference (environment-only):
  env → EncoderE → z_env → Predictor → [chl-a, POC, NFLH]

  • EncoderE: Linear(24→128) + BN + ReLU + Dropout(0.3) → Linear(128→32) + BN + ReLU + Dropout(0.3)
  • EncoderP: Linear(20→256) + BN + ReLU + Dropout(0.3) → Linear(256→128) + BN + ReLU + Dropout(0.3) → Linear(128→32) + BN + ReLU + Dropout(0.3)
  • Predictor: Linear(32→64) + ReLU → Linear(64→3)
  • Total parameters: 53,187
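The stated total can be checked by hand. The following is a minimal plain-Python sketch that assumes PyTorch conventions: Linear(a→b) holds a·b weights plus b biases, BN is BatchNorm1d with a learnable scale and shift per feature (its running mean and variance are buffers, not parameters), and ReLU and Dropout have no parameters. The helper names are illustrative only.

```python
# Reproduce the parameter count of the layers listed above (latent_dim = 32).
# Assumes PyTorch conventions: Linear(a -> b) has a*b weights + b biases;
# BatchNorm1d(b) has b scale + b shift parameters (running stats are buffers).


def linear_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out


def batchnorm_params(n_features: int) -> int:
    return 2 * n_features


def encoder_params(dims: list[int]) -> int:
    """Stack of Linear + BN blocks (ReLU and Dropout are parameter-free)."""
    total = 0
    for n_in, n_out in zip(dims[:-1], dims[1:]):
        total += linear_params(n_in, n_out) + batchnorm_params(n_out)
    return total


encoder_e = encoder_params([24, 128, 32])                 # EncoderE
encoder_p = encoder_params([20, 256, 128, 32])            # EncoderP
predictor = linear_params(32, 64) + linear_params(64, 3)  # Predictor (no BN)

total = encoder_e + encoder_p + predictor
print(f"{total:,}")  # 53,187
```

The sum (7,648 + 43,232 + 2,307) matches 53,187 exactly. This is consistent with the layer list describing the latent_dim=32 configuration; the latent_dim=16 checkpoints would have fewer parameters.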

VICReg Loss

Non-contrastive self-supervised alignment (Bardes et al., ICLR 2022):

  • Invariance: MSE between co-located env and pfam embeddings (λ=25)
  • Variance: hinge loss on the standard deviation of each embedding dimension, preventing embedding collapse (λ=25)
  • Covariance: penalty on off-diagonal covariance entries, decorrelating embedding dimensions (λ=1)
  • Prediction: MSE on the productivity targets (α=1)
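The four terms can be sketched in plain Python. This is a minimal, unoptimized illustration that follows the reference VICReg formulation: unbiased variance with ε = 1e-4 and target standard deviation γ = 1, the variance term averaged over the two branches, and the covariance term summed over them. Whether this model's training code uses exactly these conventions is an assumption, and all function names are hypothetical.

```python
import math

# Plain-Python sketch of the loss terms listed above, following the reference
# VICReg formulation (Bardes et al., ICLR 2022). Embeddings are lists of rows
# (n samples x d dimensions); n must be at least 2 for the variance estimates.


def _column_means(z):
    n = len(z)
    return [sum(row[j] for row in z) / n for j in range(len(z[0]))]


def mse(a, b):
    """Mean squared error over all entries of two equally shaped batches."""
    n, d = len(a), len(a[0])
    return sum((x - y) ** 2 for ra, rb in zip(a, b) for x, y in zip(ra, rb)) / (n * d)


def variance(z, gamma=1.0, eps=1e-4):
    """Hinge on each dimension's std: mean_j max(0, gamma - sqrt(Var_j + eps))."""
    n, d = len(z), len(z[0])
    mu = _column_means(z)
    total = 0.0
    for j in range(d):
        var_j = sum((row[j] - mu[j]) ** 2 for row in z) / (n - 1)  # unbiased
        total += max(0.0, gamma - math.sqrt(var_j + eps))
    return total / d


def covariance(z):
    """Sum of squared off-diagonal covariance entries, divided by d."""
    n, d = len(z), len(z[0])
    mu = _column_means(z)
    total = 0.0
    for i in range(d):
        for j in range(d):
            if i != j:
                c_ij = sum((row[i] - mu[i]) * (row[j] - mu[j]) for row in z) / (n - 1)
                total += c_ij ** 2
    return total / d


def world_model_loss(z_env, z_pfam, pred, target,
                     lam_inv=25.0, lam_var=25.0, lam_cov=1.0, alpha=1.0):
    """Weighted sum of invariance, variance, covariance and prediction terms."""
    inv = mse(z_env, z_pfam)                          # invariance
    var = 0.5 * (variance(z_env) + variance(z_pfam))  # averaged over branches
    cov = covariance(z_env) + covariance(z_pfam)      # summed over branches
    return lam_inv * inv + lam_var * var + lam_cov * cov + alpha * mse(pred, target)
```

A collapsed representation, where every sample maps to the same vector, drives the invariance term to zero but leaves the variance term near 1 per dimension. The variance hinge exists to penalize exactly this failure mode.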

Performance

The joint embedding improves POC prediction over the environment-only baseline (R² 0.422 → 0.532, a 26% relative improvement). Chlorophyll-a and NFLH are better predicted by environment alone, because both are measured directly by satellite.
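The 26% figure is the relative gain in R² for POC, which corresponds to an absolute gain of 0.110:

```latex
\frac{R^2_{\text{joint}} - R^2_{\text{env}}}{R^2_{\text{env}}}
  = \frac{0.532 - 0.422}{0.422}
  = \frac{0.110}{0.422}
  \approx 0.261
```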

Files

Fold Checkpoints (leave-one-basin-out spatial CV)

Two training runs are provided:

  • world_model_fold_*_20260127_110243.pt: initial configuration (latent_dim=16)
  • world_model_fold_*_20260127_111754.pt: best configuration from the hyperparameter sweep (latent_dim=32)

Six folds per run: Arctic, Atlantic, Indian, Mediterranean, Pacific, Southern.
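Each leave-one-basin-out fold trains on five basins and evaluates on the sixth, and the checkpoints appear to be named after the held-out basin. The sketch below is minimal plain Python and assumes every sample carries a basin label. `BASINS`, `RUNS`, `checkpoint_name`, and `leave_one_basin_out` are hypothetical names; only the file-name pattern and the run timestamps come from this card.

```python
from collections.abc import Iterator

BASINS = ["Arctic", "Atlantic", "Indian", "Mediterranean", "Pacific", "Southern"]
# Run timestamps taken from the checkpoint file names listed above.
RUNS = {"initial": "20260127_110243", "best": "20260127_111754"}


def checkpoint_name(basin: str, run: str = "best") -> str:
    """File name of the fold checkpoint for one held-out basin."""
    return f"world_model_fold_{basin}_{RUNS[run]}.pt"


def leave_one_basin_out(basin_labels: list[str]) -> Iterator[tuple[str, list[int], list[int]]]:
    """Yield (held_out_basin, train_indices, test_indices), one fold per basin."""
    for held_out in BASINS:
        train = [i for i, b in enumerate(basin_labels) if b != held_out]
        test = [i for i, b in enumerate(basin_labels) if b == held_out]
        yield held_out, train, test


# Toy example: 8 samples with basin labels (hypothetical data).
labels = ["Atlantic", "Pacific", "Atlantic", "Southern",
          "Indian", "Arctic", "Mediterranean", "Pacific"]
for basin, train_idx, test_idx in leave_one_basin_out(labels):
    print(basin, len(train_idx), len(test_idx), checkpoint_name(basin))
```

Splitting by basin rather than at random keeps nearby, likely correlated samples out of training for the held-out basin. The fold scores therefore estimate how well the model transfers to an unseen basin.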

Configuration

  • phase2_best_config.json: hyperparameter sweep results (54 configurations, 3 seeds each)

Hyperparameters (Best Config)

Parameter        Value
latent_dim       32
dropout          0.3
λ_invariance     25.0
λ_variance       25.0
λ_covariance     1.0
pred_alpha       1.0
learning_rate    0.001
weight_decay     1e-4
batch_size       128
max_epochs       300
patience         30
grad_clip        1.0
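The table can be carried around as a plain dictionary. The sketch below also shows one common reading of `max_epochs` and `patience`: stop once the validation loss has not improved for `patience` consecutive epochs, or at `max_epochs`, whichever comes first. The key names and this exact stopping rule are assumptions, not the training code's implementation.

```python
# Best configuration from the table above (key names are illustrative).
BEST_CONFIG = {
    "latent_dim": 32,
    "dropout": 0.3,
    "lambda_invariance": 25.0,
    "lambda_variance": 25.0,
    "lambda_covariance": 1.0,
    "pred_alpha": 1.0,
    "learning_rate": 1e-3,
    "weight_decay": 1e-4,
    "batch_size": 128,
    "max_epochs": 300,
    "patience": 30,
    "grad_clip": 1.0,  # presumably a maximum gradient norm (assumption)
}


def early_stopping_epoch(val_losses, max_epochs=300, patience=30):
    """Return (stop_epoch, best_epoch), both 1-indexed.

    Stops when `patience` epochs pass without a new best validation loss,
    or when `max_epochs` (or the end of `val_losses`) is reached.
    """
    best_loss, best_epoch = float("inf"), 0
    for epoch, loss in enumerate(val_losses[:max_epochs], start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return epoch, best_epoch
    return min(len(val_losses), max_epochs), best_epoch


losses = [1.0, 0.8, 0.7] + [0.75] * 40
print(early_stopping_epoch(losses, BEST_CONFIG["max_epochs"], BEST_CONFIG["patience"]))  # (33, 3)
```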

Usage

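import torch

# Load one fold checkpoint (best run, Atlantic fold) onto the CPU.
# PyTorch >= 2.6 defaults to weights_only=True; if loading fails because the file
# stores non-tensor objects, pass weights_only=False for trusted files only.
ckpt = torch.load("world_model_fold_Atlantic_20260127_111754.pt", map_location="cpu")

# The checkpoint stores the full WorldModel weights under "model_state_dict".
state_dict = ckpt["model_state_dict"]

# Rebuilding the model requires the WorldModel class from the training codebase:
# model = WorldModel(...)
# model.load_state_dict(state_dict)
# model.eval()  # disable dropout; BatchNorm uses its running statistics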

Dataset

  • 1,810 ocean samples with co-located environment and Pfam profiles
  • 24 environmental variables (oceanographic/atmospheric, from Google Earth Engine)
  • 20 Pfam module features (aggregated from 9,466 domains via co-occurrence clustering)
  • 3 productivity targets (chlorophyll-a, POC, NFLH)
  • Spatial cross-validation: Leave-one-basin-out (6 ocean basins)
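The card does not state how the 9,466 domain abundances are collapsed into the 20 module features. The sketch below takes one plausible reading and sums a sample's domain abundances within each co-occurrence module. The summation rule, the `domain_to_module` mapping, and the function name are hypothetical, and the clustering step that would produce the mapping is not shown.

```python
def aggregate_modules(domain_abundance: dict[str, float],
                      domain_to_module: dict[str, int],
                      n_modules: int = 20) -> list[float]:
    """Collapse per-domain abundances into one value per module by summation.

    Hypothetical aggregation rule; domains absent from the mapping are ignored.
    """
    features = [0.0] * n_modules
    for domain, value in domain_abundance.items():
        module = domain_to_module.get(domain)
        if module is not None:
            features[module] += value
    return features


# Toy mapping of four Pfam accessions to modules 0 and 1 (hypothetical values).
mapping = {"PF00001": 0, "PF00002": 0, "PF00003": 1, "PF00004": 1}
sample = {"PF00001": 2.0, "PF00002": 1.0, "PF00003": 0.5, "PF99999": 7.0}
print(aggregate_modules(sample, mapping, n_modules=2))  # [3.0, 0.5]
```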

Related Models

Citation

LA4SR classification models:

Nelson DR, Jaiswal AK, Ismail NS, Mystikou A, Salehi-Ashtiani K. Patterns. 2024;6(11).

License

Apache 2.0