# OptVQ / app.py
import gradio as gr
from io import BytesIO
import os
import sys
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
import torch
from torchvision import transforms as T
from optvq.models.quantizer import sinkhorn
from optvq.utils.init import seed_everything
seed_everything(42)
from optvq.models.vqgan_hf import VQModelHF
matplotlib.rcParams['font.family'] = 'Times New Roman'
#################
N_data = 50
N_code = 20
dim = 2
handler = None
device = torch.device("cpu")
#################
def nearest(src, trg):
    dis_mat = torch.cdist(src, trg)
    min_idx = torch.argmin(dis_mat, dim=-1)
    return min_idx
def normalize(A, dim, mode="all"):
    if mode == "all":
        A = (A - A.mean()) / (A.std() + 1e-6)
        A = A - A.min()
    elif mode == "dim":
        A = A / dim
    elif mode == "null":
        pass
    return A
def draw_NN(data, code):
    # nearest neighbor assignment
    indices = nearest(data, code)
    data = data.numpy()
    code = code.numpy()
    plt.figure(figsize=(3, 2.5), dpi=400)
    # draw the data-to-code assignments as red arrows
    for i in range(data.shape[0]):
        idx = indices[i].item()
        start = data[i]
        end = code[idx]
        plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
                  head_width=0.05, head_length=0.05, fc='red', ec='red', alpha=0.6,
                  ls="-", lw=0.5)
    plt.scatter(data[:, 0], data[:, 1], s=10, marker="o", c="gray", label="Data")
    plt.scatter(code[:, 0], code[:, 1], s=25, marker="*", c="blue", label="Code")
    plt.legend(loc="lower right")
    plt.grid(color="gray", alpha=0.8, ls="-.", lw=0.5)
    plt.title("Nearest neighbor")
    buf = BytesIO()
    plt.savefig(buf, format="png")
    plt.close()
    buf.seek(0)
    image = Image.open(buf)
    return image
def draw_optvq(data, code):
    # optimal transport assignment via Sinkhorn iterations
    cost = torch.cdist(data, code, p=2.0)
    cost = normalize(cost, dim, mode="all")
    Q = sinkhorn(cost, n_iters=5, epsilon=10, is_distributed=False)
    indices = torch.argmax(Q, dim=-1)
    data = data.numpy()
    code = code.numpy()
    plt.figure(figsize=(3, 2.5), dpi=400)
    # draw the data-to-code assignments as green arrows
    for i in range(data.shape[0]):
        idx = indices[i].item()
        start = data[i]
        end = code[idx]
        plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
                  head_width=0.05, head_length=0.05, fc='green', ec='green', alpha=0.6,
                  ls="-", lw=0.5)
    plt.scatter(data[:, 0], data[:, 1], s=10, marker="o", c="gray", label="Data")
    plt.scatter(code[:, 0], code[:, 1], s=25, marker="*", c="blue", label="Code")
    plt.legend(loc="lower right")
    plt.grid(color="gray", alpha=0.8, ls="-.", lw=0.5)
    plt.title("Optimal Transport (OptVQ)")
    buf = BytesIO()
    plt.savefig(buf, format="png")
    plt.close()
    buf.seek(0)
    image = Image.open(buf)
    return image
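
# Hedged reference sketch (not used by the app): draw_optvq relies on
# optvq.models.quantizer.sinkhorn, whose internals are not shown in this file.
# The helper below is only a minimal, self-contained Sinkhorn-Knopp iteration
# with uniform marginals, included for intuition; the library's kernel scaling
# and epsilon convention may differ. The name _sinkhorn_sketch is illustrative.
def _sinkhorn_sketch(cost, n_iters=5, epsilon=0.1):
    # turn the cost matrix into a positive kernel; smaller epsilon -> sharper assignments
    Q = torch.exp(-cost / epsilon)
    n, m = Q.shape
    row_marginal = torch.full((n,), 1.0 / n)
    col_marginal = torch.full((m,), 1.0 / m)
    for _ in range(n_iters):
        # alternately rescale rows and columns toward the uniform marginals
        Q = Q * (row_marginal / (Q.sum(dim=1) + 1e-8)).unsqueeze(1)
        Q = Q * (col_marginal / (Q.sum(dim=0) + 1e-8)).unsqueeze(0)
    return Q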
def draw_process(x, y, std):
    data = torch.randn(N_data, dim)
    code = torch.randn(N_code, dim) * std
    code[:, 0] += x
    code[:, 1] += y
    image_NN = draw_NN(data, code)
    image_optvq = draw_optvq(data, code)
    return image_NN, image_optvq
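
# Hedged usage sketch (not part of the Gradio demo): runs the 2D matching
# comparison headlessly and writes both figures to disk. The function name
# and output file names are illustrative placeholders.
def demo_2d_offline(x=0.0, y=0.0, std=1.0):
    image_nn, image_optvq = draw_process(x, y, std)
    image_nn.save("matching_nn.png")
    image_optvq.save("matching_optvq.png")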
class Handler:
    def __init__(self, device):
        self.transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(256),
            T.ToTensor()
        ])
        self.device = device
        self.basevq = VQModelHF.from_pretrained("BorelTHU/basevq-16x16x4")
        self.basevq.to(self.device)
        self.basevq.eval()
        self.vqgan = VQModelHF.from_pretrained("BorelTHU/vqgan-16x16")
        self.vqgan.to(self.device)
        self.vqgan.eval()
        self.optvq = VQModelHF.from_pretrained("BorelTHU/optvq-16x16x4")
        self.optvq.to(self.device)
        self.optvq.eval()

    def tensor_to_image(self, tensor):
        # map a [-1, 1] CHW tensor back to an HWC uint8 array
        img = tensor.squeeze(0).cpu().permute(1, 2, 0).numpy()
        img = (img + 1) / 2 * 255
        img = img.astype("uint8")
        return img

    def process_image(self, img: np.ndarray):
        img = Image.fromarray(img.astype("uint8"))
        img = self.transform(img)
        img = img.unsqueeze(0).to(self.device)
        with torch.no_grad():
            img = 2 * img - 1
            # basevq
            quant, *_ = self.basevq.encode(img)
            basevq_rec = self.basevq.decode(quant)
            # vqgan
            quant, *_ = self.vqgan.encode(img)
            vqgan_rec = self.vqgan.decode(quant)
            # optvq
            quant, *_ = self.optvq.encode(img)
            optvq_rec = self.optvq.decode(quant)
            # tensors to uint8 arrays for display
            img = self.tensor_to_image(img)
            basevq_rec = self.tensor_to_image(basevq_rec)
            vqgan_rec = self.tensor_to_image(vqgan_rec)
            optvq_rec = self.tensor_to_image(optvq_rec)
        return img, basevq_rec, vqgan_rec, optvq_rec
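
# Hedged usage sketch (not part of the Gradio demo): reconstructs a local
# image with all three models and saves the results. The function name and
# file paths are illustrative placeholders.
def reconstruct_offline(input_path, output_dir="."):
    offline_handler = Handler(device=device)
    array = np.asarray(Image.open(input_path).convert("RGB"))
    original, basevq_rec, vqgan_rec, optvq_rec = offline_handler.process_image(array)
    for name, out in [("input", original), ("basevq", basevq_rec),
                      ("vqgan", vqgan_rec), ("optvq", optvq_rec)]:
        Image.fromarray(out).save(os.path.join(output_dir, f"{name}.png"))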
if __name__ == "__main__":
    # create the model handler
    handler = Handler(device=device)

    # create the interface
    with gr.Blocks() as demo:
        gr.Textbox(value="This demo shows the image reconstruction comparison between OptVQ and other methods. The input image is resized to 256 x 256 and then fed into the models. The output images are the reconstructed images from the latent codes.", label="Demo 1: Image reconstruction results")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input data", image_mode="RGB", type="numpy")
                btn_demo1 = gr.Button(value="Run reconstruction")
            image_basevq = gr.Image(label="BaseVQ rec.")
            image_vqgan = gr.Image(label="VQGAN rec.")
            image_optvq = gr.Image(label="OptVQ rec.")
        btn_demo1.click(fn=handler.process_image, inputs=[image_input], outputs=[image_input, image_basevq, image_vqgan, image_optvq])

        gr.Textbox(value="This demo shows the 2D visualizations of nearest neighbor and optimal transport (OptVQ) methods. The data points are randomly generated from a normal distribution, and the matching results are shown as arrows with different colors.", label="Demo 2: 2D visualizations of matching results")
        with gr.Row():
            with gr.Column():
                input_x = gr.Slider(label="x", value=0, minimum=-10, maximum=10, step=0.1)
                input_y = gr.Slider(label="y", value=0, minimum=-10, maximum=10, step=0.1)
                input_std = gr.Slider(label="std", value=1, minimum=0, maximum=5, step=0.1)
                btn_demo2 = gr.Button(value="Run 2D example")
            output_nn = gr.Image(label="NN")
            output_optvq = gr.Image(label="OptVQ")
        # set the function
        input_x.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
        input_y.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
        input_std.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
        btn_demo2.click(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])

    demo.launch()