Better support for keep_shape
app.py CHANGED
@@ -25,18 +25,6 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
IMG_BITS = 13


-class ToBinary(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, x):
-        return torch.floor(
-            x + 0.5)  # no need for noise when we have plenty of data
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return grad_output.clone()  # pass-through
-
-
class ResBlock(nn.Module):

    def __init__(self, c_x, c_hidden):
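Note on this deletion: ToBinary is a straight-through estimator, i.e. it rounds to {0, 1} in the forward pass and passes gradients through unchanged in the backward pass, which only matters during training. Since the Space only runs inference, the same binarization reduces to a plain torch.floor(x + 0.5), which the updated encode() below now inlines. A minimal sketch of the equivalence (the shape is illustrative):

import torch

# Straight-through binarization, as in the removed class: hard round in
# forward, identity gradient in backward (a training-time concern only).
class ToBinary(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.floor(x + 0.5)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()

x = torch.rand(1, 13, 28, 28)
# At inference both paths produce identical values, so the helper is redundant.
assert torch.equal(ToBinary.apply(x), torch.floor(x + 0.5))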
@@ -242,26 +230,46 @@ def prepare_model(model_prefix):
    return encoder, decoder


-def encode(model_prefix, img, keep_dims):
+def compute_padding(img_shape):
+    hsize, vsize = (img_shape[1] + 7) // 8 * 8, (img_shape[0] + 7) // 8 * 8
+    hpad, vpad = hsize - img_shape[1], vsize - img_shape[0]
+    left, top = hpad // 2, vpad // 2
+    right, bottom = hpad - left, vpad - top
+    return left, top, right, bottom
+
+
+def encode(model_prefix, img, keep_shape):
+    gc.collect()
    encoder, _ = prepare_model(model_prefix)
-    img_transform = transforms.Compose(
-        [transforms.PILToTensor(),
-         transforms.ConvertImageDtype(torch.float)] +
-        ([transforms.Resize((224, 224))] if not keep_dims else []))

    with torch.no_grad():
-        img =
-
-
+        img = VF.pil_to_tensor(img.convert("RGB"))
+        img = VF.convert_image_dtype(img)
+        img = img.unsqueeze(0).to(device)
+        img_shape = img.shape[2:]
+
+        if keep_shape:
+            left, top, right, bottom = compute_padding(img_shape)
+            img = VF.pad(img, [left, top, right, bottom], padding_mode='edge')
+        else:
+            img = VF.resize(img, [224, 224])
+
+        z = torch.floor(encoder(img) + 0.5)

    with io.BytesIO() as buffer:
        np.save(buffer, np.packbits(z.cpu().numpy().astype('bool')))
        z_b64 = base64.b64encode(buffer.getvalue()).decode()

-    return json.dumps({
+    return json.dumps({
+        "img_shape": img_shape,
+        "z_shape": z.shape[2:],
+        "keep_shape": keep_shape,
+        "data": z_b64,
+    })


def decode(model_prefix, z_str):
+    gc.collect()
    _, decoder = prepare_model(model_prefix)

    z_json = json.loads(z_str)
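The new compute_padding helper rounds each image dimension up to the next multiple of 8 (presumably the encoder's total downsampling factor) and splits the required padding as evenly as possible between opposite sides. A quick standalone check of the helper as committed, with a hypothetical 250x300 input:

# compute_padding as added above, exercised on its own.
def compute_padding(img_shape):
    hsize, vsize = (img_shape[1] + 7) // 8 * 8, (img_shape[0] + 7) // 8 * 8
    hpad, vpad = hsize - img_shape[1], vsize - img_shape[0]
    left, top = hpad // 2, vpad // 2
    right, bottom = hpad - left, vpad - top
    return left, top, right, bottom

# A 250x300 (HxW) image grows to 256x304: 3 rows top / 3 bottom,
# 2 columns left / 2 right. Return order is (left, top, right, bottom).
assert compute_padding((250, 300)) == (2, 3, 2, 3)
# A 224x224 input is already a multiple of 8 and needs no padding.
assert compute_padding((224, 224)) == (0, 0, 0, 0)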
@@ -269,10 +277,23 @@ def decode(model_prefix, z_str):
        buffer.write(base64.b64decode(z_json["data"]))
        buffer.seek(0)
        z = np.load(buffer)
-
+    img_shape = z_json["img_shape"]
+    z_shape = z_json["z_shape"]
+    keep_shape = z_json["keep_shape"]
+
+    z = np.unpackbits(z)[:IMG_BITS * z_shape[0] * z_shape[1]].astype('float')
+    z = z.reshape([1, IMG_BITS] + z_shape)
+
+    img = decoder(torch.Tensor(z).to(device))
+
+    if keep_shape:
+        left, top, right, bottom = compute_padding(img_shape)
+        img = img[0, :, top:img.shape[2] - bottom, left:img.shape[3] - right]
+    else:
+        img = img[0]

-
-    return VF.to_pil_image(
+    st.write(img.shape)
+    return VF.to_pil_image(img)


st.title("Clip Guided Binary Autoencoder")
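For reference, the latent serialization that decode() reverses here: encode() packs the binary tensor 8 bits to a byte with np.packbits, wraps it in the .npy format, and base64-encodes it into a JSON payload; decode() unpacks the bits and trims the byte-alignment padding. A self-contained round trip under assumed shapes (a 28x28 latent grid of IMG_BITS planes):

import base64
import io
import json

import numpy as np

IMG_BITS = 13
z = np.random.randint(0, 2, size=(1, IMG_BITS, 28, 28))

# Serialize: packbits -> .npy bytes -> base64 string inside JSON.
with io.BytesIO() as buffer:
    np.save(buffer, np.packbits(z.astype('bool')))
    payload = json.dumps({"z_shape": [28, 28],
                          "data": base64.b64encode(buffer.getvalue()).decode()})

# Deserialize: base64 -> .npy bytes -> unpackbits, slicing off pad bits.
meta = json.loads(payload)
packed = np.load(io.BytesIO(base64.b64decode(meta["data"])))
z_shape = meta["z_shape"]
bits = np.unpackbits(packed)[:IMG_BITS * z_shape[0] * z_shape[1]]
assert (bits.reshape([1, IMG_BITS] + z_shape) == z).all()  # lossless round trip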
@@ -288,12 +309,13 @@ encoder_tab, decoder_tab = st.tabs(["Encode", "Decode"])

with encoder_tab:
    col_in, col_out = st.columns(2)
-
+    keep_shape = col_in.checkbox(
+        'Use original size of input image instead of rescaling (Experimental)')
    uploaded_file = col_in.file_uploader('Choose an Image')
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        col_in.image(image, 'Input Image')
-        z_str = encode(model_prefix, image, keep_dims)
+        z_str = encode(model_prefix, image, keep_shape)
        col_out.write("Encoded to:")
        col_out.code(z_str, language=None)
        col_out.image(decode(model_prefix, z_str), 'Output Image preview')
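With the checkbox ticked, encode() pads the image to encoder-friendly dimensions and decode() crops the same margins back off, so the round trip preserves the original size. A small sanity check that the crop inverts the pad, reusing the 250x300 margins from the compute_padding example above:

import torch
import torchvision.transforms.functional as VF

img = torch.rand(1, 3, 250, 300)
left, top, right, bottom = 2, 3, 2, 3  # compute_padding((250, 300))
padded = VF.pad(img, [left, top, right, bottom], padding_mode='edge')
cropped = padded[:, :, top:padded.shape[2] - bottom, left:padded.shape[3] - right]
assert torch.equal(cropped, img)                               # crop undoes pad
assert padded.shape[2] % 8 == 0 and padded.shape[3] % 8 == 0   # 256 x 304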