NeverlandPeter commited on
Commit
502d2e6
β€’
1 Parent(s): 76efecd
img_demoAE.py CHANGED
@@ -14,7 +14,8 @@ print(f'loading...')
14
 
15
  ########################################################################################################
16
 
17
- model_prefix = 'out-v7c_d8_256-224-13bit-OB32x0.5-745'
 
18
  input_imgs = ['lena.png', 'genshin.png', 'kodim14-modified.png', 'kodim19-modified.png', 'kodim24-modified.png']
19
  device = 'cpu' # cpu cuda
20
 
@@ -29,108 +30,187 @@ class ToBinary(torch.autograd.Function):
29
  def backward(ctx, grad_output):
30
  return grad_output.clone() # pass-through
31
 
32
- class R_ENCODER(nn.Module):
33
- def __init__(self, args):
34
  super().__init__()
35
- self.args = args
36
- dd = 8
37
- self.Bxx = nn.BatchNorm2d(dd*64)
38
-
39
- self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
40
- self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
41
- self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
42
-
43
- self.B00 = nn.BatchNorm2d(dd*4)
44
- self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
45
- self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
46
- self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
47
- self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
48
-
49
- self.B10 = nn.BatchNorm2d(dd*16)
50
- self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
51
- self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
52
- self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
53
- self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
54
-
55
- self.B20 = nn.BatchNorm2d(dd*64)
56
- self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
57
- self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
58
- self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
59
- self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
60
-
61
- self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1)
62
-
63
- def forward(self, img):
64
  ACT = F.mish
65
-
66
- x = self.CIN(img)
67
- xx = self.Bxx(F.pixel_unshuffle(x, 8))
68
- x = x + self.Cx1(ACT(self.Cx0(x)))
69
-
70
- x = F.pixel_unshuffle(x, 2)
71
- x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
72
- x = x + self.C03(ACT(self.C02(x)))
73
-
74
- x = F.pixel_unshuffle(x, 2)
75
- x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
76
- x = x + self.C13(ACT(self.C12(x)))
77
-
78
- x = F.pixel_unshuffle(x, 2)
79
- x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
80
- x = x + self.C23(ACT(self.C22(x)))
81
-
82
- x = self.COUT(x + xx)
83
- return torch.sigmoid(x)
84
-
85
- class R_DECODER(nn.Module):
86
- def __init__(self, args):
87
- super().__init__()
88
- self.args = args
89
- dd = 8
90
- self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1)
91
-
92
- self.B00 = nn.BatchNorm2d(dd*64)
93
- self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
94
- self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
95
- self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
96
- self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
97
-
98
- self.B10 = nn.BatchNorm2d(dd*16)
99
- self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
100
- self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
101
- self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
102
- self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
103
-
104
- self.B20 = nn.BatchNorm2d(dd*4)
105
- self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
106
- self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
107
- self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
108
- self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
109
-
110
- self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
111
- self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
112
- self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)
113
-
114
- def forward(self, code):
115
- ACT = F.mish
116
- x = self.CIN(code)
117
-
118
- x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
119
- x = x + self.C03(ACT(self.C02(x)))
120
- x = F.pixel_shuffle(x, 2)
121
-
122
- x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
123
- x = x + self.C13(ACT(self.C12(x)))
124
- x = F.pixel_shuffle(x, 2)
125
-
126
- x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
127
- x = x + self.C23(ACT(self.C22(x)))
128
- x = F.pixel_shuffle(x, 2)
129
-
130
- x = x + self.Cx1(ACT(self.Cx0(x)))
131
- x = self.COUT(x)
132
-
133
- return torch.sigmoid(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  ########################################################################################################
136
 
@@ -165,4 +245,4 @@ for input_img in input_imgs:
165
  print(f'Code shape = {zz.shape}\n{zz.cpu().numpy()}\n')
166
 
167
  out = decoder(z)
168
- vision.utils.save_image(out, f"img_test/{input_img.split('.')[0]}-out-13bit.png")
 
14
 
15
  ########################################################################################################
16
 
17
+ # model_prefix = 'out-v7c_d8_256-224-13bit-OB32x0.5-745'
18
+ model_prefix = 'out-v7d_d16_512-224-13bit-OB32x0.5-2487'
19
  input_imgs = ['lena.png', 'genshin.png', 'kodim14-modified.png', 'kodim19-modified.png', 'kodim24-modified.png']
20
  device = 'cpu' # cpu cuda
21
 
 
30
  def backward(ctx, grad_output):
31
  return grad_output.clone() # pass-through
32
 
33
+ class ResBlock(nn.Module):
34
+ def __init__(self, c_x, c_hidden):
35
  super().__init__()
36
+ self.B0 = nn.BatchNorm2d(c_x)
37
+ self.C0 = nn.Conv2d(c_x, c_hidden, kernel_size=3, padding=1)
38
+ self.C1 = nn.Conv2d(c_hidden, c_x, kernel_size=3, padding=1)
39
+ self.C2 = nn.Conv2d(c_x, c_hidden, kernel_size=3, padding=1)
40
+ self.C3 = nn.Conv2d(c_hidden, c_x, kernel_size=3, padding=1)
41
+ def forward(self, x):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  ACT = F.mish
43
+ x = x + self.C1(ACT(self.C0(ACT(self.B0(x)))))
44
+ x = x + self.C3(ACT(self.C2(x)))
45
+ return x
46
+
47
+ if model_prefix == 'out-v7c_d8_256-224-13bit-OB32x0.5-745':
48
+ class R_ENCODER(nn.Module):
49
+ def __init__(self, args):
50
+ super().__init__()
51
+ self.args = args
52
+ dd = 8
53
+ self.Bxx = nn.BatchNorm2d(dd*64)
54
+
55
+ self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
56
+ self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
57
+ self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
58
+
59
+ self.B00 = nn.BatchNorm2d(dd*4)
60
+ self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
61
+ self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
62
+ self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
63
+ self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
64
+
65
+ self.B10 = nn.BatchNorm2d(dd*16)
66
+ self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
67
+ self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
68
+ self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
69
+ self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
70
+
71
+ self.B20 = nn.BatchNorm2d(dd*64)
72
+ self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
73
+ self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
74
+ self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
75
+ self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
76
+
77
+ self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1)
78
+
79
+ def forward(self, img):
80
+ ACT = F.mish
81
+
82
+ x = self.CIN(img)
83
+ xx = self.Bxx(F.pixel_unshuffle(x, 8))
84
+ x = x + self.Cx1(ACT(self.Cx0(x)))
85
+
86
+ x = F.pixel_unshuffle(x, 2)
87
+ x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
88
+ x = x + self.C03(ACT(self.C02(x)))
89
+
90
+ x = F.pixel_unshuffle(x, 2)
91
+ x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
92
+ x = x + self.C13(ACT(self.C12(x)))
93
+
94
+ x = F.pixel_unshuffle(x, 2)
95
+ x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
96
+ x = x + self.C23(ACT(self.C22(x)))
97
+
98
+ x = self.COUT(x + xx)
99
+ return torch.sigmoid(x)
100
+
101
+ class R_DECODER(nn.Module):
102
+ def __init__(self, args):
103
+ super().__init__()
104
+ self.args = args
105
+ dd = 8
106
+ self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1)
107
+
108
+ self.B00 = nn.BatchNorm2d(dd*64)
109
+ self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
110
+ self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
111
+ self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
112
+ self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
113
+
114
+ self.B10 = nn.BatchNorm2d(dd*16)
115
+ self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
116
+ self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
117
+ self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
118
+ self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
119
+
120
+ self.B20 = nn.BatchNorm2d(dd*4)
121
+ self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
122
+ self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
123
+ self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
124
+ self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
125
+
126
+ self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
127
+ self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
128
+ self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)
129
+
130
+ def forward(self, code):
131
+ ACT = F.mish
132
+ x = self.CIN(code)
133
+
134
+ x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
135
+ x = x + self.C03(ACT(self.C02(x)))
136
+ x = F.pixel_shuffle(x, 2)
137
+
138
+ x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
139
+ x = x + self.C13(ACT(self.C12(x)))
140
+ x = F.pixel_shuffle(x, 2)
141
+
142
+ x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
143
+ x = x + self.C23(ACT(self.C22(x)))
144
+ x = F.pixel_shuffle(x, 2)
145
+
146
+ x = x + self.Cx1(ACT(self.Cx0(x)))
147
+ x = self.COUT(x)
148
+
149
+ return torch.sigmoid(x)
150
+ else:
151
+ class R_ENCODER(nn.Module):
152
+ def __init__(self, args):
153
+ super().__init__()
154
+ self.args = args
155
+ if 'd16_512' in model_prefix:
156
+ dd, ee, ff = 16, 64, 512
157
+ else:
158
+ dd, ee, ff = 32, 128, 1024
159
+ self.CXX = nn.Conv2d(3, dd, kernel_size=3, padding=1)
160
+ self.BXX = nn.BatchNorm2d(dd)
161
+ self.CX0 = nn.Conv2d(dd, ee, kernel_size=3, padding=1)
162
+ self.CX1 = nn.Conv2d(ee, dd, kernel_size=3, padding=1)
163
+ self.R0 = ResBlock(dd*4, ff)
164
+ self.R1 = ResBlock(dd*16, ff)
165
+ self.R2 = ResBlock(dd*64, ff)
166
+ self.CZZ = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1)
167
+
168
+ def forward(self, x):
169
+ ACT = F.mish
170
+ x = self.BXX(self.CXX(x))
171
+
172
+ x = x + self.CX1(ACT(self.CX0(x)))
173
+ x = F.pixel_unshuffle(x, 2)
174
+ x = self.R0(x)
175
+ x = F.pixel_unshuffle(x, 2)
176
+ x = self.R1(x)
177
+ x = F.pixel_unshuffle(x, 2)
178
+ x = self.R2(x)
179
+
180
+ x = self.CZZ(x)
181
+ return torch.sigmoid(x)
182
+
183
+ class R_DECODER(nn.Module):
184
+ def __init__(self, args):
185
+ super().__init__()
186
+ self.args = args
187
+ if 'd16_512' in model_prefix:
188
+ dd, ee, ff = 16, 64, 512
189
+ else:
190
+ dd, ee, ff = 32, 128, 1024
191
+ self.CZZ = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1)
192
+ self.BZZ = nn.BatchNorm2d(dd*64)
193
+ self.R0 = ResBlock(dd*64, ff)
194
+ self.R1 = ResBlock(dd*16, ff)
195
+ self.R2 = ResBlock(dd*4, ff)
196
+ self.CX0 = nn.Conv2d(dd, ee, kernel_size=3, padding=1)
197
+ self.CX1 = nn.Conv2d(ee, dd, kernel_size=3, padding=1)
198
+ self.CXX = nn.Conv2d(dd, 3, kernel_size=3, padding=1)
199
+
200
+ def forward(self, x):
201
+ ACT = F.mish
202
+ x = self.BZZ(self.CZZ(x))
203
+
204
+ x = self.R0(x)
205
+ x = F.pixel_shuffle(x, 2)
206
+ x = self.R1(x)
207
+ x = F.pixel_shuffle(x, 2)
208
+ x = self.R2(x)
209
+ x = F.pixel_shuffle(x, 2)
210
+ x = x + self.CX1(ACT(self.CX0(x)))
211
+
212
+ x = self.CXX(x)
213
+ return torch.sigmoid(x)
214
 
215
  ########################################################################################################
216
 
 
245
  print(f'Code shape = {zz.shape}\n{zz.cpu().numpy()}\n')
246
 
247
  out = decoder(z)
248
+ vision.utils.save_image(out, f"img_test/{input_img.split('.')[0]}-{model_prefix}.png")
img_test/{genshin-out-13bit.png β†’ genshin-out-v7c_d8_256-224-13bit-OB32x0.5-745.png} RENAMED
File without changes
img_test/genshin-out-v7d_d16_512-224-13bit-OB32x0.5-2487.png ADDED
img_test/{kodim14-modified-out-13bit.png β†’ kodim14-modified-out-v7c_d8_256-224-13bit-OB32x0.5-745.png} RENAMED
File without changes
img_test/kodim14-modified-out-v7d_d16_512-224-13bit-OB32x0.5-2487.png ADDED
img_test/{kodim19-modified-out-13bit.png β†’ kodim19-modified-out-v7c_d8_256-224-13bit-OB32x0.5-745.png} RENAMED
File without changes
img_test/kodim19-modified-out-v7d_d16_512-224-13bit-OB32x0.5-2487.png ADDED
img_test/{kodim24-modified-out-13bit.png β†’ kodim24-modified-out-v7c_d8_256-224-13bit-OB32x0.5-745.png} RENAMED
File without changes
img_test/kodim24-modified-out-v7d_d16_512-224-13bit-OB32x0.5-2487.png ADDED
img_test/{lena-out-13bit.png β†’ lena-out-v7c_d8_256-224-13bit-OB32x0.5-745.png} RENAMED
File without changes
img_test/lena-out-v7d_d16_512-224-13bit-OB32x0.5-2487.png ADDED
out-v7d_d16_512-224-13bit-OB32x0.5-2487-D.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c679523f7d74d54d125746a365f27a6cbed0503d48ddcab872f28131866924a
3
+ size 99724745
out-v7d_d16_512-224-13bit-OB32x0.5-2487-E.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2bf1bdeff4ebf39e4a96044f91da4cba9e525fc29ac3effd64b349637c7caf93
3
+ size 99704585