# Multiplication / MultiplicationNet.py
# Uploaded by xcx0902 via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 0613d9c (verified).
import torch.nn as nn
class MultiplicationNet(nn.Module):
    """Fully connected ReLU network (MLP) with 2 inputs and 1 output by default.

    The network is a stack of ``nn.Linear`` layers whose widths are given by
    ``layer_sizes``. An ``nn.ReLU`` sits between consecutive linear layers.
    There is no activation after the last layer, so the output is unbounded
    and can be negative.

    Args:
        layer_sizes: Widths of every layer, input width first and output
            width last. Defaults to ``DEFAULT_LAYER_SIZES``
            (2 -> 512 -> 1024 -> 2048 -> 1). That is the original hard-coded
            architecture, so existing checkpoints (``model.0.weight``,
            ``model.2.weight``, ...) still load into a default instance.

    Raises:
        ValueError: If fewer than two sizes are given, since at least one
            linear layer (input -> output) is needed.
    """

    # Original hard-coded architecture. Stored as a tuple so it cannot be
    # mutated through an instance and silently change later models.
    DEFAULT_LAYER_SIZES = (2, 512, 1024, 2048, 1)

    def __init__(self, layer_sizes=None):
        super().__init__()
        if layer_sizes is None:
            layer_sizes = self.DEFAULT_LAYER_SIZES
        sizes = tuple(layer_sizes)
        if len(sizes) < 2:
            # Without this check the trailing ``layers.pop()`` below would
            # raise an opaque IndexError on an empty list.
            raise ValueError(
                "layer_sizes needs at least 2 entries (input and output "
                f"widths), got {sizes!r}"
            )

        layers = []
        # Pair each width with the next one: (in_features, out_features).
        for in_features, out_features in zip(sizes, sizes[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())
        # Drop the trailing ReLU so the final layer is purely linear and the
        # output is not clamped to be non-negative.
        layers.pop()
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the MLP on ``x``.

        Args:
            x: Input tensor whose last dimension equals ``layer_sizes[0]``
                (2 by default). Presumably the two operands to multiply;
                TODO confirm against the training script.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension equal to ``layer_sizes[-1]`` (1 by default).
        """
        return self.model(x)