import gradio as gr |
from io import BytesIO |
import os |
import sys |
import matplotlib.pyplot as plt |
import matplotlib |
import numpy as np |
from PIL import Image |
from omegaconf import OmegaConf |
import torch |
from torchvision import transforms as T |
from optvq.models.quantizer import sinkhorn |
from optvq.utils.init import seed_everything |
seed_everything(42) |
from optvq.models.vqgan_hf import VQModelHF |
# Font family applied to every matplotlib figure rendered by this module.
matplotlib.rcParams['font.family'] = 'Times New Roman'
# Number of random data points sampled by draw_process for the 2D demo.
N_data = 50
# Number of random code vectors sampled by draw_process for the 2D demo.
N_code = 20
# Dimensionality of the sampled points; also forwarded to normalize() by draw_optvq.
dim = 2
# Placeholder for the model handler; it is assigned in the __main__ section.
handler = None
# Device handed to Handler when the script is run directly.
device = torch.device("cpu")
def nearest(src, trg):
    """Return, for each row of ``src``, the index of its closest row in ``trg``.

    Distances come from ``torch.cdist`` with its default norm.

    Args:
        src: Tensor of query points, one point per row.
        trg: Tensor of candidate points, one point per row.

    Returns:
        Tensor of indices into ``trg``, obtained by taking the argmin over
        the last dimension of the pairwise distance matrix.
    """
    return torch.cdist(src, trg).argmin(dim=-1)
def normalize(A, dim, mode="all"):
    """Rescale the tensor ``A`` according to ``mode``.

    Args:
        A: Tensor to rescale.
        dim: Divisor used when ``mode == "dim"``; ignored otherwise.
        mode: ``"all"`` standardizes ``A`` with its global mean and standard
            deviation (plus ``1e-6`` in the denominator) and then shifts it
            so its minimum is zero; ``"dim"`` divides ``A`` by ``dim``;
            ``"null"`` and any other value return ``A`` untouched.

    Returns:
        The rescaled tensor, or ``A`` itself when no rescaling applies.
    """
    if mode == "dim":
        return A / dim
    if mode == "all":
        standardized = (A - A.mean()) / (A.std() + 1e-6)
        # Shift so that the smallest entry becomes exactly zero.
        return standardized - standardized.min()
    # "null" and unrecognized modes leave the input as it is.
    return A
def draw_NN(data, code):
    """Render the nearest-neighbor assignment of data points to code vectors.

    Each data point gets a red arrow pointing at the code vector selected by
    ``nearest``. Data points are drawn as gray circles and code vectors as
    blue stars; only the first two coordinates of each point are plotted.

    The figure is built as a standalone ``Figure`` instead of through
    ``pyplot``. Figures created with ``plt.figure`` stay registered in
    pyplot's global state until explicitly closed, so the previous version
    leaked one 400-dpi figure per call in the long-running demo server.

    Args:
        data: CPU tensor of data points, one per row.
        code: CPU tensor of code vectors, one per row.

    Returns:
        A PIL image holding the rendered PNG.
    """
    # Local import: avoids relying on a new module-level import.
    from matplotlib.figure import Figure

    indices = nearest(data, code)
    data = data.numpy()
    code = code.numpy()
    # Code vector matched to every data point, in data order.
    targets = code[indices.numpy()]

    fig = Figure(figsize=(3, 2.5), dpi=400)
    ax = fig.subplots()
    for start, end in zip(data, targets):
        ax.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
                 head_width=0.05, head_length=0.05, fc='red', ec='red', alpha=0.6,
                 ls="-", lw=0.5)
    ax.scatter(data[:, 0], data[:, 1], s=10, marker="o", c="gray", label="Data")
    ax.scatter(code[:, 0], code[:, 1], s=25, marker="*", c="blue", label="Code")
    ax.legend(loc="lower right")
    ax.grid(color="gray", alpha=0.8, ls="-.", lw=0.5)
    ax.set_title("Nearest neighbor")

    # Serialize to an in-memory PNG; the figure is garbage-collected once it
    # goes out of scope because pyplot never held a reference to it.
    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    return Image.open(buf)
def draw_optvq(data, code):
    """Render the Sinkhorn-based (OptVQ) assignment of data points to codes.

    The pairwise distance matrix is normalized with ``normalize(...,
    mode="all")``, passed to ``sinkhorn``, and each data point is matched to
    the code with the largest entry in its row of the result. Matches are
    drawn as green arrows; data points are gray circles and code vectors are
    blue stars. Only the first two coordinates of each point are plotted.

    The figure is built as a standalone ``Figure`` instead of through
    ``pyplot``. Figures created with ``plt.figure`` stay registered in
    pyplot's global state until explicitly closed, so the previous version
    leaked one 400-dpi figure per call in the long-running demo server.

    Args:
        data: CPU tensor of data points, one per row.
        code: CPU tensor of code vectors, one per row.

    Returns:
        A PIL image holding the rendered PNG.
    """
    # Local import: avoids relying on a new module-level import.
    from matplotlib.figure import Figure

    cost = torch.cdist(data, code, p=2.0)
    cost = normalize(cost, dim, mode="all")
    Q = sinkhorn(cost, n_iters=5, epsilon=10, is_distributed=False)
    indices = torch.argmax(Q, dim=-1)

    data = data.numpy()
    code = code.numpy()
    # Code vector matched to every data point, in data order.
    targets = code[indices.numpy()]

    fig = Figure(figsize=(3, 2.5), dpi=400)
    ax = fig.subplots()
    for start, end in zip(data, targets):
        ax.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
                 head_width=0.05, head_length=0.05, fc='green', ec='green', alpha=0.6,
                 ls="-", lw=0.5)
    ax.scatter(data[:, 0], data[:, 1], s=10, marker="o", c="gray", label="Data")
    ax.scatter(code[:, 0], code[:, 1], s=25, marker="*", c="blue", label="Code")
    ax.legend(loc="lower right")
    ax.grid(color="gray", alpha=0.8, ls="-.", lw=0.5)
    ax.set_title("Optimal Transport (OptVQ)")

    # Serialize to an in-memory PNG; the figure is garbage-collected once it
    # goes out of scope because pyplot never held a reference to it.
    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    return Image.open(buf)
def draw_process(x, y, std):
    """Sample random data and codes, then plot both matching strategies.

    ``N_data`` data points are drawn from a standard normal distribution.
    ``N_code`` code vectors are drawn the same way, scaled by ``std`` and
    shifted by ``x`` along the first coordinate and ``y`` along the second.

    Args:
        x: Offset added to the first coordinate of every code vector.
        y: Offset added to the second coordinate of every code vector.
        std: Scale applied to the code vectors before shifting.

    Returns:
        A pair ``(nearest-neighbor image, OptVQ image)``.
    """
    # Sampling order (data first, then codes) is kept so seeded runs match.
    points = torch.randn(N_data, dim)
    codebook = torch.randn(N_code, dim) * std
    codebook[:, 0].add_(x)
    codebook[:, 1].add_(y)
    return draw_NN(points, codebook), draw_optvq(points, codebook)
class Handler:
    """Hold three pretrained VQ models and reconstruct an image with each.

    Attributes are the preprocessing ``transform``, the target ``device``,
    and the ``basevq``, ``vqgan`` and ``optvq`` models.
    """

    def __init__(self, device):
        """Build the preprocessing pipeline and load the three models.

        Args:
            device: Device each model is moved to and on which inputs are
                placed before encoding.
        """
        # Resize to 256, center-crop to 256 x 256, convert to a tensor.
        self.transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(256),
            T.ToTensor(),
        ])
        self.device = device
        self.basevq = self._load_model("BorelTHU/basevq-16x16x4")
        self.vqgan = self._load_model("BorelTHU/vqgan-16x16")
        self.optvq = self._load_model("BorelTHU/optvq-16x16x4")

    def _load_model(self, repo_id):
        """Load one pretrained model, move it to the device, set eval mode."""
        model = VQModelHF.from_pretrained(repo_id)
        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _reconstruct(model, img):
        """Encode ``img`` with ``model`` and decode the first encoder output."""
        quant, *_ = model.encode(img)
        return model.decode(quant)

    def tensor_to_image(self, tensor):
        """Convert a channel-first tensor into a uint8 channel-last array.

        Values are mapped with ``(v + 1) / 2 * 255``, so -1 becomes 0 and 1
        becomes 255. Anything outside that range is clipped before the
        uint8 cast. Without the clip, out-of-range values wrap around and
        show up as speckle artifacts in the displayed image.

        Args:
            tensor: Image tensor; a leading dimension of size one is
                squeezed away before the axes are permuted.

        Returns:
            ``numpy.ndarray`` of dtype ``uint8`` with channels last.
        """
        img = tensor.detach().squeeze(0).cpu().permute(1, 2, 0).numpy()
        img = (img + 1) / 2 * 255
        img = np.clip(img, 0, 255)
        return img.astype("uint8")

    def process_image(self, img: np.ndarray):
        """Reconstruct ``img`` with the BaseVQ, VQGAN and OptVQ models.

        Args:
            img: Input image as a NumPy array; it is cast to uint8 and
                wrapped in a PIL image before preprocessing.

        Returns:
            Tuple ``(preprocessed input, basevq rec., vqgan rec., optvq
            rec.)``, each a uint8 array produced by ``tensor_to_image``.
        """
        img = Image.fromarray(img.astype("uint8"))
        img = self.transform(img)
        img = img.unsqueeze(0).to(self.device)
        with torch.no_grad():
            # Rescale with 2 * v - 1 before encoding; tensor_to_image
            # applies the inverse mapping afterwards.
            img = 2 * img - 1
            basevq_rec = self._reconstruct(self.basevq, img)
            vqgan_rec = self._reconstruct(self.vqgan, img)
            optvq_rec = self._reconstruct(self.optvq, img)
        img = self.tensor_to_image(img)
        basevq_rec = self.tensor_to_image(basevq_rec)
        vqgan_rec = self.tensor_to_image(vqgan_rec)
        optvq_rec = self.tensor_to_image(optvq_rec)
        return img, basevq_rec, vqgan_rec, optvq_rec
if __name__ == "__main__":
    # NOTE(review): the nesting of the Blocks/Row/Column contexts was
    # reconstructed from a flattened source; confirm it matches the
    # intended layout.
    handler = Handler(device=device)

    with gr.Blocks() as demo:
        # Demo 1: side-by-side reconstructions from the three models.
        gr.Textbox(
            value="This demo shows the image reconstruction comparison between OptVQ and other methods. The input image is resized to 256 x 256 and then fed into the models. The output images are the reconstructed images from the latent codes.",
            label="Demo 1: Image reconstruction results",
        )
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input data", image_mode="RGB", type="numpy")
                btn_demo1 = gr.Button(value="Run reconstruction")
            image_basevq = gr.Image(label="BaseVQ rec.")
            image_vqgan = gr.Image(label="VQGAN rec.")
            image_optvq = gr.Image(label="OptVQ rec.")
        # The input widget is also an output: it receives the preprocessed image.
        btn_demo1.click(
            fn=handler.process_image,
            inputs=[image_input],
            outputs=[image_input, image_basevq, image_vqgan, image_optvq],
        )

        # Demo 2: 2D matching visualization driven by three sliders.
        gr.Textbox(
            value="This demo shows the 2D visualizations of nearest neighbor and optimal transport (OptVQ) methods. The data points are randomly generated from a normal distribution, and the matching results are shown as arrows with different colors.",
            label="Demo 2: 2D visualizations of matching results",
        )
        with gr.Row():
            with gr.Column():
                input_x = gr.Slider(label="x", value=0, minimum=-10, maximum=10, step=0.1)
                input_y = gr.Slider(label="y", value=0, minimum=-10, maximum=10, step=0.1)
                input_std = gr.Slider(label="std", value=1, minimum=0, maximum=5, step=0.1)
                btn_demo2 = gr.Button(value="Run 2D example")
            output_nn = gr.Image(label="NN")
            output_optvq = gr.Image(label="OptVQ")

        sliders = (input_x, input_y, input_std)
        plots = (output_nn, output_optvq)
        # Any slider change redraws both plots; the button does the same on demand.
        for slider in sliders:
            slider.change(fn=draw_process, inputs=list(sliders), outputs=list(plots))
        btn_demo2.click(fn=draw_process, inputs=list(sliders), outputs=list(plots))

    demo.launch()