File size: 457 Bytes
ab63513 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
from torch import nn
class BinaryClassificationWithLogits(nn.Module):
    """Two-layer feed-forward network that returns raw logits.

    The input goes through ``linear_block1`` (in_features -> hidden_features),
    then an optional ``activation``, then ``linear_block2``
    (hidden_features -> out_features). No sigmoid is applied to the output,
    so the result is a logit. It is presumably meant to be paired with a
    logit-aware loss such as ``nn.BCEWithLogitsLoss`` (confirm with callers).

    Note: with the default ``activation=None`` the two linear layers compose
    into a single affine map, so the hidden layer adds no expressive power.
    The model can then only learn a linear decision boundary. Pass a
    nonlinearity, e.g. ``activation=nn.ReLU()``, to get a real hidden layer.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: int,
        *,
        activation=None,
    ):
        """Create the two linear layers and store the optional activation.

        Args:
            in_features: size of the last dimension of the input.
            out_features: size of the last dimension of the output
                (number of logits).
            hidden_features: width of the intermediate layer.
            activation: optional callable (typically an ``nn.Module`` such as
                ``nn.ReLU()``) applied between the two linear layers.
                ``None`` (the default) applies no activation, matching the
                original behavior.
        """
        super().__init__()
        self.linear_block1 = nn.Linear(
            in_features=in_features, out_features=hidden_features
        )
        self.linear_block2 = nn.Linear(
            in_features=hidden_features, out_features=out_features
        )
        # When ``activation`` is an nn.Module, assigning it here registers it
        # as a submodule, so .to()/.train()/.eval() reach it. Parameter-free
        # activations (ReLU, Tanh, ...) leave the state_dict keys unchanged.
        self.activation = activation

    def forward(self, X):
        """Return the logits for input ``X``.

        Args:
            X: tensor whose last dimension is ``in_features``
                (as required by ``nn.Linear``).

        Returns:
            Tensor whose last dimension is ``out_features``. These are raw
            logits; no sigmoid has been applied.
        """
        x = self.linear_block1(X)
        if self.activation is not None:
            x = self.activation(x)
        x = self.linear_block2(x)
        # return LOGITS (no sigmoid applied)
        return x