File size: 308 Bytes
4875545 |
1 2 3 4 5 6 7 8 9 10 11 12 |
import torch
class Model(torch.nn.Module):
    """Toy module whose forward pass returns random noise.

    Attributes:
        fc1: A ``torch.nn.Linear(10, 5)`` layer. NOTE(review): it is never
            used by ``forward``; presumably a placeholder -- confirm before
            relying on it.
        threshold: The float constant ``0.0``. NOTE(review): also unused by
            ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 5)
        self.threshold = 0.0

    def forward(self, x):
        """Return one standard-normal sample per entry along ``x``'s first dim.

        Args:
            x: Input tensor with at least one dimension. Only ``x.shape[0]``
                and ``x.device`` are read; the values of ``x`` are ignored.

        Returns:
            A 1-D tensor of length ``x.shape[0]`` (not the full shape of
            ``x``), holding samples from N(0, 1) in the default floating
            dtype, located on ``x.device``. It is not computed from ``fc1``,
            so no gradient flows to the module's parameters.

        Raises:
            IndexError: If ``x`` is 0-dimensional (it has no ``shape[0]``).
        """
        # Sample directly on x's device (using that device's RNG) instead of
        # sampling on the host and copying, which avoids an extra allocation
        # and a host->device transfer.
        return torch.randn(x.shape[0], device=x.device)
|