meeww commited on
Commit
b3f5444
1 Parent(s): 146cf41

Upload discriminator.py

Browse files
Files changed (1) hide show
  1. discriminator.py +37 -0
discriminator.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ import torch.nn as nn
4
+ import torch
5
+
6
+
7
+ class Discriminator(nn.Module):
8
+ def __init__(
9
+ self,
10
+ image_shape: (int, int, int),
11
+ use_cuda: bool = False,
12
+ saved_model: str or None = None
13
+ ):
14
+ super(Discriminator, self).__init__()
15
+
16
+ self.model = nn.Sequential(
17
+ nn.Linear(int(np.prod(image_shape)), 512),
18
+ nn.LeakyReLU(0.2, inplace=True),
19
+ nn.Linear(512, 256),
20
+ nn.LeakyReLU(0.2, inplace=True),
21
+ nn.Linear(256, 1),
22
+ )
23
+ if saved_model is not None:
24
+ self.model.load_state_dict(
25
+ torch.load(
26
+ saved_model,
27
+ map_location=torch.device('cuda' if use_cuda else 'cpu')
28
+ )
29
+ )
30
+
31
+ def forward(self, img):
32
+ img_flat = img.view(img.shape[0], -1)
33
+ validity = self.model(img_flat)
34
+ return validity
35
+
36
+ def save(self, to):
37
+ torch.save(self.model.state_dict(), to)