# Provenance (residue from the hosting page): LeonardoBerti's picture
# Upload 51 files
# Commit 69524d0 (verified)
from torch import nn
import torch
from models.bin import BiN
class MLPLOB(nn.Module):
    """MLP-mixer style classifier over a window of order-book snapshots.

    The input is normalized by ``BiN`` and linearly projected to
    ``hidden_dim`` features. It then passes through ``num_layers`` pairs of
    ``MLP`` blocks: the first block of each pair mixes along the feature
    axis, the second along the time axis. The last pair shrinks both axes to
    a quarter of their size. The result is flattened and fed through a funnel
    of linear layers that ends in a 3-wide output (no final activation).

    Args:
        hidden_dim: Width of the projected feature axis.
        num_layers: Number of (feature-mixing, time-mixing) MLP pairs. Must
            be at least 1: with 0 pairs the flattened width would be
            ``seq_size * hidden_dim``, which does not match the head.
        seq_size: Length of the time axis of the input window.
        num_features: Size of the feature axis of the input.
        dataset_type: If ``"LOBSTER"``, input column 41 is treated as an
            integer order type and replaced by a 1-wide embedding appended
            after the remaining columns.
    """

    def __init__(self,
                 hidden_dim: int,
                 num_layers: int,
                 seq_size: int,
                 num_features: int,
                 dataset_type: str
                 ) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dataset_type = dataset_type
        # Keep the submodule creation order below unchanged. It fixes the
        # parameter ordering and the order in which weight initialisation
        # draws from the RNG.
        self.layers = nn.ModuleList()
        self.order_type_embedder = nn.Embedding(3, 1)
        self.first_layer = nn.Linear(num_features, hidden_dim)
        self.norm_layer = BiN(num_features, seq_size)
        self.layers.extend([self.first_layer, nn.GELU()])

        for depth in range(num_layers):
            # Intermediate pairs expand 4x and keep each axis size.
            # The final pair expands 2x and shrinks each axis to a quarter.
            expand, shrink = (2, 4) if depth == num_layers - 1 else (4, 1)
            self.layers.extend([
                MLP(hidden_dim, hidden_dim * expand, hidden_dim // shrink),
                MLP(seq_size, seq_size * expand, seq_size // shrink),
            ])

        # Head: shrink the flattened width by 4x per step until it is <= 128,
        # then map to the 3 outputs.
        width = (hidden_dim // 4) * (seq_size // 4)
        self.final_layers = nn.ModuleList()
        while width > 128:
            self.final_layers.extend([nn.Linear(width, width // 4), nn.GELU()])
            width //= 4
        self.final_layers.append(nn.Linear(width, 3))

    def forward(self, input):
        """Map a ``(batch, seq_size, num_features)`` window to ``(batch, 3)``.

        Args:
            input: 3-D tensor; the time axis is dim 1, features are dim 2.

        Returns:
            Raw outputs of the final linear layer, shape ``(batch, 3)``.
        """
        if self.dataset_type == "LOBSTER":
            continuous = torch.cat([input[:, :, :41], input[:, :, 42:]], dim=2)
            # The embedding has 3 rows, so after .long() (which truncates
            # floats) the order types must lie in {0, 1, 2}.
            # NOTE(review): .detach() blocks gradients, so the weights of
            # order_type_embedder are never updated through this path.
            # Confirm this is intended.
            order_emb = self.order_type_embedder(input[:, :, 41].long()).detach()
            x = torch.cat([continuous, order_emb], dim=2)
        else:
            x = input

        # BiN is applied to the (batch, features, time) layout; swap back after.
        x = self.norm_layer(x.permute(0, 2, 1)).permute(0, 2, 1)

        # Every layer is followed by a swap of the last two axes. The GELU that
        # follows first_layer flips the layout back, so the first MLP of each
        # pair sees features last and the second sees time last.
        for layer in self.layers:
            x = layer(x).permute(0, 2, 1)

        batch_size = x.shape[0]
        out = x.reshape(batch_size, -1)
        for head_layer in self.final_layers:
            out = head_layer(out)
        return out
class MLP(nn.Module):
    """Two-layer perceptron over the last axis, with an optional skip connection.

    Computes ``gelu(layer_norm(fc2(gelu(fc(x))) + x))`` when the output width
    equals the input width (``start_dim == final_dim``). Otherwise it computes
    ``gelu(layer_norm(fc2(gelu(fc(x)))))``, with no residual.

    Args:
        start_dim: Size of the last axis of the input.
        hidden_dim: Width of the intermediate projection.
        final_dim: Size of the last axis of the output.
    """

    def __init__(self,
                 start_dim: int,
                 hidden_dim: int,
                 final_dim: int
                 ) -> None:
        super().__init__()
        self.layer_norm = nn.LayerNorm(final_dim)
        self.fc = nn.Linear(start_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, final_dim)
        self.gelu = nn.GELU()

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(..., start_dim)``.

        Returns:
            Tensor of shape ``(..., final_dim)``.
        """
        residual = x
        x = self.fc(x)
        x = self.gelu(x)
        x = self.fc2(x)
        # The Linear layers and LayerNorm all act on the last axis, so the skip
        # connection must be decided on the last axis too. The old check used
        # shape[2], which crashed on rank-2 input and compared the wrong axis
        # for rank >= 4.
        if x.shape[-1] == residual.shape[-1]:
            x = x + residual
        x = self.layer_norm(x)
        x = self.gelu(x)
        return x