from torch import nn | |
class BinaryClassificationWithLogits(nn.Module):
    """Two-layer fully connected model that returns raw logits.

    The input is passed through ``linear_block1`` and then
    ``linear_block2``; no activation (e.g. sigmoid) is applied to the
    result, so the caller receives unnormalised scores.

    NOTE(review): there is no non-linearity between the two linear
    layers, so the overall mapping is still affine -- confirm that this
    is intentional.
    """

    def __init__(self, in_features, out_features, hidden_features):
        """Create the two linear layers.

        Args:
            in_features: Size of the last dimension of the input.
            out_features: Size of the last dimension of the output.
            hidden_features: Width of the intermediate representation
                produced by the first layer and consumed by the second.
        """
        super().__init__()
        # Layers are created in this order on purpose: parameter
        # initialisation consumes the global RNG, so the order matters
        # for reproducibility under a fixed seed.
        self.linear_block1 = nn.Linear(in_features, hidden_features)
        self.linear_block2 = nn.Linear(hidden_features, out_features)

    def forward(self, X):
        """Apply both linear layers to ``X`` and return the logits.

        Args:
            X: Input tensor whose last dimension matches ``in_features``
                (as required by ``nn.Linear``).

        Returns:
            The output of the second linear layer, returned as-is
            (logits, not probabilities).
        """
        hidden = self.linear_block1(X)
        logits = self.linear_block2(hidden)
        return logits