yourusername commited on
Commit
1832ae9
1 Parent(s): dfc71b6

:sparkles: add components

Browse files
Files changed (2) hide show
  1. model.py +60 -1
  2. requirements.txt +0 -1
model.py CHANGED
@@ -1,6 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  class MessageModel:
2
  def __init__(self, msg='hello, world'):
3
  self.msg = msg
4
-
5
  def __call__(self):
6
  print(self.msg)
 
1
+ from typing import Tuple
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class Dense(nn.Module):
8
+
9
+ def __init__(self, input_dim, output_dim, bias=True, activation=nn.LeakyReLU, **kwargs):
10
+ super().__init__()
11
+ self.fc = nn.Linear(input_dim, output_dim, bias=bias)
12
+ nn.init.xavier_uniform_(self.fc.weight)
13
+ nn.init.constant_(self.fc.bias, 0.0)
14
+ self.activation = activation(**kwargs) if activation is not None else None
15
+
16
+ def forward(self, x):
17
+ if self.activation is None:
18
+ return self.fc(x)
19
+ return self.activation(self.fc(x))
20
+
21
+
22
+ class Encoder(nn.Module):
23
+ def __init__(self, input_dim, *dims):
24
+ super().__init__()
25
+ dims = (input_dim,) + dims
26
+ self.layers = nn.Sequential(
27
+ *[Dense(dims[i], dims[i+1], negative_slope=0.4, inplace=True) for i in range(len(dims) - 1)]
28
+ )
29
+ def forward(self, x):
30
+ return self.layers(x)
31
+
32
+
33
+ class Decoder(nn.Module):
34
+ def __init__(self, output_dim, *dims):
35
+ super().__init__()
36
+ self.layers = nn.Sequential(
37
+ *[Dense(dims[i], dims[i + 1], negative_slope=0.4, inplace=True) for i in range(len(dims) - 1)]
38
+ + [Dense(dims[-1], output_dim, activation=nn.Sigmoid)]
39
+ )
40
+ def forward(self, x):
41
+ return self.layers(x)
42
+
43
+
44
+ class Autoencoder(nn.Module):
45
+
46
+ def __init__(self, input_dim: int = 784, hidden_dims: Tuple[int] = (256, 64, 16, 4, 2)):
47
+ super().__init__()
48
+ self.encoder = Encoder(input_dim, *hidden_dims)
49
+ self.decoder = Decoder(input_dim, *reversed(hidden_dims))
50
+ self.input_dim = input_dim
51
+ self.hidden_dims = hidden_dims
52
+
53
+ def forward(self, x):
54
+ x = x.flatten(1)
55
+ latent = self.encoder(x)
56
+ recon = self.decoder(latent)
57
+ loss = F.mse_loss(recon, x)
58
+ return recon, latent, loss
59
+
60
+
61
  class MessageModel:
62
  def __init__(self, msg='hello, world'):
63
  self.msg = msg
 
64
  def __call__(self):
65
  print(self.msg)
requirements.txt CHANGED
@@ -1,2 +1 @@
1
  torch
2
- torchvision
 
1
  torch