import os
import pathlib
import torch
import torch.hub
from torchvision.transforms.functional import convert_image_dtype, pil_to_tensor
from torchvision.io.image import encode_png
from PIL import Image
import PIL
from mcquic import Config
from mcquic.modules.compressor import BaseCompressor, Compressor
from mcquic.datasets.transforms import AlignedCrop
from mcquic.utils.specification import File
from mcquic.utils.vision import DeTransform
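# Streamlit is an optional dependency of this service; fail early with an actionable message if it is missing.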
try:
    import streamlit as st
except ModuleNotFoundError:
    raise ImportError("To run `mcquic service`, please install Streamlit first via `pip install streamlit`.")
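# Pretrained checkpoint hosted on the McQuic GitHub releases: qp = 2, optimized for MS-SSIM.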
MODELS_URL = "https://github.com/xiaosu-zhu/McQuic/releases/download/generic/qp_2_msssim_8e954998.mcquic"
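# When HF_SPACE is set in the environment (e.g. on the Hugging Face Space deployment), the UI enforces stricter upload limits.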
HF_SPACE = "HF_SPACE" in os.environ
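# Cached as a singleton so the checkpoint is downloaded and the model built only once, not on every Streamlit rerun.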
@st.experimental_singleton
def loadModel(device):
    # Download (or reuse the locally cached) checkpoint and verify its hash.
    ckpt = torch.hub.load_state_dict_from_url(MODELS_URL, map_location=device, check_hash=True)
    config = Config.deserialize(ckpt["config"])
    model = Compressor(**config.Model.Params).to(device)
    model.QuantizationParameter = "qp_2_msssim"
    model.load_state_dict(ckpt["model"])
    return model
@st.cache
def compressImage(image: torch.Tensor, model: BaseCompressor, crop: bool) -> File:
    # uint8 [0, 255] -> float [0, 1]
    image = convert_image_dtype(image)
    if crop:
        image = AlignedCrop()(image)
    # [c, h, w], scaled to [-1, 1]
    image = (image - 0.5) * 2
    codes, binaries, headers = model.compress(image[None, ...])
    return File(headers[0], binaries[0])
@st.cache
def decompressImage(sourceFile: File, model: BaseCompressor) -> torch.ByteTensor:
    binaries = sourceFile.Content
    # [1, c, h, w]
    restored = model.decompress([binaries], [sourceFile.FileHeader])
    # [c, h, w], mapped back to uint8 [0, 255]
    return DeTransform()(restored[0])
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = loadModel(device).eval()
    st.sidebar.markdown("""
a.k.a. Multi-codebook Quantizers for neural image compression

Compressing images on-the-fly.

> Due to resource limitations, only the `qp = 2` model targeting `ms-ssim` is provided for this compression service.
""", unsafe_allow_html=True)

    if HF_SPACE:
        st.markdown("""
> Due to resource limitations of HF Spaces, uploaded images must be smaller than `3000 x 3000`. Also, this demo is CPU-only and may be slow.

> This demo is synced with the main branch of `McQuic`, so some features may be unstable and change frequently.
""", unsafe_allow_html=True)
    with st.form("SubmitForm"):
        uploadedFile = st.file_uploader("Try running McQuic to compress or restore images!", type=["png", "jpg", "jpeg", "mcq"], help="Upload your image or a compressed `.mcq` file here.")
        cropping = st.checkbox("Crop the image to align feature-map grids.", help="If checked, the image is cropped to align with the feature map grids, which makes the compressed file smaller.")
        submitted = st.form_submit_button("Submit", help="Click to start compressing/restoring.")
    if submitted and uploadedFile is not None:
        if uploadedFile.name.endswith(".mcq"):
            uploadedFile.flush()
            binaryFile = File.deserialize(uploadedFile.read())
            st.text(str(binaryFile))
            result = decompressImage(binaryFile, model)
            # [c, h, w] -> [h, w, c] for display.
            st.image(result.cpu().permute(1, 2, 0).numpy())
            downloadButton = st.empty()
            done = downloadButton.download_button("Click to download the restored image", data=encode_png(result.cpu()).numpy().tobytes(), file_name=".".join(uploadedFile.name.split(".")[:-1] + ["png"]), mime="image/png")
            if done:
                downloadButton.empty()
        elif uploadedFile.name.lower().endswith((".png", ".jpg", ".jpeg")):
            try:
                image = Image.open(uploadedFile)
            except PIL.UnidentifiedImageError:
                st.markdown("""
> Failed to open the image. Please try another one.
""", unsafe_allow_html=True)
                return
            w, h = image.size
            if HF_SPACE and (h > 3000 or w > 3000):
                st.markdown("""
> The image is too large. Please try a smaller one.
""", unsafe_allow_html=True)
                return
            image = pil_to_tensor(image.convert("RGB")).to(device)
            # st.image(image.cpu().permute(1, 2, 0).numpy())
            result = compressImage(image, model, cropping)
            st.text(str(result))
            downloadButton = st.empty()
            done = downloadButton.download_button("Click to download the compressed file", data=result.serialize(), file_name=".".join(uploadedFile.name.split(".")[:-1] + ["mcq"]), mime="image/mcq")
            if done:
                downloadButton.empty()
        else:
            st.markdown("""
> Unsupported image format. Please try another one.
""", unsafe_allow_html=True)
            return
    st.markdown("""
CVF Open Access | arXiv | BibTex | Demo
""", unsafe_allow_html=True)
if __name__ == "__main__":
    with torch.inference_mode():
        main()