import gradio as gr
from io import BytesIO
import os
import sys

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from omegaconf import OmegaConf

import torch
from torchvision import transforms as T

from optvq.models.quantizer import sinkhorn
from optvq.utils.init import seed_everything

# Seed all RNGs before any model is built so the demo is reproducible.
seed_everything(42)

from optvq.models.vqgan_hf import VQModelHF

matplotlib.rcParams['font.family'] = 'Times New Roman'

# Configuration for the 2D toy demo
N_data = 50                   # number of data points
N_code = 20                   # number of code vectors
dim = 2                       # dimensionality of the toy points
handler = None                # set in __main__
device = torch.device("cpu")


def nearest(src, trg):
    # Assign each source point to its nearest target under Euclidean distance.
    dis_mat = torch.cdist(src, trg)
    min_idx = torch.argmin(dis_mat, dim=-1)
    return min_idx
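
# Quick sanity check (illustrative only, not used by the demo): each source
# point should map to the closer of the two targets.
# >>> nearest(torch.tensor([[0., 0.], [5., 5.]]), torch.tensor([[0., 1.], [5., 4.]]))
# tensor([0, 1])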


def normalize(A, dim, mode="all"):
    # Rescale the cost matrix before the Sinkhorn step. "all" standardizes the
    # entries and shifts them to be non-negative, "dim" divides by the feature
    # dimension, and "null" leaves the matrix unchanged.
    if mode == "all":
        A = (A - A.mean()) / (A.std() + 1e-6)
        A = A - A.min()
    elif mode == "dim":
        A = A / dim
    elif mode == "null":
        pass
    return A


def draw_NN(data, code):
    # Match each data point to its nearest code vector and draw the
    # assignments as arrows.
    indices = nearest(data, code)
    data = data.numpy()
    code = code.numpy()

    plt.figure(figsize=(3, 2.5), dpi=400)
    for i in range(data.shape[0]):
        idx = indices[i].item()
        start = data[i]
        end = code[idx]
        plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
                  head_width=0.05, head_length=0.05, fc='red', ec='red',
                  alpha=0.6, ls="-", lw=0.5)
    plt.scatter(data[:, 0], data[:, 1], s=10, marker="o", c="gray", label="Data")
    plt.scatter(code[:, 0], code[:, 1], s=25, marker="*", c="blue", label="Code")
    plt.legend(loc="lower right")
    plt.grid(color="gray", alpha=0.8, ls="-.", lw=0.5)
    plt.title("Nearest neighbor")

    # Render to a PIL image; close the figure so repeated Gradio calls do not
    # accumulate open matplotlib figures.
    buf = BytesIO()
    plt.savefig(buf, format="png")
    plt.close()
    buf.seek(0)
    image = Image.open(buf)
    return image


def draw_optvq(data, code):
    # Build the pairwise cost matrix, normalize it, and solve an
    # entropy-regularized optimal transport problem with Sinkhorn iterations;
    # each point is then assigned to the code receiving the most mass.
    cost = torch.cdist(data, code, p=2.0)
    cost = normalize(cost, dim, mode="all")
    Q = sinkhorn(cost, n_iters=5, epsilon=10, is_distributed=False)
    indices = torch.argmax(Q, dim=-1)
    data = data.numpy()
    code = code.numpy()

    plt.figure(figsize=(3, 2.5), dpi=400)
    for i in range(data.shape[0]):
        idx = indices[i].item()
        start = data[i]
        end = code[idx]
        plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
                  head_width=0.05, head_length=0.05, fc='green', ec='green',
                  alpha=0.6, ls="-", lw=0.5)
    plt.scatter(data[:, 0], data[:, 1], s=10, marker="o", c="gray", label="Data")
    plt.scatter(code[:, 0], code[:, 1], s=25, marker="*", c="blue", label="Code")
    plt.legend(loc="lower right")
    plt.grid(color="gray", alpha=0.8, ls="-.", lw=0.5)
    plt.title("Optimal Transport (OptVQ)")

    buf = BytesIO()
    plt.savefig(buf, format="png")
    plt.close()
    buf.seek(0)
    image = Image.open(buf)
    return image
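

# For reference, a minimal sketch of the Sinkhorn-Knopp iteration this demo
# relies on. This is an assumption about the general shape of the algorithm,
# not the actual optvq.models.quantizer.sinkhorn implementation, which may
# differ in details (e.g. scaling conventions and the distributed code path).
def _sinkhorn_sketch(cost, n_iters=5, epsilon=10.0):
    # Start from the Gibbs kernel of the cost matrix, then alternately
    # rescale rows and columns toward uniform marginals.
    n, m = cost.shape
    Q = torch.exp(-cost / epsilon)
    Q = Q / Q.sum()
    for _ in range(n_iters):
        Q = Q / (Q.sum(dim=1, keepdim=True) * n)  # each row sums to 1/n
        Q = Q / (Q.sum(dim=0, keepdim=True) * m)  # each column sums to 1/m
    return Q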


def draw_process(x, y, std):
    # Sample toy data and codes: the sliders shift the code cloud by (x, y)
    # and scale its spread by std.
    data = torch.randn(N_data, dim)
    code = torch.randn(N_code, dim) * std
    code[:, 0] += x
    code[:, 1] += y

    image_NN = draw_NN(data, code)
    image_optvq = draw_optvq(data, code)

    return image_NN, image_optvq
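
# Standalone smoke test without the UI, e.g.:
#   nn_img, ot_img = draw_process(x=2.0, y=-1.0, std=0.5)
#   nn_img.save("nn.png"); ot_img.save("optvq.png")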


class Handler:
    def __init__(self, device):
        # Preprocessing: resize the short side to 256, center-crop to
        # 256 x 256, and convert to a [0, 1] tensor.
        self.transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(256),
            T.ToTensor()
        ])
        self.device = device

        # Load the three pretrained tokenizers from the Hugging Face Hub.
        self.basevq = VQModelHF.from_pretrained("BorelTHU/basevq-16x16x4")
        self.basevq.to(self.device)
        self.basevq.eval()

        self.vqgan = VQModelHF.from_pretrained("BorelTHU/vqgan-16x16")
        self.vqgan.to(self.device)
        self.vqgan.eval()

        self.optvq = VQModelHF.from_pretrained("BorelTHU/optvq-16x16x4")
        self.optvq.to(self.device)
        self.optvq.eval()

    def tensor_to_image(self, tensor):
        # Map a [-1, 1] tensor of shape (1, C, H, W) back to a uint8 HWC
        # array; clip first so out-of-range decoder outputs do not wrap
        # around when cast.
        img = tensor.squeeze(0).cpu().permute(1, 2, 0).numpy()
        img = (img + 1) / 2 * 255
        img = np.clip(img, 0, 255).astype("uint8")
        return img

    def process_image(self, img: np.ndarray):
        img = Image.fromarray(img.astype("uint8"))
        img = self.transform(img)
        img = img.unsqueeze(0).to(self.device)
        with torch.no_grad():
            # The models expect inputs in [-1, 1].
            img = 2 * img - 1

            quant, *_ = self.basevq.encode(img)
            basevq_rec = self.basevq.decode(quant)

            quant, *_ = self.vqgan.encode(img)
            vqgan_rec = self.vqgan.decode(quant)

            quant, *_ = self.optvq.encode(img)
            optvq_rec = self.optvq.decode(quant)

        img = self.tensor_to_image(img)
        basevq_rec = self.tensor_to_image(basevq_rec)
        vqgan_rec = self.tensor_to_image(vqgan_rec)
        optvq_rec = self.tensor_to_image(optvq_rec)
        return img, basevq_rec, vqgan_rec, optvq_rec
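
# Example programmatic use (assumes a local RGB image file):
#   handler = Handler(device=torch.device("cpu"))
#   arr = np.array(Image.open("example.jpg").convert("RGB"))
#   orig, basevq_rec, vqgan_rec, optvq_rec = handler.process_image(arr)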


if __name__ == "__main__":
    handler = Handler(device=device)

    with gr.Blocks() as demo:
        # Demo 1: image reconstruction with the three tokenizers.
        gr.Textbox(value="This demo compares image reconstruction between OptVQ and other methods. The input image is resized to 256 x 256 and fed into each model; the outputs are the images reconstructed from the latent codes.", label="Demo 1: Image reconstruction results")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input data", image_mode="RGB", type="numpy")
                btn_demo1 = gr.Button(value="Run reconstruction")
            image_basevq = gr.Image(label="BaseVQ rec.")
            image_vqgan = gr.Image(label="VQGAN rec.")
            image_optvq = gr.Image(label="OptVQ rec.")
        btn_demo1.click(fn=handler.process_image, inputs=[image_input], outputs=[image_input, image_basevq, image_vqgan, image_optvq])

        # Demo 2: toy 2D matching visualization, redrawn on any control change.
        gr.Textbox(value="This demo shows 2D visualizations of the nearest neighbor and optimal transport (OptVQ) matchings. The data points are drawn from a normal distribution, and the matching results are shown as colored arrows.", label="Demo 2: 2D visualizations of matching results")
        with gr.Row():
            with gr.Column():
                input_x = gr.Slider(label="x", value=0, minimum=-10, maximum=10, step=0.1)
                input_y = gr.Slider(label="y", value=0, minimum=-10, maximum=10, step=0.1)
                input_std = gr.Slider(label="std", value=1, minimum=0, maximum=5, step=0.1)
                btn_demo2 = gr.Button(value="Run 2D example")
            output_nn = gr.Image(label="NN")
            output_optvq = gr.Image(label="OptVQ")

        input_x.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
        input_y.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
        input_std.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
        btn_demo2.click(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])

    demo.launch()
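
# Gradio serves the app locally by default (http://127.0.0.1:7860); pass
# share=True to demo.launch() to expose a temporary public link.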