SupraWeather-Nano-Preview
SupraWeather-Nano-Preview is an FT-Transformer (Feature Tokenizer Transformer) trained to classify a weather phenomenon from a small set of tabular meteorological features (temperature, humidity, pressure, wind, altitude, month, and air mass type).
This model is trained exclusively on a synthetically generated dataset and is NOT intended for real-world weather forecasting. The synthetic data was produced by a rule-based generator that assigns labels using simplified, hand-crafted physical heuristics (e.g. hard temperature/humidity gating per class) rather than real observational or reanalysis weather data. The model should be treated as an architecture/training-pipeline demonstration, not as a meteorological tool.
Architecture
Type: FT-Transformer (Feature Tokenizer + Transformer Encoder), following the general design described by Gorishniy et al. (2021), "Revisiting Deep Learning Models for Tabular Data."
Extracted directly from the model configuration in the notebook:
| Component | Value |
|---|---|
| d_token (hidden size) | 192 |
| n_blocks (transformer layers) | 3 |
| n_heads (attention heads) | 8 |
| d_ffn_factor | 4/3 |
| attention_dropout | 0.2 |
| ffn_dropout | 0.1 |
| residual_dropout | 0.0 |
| num_labels | 11 |
| Normalization | Pre-LayerNorm transformer blocks |
| Classification head | Linear(192→192) → ReLU → Dropout → Linear(192→11) |
| Tokens per sample | 1 CLS + 6 numerical tokens + 3 categorical tokens = 10 |
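Parameter count: ≈0.79M (790,859 trainable parameters, counted from the layer shapes of the reference implementation in Get started with the configuration above).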
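A parameter count can be sanity-checked by adding up layer shapes. The sketch below is plain arithmetic, assuming the layer layout of the reference implementation in Get started: a packed Q/K/V input projection in nn.MultiheadAttention, a GELU FFN of width int(192 · 4/3) = 256, c + 1 embedding rows per categorical feature, and a final LayerNorm before the head. The helper names (linear, layer_norm, count_params) are illustrative only.

```python
# Count trainable parameters for the FT-Transformer layout described above.
# Pure arithmetic: no model or weights are loaded.

def linear(n_in: int, n_out: int) -> int:
    """nn.Linear: weight matrix plus bias vector."""
    return n_in * n_out + n_out

def layer_norm(d: int) -> int:
    """nn.LayerNorm: elementwise affine weight and bias."""
    return 2 * d

def count_params(d_token=192, n_blocks=3, d_ffn_factor=4.0 / 3.0,
                 n_continuous=6, cat_cardinalities=(8, 12, 6), num_labels=11):
    d = d_token
    d_ffn = max(int(d * d_ffn_factor), 1)  # same rule as FTTransformerBlock -> 256
    tokenizers = (
        2 * n_continuous * d                           # NumericalFeatureTokenizer weight + bias
        + sum((c + 1) * d for c in cat_cardinalities)  # nn.Embedding(c + 1, d) per feature
        + d                                            # CLS token
    )
    attention = linear(d, 3 * d) + linear(d, d)  # packed in_proj (q, k, v) + out_proj
    ffn = linear(d, d_ffn) + linear(d_ffn, d)
    block = 2 * layer_norm(d) + attention + ffn
    head = layer_norm(d) + linear(d, d) + linear(d, num_labels)  # ln_final + MLP head
    return tokenizers + n_blocks * block + head

print(f"{count_params():,}")  # 790,859
```

Under these assumptions the three transformer blocks account for about 94% of the total (743,232 parameters).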
Feature tokenization:
- Continuous features use a per-feature learned affine projection (NumericalFeatureTokenizer): token_i = x_i * W_i + b_i.
- Categorical features (wind_direction, month, air_mass) use independent nn.Embedding layers (CategoricalFeatureTokenizer).
- A learnable CLS token aggregates all feature tokens through the transformer stack; its final representation is passed to the classification head.
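The token layout can be illustrated without PyTorch. This is a minimal pure-Python sketch, not the model's code: it assumes a hypothetical width D_TOKEN = 8 (the model uses 192), random parameters instead of trained ones, and a fixed zero CLS vector, but it builds the same 1 CLS + 6 numerical + 3 categorical sequence.

```python
import random

D_TOKEN = 8                     # hypothetical small width for readability (the model uses 192)
N_NUMERICAL = 6                 # temperature, humidity, pressure, pressure_trend, wind_speed, altitude
CAT_CARDINALITIES = [8, 12, 6]  # wind_direction, month, air_mass

rng = random.Random(0)

def rand_vec(d):
    return [rng.uniform(-0.1, 0.1) for _ in range(d)]

# Numerical tokenizer: one weight vector W_i and bias vector b_i per continuous feature
W = [rand_vec(D_TOKEN) for _ in range(N_NUMERICAL)]
b = [rand_vec(D_TOKEN) for _ in range(N_NUMERICAL)]
# Categorical tokenizer: one embedding table per feature, c + 1 rows as in the reference code
tables = [[rand_vec(D_TOKEN) for _ in range(c + 1)] for c in CAT_CARDINALITIES]
cls = [0.0] * D_TOKEN  # learnable in the real model; fixed here for illustration

def tokenize(x_num, x_cat):
    """Build [CLS] + 6 numerical tokens + 3 categorical tokens."""
    num_tokens = [
        [x * w + bias for w, bias in zip(W[i], b[i])]  # token_i = x_i * W_i + b_i
        for i, x in enumerate(x_num)
    ]
    cat_tokens = [tables[j][idx] for j, idx in enumerate(x_cat)]  # embedding lookup
    return [cls] + num_tokens + cat_tokens

tokens = tokenize([0.3, -1.2, 0.5, 0.0, 1.1, -0.4], [2, 7, 4])
print(len(tokens), len(tokens[0]))  # 10 8
```

In the model, the sequence passes through the transformer blocks and only position 0 (CLS) is read out, normalized by the final LayerNorm, and fed to the head.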
Inputs
| Feature | Type | Notes |
|---|---|---|
| temperature | float (°C) | Continuous |
| humidity | float (%) | Continuous |
| pressure | float (hPa) | Continuous |
| pressure_trend | float (hPa / 3h) | Continuous |
| wind_speed | float (km/h) | Continuous |
| altitude | float (m) | Continuous |
| wind_direction | str | Categorical: one of n/ne/e/se/s/sw/w/nw (or full names) |
| month | int (1–12) | Categorical |
| air_mass | str | Categorical: one of polar, arctic, continental, maritime, tropical, equatorial |
Continuous features are z-score normalized, (x - mean) / std, using the means and standard deviations computed on the training split (stored alongside the model as normalisation_stats.json).
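A sketch of the same normalization in plain Python, assuming the normalisation_stats.json layout that preprocess_single in Get started reads (top-level means, stds and continuous_cols keys). The statistics below are made-up placeholders, not the values shipped with the model.

```python
CONTINUOUS_COLS = ["temperature", "humidity", "pressure", "pressure_trend", "wind_speed", "altitude"]

def normalize(sample, stats, eps=1e-8):
    """z-score each continuous feature: (x - mean) / (std + eps), looked up by feature name."""
    return [
        (sample[c] - stats["means"][c]) / (stats["stds"][c] + eps)
        for c in stats["continuous_cols"]
    ]

# Placeholder statistics for illustration only (NOT the shipped normalisation_stats.json values)
stats = {
    "continuous_cols": CONTINUOUS_COLS,
    "means": {"temperature": 10.0, "humidity": 60.0, "pressure": 1013.0,
              "pressure_trend": 0.0, "wind_speed": 15.0, "altitude": 300.0},
    "stds": {"temperature": 10.0, "humidity": 20.0, "pressure": 10.0,
             "pressure_trend": 2.0, "wind_speed": 10.0, "altitude": 300.0},
}
sample = {"temperature": 30.0, "humidity": 30.0, "pressure": 1025.0,
          "pressure_trend": 1.0, "wind_speed": 5.0, "altitude": 100.0}
print([round(v, 3) for v in normalize(sample, stats)])  # [2.0, -1.5, 1.2, 0.5, -1.0, -0.667]
```

Looking features up by name avoids pairing a value with the wrong mean if the stored column order differs from the argument order; preprocess_single instead relies on continuous_cols listing the features as temperature, humidity, pressure, pressure_trend, wind_speed, altitude.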
Output Classes
The model predicts one of 11 classes:
| ID | Class |
|---|---|
| 0 | clear |
| 1 | cloudy |
| 2 | light_rain |
| 3 | heavy_rain |
| 4 | thunderstorm |
| 5 | snow |
| 6 | freezing_rain |
| 7 | soft_hail |
| 8 | fog |
| 9 | cold_front |
| 10 | windstorm |
Dataset
The dataset is fully synthetic, generated programmatically (no real observational data is used).
- Size: 120,000 samples (generate_dataset(n=120_000)).
- Split: 80% train / 10% validation / 10% test.
- Feature generation: Continuous features are sampled from parametric distributions (e.g. exponential altitude, cosine-seasonal temperature with lapse-rate adjustment, beta-distributed humidity, gamma-distributed wind speed) with a fixed random seed (seed=42).
- Label assignment: Each class is assigned a hand-crafted "logit score" as a function of the input features (e.g. hard exclusions such as snow being disallowed above +2 °C, freezing rain restricted to a −3 °C to +1 °C window, and thunderstorms requiring temperature > 8 °C and humidity > 70%). Scores are scaled by a temperature factor (T = 0.15), softmax-normalized, and sampled; if the resulting maximum probability falls below 0.40, the label falls back to a deterministic argmax over the raw scores.
- Class weighting: balanced class weights (sklearn.utils.class_weight.compute_class_weight with class_weight="balanced") are computed on the training split to address residual class imbalance.
No external or third-party dataset is used at any stage.
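The sampling step of the label generator can be sketched as follows. This is an illustration under assumptions, not the generator's code: scores are divided by T (the usual softmax-temperature convention; the description only says "scaled"), a hard exclusion is encoded as a score of -inf, and the hand-crafted scoring rules themselves are not reproduced. The name assign_label is hypothetical.

```python
import math
import random

T = 0.15         # temperature factor quoted above
MIN_CONF = 0.40  # below this max probability, fall back to argmax

def assign_label(scores, rng):
    """Sample a class from softmax(scores / T), or take argmax(scores) if that softmax is too flat.

    Assumptions: "scaled by T" means dividing by T, and a hard exclusion is
    float("-inf"). At least one score must be finite.
    """
    scaled = [s / T for s in scores]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]  # subtracting the max avoids overflow; exp(-inf) == 0.0
    z = sum(exps)
    probs = [e / z for e in exps]
    if max(probs) < MIN_CONF:
        # Deterministic fallback: argmax over the raw (unscaled) scores
        return max(range(len(scores)), key=lambda i: scores[i])
    return rng.choices(range(len(scores)), weights=probs, k=1)[0]

rng = random.Random(42)
# Nearly tied scores -> flat softmax -> deterministic fallback to the top raw score
print(assign_label([0.0, 0.05, 0.02, 0.01], rng))  # 1
# Third class hard-excluded (e.g. snow above +2 °C): it can never be sampled
draws = [assign_label([0.10, 0.12, float("-inf")], rng) for _ in range(1000)]
print(sorted(set(draws)))  # [0, 1]
```

With T = 0.15, small score gaps are strongly amplified (a gap of 0.3 becomes a probability ratio of e^2 ≈ 7.4), so the argmax fallback mainly triggers when several classes score almost the same.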
Training
Extracted from the notebook's TrainingArguments and custom Trainer subclass:
| Setting | Value |
|---|---|
| Epochs | 25 |
| Train batch size | 512 |
| Eval batch size | 1024 |
| Learning rate | 1e-3 |
| Weight decay | 1e-4 |
| Max grad norm | 1.0 |
| Warmup ratio | 0.06 |
| LR scheduler | Cosine |
| Eval / save strategy | Per epoch |
| Best model selection | Highest F1 (metric_for_best_model="f1") |
| Early stopping | Patience = 5 epochs |
| Mixed precision (fp16) | Enabled if CUDA is available |
| Optimizer | Hugging Face Trainer default (AdamW); no custom optimizer is configured in the notebook |
| Loss function | Custom Focal Loss (gamma=2.0), with per-class alpha weights set from the balanced class weights computed on the training set |
| Class balancing | Focal loss alpha weighting only; no oversampling/undersampling or data augmentation is used |
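The loss configuration above can be made concrete with a per-sample sketch in plain Python. It assumes the standard focal-loss form FL = -α_t (1 - p_t)^γ log p_t, with α taken from balanced class weights (n_samples / (n_classes · count_c), the formula scikit-learn documents for class_weight="balanced"). The notebook's batching, reduction (e.g. mean over the batch), and any normalization of α are not documented here and may differ; the function names are illustrative.

```python
import math
from collections import Counter

def balanced_class_weights(labels, n_classes):
    """n_samples / (n_classes * count_c); assumes every class occurs at least once."""
    counts = Counter(labels)
    return [len(labels) / (n_classes * counts[c]) for c in range(n_classes)]

def focal_loss(logits, target, alpha, gamma=2.0):
    """Per-sample focal loss: -alpha[target] * (1 - p_t)**gamma * log(p_t)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))  # stable log-sum-exp
    log_p_t = logits[target] - log_z                            # log-softmax of the true class
    p_t = math.exp(log_p_t)
    return -alpha[target] * (1.0 - p_t) ** gamma * log_p_t

# Toy imbalanced training labels: class 1 is three times rarer than class 0
alpha = balanced_class_weights([0, 0, 0, 1], n_classes=2)
print(alpha)  # [0.6666666666666666, 2.0]

# gamma=0 with unit alpha reduces to plain cross-entropy; gamma=2 shrinks the
# loss of an already confident, correct prediction much further
ce = focal_loss([3.0, 0.0], target=0, alpha=[1.0, 1.0], gamma=0.0)
fl = focal_loss([3.0, 0.0], target=0, alpha=alpha)
print(f"cross-entropy={ce:.4f} focal={fl:.6f}")
```

The (1 - p_t)^γ factor is what lets training focus on hard, misclassified samples, while α up-weights rare classes; together they replace resampling as the only balancing mechanism.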
Hardware
2× NVIDIA T4 GPUs (free Kaggle quota)
Training Time
5 minutes (yes, not joking)
Evaluation
The notebook computes accuracy, weighted precision, weighted recall, weighted F1, a full classification_report, and a confusion matrix on the held-out test split via trainer.predict(). However, no executed run log with the resulting numeric values is available for this README, so test-set metrics are omitted here rather than estimated.
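The weighted metrics average per-class scores by each class's support (its number of true samples). The sketch below reproduces that definition in plain Python, assuming the notebook follows the scikit-learn convention (average="weighted", with undefined per-class precision or recall counted as 0); it is not the notebook's code, and weighted_prf is a hypothetical name.

```python
from collections import Counter

def weighted_prf(y_true, y_pred):
    """Return (accuracy, precision, recall, f1), with P/R/F1 support-weighted across classes."""
    support = Counter(y_true)
    labels = sorted(set(y_true) | set(y_pred))
    total = len(y_true)
    p_sum = r_sum = f_sum = 0.0
    for c in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        w = support[c]  # 0 for classes that are predicted but never occur in y_true
        p_sum += w * prec
        r_sum += w * rec
        f_sum += w * f1
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / total
    return accuracy, p_sum / total, r_sum / total, f_sum / total

acc, prec, rec, f1 = weighted_prf([0, 0, 1, 1, 2], [0, 1, 1, 1, 2])
print(round(acc, 4), round(prec, 4), round(rec, 4), round(f1, 4))  # 0.8 0.8667 0.8 0.7867
```

Weighted recall always equals accuracy, because each class contributes support × TP / support = TP; weighted F1 is the support-weighted mean of per-class F1, not the harmonic mean of weighted precision and recall.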
Stress Tests
In addition to the formal test split, the notebook runs a fixed set of six hand-designed "stress test" scenarios intended to probe known edge cases. Reported results from an executed run:
| Scenario | Expected | Predicted | Result |
|---|---|---|---|
| Deep winter polar conditions | snow | snow | PASS |
| Saturated calm air | fog | fog | PASS |
| Extreme wind + low pressure | windstorm | cold_front | FAIL |
| Hot dry high pressure summer | clear | clear | PASS |
| Freezing rain window (−1 °C, high humidity) | freezing_rain | freezing_rain | PASS |
| Classic thunderstorm | thunderstorm | thunderstorm | PASS |
Summary: 5 / 6 stress tests passed.
The single failure (extreme wind + low pressure classified as cold_front instead of windstorm) reflects an overlap between the synthetic scoring rules for these two classes: both are driven by falling pressure and elevated wind speed, and the model's confidence was split between them (≈0.86 for the windstorm-favoring features vs. a meaningful cold_front probability mass in the reported softmax distribution). This points to an ambiguity in the synthetic labeling rules rather than a clear model failure.
Limitations
- The training data is entirely synthetic and generated from simplified, hand-written heuristics, not real meteorological observations or physical simulation.
- This model is not suitable for real-world weather forecasting or any operational meteorological use.
- The "physical" rules encoding label assignment (e.g. temperature thresholds for snow, humidity/wind gating for fog) are simplifications and do not capture real atmospheric dynamics.
- Class boundaries between physically similar phenomena (e.g. windstorm vs. cold front) can overlap in the synthetic label-generation logic, as shown in the stress test results above.
- Intended use is restricted to experimentation, architecture benchmarking, and research/educational purposes within the SupraLabs pipeline.
Get started
```python
import json
import math
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel

# ──────────────────────────────────────────────────────────────────────────
# Config
# ──────────────────────────────────────────────────────────────────────────
class WeatherFTConfig(PretrainedConfig):
    model_type = "weather_ft_transformer"

    def __init__(
        self,
        num_continuous: int = 6,
        num_wind_dirs: int = 8,
        num_months: int = 12,
        num_air_masses: int = 6,
        d_token: int = 192,
        n_blocks: int = 3,
        n_heads: int = 8,
        d_ffn_factor: float = 4.0 / 3.0,
        attention_dropout: float = 0.2,
        ffn_dropout: float = 0.1,
        residual_dropout: float = 0.0,
        num_labels: int = 11,
        means: Optional[Dict] = None,
        stds: Optional[Dict] = None,
        label2id: Optional[Dict] = None,
        id2label: Optional[Dict] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_continuous = num_continuous
        self.num_wind_dirs = num_wind_dirs
        self.num_months = num_months
        self.num_air_masses = num_air_masses
        self.d_token = d_token
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.d_ffn_factor = d_ffn_factor
        self.attention_dropout = attention_dropout
        self.ffn_dropout = ffn_dropout
        self.residual_dropout = residual_dropout
        # PretrainedConfig.num_labels is a property: setting it creates default
        # LABEL_i maps of this size, and its getter returns len(self.id2label).
        self.num_labels = num_labels
        self.means = means or {}
        self.stds = stds or {}
        if id2label:
            # JSON object keys are strings, so convert class ids back to int.
            self.id2label = {int(k): v for k, v in id2label.items()}
            self.label2id = label2id or {v: k for k, v in self.id2label.items()}
        # Otherwise keep the default maps: an empty id2label would make
        # config.num_labels evaluate to 0 and build a zero-width head.


# ──────────────────────────────────────────────────────────────────────────
# Model
# ──────────────────────────────────────────────────────────────────────────
class NumericalFeatureTokenizer(nn.Module):
    """Projects each continuous feature independently: token_i = x_i * W_i + b_i."""

    def __init__(self, n_features: int, d_token: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(n_features, d_token))
        self.bias = nn.Parameter(torch.empty(n_features, d_token))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, F) -> (B, F, D)
        return x.unsqueeze(-1) * self.weight.unsqueeze(0) + self.bias.unsqueeze(0)


class CategoricalFeatureTokenizer(nn.Module):
    """Embeds each categorical feature independently (one table per feature)."""

    def __init__(self, cardinalities: List[int], d_token: int):
        super().__init__()
        self.embeddings = nn.ModuleList(
            [nn.Embedding(c + 1, d_token) for c in cardinalities]
        )

    def forward(self, cats: List[torch.Tensor]) -> torch.Tensor:
        # List of (B,) index tensors -> (B, n_cat, D)
        return torch.stack(
            [emb(cat) for emb, cat in zip(self.embeddings, cats)], dim=1
        )


class FTTransformerBlock(nn.Module):
    """Pre-LayerNorm block: LayerNorm -> attention / FFN, then residual add."""

    def __init__(self, config: WeatherFTConfig):
        super().__init__()
        D = config.d_token
        D_ffn = max(int(D * config.d_ffn_factor), 1)
        self.ln1 = nn.LayerNorm(D)
        self.attn = nn.MultiheadAttention(
            D, config.n_heads, dropout=config.attention_dropout, batch_first=True
        )
        self.attn_drop = nn.Dropout(config.residual_dropout)
        self.ln2 = nn.LayerNorm(D)
        self.ffn = nn.Sequential(
            nn.Linear(D, D_ffn),
            nn.GELU(),
            nn.Dropout(config.ffn_dropout),
            nn.Linear(D_ffn, D),
        )
        self.ffn_drop = nn.Dropout(config.residual_dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.ln1(x)
        h, _ = self.attn(h, h, h)
        x = x + self.attn_drop(h)
        h = self.ffn(self.ln2(x))
        x = x + self.ffn_drop(h)
        return x


class WeatherFTTransformer(PreTrainedModel):
    config_class = WeatherFTConfig

    def __init__(self, config: WeatherFTConfig):
        super().__init__(config)
        D = config.d_token
        self.num_tokenizer = NumericalFeatureTokenizer(config.num_continuous, D)
        self.cat_tokenizer = CategoricalFeatureTokenizer(
            cardinalities=[
                config.num_wind_dirs,
                config.num_months,
                config.num_air_masses,
            ],
            d_token=D,
        )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, D))
        self.blocks = nn.ModuleList(
            [FTTransformerBlock(config) for _ in range(config.n_blocks)]
        )
        self.ln_final = nn.LayerNorm(D)
        self.head = nn.Sequential(
            nn.Linear(D, D),
            nn.ReLU(),
            nn.Dropout(config.ffn_dropout),
            nn.Linear(D, config.num_labels),
        )
        self.post_init()

    def forward(
        self,
        continuous: torch.Tensor,
        wind_direction: torch.Tensor,
        month: torch.Tensor,
        air_mass: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict:
        B = continuous.size(0)
        num_tokens = self.num_tokenizer(continuous)
        cat_tokens = self.cat_tokenizer([wind_direction, month, air_mass])
        cls = self.cls_token.expand(B, -1, -1)
        # Sequence: [CLS] + 6 numerical + 3 categorical = 10 tokens
        tokens = torch.cat([cls, num_tokens, cat_tokens], dim=1)
        for block in self.blocks:
            tokens = block(tokens)
        # Classify from the CLS token only
        cls_repr = self.ln_final(tokens[:, 0])
        logits = self.head(cls_repr)
        loss = None
        if labels is not None:
            # Plain cross-entropy for convenience; training used focal loss.
            loss = nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}


AutoConfig.register("weather_ft_transformer", WeatherFTConfig)
AutoModel.register(WeatherFTConfig, WeatherFTTransformer)


# ──────────────────────────────────────────────────────────────────────────
# Loading helpers
# ──────────────────────────────────────────────────────────────────────────
def load_model(model_dir: str):
    """Loads config, weights, and normalisation stats from a local directory or a Hub repo id."""
    path = Path(model_dir)
    if not path.is_dir():
        # Not a local folder: download the repository snapshot from the Hugging Face Hub.
        path = Path(snapshot_download(repo_id=model_dir))
    config = WeatherFTConfig.from_pretrained(str(path))
    model = WeatherFTTransformer(config)
    state_dict = load_file(str(path / "model.safetensors"))
    model.load_state_dict(state_dict)
    model.eval()
    with open(path / "normalisation_stats.json") as f:
        norm_stats = json.load(f)
    return model, config, norm_stats


# ──────────────────────────────────────────────────────────────────────────
# Inference helpers
# ──────────────────────────────────────────────────────────────────────────
def preprocess_single(
    norm_stats: Dict,
    temperature: float,
    humidity: float,
    pressure: float,
    pressure_trend: float,
    wind_speed: float,
    wind_direction: str,
    altitude: float,
    month: int,
    air_mass: str,
) -> Dict[str, torch.Tensor]:
    means = norm_stats["means"]
    stds = norm_stats["stds"]
    continuous_cols = norm_stats["continuous_cols"]
    wind_dir_map = norm_stats["wind_dir_map"]
    air_mass_map = norm_stats["air_mass_map"]
    # Assumes continuous_cols lists the features in exactly this order.
    cont_raw = np.array(
        [temperature, humidity, pressure, pressure_trend, wind_speed, altitude],
        dtype=np.float32,
    )
    m = np.array([means[c] for c in continuous_cols], dtype=np.float32)
    s = np.array([stds[c] for c in continuous_cols], dtype=np.float32)
    cont_norm = (cont_raw - m) / (s + 1e-8)
    # Unknown category strings fall back to index 0.
    return {
        "continuous": torch.tensor(cont_norm).unsqueeze(0),
        "wind_direction": torch.tensor([wind_dir_map.get(wind_direction.lower(), 0)]),
        "month": torch.tensor([month]),
        "air_mass": torch.tensor([air_mass_map.get(air_mass.lower(), 0)]),
    }


@torch.no_grad()
def predict_single(model, config, norm_stats, **kwargs):
    inputs = preprocess_single(norm_stats, **kwargs)
    out = model(**inputs)
    probs = torch.softmax(out["logits"], dim=-1).squeeze(0).numpy()
    pred_id = int(np.argmax(probs))
    pred_label = config.id2label[pred_id]
    return pred_label, pred_id, probs


# ──────────────────────────────────────────────────────────────────────────
# Example
# ──────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    MODEL_DIR = "SupraLabs/SupraWeather-Nano-Preview"  # local path or HF repo id
    model, config, norm_stats = load_model(MODEL_DIR)
    pred_label, pred_id, probs = predict_single(
        model,
        config,
        norm_stats,
        temperature=30,
        humidity=30,
        pressure=1025,
        pressure_trend=1,
        wind_speed=5,
        wind_direction="east",
        altitude=100,
        month=7,
        air_mass="tropical",
    )
    print(f"Predicted class : {pred_label} (id={pred_id})\n")
    print("Class probabilities:")
    for i in np.argsort(probs)[::-1]:
        bar = "█" * int(probs[i] * 40)
        print(f"  {config.id2label[int(i)]:<14} {probs[i]:.4f} {bar}")
```
License
Apache 2.0
