Blealtan committed
Commit 5527144
Parent: f1b6ca9

Better support for keep_shape

Files changed (1)
  1. app.py +48 -26
app.py CHANGED
@@ -25,18 +25,6 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 IMG_BITS = 13
 
 
-class ToBinary(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, x):
-        return torch.floor(
-            x + 0.5)  # no need for noise when we have plenty of data
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return grad_output.clone()  # pass-through
-
-
 class ResBlock(nn.Module):
 
     def __init__(self, c_x, c_hidden):
@@ -242,26 +230,46 @@ def prepare_model(model_prefix):
     return encoder, decoder
 
 
-def encode(model_prefix, img, keep_dims):
+def compute_padding(img_shape):
+    hsize, vsize = (img_shape[1] + 7) // 8 * 8, (img_shape[0] + 7) // 8 * 8
+    hpad, vpad = hsize - img_shape[1], vsize - img_shape[0]
+    left, top = hpad // 2, vpad // 2
+    right, bottom = hpad - left, vpad - top
+    return left, top, right, bottom
+
+
+def encode(model_prefix, img, keep_shape):
+    gc.collect()
     encoder, _ = prepare_model(model_prefix)
-    img_transform = transforms.Compose(
-        [transforms.PILToTensor(),
-         transforms.ConvertImageDtype(torch.float)] +
-        ([transforms.Resize((224, 224))] if not keep_dims else []))
 
     with torch.no_grad():
-        img = img_transform(img.convert("RGB")).unsqueeze(0).to(device)
-        z = encoder(img)
-        z = ToBinary.apply(z)
+        img = VF.pil_to_tensor(img.convert("RGB"))
+        img = VF.convert_image_dtype(img)
+        img = img.unsqueeze(0).to(device)
+        img_shape = img.shape[2:]
+
+        if keep_shape:
+            left, top, right, bottom = compute_padding(img_shape)
+            img = VF.pad(img, [left, top, right, bottom], padding_mode='edge')
+        else:
+            img = VF.resize(img, [224, 224])
+
+        z = torch.floor(encoder(img) + 0.5)
 
     with io.BytesIO() as buffer:
         np.save(buffer, np.packbits(z.cpu().numpy().astype('bool')))
         z_b64 = base64.b64encode(buffer.getvalue()).decode()
 
-    return json.dumps({"shape": list(z.shape), "data": z_b64})
+    return json.dumps({
+        "img_shape": img_shape,
+        "z_shape": z.shape[2:],
+        "keep_shape": keep_shape,
+        "data": z_b64,
+    })
 
 
 def decode(model_prefix, z_str):
+    gc.collect()
     _, decoder = prepare_model(model_prefix)
 
     z_json = json.loads(z_str)
@@ -269,10 +277,23 @@ def decode(model_prefix, z_str):
         buffer.write(base64.b64decode(z_json["data"]))
         buffer.seek(0)
         z = np.load(buffer)
-    z = np.unpackbits(z).astype('float').reshape(z_json["shape"])
+    img_shape = z_json["img_shape"]
+    z_shape = z_json["z_shape"]
+    keep_shape = z_json["keep_shape"]
+
+    z = np.unpackbits(z)[:IMG_BITS * z_shape[0] * z_shape[1]].astype('float')
+    z = z.reshape([1, IMG_BITS] + z_shape)
+
+    img = decoder(torch.Tensor(z).to(device))
+
+    if keep_shape:
+        left, top, right, bottom = compute_padding(img_shape)
+        img = img[0, :, top:img.shape[2] - bottom, left:img.shape[3] - right]
+    else:
+        img = img[0]
 
-    decoded = decoder(torch.Tensor(z).to(device))
-    return VF.to_pil_image(decoded[0])
+    st.write(img.shape)
+    return VF.to_pil_image(img)
 
 
 st.title("Clip Guided Binary Autoencoder")
@@ -288,12 +309,13 @@ encoder_tab, decoder_tab = st.tabs(["Encode", "Decode"])
 
 with encoder_tab:
     col_in, col_out = st.columns(2)
-    keep_dims = col_in.checkbox('Keep the size of original input image', True)
+    keep_shape = col_in.checkbox(
+        'Use original size of input image instead of rescaling (Experimental)')
     uploaded_file = col_in.file_uploader('Choose an Image')
     if uploaded_file is not None:
         image = Image.open(uploaded_file)
         col_in.image(image, 'Input Image')
-        z_str = encode(model_prefix, image, keep_dims)
+        z_str = encode(model_prefix, image, keep_shape)
         col_out.write("Encoded to:")
         col_out.code(z_str, language=None)
         col_out.image(decode(model_prefix, z_str), 'Output Image preview')
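A quick sanity check of the new compute_padding helper, run on a made-up 300x500 input: it rounds each side up to the next multiple of 8 (presumably the resolution step the encoder works at; the diff does not state this) and splits the padding between the two edges, which decode() later crops back off when keep_shape is set.

def compute_padding(img_shape):
    # Copied from the diff above; img_shape is (H, W).
    hsize, vsize = (img_shape[1] + 7) // 8 * 8, (img_shape[0] + 7) // 8 * 8
    hpad, vpad = hsize - img_shape[1], vsize - img_shape[0]
    left, top = hpad // 2, vpad // 2
    right, bottom = hpad - left, vpad - top
    return left, top, right, bottom

left, top, right, bottom = compute_padding((300, 500))  # hypothetical H=300, W=500
print(left, top, right, bottom)                # 2 2 2 2
print(500 + left + right, 300 + top + bottom)  # 504 304, both multiples of 8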
 
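One detail worth flagging in the new decode(): np.packbits pads the flattened bit stream up to a whole number of bytes, so np.unpackbits can return a few trailing zero bits, which is why the result is truncated to IMG_BITS * z_shape[0] * z_shape[1] before reshaping. A minimal round-trip sketch, with a 3x3 latent grid made up purely for illustration:

import numpy as np

IMG_BITS = 13
z = np.random.rand(1, IMG_BITS, 3, 3) > 0.5   # binary latent: 117 bits
packed = np.packbits(z.astype('bool'))        # 15 bytes = 120 bits
unpacked = np.unpackbits(packed)              # 120 bits, last 3 are byte padding
restored = unpacked[:IMG_BITS * 3 * 3].reshape(1, IMG_BITS, 3, 3)
assert (restored == z).all()                  # round trip is lossless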