import torch
import torch.nn as nn


class WRegressor(nn.Module):
    """Feed-forward regression network built from Linear/ReLU/Dropout stages.

    With the default arguments the network is::

        Linear(768, 256) -> ReLU -> Dropout
        Linear(256, 64)  -> ReLU -> Dropout
        Linear(64, 16)   -> ReLU -> Dropout
        Linear(16, 1)

    The layers live in ``self.linear_relu_stack`` (an ``nn.Sequential``), and
    are always created in the order shown above, so the sub-module indices
    (and therefore the ``state_dict`` keys) of a default-constructed model
    are ``0, 3, 6, 9`` for the four linear layers.
    """

    def __init__(
        self,
        *,
        in_features: int = 768,
        hidden_sizes: tuple = (256, 64, 16),
        out_features: int = 1,
        dropout: float = 0.5,
    ) -> None:
        """Build the layer stack.

        Args:
            in_features: Size of the last dimension of the input tensor.
            hidden_sizes: Output width of each hidden ``Linear`` layer, in
                order. Every hidden layer is followed by ``ReLU`` and
                ``Dropout``. May be empty, in which case the model is a
                single ``Linear(in_features, out_features)``.
            out_features: Width of the final ``Linear`` layer.
            dropout: Probability passed to every ``nn.Dropout``. The default
                of ``0.5`` matches ``nn.Dropout()`` called without arguments.
        """
        super().__init__()

        layers = []
        width = in_features
        for hidden in hidden_sizes:
            # Keep the Linear -> ReLU -> Dropout ordering so that positional
            # indices inside the Sequential stay stable for saved checkpoints.
            layers.extend((nn.Linear(width, hidden), nn.ReLU(), nn.Dropout(dropout)))
            width = hidden
        # Final projection has no activation or dropout after it.
        layers.append(nn.Linear(width, out_features))

        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack and return the result.

        Args:
            x: Input tensor whose last dimension equals ``in_features``.

        Returns:
            The output of the final linear layer; its last dimension equals
            ``out_features``.
        """
        return self.linear_relu_stack(x)