from torch import nn


class BinaryClassificationWithLogits(nn.Module):
    """Two-layer feed-forward network that returns raw logits.

    The forward pass is ``linear_block1`` (in_features -> hidden_features),
    then ``activation``, then ``linear_block2`` (hidden_features ->
    out_features). No sigmoid/softmax is applied to the output.

    NOTE: with the default ``activation=None`` nothing non-linear sits
    between the two linear layers. The whole network is then mathematically
    a single affine map (``W2 @ (W1 @ x + b1) + b2``), so the hidden layer
    adds no expressive power. Pass e.g. ``activation=nn.ReLU()`` to get a
    genuine MLP. The default is kept so existing models behave unchanged.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Number of output logits per sample.
        hidden_features: Width of the hidden layer.
        activation: Optional module applied between the two linear layers.
            ``None`` (default) means identity, i.e. the original behavior.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: int,
        activation: "nn.Module | None" = None,
    ):
        super().__init__()
        self.linear_block1 = nn.Linear(in_features=in_features, out_features=hidden_features)
        self.linear_block2 = nn.Linear(in_features=hidden_features, out_features=out_features)
        # nn.Identity has no parameters, so the default keeps the state_dict
        # keys identical to the original two-layer model (checkpoints still load).
        self.activation = activation if activation is not None else nn.Identity()

    def forward(self, X):
        """Map ``X`` to unnormalized logits.

        Args:
            X: Input tensor whose last dimension is ``in_features``.

        Returns:
            Tensor with the last dimension replaced by ``out_features``,
            containing raw logits (no sigmoid applied).
        """
        x = self.linear_block1(X)
        x = self.activation(x)
        x = self.linear_block2(x)
        # return LOGITS
        return x