import torch.nn as nn | |
class MultiplicationNet(nn.Module):
    """Fully connected ReLU network mapping ``layer_sizes[0]`` inputs to ``layer_sizes[-1]`` outputs.

    The layers are stored in ``self.model`` as an ``nn.Sequential`` of
    alternating ``nn.Linear`` and ``nn.ReLU`` modules, with no activation after
    the final ``nn.Linear``. With the default sizes this is
    ``Linear(2, 512) -> ReLU -> Linear(512, 1024) -> ReLU ->
    Linear(1024, 2048) -> ReLU -> Linear(2048, 1)``.

    NOTE(review): the class name suggests it is trained to approximate the
    product of its two inputs; that is inferred from the name only — confirm
    against the training code.
    """

    def __init__(self, layer_sizes=(2, 512, 1024, 2048, 1)):
        """Build the network.

        Args:
            layer_sizes: Widths of every layer, from the input features to the
                output features. Consecutive pairs become ``nn.Linear`` layers.
                Defaults to ``(2, 512, 1024, 2048, 1)``.

        Raises:
            ValueError: If fewer than two sizes are given. At least one
                ``nn.Linear`` layer is needed.
        """
        super().__init__()
        sizes = list(layer_sizes)
        if len(sizes) < 2:
            raise ValueError(
                f"layer_sizes needs at least 2 entries (input and output), got {sizes!r}"
            )

        layers = []
        for in_features, out_features in zip(sizes, sizes[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())
        # Drop the activation after the output layer. A trailing ReLU would
        # clamp every prediction to >= 0. The layer order (and therefore the
        # state_dict keys model.0, model.2, ...) is unchanged from before.
        layers.pop()
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the network.

        ``x``'s last dimension must equal ``layer_sizes[0]`` (2 by default),
        as required by the first ``nn.Linear``. The output has the same
        leading dimensions, with a last dimension of ``layer_sizes[-1]``.
        """
        return self.model(x)