File size: 1,126 Bytes
bb0a0a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
import torch.nn as nn

class ResBlock(nn.Module):
	"""Residual block built from three width-preserving linear layers.

	The residual branch (``self.long``) is Linear -> LeakyReLU(0.1) ->
	Linear -> LeakyReLU(0.1) -> Linear, every layer mapping ``ch`` features
	to ``ch`` features. Its result is added to the block input and the sum
	goes through a ReLU (``self.join``).

	Args:
		ch: Number of input and output features of every linear layer.
	"""
	def __init__(self, ch):
		super().__init__()
		self.join = nn.ReLU()
		# Assemble the branch layer by layer: an activation sits between
		# consecutive linear layers, but not after the last one.
		stack = []
		for position in range(3):
			if position > 0:
				stack.append(nn.LeakyReLU(0.1))
			stack.append(nn.Linear(ch, ch))
		self.long = nn.Sequential(*stack)
	def forward(self, x):
		"""Return ``ReLU(long(x) + x)``; the output has the same shape as ``x``."""
		branch = self.long(x)
		merged = branch + x
		return self.join(merged)

class PredictorModel(nn.Module):
	"""Feed-forward predictor that maps feature vectors to scores.

	The network has two stages:

	* ``up``: a linear projection from ``features`` to ``hidden`` followed by
	  one ``ResBlock`` of width ``hidden``.
	* ``down``: Linear(hidden, 128) -> Linear(128, 64) -> Dropout(0.1) ->
	  LeakyReLU -> Linear(64, 32) -> Linear(32, outputs).

	The final activation depends on ``outputs``:

	* ``outputs > 1``: softmax over the last dimension, so the values along
	  that dimension sum to 1.
	* otherwise: ``tanh`` rescaled from [-1, 1] to [0, 1].

	Args:
		features: Size of the last dimension of the input.
		outputs: Size of the last dimension of the result.
		hidden: Width of the projection and of the residual block.
	"""
	def __init__(self, features: int = 768, outputs: int = 1, hidden: int = 1024):
		super().__init__()
		self.features = features
		self.outputs = outputs
		self.hidden = hidden
		self.up = nn.Sequential(
			nn.Linear(self.features, self.hidden),
			ResBlock(ch=self.hidden),
		)
		self.down = nn.Sequential(
			nn.Linear(self.hidden, 128),
			nn.Linear(128, 64),
			nn.Dropout(0.1),
			nn.LeakyReLU(),
			nn.Linear(64, 32),
			nn.Linear(32, self.outputs),
		)
		# nn.Linear leaves the class scores on the LAST dimension, so normalise
		# over dim=-1. A hard-coded dim=1 raises on an unbatched 1-D input and
		# picks the wrong axis for inputs with more than two dimensions; for
		# the usual (batch, features) input both choices are identical.
		self.out = nn.Softmax(dim=-1) if self.outputs > 1 else nn.Tanh()
	def forward(self, x: torch.Tensor) -> torch.Tensor:
		"""Run the model on ``x`` (last dimension of size ``features``).

		Returns:
			A tensor whose last dimension has size ``outputs``: softmax
			probabilities when ``outputs > 1``, otherwise values in [0, 1].
		"""
		y = self.up(x)
		z = self.down(y)
		if self.outputs > 1:
			return self.out(z)
		# Tanh produces values in [-1, 1]; shift and halve to land in [0, 1].
		return (self.out(z) + 1.0) / 2.0