SupraWeather-Nano-Preview

SupraWeather-Nano-Preview is an FT-Transformer (Feature Tokenizer Transformer) trained to classify a weather phenomenon from a small set of tabular meteorological features (temperature, humidity, pressure, wind, altitude, month, and air mass type).

This model is trained exclusively on a synthetically generated dataset and is NOT intended for real-world weather forecasting. The synthetic data was produced by a rule-based generator that assigns labels using simplified, hand-crafted physical heuristics (e.g. hard temperature/humidity gating per class) rather than real observational or reanalysis weather data. The model should be treated as an architecture/training-pipeline demonstration, not as a meteorological tool.


Architecture

Type: FT-Transformer (Feature Tokenizer + Transformer Encoder), following the general design described by Gorishniy et al. (2021), "Revisiting Deep Learning Models for Tabular Data."

Extracted directly from the model configuration in the notebook:

| Component | Value |
|---|---|
| d_token (hidden size) | 192 |
| n_blocks (transformer layers) | 3 |
| n_heads (attention heads) | 8 |
| d_ffn_factor | 4/3 |
| attention_dropout | 0.2 |
| ffn_dropout | 0.1 |
| residual_dropout | 0.0 |
| num_labels | 11 |
| Normalization | Pre-LayerNorm transformer blocks |
| Classification head | Linear(192→192) → ReLU → Dropout → Linear(192→11) |
| Tokens per sample | 1 CLS + 6 numerical tokens + 3 categorical tokens = 10 |

Parameter count: approximately 791k (790,859 by direct count of the modules in the Get started code).
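The total can be tallied by hand from the hyperparameters in the table above and the module definitions in the Get started code. The sketch below is a back-of-the-envelope count, not a utility shipped with the model. `count_parameters` is a hypothetical helper. It assumes an FFN width of 256 (that is, int(192 × 4/3)) and one extra row per embedding table (`c + 1`), both taken from that code.

```python
def count_parameters(d_token=192, n_blocks=3, d_ffn=256, num_labels=11,
                     num_continuous=6, cardinalities=(8, 12, 6)):
    """Tally trainable parameters of the FT-Transformer described above."""
    # Numerical tokenizer: one weight and one bias vector per continuous feature.
    num_tok = 2 * num_continuous * d_token
    # Categorical tokenizer: each embedding table has (cardinality + 1) rows.
    cat_tok = sum((c + 1) * d_token for c in cardinalities)
    cls = d_token                      # learnable CLS token
    layer_norm = 2 * d_token           # LayerNorm weight + bias
    # nn.MultiheadAttention: fused QKV projection plus output projection.
    attn = (3 * d_token * d_token + 3 * d_token) + (d_token * d_token + d_token)
    # FFN: Linear(D, D_ffn) followed by Linear(D_ffn, D).
    ffn = (d_token * d_ffn + d_ffn) + (d_ffn * d_token + d_token)
    block = 2 * layer_norm + attn + ffn
    # Head: Linear(D, D) followed by Linear(D, num_labels).
    head = (d_token * d_token + d_token) + (d_token * num_labels + num_labels)
    return num_tok + cat_tok + cls + n_blocks * block + layer_norm + head


print(count_parameters())  # 790859
```

Almost all of the budget sits in the three transformer blocks (247,744 parameters each); the tokenizers and the head together contribute under 50k.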

Feature tokenization:

  • Continuous features use a per-feature learned affine projection (NumericalFeatureTokenizer): token_i = x_i * W_i + b_i.
  • Categorical features (wind_direction, month, air_mass) use independent nn.Embedding layers (CategoricalFeatureTokenizer).
  • A learnable CLS token aggregates all feature tokens through the transformer stack; its final representation is passed to the classification head.

Inputs

| Feature | Type | Notes |
|---|---|---|
| temperature | float (°C) | Continuous |
| humidity | float (%) | Continuous |
| pressure | float (hPa) | Continuous |
| pressure_trend | float (hPa / 3h) | Continuous |
| wind_speed | float (km/h) | Continuous |
| altitude | float (m) | Continuous |
| wind_direction | str | Categorical; one of: n/ne/e/se/s/sw/w/nw (or full names) |
| month | int (1–12) | Categorical |
| air_mass | str | Categorical; one of: polar, arctic, continental, maritime, tropical, equatorial |

Continuous features are normalized using means/standard deviations computed on the training split (stored alongside the model as normalisation_stats.json).
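As a rough illustration of this step, the sketch below fits per-feature statistics on training rows only and applies them as a z-score. The key names (`continuous_cols`, `means`, `stds`) and the `1e-8` epsilon mirror what `preprocess_single` in the Get started code reads; `fit_norm_stats`, `normalize`, and the toy rows are hypothetical, and the use of a population standard deviation is an assumption, since the notebook's exact choice is not stated here.

```python
import json
import statistics

CONTINUOUS_COLS = ["temperature", "humidity", "pressure",
                   "pressure_trend", "wind_speed", "altitude"]


def fit_norm_stats(train_rows):
    """Compute per-feature mean/std from the training split only."""
    means, stds = {}, {}
    for col in CONTINUOUS_COLS:
        values = [row[col] for row in train_rows]
        means[col] = statistics.fmean(values)
        stds[col] = statistics.pstdev(values)  # assumption: population std
    return {"continuous_cols": CONTINUOUS_COLS, "means": means, "stds": stds}


def normalize(row, stats, eps=1e-8):
    """Z-score one sample using previously fitted statistics."""
    return [(row[c] - stats["means"][c]) / (stats["stds"][c] + eps)
            for c in stats["continuous_cols"]]


# Toy training rows (made-up values, for illustration only).
train_rows = [
    {"temperature": 10.0, "humidity": 40.0, "pressure": 1000.0,
     "pressure_trend": -1.0, "wind_speed": 5.0, "altitude": 100.0},
    {"temperature": 20.0, "humidity": 60.0, "pressure": 1020.0,
     "pressure_trend": 1.0, "wind_speed": 15.0, "altitude": 300.0},
]
stats = fit_norm_stats(train_rows)
stats_json = json.dumps(stats)  # the dict is plain JSON, ready to write to disk
z = normalize(train_rows[0], stats)
```

Fitting on the training split alone keeps validation and test statistics from leaking into the inputs the model sees during training.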


Output Classes

The model predicts one of 11 classes:

| ID | Class |
|---|---|
| 0 | clear |
| 1 | cloudy |
| 2 | light_rain |
| 3 | heavy_rain |
| 4 | thunderstorm |
| 5 | snow |
| 6 | freezing_rain |
| 7 | soft_hail |
| 8 | fog |
| 9 | cold_front |
| 10 | windstorm |


Dataset

The dataset is fully synthetic, generated programmatically (no real observational data is used).

  • Size: 120,000 samples (generate_dataset(n=120_000)).
  • Split: 80% train / 10% validation / 10% test.
  • Feature generation: Continuous features are sampled from parametric distributions (e.g. exponential altitude, cosine-seasonal temperature with lapse-rate adjustment, beta-distributed humidity, gamma-distributed wind speed) with fixed random seed (seed=42).
  • Label assignment: Each class is assigned a hand-crafted "logit score" as a function of the input features (e.g. hard exclusions such as snow being disallowed above +2 °C, freezing rain restricted to a −3 °C to +1 °C window, thunderstorms requiring temperature > 8 °C and humidity > 70%). Scores are scaled by a temperature factor (T = 0.15), softmax-normalized, and sampled; if the resulting maximum probability falls below 0.40, the label falls back to a deterministic argmax over the raw scores.
  • Class weighting (sklearn.utils.class_weight.compute_class_weight, balanced) is computed on the training split to address residual class imbalance.
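The label-assignment rule in the list above can be sketched in a few lines. This is an interpretation, not the generator's actual code: it assumes "scaled by a temperature factor" means dividing the scores by T before the softmax, and that a hard exclusion can be modelled as a score of negative infinity. `softmax` and `assign_label` are hypothetical names, and the per-class scoring functions themselves are not reproduced.

```python
import math
import random


def softmax(scores, temperature):
    """Temperature-scaled softmax, stabilised by subtracting the max."""
    scaled = [s / temperature for s in scores]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]  # exp(-inf) is 0.0 for excluded classes
    total = sum(exps)
    return [e / total for e in exps]


def assign_label(scores, rng, temperature=0.15, min_confidence=0.40):
    """Sample a class id, or fall back to argmax when the distribution is flat."""
    probs = softmax(scores, temperature)
    if max(probs) < min_confidence:
        # Low-confidence case: deterministic argmax over the raw scores.
        return max(range(len(scores)), key=scores.__getitem__)
    return rng.choices(range(len(scores)), weights=probs, k=1)[0]


rng = random.Random(42)
flat = assign_label([0.02, 0.01, 0.0, 0.0], rng)    # max prob < 0.40, so argmax
gated = assign_label([5.0, float("-inf")], rng)     # second class is excluded
```

With T = 0.15 even small score gaps are amplified (a gap of 1.0 becomes a factor of roughly e^6.7 in probability), so sampling mostly matters for samples whose top scores are close together.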

No external or third-party dataset is used at any stage.


Training

Extracted from the notebook's TrainingArguments and custom Trainer subclass:

| Setting | Value |
|---|---|
| Epochs | 25 |
| Train batch size | 512 |
| Eval batch size | 1024 |
| Learning rate | 1e-3 |
| Weight decay | 1e-4 |
| Max grad norm | 1.0 |
| Warmup ratio | 0.06 |
| LR scheduler | Cosine |
| Eval / save strategy | Per epoch |
| Best model selection | Highest F1 (metric_for_best_model="f1") |
| Early stopping | Patience = 5 epochs |
| Mixed precision (fp16) | Enabled if CUDA is available |
| Optimizer | Hugging Face Trainer default (AdamW); no custom optimizer is configured in the notebook |
| Loss function | Custom Focal Loss (gamma=2.0), with per-class alpha weights set from the balanced class weights computed on the training set |
| Class balancing | Focal loss alpha weighting only; no oversampling/undersampling or data augmentation is used |
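To make the loss configuration concrete, the sketch below computes balanced class weights and a per-sample focal loss in plain Python. The weight formula, n_samples / (n_classes × count), is the one scikit-learn documents for `class_weight="balanced"`. The focal-loss form, −α_y (1 − p_y)^γ log p_y, is the standard one and is assumed here; the notebook's exact reduction and any alpha rescaling are not shown in this README. `balanced_class_weights` and `focal_loss` are hypothetical names.

```python
import math
from collections import Counter


def balanced_class_weights(labels, n_classes):
    """n_samples / (n_classes * count_c); assumes every class occurs at least once."""
    counts = Counter(labels)
    n = len(labels)
    return [n / (n_classes * counts[c]) for c in range(n_classes)]


def focal_loss(logits, target, alpha, gamma=2.0):
    """Focal loss for one sample: -alpha[y] * (1 - p_y)**gamma * log(p_y)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))  # log-sum-exp
    log_p = logits[target] - log_z
    p = math.exp(log_p)
    return -alpha[target] * (1.0 - p) ** gamma * log_p


weights = balanced_class_weights([0, 0, 0, 1], n_classes=2)  # rare class gets 2.0
ce = focal_loss([0.0, 0.0], 0, alpha=[1.0, 1.0], gamma=0.0)  # gamma=0: cross-entropy
fl = focal_loss([0.0, 0.0], 0, alpha=[1.0, 1.0], gamma=2.0)  # down-weighted by 0.25
```

The (1 − p_y)^γ factor shrinks the loss on samples the model already classifies confidently, while alpha raises the contribution of rare classes.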

Hardware

2× NVIDIA T4 GPUs (free Kaggle quota)


Training Time

About 5 minutes (yes, not joking)


Evaluation

The notebook computes accuracy, precision (weighted), recall (weighted), F1 (weighted), a full classification_report, and a confusion matrix on the held-out test split via trainer.predict(). However, no executed run log with the resulting numeric values is available for this README, so quantitative test-set metrics are omitted here rather than estimated.
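For readers unfamiliar with "weighted" averaging, the sketch below shows what those metrics compute: per-class precision, recall, and F1 averaged with weights proportional to each class's support in the true labels. It is a plain-Python illustration on made-up labels, not the notebook's evaluation code and not a result for this model. `weighted_metrics` is a hypothetical name.

```python
from collections import Counter


def weighted_metrics(y_true, y_pred):
    """Accuracy plus support-weighted precision, recall, and F1."""
    support = Counter(y_true)
    n = len(y_true)
    precision = recall = f1 = 0.0
    for c in sorted(set(y_true) | set(y_pred)):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        predicted = sum(1 for p in y_pred if p == c)
        p_c = tp / predicted if predicted else 0.0
        r_c = tp / support[c] if support[c] else 0.0
        f_c = 2 * p_c * r_c / (p_c + r_c) if (p_c + r_c) else 0.0
        w = support[c] / n  # weight by the class's share of the true labels
        precision += w * p_c
        recall += w * r_c
        f1 += w * f_c
    accuracy = sum(1 for t, p in zip(y_true, y_pred) if t == p) / n
    return {"accuracy": accuracy, "precision": precision,
            "recall": recall, "f1": f1}


# Toy labels for illustration only.
metrics = weighted_metrics([0, 0, 1, 1], [0, 1, 1, 1])
```

Support-weighted recall always equals accuracy, so the weighted F1 used for best-model selection is the more informative of the three on imbalanced classes.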


Stress Tests

In addition to the formal test split, the notebook runs a fixed set of six hand-designed "stress test" scenarios intended to probe known edge cases. Reported results from an executed run:

| Scenario | Expected | Predicted | Result |
|---|---|---|---|
| Deep winter polar conditions | snow | snow | PASS |
| Saturated calm air | fog | fog | PASS |
| Extreme wind + low pressure | windstorm | cold_front | FAIL |
| Hot dry high pressure summer | clear | clear | PASS |
| Freezing rain window (−1 °C, high humidity) | freezing_rain | freezing_rain | PASS |
| Classic thunderstorm | thunderstorm | thunderstorm | PASS |

Summary: 5 / 6 stress tests passed.
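The tally can be reproduced from the table with a few lines of Python. This is only a bookkeeping sketch: `STRESS_RESULTS` and `summarize` are hypothetical names, and because the scenarios' input feature values are not listed in this README, only the expected and predicted labels are used.

```python
# (scenario, expected, predicted) triples copied from the table above.
STRESS_RESULTS = [
    ("Deep winter polar conditions", "snow", "snow"),
    ("Saturated calm air", "fog", "fog"),
    ("Extreme wind + low pressure", "windstorm", "cold_front"),
    ("Hot dry high pressure summer", "clear", "clear"),
    ("Freezing rain window (−1 °C, high humidity)", "freezing_rain", "freezing_rain"),
    ("Classic thunderstorm", "thunderstorm", "thunderstorm"),
]


def summarize(results):
    """Return (passed, total, failures) for a list of stress-test triples."""
    failures = [(name, exp, pred) for name, exp, pred in results if exp != pred]
    return len(results) - len(failures), len(results), failures


passed, total, failures = summarize(STRESS_RESULTS)
print(f"{passed} / {total} stress tests passed")  # 5 / 6 stress tests passed
for name, expected, predicted in failures:
    print(f"FAIL: {name}: expected {expected}, got {predicted}")
```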

The single failure (extreme wind + low pressure misclassified as cold_front instead of windstorm) reflects an overlap between the synthetic scoring rules for these two classes: both are driven by falling pressure and elevated wind speed, and the model's confidence was split between them (≈0.86 windstorm-favoring features vs. a meaningful cold_front probability mass in the reported softmax distribution). This indicates a labeling-rule ambiguity in the synthetic data generator rather than a clear model failure.


Limitations

  • The training data is entirely synthetic and generated from simplified, hand-written heuristics, not real meteorological observations or physical simulation.
  • This model is not suitable for real-world weather forecasting or any operational meteorological use.
  • The "physical" rules encoding label assignment (e.g. temperature thresholds for snow, humidity/wind gating for fog) are simplifications and do not capture real atmospheric dynamics.
  • Class boundaries between physically similar phenomena (e.g. windstorm vs. cold front) can overlap in the synthetic label-generation logic, as shown in the stress test results above.
  • Intended use is restricted to experimentation, architecture benchmarking, and research/educational purposes within the SupraLabs pipeline.

Get started 🚀

import json
import math
from pathlib import Path
from typing import Optional, Dict, List

import numpy as np
import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModel
from safetensors.torch import load_file


# ──────────────────────────────────────────────────────────────────────────
# Config
# ──────────────────────────────────────────────────────────────────────────
class WeatherFTConfig(PretrainedConfig):
    model_type = "weather_ft_transformer"

    def __init__(
        self,
        num_continuous: int = 6,
        num_wind_dirs: int = 8,
        num_months: int = 12,
        num_air_masses: int = 6,
        d_token: int = 192,
        n_blocks: int = 3,
        n_heads: int = 8,
        d_ffn_factor: float = 4.0 / 3.0,
        attention_dropout: float = 0.2,
        ffn_dropout: float = 0.1,
        residual_dropout: float = 0.0,
        num_labels: int = 11,
        means: Optional[Dict] = None,
        stds: Optional[Dict] = None,
        label2id: Optional[Dict] = None,
        id2label: Optional[Dict] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_continuous = num_continuous
        self.num_wind_dirs = num_wind_dirs
        self.num_months = num_months
        self.num_air_masses = num_air_masses
        self.d_token = d_token
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.d_ffn_factor = d_ffn_factor
        self.attention_dropout = attention_dropout
        self.ffn_dropout = ffn_dropout
        self.residual_dropout = residual_dropout
        self.num_labels = num_labels
        self.means = means or {}
        self.stds = stds or {}
        self.label2id = label2id or {}
        self.id2label = (
            {int(k): v for k, v in id2label.items()} if id2label else {}
        )


# ──────────────────────────────────────────────────────────────────────────
# Model
# ──────────────────────────────────────────────────────────────────────────
class NumericalFeatureTokenizer(nn.Module):
    """Projects each continuous feature independently: token_i = x_i * W_i + b_i."""

    def __init__(self, n_features: int, d_token: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(n_features, d_token))
        self.bias = nn.Parameter(torch.empty(n_features, d_token))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.unsqueeze(-1) * self.weight.unsqueeze(0) + self.bias.unsqueeze(0)


class CategoricalFeatureTokenizer(nn.Module):
    """Embeds each categorical feature independently."""

    def __init__(self, cardinalities: List[int], d_token: int):
        super().__init__()
        self.embeddings = nn.ModuleList(
            [nn.Embedding(c + 1, d_token) for c in cardinalities]
        )

    def forward(self, cats: List[torch.Tensor]) -> torch.Tensor:
        return torch.stack(
            [emb(cat) for emb, cat in zip(self.embeddings, cats)], dim=1
        )


class FTTransformerBlock(nn.Module):
    def __init__(self, config: WeatherFTConfig):
        super().__init__()
        D = config.d_token
        D_ffn = max(int(D * config.d_ffn_factor), 1)
        self.ln1 = nn.LayerNorm(D)
        self.attn = nn.MultiheadAttention(
            D, config.n_heads, dropout=config.attention_dropout, batch_first=True
        )
        self.attn_drop = nn.Dropout(config.residual_dropout)
        self.ln2 = nn.LayerNorm(D)
        self.ffn = nn.Sequential(
            nn.Linear(D, D_ffn),
            nn.GELU(),
            nn.Dropout(config.ffn_dropout),
            nn.Linear(D_ffn, D),
        )
        self.ffn_drop = nn.Dropout(config.residual_dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.ln1(x)
        h, _ = self.attn(h, h, h)
        x = x + self.attn_drop(h)
        h = self.ffn(self.ln2(x))
        x = x + self.ffn_drop(h)
        return x


class WeatherFTTransformer(PreTrainedModel):
    config_class = WeatherFTConfig

    def __init__(self, config: WeatherFTConfig):
        super().__init__(config)
        D = config.d_token

        self.num_tokenizer = NumericalFeatureTokenizer(config.num_continuous, D)
        self.cat_tokenizer = CategoricalFeatureTokenizer(
            cardinalities=[
                config.num_wind_dirs,
                config.num_months,
                config.num_air_masses,
            ],
            d_token=D,
        )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, D))
        self.blocks = nn.ModuleList(
            [FTTransformerBlock(config) for _ in range(config.n_blocks)]
        )
        self.ln_final = nn.LayerNorm(D)
        self.head = nn.Sequential(
            nn.Linear(D, D),
            nn.ReLU(),
            nn.Dropout(config.ffn_dropout),
            nn.Linear(D, config.num_labels),
        )
        self.post_init()

    def forward(
        self,
        continuous: torch.Tensor,
        wind_direction: torch.Tensor,
        month: torch.Tensor,
        air_mass: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict:
        B = continuous.size(0)
        num_tokens = self.num_tokenizer(continuous)
        cat_tokens = self.cat_tokenizer([wind_direction, month, air_mass])
        cls = self.cls_token.expand(B, -1, -1)
        tokens = torch.cat([cls, num_tokens, cat_tokens], dim=1)

        for block in self.blocks:
            tokens = block(tokens)

        cls_repr = self.ln_final(tokens[:, 0])
        logits = self.head(cls_repr)

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)

        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}


AutoConfig.register("weather_ft_transformer", WeatherFTConfig)
AutoModel.register(WeatherFTConfig, WeatherFTTransformer)


# ──────────────────────────────────────────────────────────────────────────
# Loading helpers
# ──────────────────────────────────────────────────────────────────────────
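def load_model(model_dir: str):
    """Loads config, weights, and normalisation stats.

    `model_dir` may be a local directory or a Hugging Face Hub repo id;
    a repo id is downloaded to the local cache first.
    """
    path = Path(model_dir)
    if not path.is_dir():
        # Not a local directory: treat the argument as a Hub repo id, since
        # load_file() and open() below need real files on disk.
        from huggingface_hub import snapshot_download

        path = Path(snapshot_download(repo_id=model_dir))

    config = WeatherFTConfig.from_pretrained(str(path))
    model = WeatherFTTransformer(config)
    state_dict = load_file(str(path / "model.safetensors"))
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout for inference

    with open(path / "normalisation_stats.json") as f:
        norm_stats = json.load(f)

    return model, config, norm_stats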


# ──────────────────────────────────────────────────────────────────────────
# Inference helpers
# ──────────────────────────────────────────────────────────────────────────
def preprocess_single(
    norm_stats: Dict,
    temperature: float,
    humidity: float,
    pressure: float,
    pressure_trend: float,
    wind_speed: float,
    wind_direction: str,
    altitude: float,
    month: int,
    air_mass: str,
) -> Dict[str, torch.Tensor]:
    means = norm_stats["means"]
    stds = norm_stats["stds"]
    continuous_cols = norm_stats["continuous_cols"]
    wind_dir_map = norm_stats["wind_dir_map"]
    air_mass_map = norm_stats["air_mass_map"]

    cont_raw = np.array(
        [temperature, humidity, pressure, pressure_trend, wind_speed, altitude],
        dtype=np.float32,
    )
    m = np.array([means[c] for c in continuous_cols], dtype=np.float32)
    s = np.array([stds[c] for c in continuous_cols], dtype=np.float32)
    cont_norm = (cont_raw - m) / (s + 1e-8)

    return {
        "continuous": torch.tensor(cont_norm).unsqueeze(0),
        "wind_direction": torch.tensor([wind_dir_map.get(wind_direction.lower(), 0)]),
        "month": torch.tensor([month]),
        "air_mass": torch.tensor([air_mass_map.get(air_mass.lower(), 0)]),
    }


@torch.no_grad()
def predict_single(model, config, norm_stats, **kwargs):
    inputs = preprocess_single(norm_stats, **kwargs)
    out = model(**inputs)
    probs = torch.softmax(out["logits"], dim=-1).squeeze(0).numpy()
    pred_id = int(np.argmax(probs))
    pred_label = config.id2label[pred_id]
    return pred_label, pred_id, probs


# ──────────────────────────────────────────────────────────────────────────
# Example
# ──────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    MODEL_DIR = "SupraLabs/SupraWeather-Nano-Preview"  # local path or HF repo id

    model, config, norm_stats = load_model(MODEL_DIR)

    pred_label, pred_id, probs = predict_single(
        model,
        config,
        norm_stats,
        temperature=30,
        humidity=30,
        pressure=1025,
        pressure_trend=1,
        wind_speed=5,
        wind_direction="east",
        altitude=100,
        month=7,
        air_mass="tropical",
    )

    print(f"Predicted class : {pred_label} (id={pred_id})\n")
    print("Class probabilities:")
    for i in np.argsort(probs)[::-1]:
        bar = "█" * int(probs[i] * 40)
        print(f"  {config.id2label[i]:<14} {probs[i]:.4f}  {bar}")

License

Apache 2.0

Model size: 791k params (Safetensors, tensor type F32)