soilformer/modelling/embed_categorical.py
# embed_categorical.py
# -*- coding: utf-8 -*-
"""
Categorical embedding module for tabular transformer.
Design:
- Each categorical column = 1 token
- Value embedding: ONE global lookup table using (offset + local_id)
- ID embedding: ONE categorical column-ID embedding table
- Explicit col_id stored in cat_vocab.json (no implicit ordering assumptions)
Outputs:
local_ids [B,M] -> tokens [B,M,H]
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from utils import load_json, save_json
SPECIAL_MASK = "__MASK__"
# ============================================================
# Meta → categorical column list
# ============================================================
def get_categorical_feature_names_from_meta(tabular_meta: Dict) -> List[str]:
"""
Deterministic ordering:
alphabetical by feature name.
"""
cols = []
for k, v in tabular_meta.items():
if v.get("dataclass") == "categorical" and not v.get("is_array_valued", False):
cols.append(k)
return sorted(cols)
# ============================================================
# Vocab spec
# ============================================================
@dataclass
class CatColSpec:
name: str
col_id: int
offset: int
num_classes: int
mask_local_id: int
label2id: Dict[str, int]
def build_cat_vocab_spec_from_meta(
tabular_meta: Dict,
categorical_feature_names: List[str],
label_order: str = "alpha",
) -> Dict[str, CatColSpec]:
vocab: Dict[str, CatColSpec] = {}
offset = 0
for j, col in enumerate(categorical_feature_names):
info = tabular_meta[col]
class_stats = info.get("class_stats", {}) or {}
# deterministic label order
if label_order == "alpha":
labels = sorted(class_stats.keys())
elif label_order == "freq_desc":
labels = sorted(class_stats.keys(), key=lambda k: (-class_stats[k], k))
else:
raise ValueError("label_order must be alpha or freq_desc")
label2id = {lab: i for i, lab in enumerate(labels)}
mask_local_id = len(labels)
label2id[SPECIAL_MASK] = mask_local_id
spec = CatColSpec(
name=col,
col_id=j, # EXPLICIT categorical column id
offset=offset,
num_classes=mask_local_id + 1,
mask_local_id=mask_local_id,
label2id=label2id,
)
vocab[col] = spec
offset += spec.num_classes
return vocab
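# Minimal usage sketch for the builder above, assuming a hypothetical toy meta
# in which class_stats maps label -> count:
#   toy_meta = {
#       "land_use":  {"dataclass": "categorical", "class_stats": {"crop": 7, "forest": 3}},
#       "soil_type": {"dataclass": "categorical", "class_stats": {"clay": 10, "loam": 5, "sand": 2}},
#   }
#   names = get_categorical_feature_names_from_meta(toy_meta)  # ["land_use", "soil_type"]
#   vocab = build_cat_vocab_spec_from_meta(toy_meta, names)
#   # vocab["land_use"]: offset 0, num_classes 3; vocab["soil_type"]: offset 3, num_classes 4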
def save_cat_vocab_json(vocab: Dict[str, CatColSpec], path: str) -> None:
out = {}
for col, spec in vocab.items():
out[col] = {
"col_id": spec.col_id,
"offset": spec.offset,
"num_classes": spec.num_classes,
"mask_local_id": spec.mask_local_id,
"global_id_start": spec.offset,
"global_id_end": spec.offset + spec.num_classes - 1,
"label2id": spec.label2id,
}
save_json(out, path)
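# For the toy vocab sketched above, the saved JSON would hold one entry per column,
# e.g. (values are from the hypothetical example, not the real data):
#   "land_use": {"col_id": 0, "offset": 0, "num_classes": 3, "mask_local_id": 2,
#                "global_id_start": 0, "global_id_end": 2,
#                "label2id": {"crop": 0, "forest": 1, "__MASK__": 2}}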
# ============================================================
# Embedding modules
# ============================================================
class CategoricalValueEmbedding(nn.Module):
"""
Global value embedding using offsets.
"""
def __init__(self, hidden_size: int, cat_vocab_json: str):
super().__init__()
spec = load_json(cat_vocab_json)
# sort by col_id to ensure consistent tensor layout
items = sorted(spec.items(), key=lambda x: x[1]["col_id"])
offsets = []
num_classes = []
col_ids = []
total_vocab = 0
for name, s in items:
offsets.append(int(s["offset"]))
num_classes.append(int(s["num_classes"]))
col_ids.append(int(s["col_id"]))
total_vocab = max(total_vocab, s["offset"] + s["num_classes"])
self.hidden_size = int(hidden_size)
self.total_vocab_size = int(total_vocab)
# Merge all classes to avoid many small nn.Embedding modules
self.emb = nn.Embedding(self.total_vocab_size, self.hidden_size)
self.register_buffer("offsets", torch.tensor(offsets, dtype=torch.long), persistent=True)
self.register_buffer("num_classes", torch.tensor(num_classes, dtype=torch.long), persistent=True)
self.register_buffer("col_ids", torch.tensor(col_ids, dtype=torch.long), persistent=True)
def init_weights(self, std=0.02):
nn.init.normal_(self.emb.weight, std=std)
def forward(self, local_ids: torch.LongTensor) -> torch.Tensor:
"""
local_ids: [B,M]
returns: [B,M,H]
"""
if local_ids.dim() != 2:
raise ValueError("local_ids must be [B,M]")
B, M = local_ids.shape
        if M != self.offsets.numel():
            raise ValueError(f"Column count mismatch: got {M} columns, expected {self.offsets.numel()}")
if torch.any(local_ids < 0):
raise ValueError("Negative local_id")
nc = self.num_classes.view(1, M).expand(B, M)
if torch.any(local_ids >= nc):
raise ValueError("local_ids out of range")
gid = self.offsets.view(1, M) + local_ids
return self.emb(gid)
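# Shape sketch for CategoricalValueEmbedding under the toy vocab above (offsets [0, 3]):
#   local_ids = [[1, 0]]          # B=1, M=2: land_use="forest", soil_type="clay"
#   gid       = [[0 + 1, 3 + 0]]  # = [[1, 3]]
#   output    = emb(gid)          # [1, 2, hidden_size]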
class CategoricalIdEmbedding(nn.Module):
"""
Explicit categorical column ID embedding.
"""
def __init__(self, hidden_size: int, cat_vocab_json: str):
super().__init__()
spec = load_json(cat_vocab_json)
items = sorted(spec.items(), key=lambda x: x[1]["col_id"])
col_ids = [s["col_id"] for _, s in items]
max_col_id = max(col_ids)
self.emb = nn.Embedding(max_col_id + 1, hidden_size)
self.register_buffer(
"cat_col_ids",
torch.tensor(col_ids, dtype=torch.long),
persistent=True,
)
self.hidden_size = hidden_size
def init_weights(self, std=0.02):
nn.init.normal_(self.emb.weight, std=std)
def forward(self, batch_size: int) -> torch.Tensor:
"""
returns [B,M,H]
"""
id_vec = self.emb(self.cat_col_ids) # [M,H]
return id_vec.view(1, -1, self.hidden_size).expand(batch_size, -1, -1)
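# E.g. with M=2 categorical columns, CategoricalIdEmbedding(...).forward(batch_size=4)
# returns a [4, 2, hidden_size] tensor in which every batch row shares the same
# per-column ID embeddings (the expand does not copy memory).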
class CategoricalEmbedding(nn.Module):
"""
token = value_embedding + categorical_id_embedding
"""
def __init__(self, hidden_size: int, cat_vocab_json: str):
super().__init__()
self.value_emb = CategoricalValueEmbedding(hidden_size, cat_vocab_json)
self.id_emb = CategoricalIdEmbedding(hidden_size, cat_vocab_json)
def init_weights(self, std=0.02):
self.value_emb.init_weights(std=std)
self.id_emb.init_weights(std=std)
def forward(
self,
local_ids: torch.LongTensor, # [B, M]
valid_positions: Optional[torch.Tensor] = None, # Bool [B,M] (True=valid) or indices [K,2]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns:
tokens: [B, M, H]
token_mask: [B, M] (1=valid, 0=invalid)
"""
if local_ids.dim() != 2:
raise ValueError(f"local_ids must be [B,M], got {tuple(local_ids.shape)}")
B, M = local_ids.shape
tokens = self.value_emb(local_ids) + self.id_emb(B) # [B,M,H]
# Default: all tokens are valid
valid = torch.ones((B, M), dtype=torch.bool, device=local_ids.device)
if valid_positions is not None:
if valid_positions.dtype == torch.bool:
if valid_positions.shape != (B, M):
raise ValueError(
f"valid_positions (bool) must be [B,M]=({B}, {M}), got {tuple(valid_positions.shape)}")
valid = valid_positions.to(device=local_ids.device)
else:
# Optional: support index pairs [K,2] where each row is (b_idx, m_idx) for valid positions
if valid_positions.dim() != 2 or valid_positions.size(1) != 2:
raise ValueError("valid_positions (indices) must be [K,2] with (batch_idx, col_idx)")
valid = torch.zeros((B, M), dtype=torch.bool, device=local_ids.device)
b_idx = valid_positions[:, 0].to(device=local_ids.device, dtype=torch.long)
m_idx = valid_positions[:, 1].to(device=local_ids.device, dtype=torch.long)
valid[b_idx, m_idx] = True
# Token mask: 1=valid, 0=invalid
token_mask = valid.to(dtype=torch.long) # [B,M]
        # NOTE: the zeroing below is intentionally disabled. A __MASK__ token must
        # still be able to attend to the other columns, so invalid positions are
        # only flagged via token_mask rather than zeroed out here.
        # invalid = ~valid
        # if invalid.any():
        #     tokens = tokens.masked_fill(invalid.unsqueeze(-1), 0.0)
return tokens, token_mask
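# Usage sketch for the index-pair form of valid_positions (sizes and paths below are
# hypothetical; the demo in _demo_main uses the default of all-valid positions):
#   emb = CategoricalEmbedding(hidden_size=16, cat_vocab_json="data/cat_vocab.json")
#   local_ids = torch.zeros((2, 3), dtype=torch.long)           # B=2, M=3 columns
#   pairs = torch.tensor([[0, 0], [0, 2], [1, 1]])              # (batch_idx, col_idx) of valid cells
#   tokens, token_mask = emb(local_ids, valid_positions=pairs)  # token_mask is 1 only at those cells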
# ============================================================
# DEMO
# ============================================================
def _demo_main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--tabular_meta", type=str, default="data/tabular_meta.json")
parser.add_argument("--cat_vocab_json", type=str, default="data/cat_vocab.json")
parser.add_argument("--hidden_size", type=int, default=768)
parser.add_argument("--batch_size", type=int, default=4)
args = parser.parse_args()
tabular_meta = load_json(args.tabular_meta)
cat_names = get_categorical_feature_names_from_meta(tabular_meta)
print(f"Found {len(cat_names)} categorical columns")
vocab = build_cat_vocab_spec_from_meta(tabular_meta, cat_names)
save_cat_vocab_json(vocab, args.cat_vocab_json)
print(f"Saved vocab to {args.cat_vocab_json}")
model = CategoricalEmbedding(
hidden_size=args.hidden_size,
cat_vocab_json=args.cat_vocab_json,
)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters (CategoricalEmbedding): {total_params:,} (trainable: {trainable_params:,})")
B = args.batch_size
M = len(cat_names)
local_ids = torch.zeros((B, M), dtype=torch.long)
with torch.no_grad():
out, mask = model(local_ids)
print("local_ids:", tuple(local_ids.shape))
print("output:", tuple(out.shape)) # [B,M,H]
print("mask:", tuple(mask.shape)) # [B,M]
if __name__ == "__main__":
_demo_main()