import torch


class Model(torch.nn.Module):
    """Module whose forward pass ignores the input's values and returns noise.

    Only ``x.shape[0]`` and ``x.device`` are read from the input. The output is
    a 1-D tensor of standard-normal samples with one entry per element along
    the leading dimension of ``x``.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): neither ``fc1`` nor ``threshold`` is used by
        # ``forward``. They are kept for backward compatibility (``fc1`` still
        # contributes parameters to ``state_dict()``); confirm whether they
        # are still needed.
        self.fc1 = torch.nn.Linear(10, 5)
        self.threshold = 0.

    def forward(self, x):
        """Return ``x.shape[0]`` standard-normal floats on ``x``'s device.

        Args:
            x: Input tensor with at least one dimension. Only its leading size
                and its device are used; its values are ignored.

        Returns:
            A 1-D tensor of shape ``(x.shape[0],)`` in the default floating
            dtype, located on ``x.device``. On non-CPU devices the samples are
            drawn from that device's random generator.
        """
        # Sample directly on the target device instead of sampling on the CPU
        # and then copying the result over with ``.to(...)``.
        return torch.randn(x.shape[0], device=x.device)