soilformer / modelling /decode_numeric.py
Kuangdai
Initial release of SoilFormer
6fb6c07
# decode_numeric.py
# -*- coding: utf-8 -*-
"""
Numeric decoder module for tabular transformer.
Symmetric to embed_numeric.py (bucketed by n_in):
- For each bucket (same n_in), we decode tokens without a Python for-loop over columns.
- Uses a batched per-variable MLP with per-column parameters (NOT shared across V).
Input:
x_tokens: [B, total_numeric_tokens, H]
token order must match numeric_vocab.json:
groups by n_in ascending, within group by feature name,
and within each feature: n_in tokens.
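Output (a tuple of two dicts keyed by n_in):
values_by_nin: Dict[int, Tensor]
n_in -> y_hat [B, V, n_in]
s_by_nin: Dict[int, Tensor]
n_in -> s [B, V], where s = log(sigma^2) is shared across the n_in dimensions of each variable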
middle_size:
- None: 1-layer per-variable Linear
- int : 2-layer per-variable MLP (Linear -> GELU -> Linear)
"""
from typing import Dict, List, Optional
import torch
import torch.nn as nn
from utils import GroupedMLP, load_json
class NumericDecoder(nn.Module):
    """
    Decode numeric tokens back to numeric values, bucketed by n_in.

    Each group in the numeric vocab spec holds V features, and each feature
    occupies n_in consecutive tokens. All tokens of a group are decoded at once
    by a GroupedMLP with per-feature (not shared across V) parameters.

    Args:
        hidden_size: token width H.
        numeric_vocab_json: path to the numeric vocab spec, read with load_json;
            uses the keys "groups", "total_numeric_tokens" and
            "group_token_offsets".
        middle_size: None -> one-layer decoders; int -> two-layer decoders with
            this middle width (forwarded to GroupedMLP).
        homoscedastic: if True, s is a learned per-feature parameter that does
            not depend on the input; otherwise s is predicted from the tokens
            by a separate GroupedMLP.

    Input:
        x_tokens: [B, total_numeric_tokens, H]

    Output (tuple):
        values_by_nin:
            n_in -> y_hat [B, V, n_in]
        s_by_nin:
            n_in -> s [B, V]
            where s = log(sigma^2), shared across the n_in dimensions
            of each variable, intended for heteroscedastic loss computation.

    Raises:
        ValueError: if the spec's token layout is inconsistent (missing,
            misplaced or duplicated group offsets, or a wrong total).
    """

    def __init__(
        self,
        hidden_size: int,
        numeric_vocab_json: str,
        middle_size: Optional[int] = None,
        homoscedastic: bool = True,
    ):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.middle_size = None if middle_size is None else int(middle_size)
        self.homoscedastic = bool(homoscedastic)

        spec = load_json(numeric_vocab_json)
        self.groups: List[Dict] = list(spec["groups"])
        self.total_numeric_tokens = int(spec["total_numeric_tokens"])
        self.group_token_offsets: Dict[str, int] = dict(spec.get("group_token_offsets", {}))

        # Validate the token layout BEFORE allocating any parameters so a bad
        # spec fails fast; the validated per-group start offsets are cached so
        # forward() does not re-derive them from the str-keyed dict each call.
        self.group_starts: List[int] = self._check_spec(
            self.groups, self.group_token_offsets, self.total_numeric_tokens
        )

        self.group_v_decoders = nn.ModuleList()
        self.group_s_decoders = nn.ModuleList()
        self.group_nins: List[int] = []
        self.group_Vs: List[int] = []

        for g in self.groups:
            n_in = int(g["n_in"])
            V = len(g["feature_names"])
            self.group_nins.append(n_in)
            self.group_Vs.append(V)

            # value decoder: [B,V,n_in*H] -> [B,V,n_in]
            self.group_v_decoders.append(
                GroupedMLP(
                    n_var=V,
                    n_in=n_in * self.hidden_size,
                    n_out=n_in,
                    middle_size=self.middle_size,
                )
            )
            # uncertainty decoder: [B,V,H] -> [B,V,1] -> [B,V]
            if not self.homoscedastic:
                self.group_s_decoders.append(
                    GroupedMLP(
                        n_var=V,
                        n_in=self.hidden_size,
                        n_out=1,
                        middle_size=self.middle_size,
                    )
                )

        if self.homoscedastic:
            # One learnable s per feature, shared by every sample in the batch.
            self.group_s_params = nn.ParameterList(
                [nn.Parameter(torch.zeros(V)) for V in self.group_Vs]
            )
        else:
            self.group_s_params = None

    @staticmethod
    def _check_spec(
        groups: List[Dict],
        group_token_offsets: Dict[str, int],
        total_numeric_tokens: int,
    ) -> List[int]:
        """Verify that the groups tile the token axis contiguously; return their starts.

        Group i must start exactly where group i-1 ends, and
        group_token_offsets[str(n_in)] must equal that start. Each n_in may
        appear at most once because the decoder's outputs are keyed by n_in.

        Raises:
            ValueError: on a duplicate n_in, a missing or mismatched offset, or a
                total token count that does not match the sum of V * n_in.
        """
        starts: List[int] = []
        seen_nins = set()
        running = 0
        for g in groups:
            n_in = int(g["n_in"])
            V = len(g["feature_names"])
            key = str(n_in)
            if n_in in seen_nins:
                raise ValueError(
                    f"Duplicate group for n_in={n_in}; outputs are keyed by n_in"
                )
            seen_nins.add(n_in)
            if key not in group_token_offsets:
                raise ValueError(f"Missing group_token_offsets entry for n_in={n_in}")
            if int(group_token_offsets[key]) != running:
                raise ValueError(
                    f"group_token_offsets[{key}]={group_token_offsets[key]} does not match expected {running}"
                )
            starts.append(running)
            running += V * n_in
        if running != total_numeric_tokens:
            raise ValueError(
                f"total_numeric_tokens={total_numeric_tokens} does not match expected {running}"
            )
        return starts

    def init_weights(self, std: float = 0.02):
        """(Re)initialize decoder weights.

        Value decoders use GroupedMLP.init_weights(std=std). Homoscedastic s
        parameters are reset to zero. Heteroscedastic s decoders are called with
        std=0.0 (presumably so the predicted s starts input-independent, like the
        zero-initialized homoscedastic case; depends on GroupedMLP.init_weights).
        """
        for dec in self.group_v_decoders:
            dec.init_weights(std=std)
        if self.homoscedastic:
            for p in self.group_s_params:
                nn.init.zeros_(p)
        else:
            for dec in self.group_s_decoders:
                dec.init_weights(std=0.0)

    def forward(self, x_tokens: torch.Tensor):
        """Decode a batch of numeric tokens.

        Args:
            x_tokens: [B, total_numeric_tokens, H], ordered as in the vocab spec
                (group by group, feature by feature, n_in tokens per feature).

        Returns:
            (values_by_nin, s_by_nin): dicts keyed by n_in holding tensors of
            shape [B, V, n_in] and [B, V] respectively.

        Raises:
            ValueError: if x_tokens is not 3-D or its T / H disagree with the spec.
        """
        if x_tokens.dim() != 3:
            raise ValueError(f"x_tokens must be [B,T,H], got {tuple(x_tokens.shape)}")
        B, T, H = x_tokens.shape
        if H != self.hidden_size:
            raise ValueError(f"hidden_size mismatch: got H={H}, expected {self.hidden_size}")
        if T != self.total_numeric_tokens:
            raise ValueError(f"token length mismatch: got T={T}, expected {self.total_numeric_tokens}")

        value_out: Dict[int, torch.Tensor] = {}
        s_out: Dict[int, torch.Tensor] = {}
        groups = zip(self.group_nins, self.group_Vs, self.group_starts)
        for gi, (n_in, V, start) in enumerate(groups):
            length = V * n_in
            # Tokens of one feature are contiguous, so [V*n_in] splits as [V, n_in].
            xg_tok4 = x_tokens[:, start:start + length, :].reshape(B, V, n_in, H)
            xg_flat = xg_tok4.reshape(B, V, n_in * H)  # [B, V, n_in*H]

            # values: [B, V, n_in]
            y = self.group_v_decoders[gi](xg_flat)

            # s = log sigma^2: [B, V]
            if self.homoscedastic:
                # Input-independent: broadcast the per-feature parameter over the
                # batch (expand returns a view, not a copy).
                s = self.group_s_params[gi].unsqueeze(0).expand(B, -1)
            else:
                # Average each feature's n_in tokens, then predict one s per feature.
                x_var = xg_tok4.mean(dim=2)  # [B, V, H]
                s = self.group_s_decoders[gi](x_var).squeeze(-1)  # [B, V]

            value_out[n_in] = y
            s_out[n_in] = s
        return value_out, s_out
# ============================================================
# DEMO
# ============================================================
def _demo_main():
    """Smoke-test NumericDecoder on random tokens and print the output shapes.

    Loads the spec named by --numeric_vocab_json, builds a decoder with the
    requested size/device/dtype (homoscedastic unless --heteroscedastic is
    given), and runs one no-grad forward pass on standard-normal tokens of
    shape [batch_size, total_numeric_tokens, hidden_size].
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--numeric_vocab_json", type=str, default="data/numeric_vocab.json")
    parser.add_argument("--hidden_size", type=int, default=768)
    parser.add_argument("--middle_size", type=int, default=-1,
                        help="If <0 -> one-layer. If >=0 -> two-layer with this middle size.")
    parser.add_argument("--heteroscedastic", action="store_true",
                        help="Predict s from the tokens instead of using per-feature parameters.")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--dtype", type=str, default="float32", choices=["float16", "bfloat16", "float32"])
    args = parser.parse_args()

    device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
    dtype = dtype_map[args.dtype]

    # Directly load existing numeric vocab spec
    spec = load_json(args.numeric_vocab_json)
    print(f"Loaded numeric vocab spec from: {args.numeric_vocab_json}")
    print("Groups (n_in -> V):", {int(g["n_in"]): len(g["feature_names"]) for g in spec["groups"]})
    print("total_numeric_tokens:", spec["total_numeric_tokens"])
    # .get: a missing key is reported by NumericDecoder's descriptive ValueError
    # below instead of a bare KeyError here.
    print("group_token_offsets:", spec.get("group_token_offsets", {}))

    middle_size = None if args.middle_size < 0 else int(args.middle_size)
    model = NumericDecoder(
        hidden_size=args.hidden_size,
        numeric_vocab_json=args.numeric_vocab_json,
        middle_size=middle_size,
        homoscedastic=not args.heteroscedastic,
    ).to(device=device, dtype=dtype)
    model.eval()

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters (NumericDecoder): {total_params:,} (trainable: {trainable_params:,})")

    B = args.batch_size
    T = int(spec["total_numeric_tokens"])
    H = args.hidden_size
    x_tokens = torch.randn(B, T, H, device=device, dtype=dtype)

    with torch.no_grad():
        values_by_nin, s_by_nin = model(x_tokens)

    print("Input tokens:", tuple(x_tokens.shape), x_tokens.dtype, x_tokens.device)
    print("Decoded values:", {k: tuple(v.shape) for k, v in values_by_nin.items()})
    print("Decoded s:", {k: tuple(s.shape) for k, s in s_by_nin.items()})
    # values_by_nin[n_in]: [B, V, n_in]
    # s_by_nin[n_in]: [B, V]


if __name__ == "__main__":
    _demo_main()