Blealtan committed on
Commit
7132b8a
1 Parent(s): 49e397a

Streamlit app added.

Files changed (2)
  1. app.py +323 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,323 @@
+ from huggingface_hub import hf_hub_url, cached_download
+ import streamlit as st
+ import io
+ import gc
+
+ ########################################################################################################
+ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+ ########################################################################################################
+
+ MODEL_REPO = 'BlinkDL/clip-guided-binary-autoencoder'
+
+ import torch, types
+ import numpy as np
+ from PIL import Image
+ import torch.nn as nn
+ from torch.nn import functional as F
+ import torchvision as vision
+ import torchvision.transforms as transforms
+ from torchvision.transforms import functional as VF
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+ class ToBinary(torch.autograd.Function):
+     # Straight-through binarization: round to {0, 1} in the forward pass,
+     # pass the gradient through unchanged in the backward pass.
+
+     @staticmethod
+     def forward(ctx, x):
+         return torch.floor(x + 0.5)  # no need for noise when we have plenty of data
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         return grad_output.clone()  # pass-through
+
+
+ class ResBlock(nn.Module):
+
+     def __init__(self, c_x, c_hidden):
+         super().__init__()
+         self.B0 = nn.BatchNorm2d(c_x)
+         self.C0 = nn.Conv2d(c_x, c_hidden, kernel_size=3, padding=1)
+         self.C1 = nn.Conv2d(c_hidden, c_x, kernel_size=3, padding=1)
+         self.C2 = nn.Conv2d(c_x, c_hidden, kernel_size=3, padding=1)
+         self.C3 = nn.Conv2d(c_hidden, c_x, kernel_size=3, padding=1)
+
+     def forward(self, x):
+         ACT = F.mish
+         x = x + self.C1(ACT(self.C0(ACT(self.B0(x)))))
+         x = x + self.C3(ACT(self.C2(x)))
+         return x
+
+
+ class REncoderSmall(nn.Module):
+
+     def __init__(self, args):
+         super().__init__()
+         self.args = args
+         dd = 8
+         self.Bxx = nn.BatchNorm2d(dd * 64)
+
+         self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
+         self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
+         self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
+
+         self.B00 = nn.BatchNorm2d(dd * 4)
+         self.C00 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1)
+         self.C01 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1)
+         self.C02 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1)
+         self.C03 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1)
+
+         self.B10 = nn.BatchNorm2d(dd * 16)
+         self.C10 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1)
+         self.C11 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1)
+         self.C12 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1)
+         self.C13 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1)
+
+         self.B20 = nn.BatchNorm2d(dd * 64)
+         self.C20 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1)
+         self.C21 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1)
+         self.C22 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1)
+         self.C23 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1)
+
+         self.COUT = nn.Conv2d(dd * 64, args.my_img_bit, kernel_size=3, padding=1)
+
+     def forward(self, img):
+         ACT = F.mish
+
+         x = self.CIN(img)
+         xx = self.Bxx(F.pixel_unshuffle(x, 8))  # long skip at the final 1/8 resolution
+         x = x + self.Cx1(ACT(self.Cx0(x)))
+
+         x = F.pixel_unshuffle(x, 2)
+         x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
+         x = x + self.C03(ACT(self.C02(x)))
+
+         x = F.pixel_unshuffle(x, 2)
+         x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
+         x = x + self.C13(ACT(self.C12(x)))
+
+         x = F.pixel_unshuffle(x, 2)
+         x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
+         x = x + self.C23(ACT(self.C22(x)))
+
+         x = self.COUT(x + xx)  # 3x224x224 image -> my_img_bit x 28x28 code
+         return torch.sigmoid(x)
+
+
+ class RDecoderSmall(nn.Module):
+
+     def __init__(self, args):
+         super().__init__()
+         self.args = args
+         dd = 8
+         self.CIN = nn.Conv2d(args.my_img_bit, dd * 64, kernel_size=3, padding=1)
+
+         self.B00 = nn.BatchNorm2d(dd * 64)
+         self.C00 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1)
+         self.C01 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1)
+         self.C02 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1)
+         self.C03 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1)
+
+         self.B10 = nn.BatchNorm2d(dd * 16)
+         self.C10 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1)
+         self.C11 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1)
+         self.C12 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1)
+         self.C13 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1)
+
+         self.B20 = nn.BatchNorm2d(dd * 4)
+         self.C20 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1)
+         self.C21 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1)
+         self.C22 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1)
+         self.C23 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1)
+
+         self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
+         self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
+         self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)
+
+     def forward(self, code):
+         ACT = F.mish
+         x = self.CIN(code)
+
+         x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
+         x = x + self.C03(ACT(self.C02(x)))
+         x = F.pixel_shuffle(x, 2)
+
+         x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
+         x = x + self.C13(ACT(self.C12(x)))
+         x = F.pixel_shuffle(x, 2)
+
+         x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
+         x = x + self.C23(ACT(self.C22(x)))
+         x = F.pixel_shuffle(x, 2)
+
+         x = x + self.Cx1(ACT(self.Cx0(x)))
+         x = self.COUT(x)
+
+         return torch.sigmoid(x)
+
+
+ class REncoderLarge(nn.Module):
+
+     def __init__(self, args, dd, ee, ff):
+         super().__init__()
+         self.args = args
+         self.CXX = nn.Conv2d(3, dd, kernel_size=3, padding=1)
+         self.BXX = nn.BatchNorm2d(dd)
+         self.CX0 = nn.Conv2d(dd, ee, kernel_size=3, padding=1)
+         self.CX1 = nn.Conv2d(ee, dd, kernel_size=3, padding=1)
+         self.R0 = ResBlock(dd * 4, ff)
+         self.R1 = ResBlock(dd * 16, ff)
+         self.R2 = ResBlock(dd * 64, ff)
+         self.CZZ = nn.Conv2d(dd * 64, args.my_img_bit, kernel_size=3, padding=1)
+
+     def forward(self, x):
+         ACT = F.mish
+         x = self.BXX(self.CXX(x))
+
+         x = x + self.CX1(ACT(self.CX0(x)))
+         x = F.pixel_unshuffle(x, 2)
+         x = self.R0(x)
+         x = F.pixel_unshuffle(x, 2)
+         x = self.R1(x)
+         x = F.pixel_unshuffle(x, 2)
+         x = self.R2(x)
+
+         x = self.CZZ(x)
+         return torch.sigmoid(x)
+
+
+ class RDecoderLarge(nn.Module):
+
+     def __init__(self, args, dd, ee, ff):  # dd/ee/ff select the width, as in REncoderLarge
+         super().__init__()
+         self.args = args
+         self.CZZ = nn.Conv2d(args.my_img_bit, dd * 64, kernel_size=3, padding=1)
+         self.BZZ = nn.BatchNorm2d(dd * 64)
+         self.R0 = ResBlock(dd * 64, ff)
+         self.R1 = ResBlock(dd * 16, ff)
+         self.R2 = ResBlock(dd * 4, ff)
+         self.CX0 = nn.Conv2d(dd, ee, kernel_size=3, padding=1)
+         self.CX1 = nn.Conv2d(ee, dd, kernel_size=3, padding=1)
+         self.CXX = nn.Conv2d(dd, 3, kernel_size=3, padding=1)
+
+     def forward(self, x):
+         ACT = F.mish
+         x = self.BZZ(self.CZZ(x))
+
+         x = self.R0(x)
+         x = F.pixel_shuffle(x, 2)
+         x = self.R1(x)
+         x = F.pixel_shuffle(x, 2)
+         x = self.R2(x)
+         x = F.pixel_shuffle(x, 2)
+         x = x + self.CX1(ACT(self.CX0(x)))
+
+         x = self.CXX(x)
+         return torch.sigmoid(x)
+
+
+ @st.cache_resource(max_entries=1)
+ def prepare_model(model_prefix):
+     gc.collect()
+
+     if model_prefix == 'out-v7c_d8_256-224-13bit-OB32x0.5-745':
+         R_ENCODER, R_DECODER = REncoderSmall, RDecoderSmall
+     else:
+         if 'd16_512' in model_prefix:
+             dd, ee, ff = 16, 64, 512
+         elif 'd32_1024' in model_prefix:
+             dd, ee, ff = 32, 128, 1024
+         R_ENCODER = lambda args: REncoderLarge(args, dd, ee, ff)
+         R_DECODER = lambda args: RDecoderLarge(args, dd, ee, ff)
+
+     args = types.SimpleNamespace()
+     args.my_img_bit = 13
+     encoder = R_ENCODER(args).eval().to(device)
+     decoder = R_DECODER(args).eval().to(device)
+
+     # map_location keeps the checkpoints loadable on CPU-only machines
+     encoder.load_state_dict(
+         torch.load(cached_download(hf_hub_url(MODEL_REPO, f'{model_prefix}-E.pth')),
+                    map_location=device))
+     decoder.load_state_dict(
+         torch.load(cached_download(hf_hub_url(MODEL_REPO, f'{model_prefix}-D.pth')),
+                    map_location=device))
+
+     return encoder, decoder
+
+
+ def encode(model_prefix, img):
+     encoder, _ = prepare_model(model_prefix)
+     img_transform = transforms.Compose([
+         transforms.PILToTensor(),
+         transforms.ConvertImageDtype(torch.float),
+         transforms.Resize((224, 224))
+     ])
+
+     with torch.no_grad():
+         img = img_transform(img.convert("RGB")).unsqueeze(0).to(device)
+         z = encoder(img)
+         z = ToBinary.apply(z)  # quantize the sigmoid output to {0, 1}
+
+     return z.cpu().numpy()
+
+
+ def decode(model_prefix, z):
+     _, decoder = prepare_model(model_prefix)
+     with torch.no_grad():  # inference only, same as encode
+         decoded = decoder(torch.Tensor(z).to(device))
+     return VF.to_pil_image(decoded[0])
+
+
+ st.title("clip-guided-binary-autoencoder")
+ model_prefix = st.selectbox('The model to use',
+                             ('out-v7c_d8_256-224-13bit-OB32x0.5-745',
+                              'out-v7d_d16_512-224-13bit-OB32x0.5-2487',
+                              'out-v7d_d32_1024-224-13bit-OB32x0.5-5560'))
+
+ encoder_tab, decoder_tab = st.tabs(["Encode", "Decode"])
+
+ with encoder_tab:
+     col_in, col_out = st.columns(2)
+     uploaded_file = col_in.file_uploader('Choose an Image')
+     if uploaded_file is not None:
+         image = Image.open(uploaded_file)
+         col_in.image(image, 'Input Image')
+         z = encode(model_prefix, image)
+         with io.BytesIO() as buffer:
+             np.save(buffer, z)
+             col_out.download_button(
+                 label="Download Encoded Data",
+                 data=buffer,
+                 file_name=uploaded_file.name + '.npy',
+             )
+         col_out.image(decode(model_prefix, z), 'Output Image Preview')
+
+ with decoder_tab:
+     col_in, col_out = st.columns(2)
+     uploaded_file = col_in.file_uploader('Choose an Encoded Data File')
+     if uploaded_file is not None:
+         z = np.load(uploaded_file)
+         image = decode(model_prefix, z)
+         col_out.image(image, 'Output Image')
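
The "Download Encoded Data" button serializes the binary code with np.save, so it can be inspected offline with NumPy alone. A minimal sketch (not part of this commit; the filename is hypothetical): for a 224x224 input the code has shape (1, 13, 28, 28), since the encoder pixel-unshuffles by 2 three times (224/8 = 28) and args.my_img_bit = 13 sets the channel count.

    import numpy as np

    z = np.load('photo.jpg.npy')  # file saved via the app's download button (hypothetical name)
    print(z.shape)                # (1, 13, 28, 28) for a 224x224 input
    print(np.unique(z))           # [0. 1.] -- ToBinary makes the code strictly binary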
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ numpy==1.21.5
+ Pillow==9.4.0
+ torch==1.13.1+cu117
+ torchvision==0.14.1+cu117
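
Note that app.py also imports streamlit and huggingface_hub, which are not pinned here; presumably the Hugging Face Spaces runtime provides them. For a local install, the +cu117 builds come from the PyTorch wheel index rather than PyPI, e.g. pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu117.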