soarhigh commited on
Commit
56e6323
·
1 Parent(s): fc44eab

Upload regressor.py

Browse files
Files changed (1) hide show
  1. regressor.py +19 -0
regressor.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ class WRegressor(nn.Module):
4
+ def __init__(self):
5
+ super().__init__()
6
+ self.linear_relu_stack = nn.Sequential(
7
+ nn.Linear(768, 256),
8
+ nn.ReLU(),
9
+ nn.Linear(256, 64),
10
+ nn.ReLU(),
11
+ nn.Linear(64, 16),
12
+ nn.ReLU(),
13
+ nn.Linear(16, 1),
14
+ )
15
+ return
16
+
17
+ def forward(self, x):
18
+ r = self.linear_relu_stack(x)
19
+ return r