# File size: 308 Bytes
# 4875545

import torch

class Model(torch.nn.Module):
    """Toy module whose forward pass ignores the input's values and returns noise.

    ``fc1`` and ``threshold`` are created in ``__init__`` but are not used by
    ``forward``; they are kept so the module's attributes and ``state_dict``
    keys (``fc1.weight``, ``fc1.bias``) stay unchanged for existing callers
    and checkpoints.
    """

    def __init__(self):
        super().__init__()
        # 10 -> 5 linear layer; registered as a submodule but unused by forward().
        self.fc1 = torch.nn.Linear(10, 5)
        # Plain float attribute; unused by forward().
        self.threshold = 0.

    def forward(self, x):
        """Return standard-normal samples, one per entry along ``x``'s first dim.

        Args:
            x: Tensor with at least one dimension. Only ``x.shape[0]`` and
                ``x.device`` are read; the values of ``x`` are ignored.

        Returns:
            A 1-D tensor of shape ``(x.shape[0],)`` in the default floating
            dtype, located on ``x.device``. This is NOT the same shape as ``x``
            unless ``x`` is itself 1-D.
        """
        # Sample directly on the target device rather than sampling on CPU and
        # copying; on a non-CPU device this draws from that device's generator.
        return torch.randn(x.shape[0], device=x.device)