import torch.nn as nn
import pytorch_lightning as pl


class MLP(pl.LightningModule):
    """Fully connected network that maps ``input_size`` features to one output.

    The network is a stack of four hidden stages, each made of
    ``Linear -> ReLU -> (BatchNorm1d | Identity) -> Dropout``, with widths
    2048, 512, 256 and 128 and dropout probabilities 0.3, 0.3, 0.2 and 0.1.
    It ends with ``Linear(128, 32) -> ReLU -> Linear(32, 1)``.

    Args:
        input_size: Number of input features expected by the first linear
            layer.
        xcol: Stored on the instance as ``self.xcol``. Presumably the name of
            the input column/key in a batch; it is not used in this class
            body, so verify against the caller.
        ycol: Stored on the instance as ``self.ycol``. Presumably the name of
            the target column/key; likewise unused here, verify against the
            caller.
        batch_norm: When truthy, each hidden stage includes a
            ``nn.BatchNorm1d`` layer; otherwise an ``nn.Identity`` takes its
            place so the layer indices inside ``self.layers`` stay the same.
    """

    def __init__(self, input_size, xcol='emb', ycol='avg_rating', batch_norm=True):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol

        # (output width, dropout probability) of each hidden stage, in order.
        hidden_stages = (
            (2048, 0.3),
            (512, 0.3),
            (256, 0.2),
            (128, 0.1),
        )

        stack = []
        width = self.input_size
        for out_width, drop_p in hidden_stages:
            stack.append(nn.Linear(width, out_width))
            stack.append(nn.ReLU())
            # Identity is a placeholder that keeps the position of every
            # later layer fixed whether or not normalisation is enabled.
            if batch_norm:
                stack.append(nn.BatchNorm1d(out_width))
            else:
                stack.append(nn.Identity())
            stack.append(nn.Dropout(drop_p))
            width = out_width

        # Output head: no normalisation or dropout after the last hidden stage.
        stack.append(nn.Linear(width, 32))
        stack.append(nn.ReLU())
        stack.append(nn.Linear(32, 1))

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result.

        Assumes ``x`` has ``input_size`` features in its last dimension
        (TODO confirm the expected batch layout with the caller).
        """
        return self.layers(x)