File size: 457 Bytes
ab63513
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

from torch import nn

class BinaryClassificationWithLogits(nn.Module):
    """Feed-forward classifier made of two stacked linear layers.

    The input is projected from ``in_features`` to ``hidden_features`` and
    then from ``hidden_features`` to ``out_features``. The result of the
    second projection is returned as-is (raw logits); no sigmoid or other
    activation is applied anywhere in this module.

    NOTE(review): there is no non-linearity between the two linear layers,
    so the network as a whole is still an affine map of its input. Confirm
    with the author whether that is intentional.

    Args:
        in_features: Size of the last dimension of the input tensor.
        out_features: Size of the last dimension of the returned logits.
        hidden_features: Width of the intermediate projection.
    """

    def __init__(self, in_features, out_features, hidden_features):
        super().__init__()
        # Attribute names and creation order are kept stable: they determine
        # the state_dict keys and the order in which weights are initialised.
        self.linear_block1 = nn.Linear(in_features, hidden_features)
        self.linear_block2 = nn.Linear(hidden_features, out_features)

    def forward(self, X):
        """Run ``X`` through both linear layers and return the logits.

        Args:
            X: Input tensor whose last dimension equals ``in_features``.

        Returns:
            Tensor of un-normalised scores whose last dimension equals
            ``out_features``.
        """
        hidden = self.linear_block1(X)
        logits = self.linear_block2(hidden)
        return logits