junjuice0 commited on
Commit
96a87c7
1 Parent(s): 1388c9c

Update testvae.py

Browse files
Files changed (1) hide show
  1. testvae.py +9 -5
testvae.py CHANGED
@@ -11,18 +11,22 @@ import gc
11
  class TestModel(torch.nn.Module):
12
  def __init__(self):
13
  super().__init__()
14
- self.conv1 = torch.nn.Conv2d(3, 16, 5, 1, 2, bias=False)
 
15
  self.conv2 = torch.nn.Conv2d(16, 16, 3, 1, 1, bias=False)
16
- self.conv3 = torch.nn.Conv2d(16, 3, 3, 1, 1, bias=True)
 
17
  self.bn1 = torch.nn.BatchNorm2d(16)
18
  self.bn2 = torch.nn.BatchNorm2d(16)
19
 
20
  def forward(self, x):
21
- x = self.conv1(x)
22
  x = self.bn1(x)
23
- x = self.conv2(x)
 
 
24
  x = self.bn2(x)
25
- x = self.conv3(x)
26
  x = torch.clamp(x, -1, 1)
27
  return x
28
 
 
11
  class TestModel(torch.nn.Module):
12
  def __init__(self):
13
  super().__init__()
14
+ self.start = torch.nn.Conv2d(3, 16, 3, 1, 1, bias=False)
15
+ self.conv1 = torch.nn.Conv2d(16, 16, 3, 1, 1, bias=False)
16
  self.conv2 = torch.nn.Conv2d(16, 16, 3, 1, 1, bias=False)
17
+ self.conv3 = torch.nn.Conv2d(16, 16, 3, 1, 1, bias=False)
18
+ self.final = torch.nn.Conv2d(16, 3, 3, 1, 1, bias=False)
19
  self.bn1 = torch.nn.BatchNorm2d(16)
20
  self.bn2 = torch.nn.BatchNorm2d(16)
21
 
22
  def forward(self, x):
23
+ x = self.start(x)
24
  x = self.bn1(x)
25
+ x = self.conv1(x) + x
26
+ x = self.conv2(x) + x
27
+ x = self.conv3(x) + x
28
  x = self.bn2(x)
29
+ x = self.final(x)
30
  x = torch.clamp(x, -1, 1)
31
  return x
32