Blealtan committed on
Commit 2034096
1 Parent(s): 9bd4927

Refine code and use text instead of file

Files changed (1)
  1. app.py +41 -57
app.py CHANGED
@@ -1,7 +1,9 @@
-from huggingface_hub import hf_hub_url, cached_download
+import base64
+from huggingface_hub import hf_hub_download
 import streamlit as st
 import io
 import gc
+import json

 ########################################################################################################
 # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
@@ -20,6 +22,8 @@ from torchvision.transforms import functional as VF

 device = 'cuda' if torch.cuda.is_available() else 'cpu'

+IMG_BITS = 13
+

 class ToBinary(torch.autograd.Function):

@@ -52,9 +56,8 @@ class ResBlock(nn.Module):

 class REncoderSmall(nn.Module):

-    def __init__(self, args):
+    def __init__(self):
         super().__init__()
-        self.args = args
         dd = 8
         self.Bxx = nn.BatchNorm2d(dd * 64)

@@ -80,10 +83,7 @@ class REncoderSmall(nn.Module):
         self.C22 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1)
         self.C23 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1)

-        self.COUT = nn.Conv2d(dd * 64,
-                              args.my_img_bit,
-                              kernel_size=3,
-                              padding=1)
+        self.COUT = nn.Conv2d(dd * 64, IMG_BITS, kernel_size=3, padding=1)

     def forward(self, img):
         ACT = F.mish
@@ -110,14 +110,10 @@ class REncoderSmall(nn.Module):

 class RDecoderSmall(nn.Module):

-    def __init__(self, args):
+    def __init__(self):
         super().__init__()
-        self.args = args
         dd = 8
-        self.CIN = nn.Conv2d(args.my_img_bit,
-                             dd * 64,
-                             kernel_size=3,
-                             padding=1)
+        self.CIN = nn.Conv2d(IMG_BITS, dd * 64, kernel_size=3, padding=1)

         self.B00 = nn.BatchNorm2d(dd * 64)
         self.C00 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1)
@@ -165,9 +161,8 @@ class RDecoderSmall(nn.Module):

 class REncoderLarge(nn.Module):

-    def __init__(self, args, dd, ee, ff):
+    def __init__(self, dd, ee, ff):
         super().__init__()
-        self.args = args
         self.CXX = nn.Conv2d(3, dd, kernel_size=3, padding=1)
         self.BXX = nn.BatchNorm2d(dd)
         self.CX0 = nn.Conv2d(dd, ee, kernel_size=3, padding=1)
@@ -175,10 +170,7 @@ class REncoderLarge(nn.Module):
         self.R0 = ResBlock(dd * 4, ff)
         self.R1 = ResBlock(dd * 16, ff)
         self.R2 = ResBlock(dd * 64, ff)
-        self.CZZ = nn.Conv2d(dd * 64,
-                             args.my_img_bit,
-                             kernel_size=3,
-                             padding=1)
+        self.CZZ = nn.Conv2d(dd * 64, IMG_BITS, kernel_size=3, padding=1)

     def forward(self, x):
         ACT = F.mish
@@ -198,13 +190,9 @@ class REncoderLarge(nn.Module):

 class RDecoderLarge(nn.Module):

-    def __init__(self, args, dd, ee, ff):
+    def __init__(self, dd, ee, ff):
         super().__init__()
-        self.args = args
-        self.CZZ = nn.Conv2d(args.my_img_bit,
-                             dd * 64,
-                             kernel_size=3,
-                             padding=1)
+        self.CZZ = nn.Conv2d(IMG_BITS, dd * 64, kernel_size=3, padding=1)
         self.BZZ = nn.BatchNorm2d(dd * 64)
         self.R0 = ResBlock(dd * 64, ff)
         self.R1 = ResBlock(dd * 16, ff)
@@ -234,32 +222,22 @@ def prepare_model(model_prefix):
     gc.collect()

     if model_prefix == 'out-v7c_d8_256-224-13bit-OB32x0.5-745':
-        R_ENCODER, R_DECODER = REncoderSmall, RDecoderSmall
+        R_ENCODER, R_DECODER = REncoderSmall(), RDecoderSmall()
     else:
         if 'd16_512' in model_prefix:
             dd, ee, ff = 16, 64, 512
         elif 'd32_1024' in model_prefix:
             dd, ee, ff = 32, 128, 1024
-        R_ENCODER, R_DECODER = ((lambda args: REncoderLarge(args, dd, ee, ff)),
-                                (lambda args: RDecoderLarge(args, dd, ee, ff)))
-
-    args = types.SimpleNamespace()
-    args.my_img_bit = 13
-    encoder = R_ENCODER(args).eval().to(device)
-    decoder = R_DECODER(args).eval().to(device)
+        R_ENCODER = REncoderLarge(dd, ee, ff)
+        R_DECODER = RDecoderLarge(dd, ee, ff)

-    zpow = torch.tensor([2**i for i in range(0, 13)]).reshape(13, 1, 1)
-    zpow = zpow.to(device).long()
+    encoder = R_ENCODER.eval().to(device)
+    decoder = R_DECODER.eval().to(device)

     encoder.load_state_dict(
-        torch.load(
-            cached_download(hf_hub_url(MODEL_REPO, f'{model_prefix}-E.pth'))))
+        torch.load(hf_hub_download(MODEL_REPO, f'{model_prefix}-E.pth')))
     decoder.load_state_dict(
-        torch.load(
-            cached_download(hf_hub_url(MODEL_REPO, f'{model_prefix}-D.pth'))))
-
-    encoder.eval()
-    decoder.eval()
+        torch.load(hf_hub_download(MODEL_REPO, f'{model_prefix}-D.pth')))

     return encoder, decoder

@@ -277,11 +255,23 @@ def encode(model_prefix, img):
     z = encoder(img)
     z = ToBinary.apply(z)

-    return z.cpu().numpy()
+    with io.BytesIO() as buffer:
+        np.save(buffer, np.packbits(z.cpu().numpy().astype('bool')))
+        z_b64 = base64.b64encode(buffer.getvalue()).decode()
+
+    return json.dumps({"shape": list(z.shape), "data": z_b64})


-def decode(model_prefix, z):
+def decode(model_prefix, z_str):
     _, decoder = prepare_model(model_prefix)
+
+    z_json = json.loads(z_str)
+    with io.BytesIO() as buffer:
+        buffer.write(base64.b64decode(z_json["data"]))
+        buffer.seek(0)
+        z = np.load(buffer)
+    z = np.unpackbits(z).astype('float').reshape(z_json["shape"])
+
     decoded = decoder(torch.Tensor(z).to(device))
     return VF.to_pil_image(decoded[0])

@@ -300,20 +290,14 @@ with encoder_tab:
     if uploaded_file is not None:
         image = Image.open(uploaded_file)
         col_in.image(image, 'Input Image')
-        z = encode(model_prefix, image)
-        with io.BytesIO() as buffer:
-            np.save(buffer, z)
-            col_out.download_button(
-                label="Download Encoded Data",
-                data=buffer,
-                file_name=uploaded_file.name + '.npy',
-            )
-        col_out.image(decode(model_prefix, z), 'Output Image preview')
+        z_str = encode(model_prefix, image)
+        col_out.write("Encoded to:")
+        col_out.code(z_str,language=None)
+        col_out.image(decode(model_prefix, z_str), 'Output Image preview')

 with decoder_tab:
     col_in, col_out = st.columns(2)
-    uploaded_file = col_in.file_uploader('Choose an Encoded Data')
-    if uploaded_file is not None:
-        z = np.load(uploaded_file)
-        image = decode(model_prefix, z)
+    z_str = col_in.text_area('Paste encoded string here:')
+    if len(z_str) > 0:
+        image = decode(model_prefix, z_str)
         col_out.image(image, 'Output Image')
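
For reference, a minimal sketch of the round trip behind the new text format (not part of the commit; the 1x13x28x28 latent shape is only an assumed example for the 13-bit models): the binary latent is packed with np.packbits, serialized with np.save, base64-encoded, and wrapped in JSON together with its shape, which is exactly what decode() undoes.

import base64
import io
import json

import numpy as np

# Assumed example latent: 13 binary bit-planes on a 28x28 grid (values 0.0/1.0).
# The element count must be a multiple of 8, since packbits/unpackbits work on whole bytes.
z = np.random.randint(0, 2, size=(1, 13, 28, 28)).astype('float')

# Pack as in encode(): bools -> packed bytes -> .npy bytes -> base64 -> JSON string.
with io.BytesIO() as buffer:
    np.save(buffer, np.packbits(z.astype('bool')))
    z_str = json.dumps({"shape": list(z.shape),
                        "data": base64.b64encode(buffer.getvalue()).decode()})

# Unpack as in decode(): JSON string -> base64 -> .npy bytes -> unpackbits -> reshape.
z_json = json.loads(z_str)
with io.BytesIO() as buffer:
    buffer.write(base64.b64decode(z_json["data"]))
    buffer.seek(0)
    packed = np.load(buffer)
restored = np.unpackbits(packed).astype('float').reshape(z_json["shape"])

assert (restored == z).all()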