multimodalart HF staff commited on
Commit
01d95b9
1 Parent(s): a2bfbe9

Update previewer/modules.py

Browse files
Files changed (1) hide show
  1. previewer/modules.py +20 -11
previewer/modules.py CHANGED
@@ -1,11 +1,12 @@
1
  from torch import nn
2
 
3
- # Effnet 16x16 to 64x64 previewer
 
4
  class Previewer(nn.Module):
5
  def __init__(self, c_in=16, c_hidden=512, c_out=3):
6
  super().__init__()
7
  self.blocks = nn.Sequential(
8
- nn.Conv2d(c_in, c_hidden, kernel_size=1), # 36 channels to 512 channels
9
  nn.GELU(),
10
  nn.BatchNorm2d(c_hidden),
11
 
@@ -13,23 +14,31 @@ class Previewer(nn.Module):
13
  nn.GELU(),
14
  nn.BatchNorm2d(c_hidden),
15
 
16
- nn.ConvTranspose2d(c_hidden, c_hidden//2, kernel_size=2, stride=2), # 16 -> 32
 
 
 
 
 
 
 
 
17
  nn.GELU(),
18
- nn.BatchNorm2d(c_hidden//2),
19
 
20
- nn.Conv2d(c_hidden//2, c_hidden//2, kernel_size=3, padding=1),
21
  nn.GELU(),
22
- nn.BatchNorm2d(c_hidden//2),
23
 
24
- nn.ConvTranspose2d(c_hidden//2, c_hidden//4, kernel_size=2, stride=2), # 32 -> 64
25
  nn.GELU(),
26
- nn.BatchNorm2d(c_hidden//4),
27
 
28
- nn.Conv2d(c_hidden//4, c_hidden//4, kernel_size=3, padding=1),
29
  nn.GELU(),
30
- nn.BatchNorm2d(c_hidden//4),
31
 
32
- nn.Conv2d(c_hidden//4, c_out, kernel_size=1),
33
  )
34
 
35
  def forward(self, x):
 
1
  from torch import nn
2
 
3
+
4
+ # Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
5
  class Previewer(nn.Module):
6
  def __init__(self, c_in=16, c_hidden=512, c_out=3):
7
  super().__init__()
8
  self.blocks = nn.Sequential(
9
+ nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
10
  nn.GELU(),
11
  nn.BatchNorm2d(c_hidden),
12
 
 
14
  nn.GELU(),
15
  nn.BatchNorm2d(c_hidden),
16
 
17
+ nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
18
+ nn.GELU(),
19
+ nn.BatchNorm2d(c_hidden // 2),
20
+
21
+ nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
22
+ nn.GELU(),
23
+ nn.BatchNorm2d(c_hidden // 2),
24
+
25
+ nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
26
  nn.GELU(),
27
+ nn.BatchNorm2d(c_hidden // 4),
28
 
29
+ nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
30
  nn.GELU(),
31
+ nn.BatchNorm2d(c_hidden // 4),
32
 
33
+ nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
34
  nn.GELU(),
35
+ nn.BatchNorm2d(c_hidden // 4),
36
 
37
+ nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
38
  nn.GELU(),
39
+ nn.BatchNorm2d(c_hidden // 4),
40
 
41
+ nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
42
  )
43
 
44
  def forward(self, x):