import base64
from huggingface_hub import hf_hub_download
import streamlit as st
import io
import gc
import json
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
MODEL_REPO = 'BlinkDL/clip-guided-binary-autoencoder'
import torch
import types
import numpy as np
from PIL import Image
import torch.nn as nn
from torch.nn import functional as F
import torchvision as vision
import torchvision.transforms as transforms
from torchvision.transforms import functional as VF

device = 'cuda' if torch.cuda.is_available() else 'cpu'
IMG_BITS = 13  # number of code channels produced by the encoders / consumed by the decoders


class ToBinary(torch.autograd.Function):
    """Round to the nearest integer in the forward pass, pass gradients straight through in backward."""

    @staticmethod
    def forward(ctx, x):
        return torch.floor(x + 0.5)  # no need for noise when we have plenty of data

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()  # pass-through


class ResBlock(nn.Module):
    """Two residual conv sub-blocks (BatchNorm + Mish on the first) at a fixed channel width c_x."""

    def __init__(self, c_x, c_hidden):
        super().__init__()
        self.B0 = nn.BatchNorm2d(c_x)
        self.C0 = nn.Conv2d(c_x, c_hidden, kernel_size=3, padding=1)
        self.C1 = nn.Conv2d(c_hidden, c_x, kernel_size=3, padding=1)
        self.C2 = nn.Conv2d(c_x, c_hidden, kernel_size=3, padding=1)
        self.C3 = nn.Conv2d(c_hidden, c_x, kernel_size=3, padding=1)

    def forward(self, x):
        ACT = F.mish
        x = x + self.C1(ACT(self.C0(ACT(self.B0(x)))))
        x = x + self.C3(ACT(self.C2(x)))
        return x


class REncoderSmall(nn.Module):
    """Small encoder: RGB (N, 3, H, W) -> sigmoid code (N, IMG_BITS, H/8, W/8).

    H and W must be divisible by 8 (three pixel_unshuffle(2) steps, plus a pixel_unshuffle(8) skip path).
    Attribute names are checkpoint (state_dict) keys and must not be renamed.
    """

    def __init__(self):
        super().__init__()
        dd = 8
        self.Bxx = nn.BatchNorm2d(dd * 64)
        self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
        self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
        self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
        self.B00 = nn.BatchNorm2d(dd * 4)
        self.C00 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1)
        self.C01 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1)
        self.C02 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1)
        self.C03 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1)
        self.B10 = nn.BatchNorm2d(dd * 16)
        self.C10 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1)
        self.C11 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1)
        self.C12 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1)
        self.C13 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1)
        self.B20 = nn.BatchNorm2d(dd * 64)
        self.C20 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1)
        self.C21 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1)
        self.C22 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1)
        self.C23 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1)
        self.COUT = nn.Conv2d(dd * 64, IMG_BITS, kernel_size=3, padding=1)

    def forward(self, img):
        ACT = F.mish
        x = self.CIN(img)
        # Skip path: jump straight to the 1/8 resolution, added back just before COUT.
        xx = self.Bxx(F.pixel_unshuffle(x, 8))
        x = x + self.Cx1(ACT(self.Cx0(x)))
        x = F.pixel_unshuffle(x, 2)
        x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
        x = x + self.C03(ACT(self.C02(x)))
        x = F.pixel_unshuffle(x, 2)
        x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
        x = x + self.C13(ACT(self.C12(x)))
        x = F.pixel_unshuffle(x, 2)
        x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
        x = x + self.C23(ACT(self.C22(x)))
        x = self.COUT(x + xx)
        return torch.sigmoid(x)


class RDecoderSmall(nn.Module):
    """Small decoder: code (N, IMG_BITS, h, w) -> sigmoid RGB (N, 3, 8h, 8w).

    Attribute names are checkpoint (state_dict) keys and must not be renamed.
    """

    def __init__(self):
        super().__init__()
        dd = 8
        self.CIN = nn.Conv2d(IMG_BITS, dd * 64, kernel_size=3, padding=1)
        self.B00 = nn.BatchNorm2d(dd * 64)
        self.C00 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1)
        self.C01 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1)
        self.C02 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1)
        self.C03 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1)
        self.B10 = nn.BatchNorm2d(dd * 16)
        self.C10 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1)
        self.C11 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1)
        self.C12 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1)
        self.C13 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1)
        self.B20 = nn.BatchNorm2d(dd * 4)
        self.C20 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1)
        self.C21 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1)
        self.C22 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1)
        self.C23 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1)
        self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
        self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
        self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)

    def forward(self, code):
        ACT = F.mish
        x = self.CIN(code)
        x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
        x = x + self.C03(ACT(self.C02(x)))
        x = F.pixel_shuffle(x, 2)
        x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
        x = x + self.C13(ACT(self.C12(x)))
        x = F.pixel_shuffle(x, 2)
        x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
        x = x + self.C23(ACT(self.C22(x)))
        x = F.pixel_shuffle(x, 2)
        x = x + self.Cx1(ACT(self.Cx0(x)))
        x = self.COUT(x)
        return torch.sigmoid(x)


class REncoderLarge(nn.Module):
    """Configurable encoder built from ResBlocks: RGB (N, 3, H, W) -> sigmoid code (N, IMG_BITS, H/8, W/8).

    dd: base channel width; ee: hidden width of the full-resolution conv pair; ff: ResBlock hidden width.
    H and W must be divisible by 8. Attribute names are checkpoint (state_dict) keys.
    """

    def __init__(self, dd, ee, ff):
        super().__init__()
        self.CXX = nn.Conv2d(3, dd, kernel_size=3, padding=1)
        self.BXX = nn.BatchNorm2d(dd)
        self.CX0 = nn.Conv2d(dd, ee, kernel_size=3, padding=1)
        self.CX1 = nn.Conv2d(ee, dd, kernel_size=3, padding=1)
        self.R0 = ResBlock(dd * 4, ff)
        self.R1 = ResBlock(dd * 16, ff)
        self.R2 = ResBlock(dd * 64, ff)
        self.CZZ = nn.Conv2d(dd * 64, IMG_BITS, kernel_size=3, padding=1)

    def forward(self, x):
        ACT = F.mish
        x = self.BXX(self.CXX(x))
        x = x + self.CX1(ACT(self.CX0(x)))
        x = F.pixel_unshuffle(x, 2)
        x = self.R0(x)
        x = F.pixel_unshuffle(x, 2)
        x = self.R1(x)
        x = F.pixel_unshuffle(x, 2)
        x = self.R2(x)
        x = self.CZZ(x)
        return torch.sigmoid(x)


class RDecoderLarge(nn.Module):
    """Configurable decoder mirroring REncoderLarge: code (N, IMG_BITS, h, w) -> sigmoid RGB (N, 3, 8h, 8w).

    Attribute names are checkpoint (state_dict) keys and must not be renamed.
    """

    def __init__(self, dd, ee, ff):
        super().__init__()
        self.CZZ = nn.Conv2d(IMG_BITS, dd * 64, kernel_size=3, padding=1)
        self.BZZ = nn.BatchNorm2d(dd * 64)
        self.R0 = ResBlock(dd * 64, ff)
        self.R1 = ResBlock(dd * 16, ff)
        self.R2 = ResBlock(dd * 4, ff)
        self.CX0 = nn.Conv2d(dd, ee, kernel_size=3, padding=1)
        self.CX1 = nn.Conv2d(ee, dd, kernel_size=3, padding=1)
        self.CXX = nn.Conv2d(dd, 3, kernel_size=3, padding=1)

    def forward(self, x):
        ACT = F.mish
        x = self.BZZ(self.CZZ(x))
        x = self.R0(x)
        x = F.pixel_shuffle(x, 2)
        x = self.R1(x)
        x = F.pixel_shuffle(x, 2)
        x = self.R2(x)
        x = F.pixel_shuffle(x, 2)
        x = x + self.CX1(ACT(self.CX0(x)))
        x = self.CXX(x)
        return torch.sigmoid(x)


# st.cache is deprecated; st.cache_resource (Streamlit >= 1.18) is meant for shared, unhashable
# objects such as torch models. On older Streamlit, fall back to st.cache with
# allow_output_mutation=True so the returned models are not re-hashed on every rerun.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def prepare_model(model_prefix):
    """Build the encoder/decoder pair for `model_prefix` and load its weights from MODEL_REPO.

    Returns:
        (encoder, decoder), both in eval mode on `device`.

    Raises:
        ValueError: if `model_prefix` does not name a known architecture.
    """
    gc.collect()
    if model_prefix == 'out-v7c_d8_256-224-13bit-OB32x0.5-745':
        R_ENCODER, R_DECODER = REncoderSmall(), RDecoderSmall()
    else:
        if 'd16_512' in model_prefix:
            dd, ee, ff = 16, 64, 512
        elif 'd32_1024' in model_prefix:
            dd, ee, ff = 32, 128, 1024
        else:
            # Previously fell through with dd/ee/ff unbound (UnboundLocalError).
            raise ValueError(f"unknown model prefix: {model_prefix!r}")
        R_ENCODER = REncoderLarge(dd, ee, ff)
        R_DECODER = RDecoderLarge(dd, ee, ff)
    encoder = R_ENCODER.eval().to(device)
    decoder = R_DECODER.eval().to(device)
    # map_location lets CUDA-saved checkpoints load on a CPU-only host.
    encoder.load_state_dict(
        torch.load(hf_hub_download(MODEL_REPO, f'{model_prefix}-E.pth'), map_location=device))
    decoder.load_state_dict(
        torch.load(hf_hub_download(MODEL_REPO, f'{model_prefix}-D.pth'), map_location=device))
    return encoder, decoder


def encode(model_prefix, img):
    """Encode a PIL image into a JSON string {"shape": [...], "data": <base64 of .npy of packed bits>}.

    The image is converted to RGB and cropped (from the top-left) to dimensions divisible by 8,
    because the encoders downsample with pixel_unshuffle by a total factor of 8.

    Raises:
        ValueError: if the image is smaller than 8x8 pixels.
    """
    encoder, _ = prepare_model(model_prefix)
    img = img.convert("RGB")
    width, height = img.size
    width, height = width - width % 8, height - height % 8
    if width == 0 or height == 0:
        raise ValueError("image must be at least 8x8 pixels")
    if (width, height) != img.size:
        img = img.crop((0, 0, width, height))
    img_transform = transforms.Compose([
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
    ])
    with torch.no_grad():
        img = img_transform(img).unsqueeze(0).to(device)
        z = encoder(img)
        z = ToBinary.apply(z)
    with io.BytesIO() as buffer:
        np.save(buffer, np.packbits(z.cpu().numpy().astype('bool')))
        z_b64 = base64.b64encode(buffer.getvalue()).decode()
    return json.dumps({"shape": list(z.shape), "data": z_b64})


def decode(model_prefix, z_str):
    """Decode a JSON string produced by `encode` back into a PIL image.

    Raises:
        ValueError / KeyError / TypeError: if `z_str` is malformed or its data does not match "shape".
        RuntimeError: if the shape is incompatible with the decoder (e.g. wrong channel count).
    """
    _, decoder = prepare_model(model_prefix)
    z_json = json.loads(z_str)
    shape = [int(d) for d in z_json["shape"]]
    if any(d <= 0 for d in shape):
        raise ValueError(f"invalid shape: {shape}")
    # The string may be pasted by a user: never allow pickled payloads.
    with io.BytesIO(base64.b64decode(z_json["data"])) as buffer:
        packed = np.load(buffer, allow_pickle=False)
    n_bits = int(np.prod(shape))
    # packbits pads to a whole byte, so unpackbits may return up to 7 extra bits;
    # trim to the exact count instead of letting reshape fail.
    if packed.size != (n_bits + 7) // 8:
        raise ValueError(f"data holds {packed.size} bytes, expected {(n_bits + 7) // 8} for shape {shape}")
    z = np.unpackbits(packed)[:n_bits].reshape(shape).astype(np.float32)
    with torch.no_grad():  # inference only: avoid building an autograd graph
        decoded = decoder(torch.from_numpy(z).to(device))
    return VF.to_pil_image(decoded[0].cpu())


st.title("Clip Guided Binary Autoencoder")
st.write("Model is from [@BlinkDL](https://huggingface.co/BlinkDL/clip-guided-binary-autoencoder)")
model_prefix = st.selectbox('The model to use',
                            ('out-v7c_d8_256-224-13bit-OB32x0.5-745',
                             'out-v7d_d16_512-224-13bit-OB32x0.5-2487',
                             'out-v7d_d32_1024-224-13bit-OB32x0.5-5560'))

encoder_tab, decoder_tab = st.tabs(["Encode", "Decode"])

with encoder_tab:
    col_in, col_out = st.columns(2)
    uploaded_file = col_in.file_uploader('Choose an Image')
    if uploaded_file is not None:
        try:
            image = Image.open(uploaded_file)
        except OSError as err:  # PIL.UnidentifiedImageError is an OSError subclass
            col_in.error(f"Could not read image: {err}")
        else:
            col_in.image(image, 'Input Image')
            try:
                z_str = encode(model_prefix, image)
            except ValueError as err:
                col_out.error(f"Could not encode image: {err}")
            else:
                col_out.write("Encoded to:")
                col_out.code(z_str, language=None)
                col_out.image(decode(model_prefix, z_str), 'Output Image preview')

with decoder_tab:
    col_in, col_out = st.columns(2)
    z_str = col_in.text_area('Paste encoded string here:')
    if z_str.strip():
        try:
            image = decode(model_prefix, z_str)
        except (ValueError, KeyError, TypeError, RuntimeError) as err:
            # User-pasted input: report the problem instead of crashing the app.
            col_out.error(f"Could not decode the pasted string: {err}")
        else:
            col_out.image(image, 'Output Image')