rtferraz committed on
Commit
d685c0e
·
verified ·
1 Parent(s): 0dec8e4

Add PLR embeddings (Gorishniy et al. 2022)

Browse files
src/domain_tokenizer/models/plr_embeddings.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PLR (Periodic Linear ReLU) Numerical Embeddings.
3
+
4
+ Maps scalar numerical features to high-dimensional dense vectors via
5
+ learned periodic (sin/cos) activations followed by a linear projection.
6
+
7
+ From: Gorishniy et al. 2022, "On Embeddings for Numerical Features in
8
+ Tabular Deep Learning" (arXiv:2203.05556, NeurIPS 2022).
9
+
10
+ Used by Nubank nuFormer for the tabular feature branch (291 features).
11
+ PLR is the ingredient that makes DCNv2 beat LightGBM.
12
+ """
13
+
14
+ import math
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+
19
class PeriodicLinearReLU(nn.Module):
    """PLR numerical embeddings (Gorishniy et al. 2022).

    Maps each scalar feature through learned periodic activations:
        x -> [sin(2pi*w*x + b), cos(2pi*w*x + b)] -> Linear -> ReLU

    Frequencies ``w`` and phases ``b`` are learned parameters with one row
    per feature. The linear projection is a single ``nn.Linear`` applied to
    the last dimension, so its weights are shared by all features.

    Args:
        n_features: Number of numerical features.
        n_frequencies: Number of sin/cos frequency pairs per feature.
        embedding_dim: Output embedding dimension per feature.
        frequency_init_scale: Standard deviation of the normal distribution
            used to initialise the frequencies (keyword-only). Defaults to
            0.01.

    Input:  (batch, n_features) -- raw scalar feature values
    Output: (batch, n_features, embedding_dim)
    """

    def __init__(
        self,
        n_features: int,
        n_frequencies: int = 64,
        embedding_dim: int = 64,
        *,
        frequency_init_scale: float = 0.01,
    ):
        super().__init__()
        self.n_features = n_features
        self.n_frequencies = n_frequencies
        self.embedding_dim = embedding_dim
        self.frequency_init_scale = frequency_init_scale

        # One independent set of frequencies / phases per feature.
        self.frequencies = nn.Parameter(
            torch.randn(n_features, n_frequencies) * frequency_init_scale
        )
        self.phases = nn.Parameter(torch.zeros(n_features, n_frequencies))
        # Consumes the concatenated [sin, cos] vector, hence 2 * n_frequencies.
        self.linear = nn.Linear(2 * n_frequencies, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed every scalar feature of ``x``.

        Args:
            x: Tensor whose last dimension has size ``n_features``,
                typically of shape (batch, n_features).

        Returns:
            Non-negative tensor of shape (batch, n_features, embedding_dim).

        Raises:
            ValueError: If ``x`` is 0-dimensional or its last dimension does
                not equal ``n_features``.
        """
        # Without this check a (batch, 1) input would silently broadcast one
        # column across all per-feature frequencies instead of failing.
        if x.dim() == 0 or x.shape[-1] != self.n_features:
            raise ValueError(
                f"expected input with last dimension {self.n_features}, "
                f"got shape {tuple(x.shape)}"
            )

        # (batch, n_features) -> (batch, n_features, 1) so each value
        # broadcasts against its feature's row of frequencies.
        x = x.unsqueeze(-1)
        angles = 2 * math.pi * self.frequencies.unsqueeze(0) * x + self.phases.unsqueeze(0)
        periodic = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        return torch.relu(self.linear(periodic))

    def extra_repr(self) -> str:
        """Return the constructor sizes shown in the module's repr."""
        return (f"n_features={self.n_features}, n_frequencies={self.n_frequencies}, "
                f"embedding_dim={self.embedding_dim}")