import torch.nn as nn


class MultiplicationNet(nn.Module):
    """Fully connected feed-forward network (MLP) built from ``nn.Linear`` layers.

    Consecutive ``nn.Linear`` layers are separated by ``nn.ReLU``. No
    activation follows the final layer, so the output is unbounded.

    With the default sizes the input has 2 features and the output has 1.
    NOTE(review): the class name suggests it is trained to approximate the
    product of its two inputs; the training code is not visible here.
    """

    def __init__(self, layer_sizes=(2, 512, 1024, 2048, 1)):
        """Build the network.

        Args:
            layer_sizes: Feature widths from input to output. Each adjacent
                pair ``(a, b)`` becomes one ``nn.Linear(a, b)``. The default
                reproduces the original architecture 2 -> 512 -> 1024 -> 2048 -> 1.

        Raises:
            ValueError: If fewer than two sizes are given, since at least one
                linear layer is needed.
        """
        super().__init__()
        sizes = list(layer_sizes)
        if len(sizes) < 2:
            raise ValueError(
                f"layer_sizes needs at least 2 entries (input and output), got {sizes!r}"
            )

        layers = []
        for in_features, out_features in zip(sizes, sizes[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())
        # Drop the trailing ReLU so the output layer stays linear (unbounded).
        layers.pop()
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stacked layers.

        ``x``'s last dimension must equal ``layer_sizes[0]`` (2 by default),
        as required by the first ``nn.Linear``. The output's last dimension is
        ``layer_sizes[-1]`` (1 by default).
        """
        return self.model(x)