# lob-pattern-net / model.py
# Uploaded to the Hugging Face Hub via huggingface_hub (commit 8171f7d).
import torch
import torch.nn as nn
import torch.nn.functional as F
class BilinearNorm(nn.Module):
    """Gated blend of a temporal z-normalization branch and the raw input.

    Each of the ``d`` features is standardized over the time axis (dim 1),
    passed through a learned affine transform, and mixed back with the
    untouched input via a per-feature sigmoid gate.
    """
    def __init__(self, d):
        super().__init__()
        # Affine parameters applied to the normalized branch.
        self.gamma = nn.Parameter(torch.ones(1, 1, d))
        self.beta = nn.Parameter(torch.zeros(1, 1, d))
        # Pre-sigmoid gate logits; ones -> gate starts near 0.73.
        self.gate = nn.Parameter(torch.ones(1, 1, d))

    def forward(self, x):
        # Standardize per feature over time; eps keeps the division finite.
        centered = x - x.mean(dim=1, keepdim=True)
        standardized = centered / (x.std(dim=1, keepdim=True) + 1e-8)
        affine = self.gamma * standardized + self.beta
        mix = torch.sigmoid(self.gate)
        # Convex combination of the normalized branch and the identity path.
        return mix * affine + (1.0 - mix) * x
class LOBAlgoNet(nn.Module):
    """
    CNN + Transformer model for algorithmic order pattern detection.

    Input:  (B, T=100, 40) normalized LOB snapshots
    Output: (B, 5) logits for [TWAP, VWAP, ICEBERG, SUPPORT, NORMAL]
    """
    def __init__(self, num_classes=5, d_model=128, nhead=4, dropout=0.25):
        super().__init__()
        self.norm = BilinearNorm(40)
        # Spatial CNN: cross-level patterns, collapsing the 40-wide feature
        # axis down to a single column (40 -> 20 -> 10 -> 1).
        self.spatial = nn.Sequential(
            nn.Conv2d(1, 32, (1, 2), stride=(1, 2)),   # 40 -> 20
            nn.BatchNorm2d(32), nn.LeakyReLU(0.01),
            nn.Conv2d(32, 32, (1, 2), stride=(1, 2)),  # 20 -> 10
            nn.BatchNorm2d(32), nn.LeakyReLU(0.01),
            nn.Conv2d(32, 32, (1, 10)),                # 10 -> 1
            nn.BatchNorm2d(32), nn.LeakyReLU(0.01),
        )
        # Temporal CNN: multi-scale temporal features (kernel sizes 3/5/3,
        # padding chosen to preserve sequence length T).
        self.temporal = nn.Sequential(
            nn.Conv1d(32, 64, 3, padding=1), nn.BatchNorm1d(64), nn.LeakyReLU(0.01), nn.Dropout(dropout),
            nn.Conv1d(64, 64, 5, padding=2), nn.BatchNorm1d(64), nn.LeakyReLU(0.01), nn.Dropout(dropout),
            nn.Conv1d(64, d_model, 3, padding=1), nn.BatchNorm1d(d_model), nn.LeakyReLU(0.01), nn.Dropout(dropout),
        )
        # Transformer attention over the T time steps.
        enc_layer = nn.TransformerEncoderLayer(d_model, nhead, d_model*2, dropout, batch_first=True, activation='gelu')
        self.attention = nn.TransformerEncoder(enc_layer, num_layers=2)
        # Classifier head on the time-pooled embedding.
        self.head = nn.Sequential(
            nn.LayerNorm(d_model), nn.Dropout(dropout),
            nn.Linear(d_model, 64), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(64, num_classes)
        )
        self._init()

    def _init(self):
        """Kaiming-normal init for conv/linear weights, zero biases."""
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d)):
                # Match the gain to the LeakyReLU(0.01) activations actually
                # used (the original passed nonlinearity='relu').
                nn.init.kaiming_normal_(m.weight, a=0.01, nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _encode(self, x):
        """Shared encoder pipeline: (B, T, 40) -> (B, d_model).

        Factored out so forward() and get_embeddings() cannot drift apart.
        """
        x = self.norm(x)
        x = self.spatial(x.unsqueeze(1)).squeeze(-1)  # (B, 32, T)
        x = self.temporal(x)                          # (B, d_model, T)
        x = self.attention(x.permute(0, 2, 1))        # (B, T, d_model)
        return x.mean(dim=1)                          # (B, d_model)

    def forward(self, x):
        """Return classification logits of shape (B, num_classes)."""
        return self.head(self._encode(x))

    def get_embeddings(self, x):
        """Extract pooled feature vectors for clustering / visualization.

        Returns the (B, d_model) embedding produced just before the head.
        """
        return self._encode(x)