import torch
from torch import nn

from models.bin import BiN

class MLPLOB(nn.Module):
    """MLP-based model for limit order book (LOB) data with a 3-class output head."""

    def __init__(self,
                 hidden_dim: int,
                 num_layers: int,
                 seq_size: int,
                 num_features: int,
                 dataset_type: str
                 ) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dataset_type = dataset_type
        self.layers = nn.ModuleList()
        # Embedding for the categorical order-type column of LOBSTER inputs.
        self.order_type_embedder = nn.Embedding(3, 1)
        self.first_layer = nn.Linear(num_features, hidden_dim)
        # Bilinear normalization (BiN) over the feature and temporal dimensions.
        self.norm_layer = BiN(num_features, seq_size)
        self.layers.append(self.first_layer)
        self.layers.append(nn.GELU())
        # Alternate MLP blocks acting on the feature dimension and on the
        # temporal dimension; the last pair compresses both by a factor of 4.
        for i in range(num_layers):
            if i != num_layers - 1:
                self.layers.append(MLP(hidden_dim, hidden_dim * 4, hidden_dim))
                self.layers.append(MLP(seq_size, seq_size * 4, seq_size))
            else:
                self.layers.append(MLP(hidden_dim, hidden_dim * 2, hidden_dim // 4))
                self.layers.append(MLP(seq_size, seq_size * 2, seq_size // 4))

        # Classification head: shrink the flattened representation until it is
        # small enough for a final linear layer with 3 output classes.
        total_dim = (hidden_dim // 4) * (seq_size // 4)
        self.final_layers = nn.ModuleList()
        while total_dim > 128:
            self.final_layers.append(nn.Linear(total_dim, total_dim // 4))
            self.final_layers.append(nn.GELU())
            total_dim = total_dim // 4
        self.final_layers.append(nn.Linear(total_dim, 3))

    def forward(self, input):
        # input: (batch_size, seq_size, num_features)
        if self.dataset_type == "LOBSTER":
            # Column 41 is the categorical order type: replace it with its
            # (detached) embedding and keep the remaining columns unchanged.
            continuous_features = torch.cat([input[:, :, :41], input[:, :, 42:]], dim=2)
            order_type = input[:, :, 41].long()
            order_type_emb = self.order_type_embedder(order_type).detach()
            x = torch.cat([continuous_features, order_type_emb], dim=2)
        else:
            x = input
        # BiN expects (batch_size, num_features, seq_size).
        x = x.permute(0, 2, 1)
        x = self.norm_layer(x)
        x = x.permute(0, 2, 1)
        # Permute after every layer so that the feature MLPs and the temporal
        # MLPs alternately see their own dimension last.
        for layer in self.layers:
            x = layer(x)
            x = x.permute(0, 2, 1)
        x = x.reshape(x.shape[0], -1)
        for layer in self.final_layers:
            x = layer(x)
        return x


class MLP(nn.Module):
    """Two-layer feed-forward block with GELU, optional residual, and LayerNorm."""

    def __init__(self,
                 start_dim: int,
                 hidden_dim: int,
                 final_dim: int
                 ) -> None:
        super().__init__()
        self.layer_norm = nn.LayerNorm(final_dim)
        self.fc = nn.Linear(start_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, final_dim)
        self.gelu = nn.GELU()

    def forward(self, x):
        residual = x
        x = self.fc(x)
        x = self.gelu(x)
        x = self.fc2(x)
        # The residual connection is applied only when the block preserves
        # the size of the last dimension.
        if x.shape[2] == residual.shape[2]:
            x = x + residual
        x = self.layer_norm(x)
        x = self.gelu(x)
        return x
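

if __name__ == "__main__":
    # Minimal smoke-test sketch: the hyperparameters and the dataset_type value
    # below are illustrative assumptions, not values taken from the experiments.
    torch.manual_seed(0)
    batch_size, seq_size, num_features = 8, 128, 40
    model = MLPLOB(
        hidden_dim=64,
        num_layers=3,
        seq_size=seq_size,
        num_features=num_features,
        dataset_type="FI-2010",  # any value other than "LOBSTER" skips the order-type embedding
    )
    dummy_batch = torch.randn(batch_size, seq_size, num_features)
    logits = model(dummy_batch)
    print(logits.shape)  # expected: torch.Size([8, 3])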