"""Gradio demo comparing OptVQ with other vector-quantization methods.

Demo 1 reconstructs an uploaded image with three pretrained models
(BaseVQ, VQGAN, OptVQ). Demo 2 draws, for random 2-D points, which code
each data point is matched to under nearest-neighbor search and under the
Sinkhorn-based assignment used by OptVQ.
"""
import gradio as gr
from io import BytesIO
import os
import sys
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
import torch
from torchvision import transforms as T

from optvq.models.quantizer import sinkhorn
from optvq.utils.init import seed_everything
# Seeding happens before the remaining import, as in the original script.
seed_everything(42)
from optvq.models.vqgan_hf import VQModelHF

matplotlib.rcParams['font.family'] = 'Times New Roman'

#################
N_data = 50    # number of random data points drawn in the 2-D demo
N_code = 20    # number of random codes drawn in the 2-D demo
dim = 2        # dimensionality of the points in the 2-D demo
handler = None  # replaced by a Handler instance when run as a script
device = torch.device("cpu")
#################


def nearest(src, trg):
    """Return, for every row of ``src``, the index of the closest row of ``trg``.

    Distances are computed with ``torch.cdist`` and the argmin is taken over
    the last dimension.
    """
    dis_mat = torch.cdist(src, trg)
    return torch.argmin(dis_mat, dim=-1)


def normalize(A, dim, mode="all"):
    """Rescale the matrix ``A`` according to ``mode``.

    Args:
        A: tensor to rescale.
        dim: divisor used when ``mode == "dim"``; ignored otherwise.
        mode: ``"all"`` standardizes ``A`` with its global mean and standard
            deviation (plus 1e-6 for stability) and then shifts it so that
            its minimum is zero; ``"dim"`` divides ``A`` by ``dim``;
            ``"null"`` (or any other value) returns ``A`` unchanged.

    Returns:
        The rescaled tensor.
    """
    if mode == "all":
        A = (A - A.mean()) / (A.std() + 1e-6)
        A = A - A.min()
    elif mode == "dim":
        A = A / dim
    return A


def _render_matching(data, code, indices, color, title):
    """Plot data points, codes and one arrow per data->code match as a PIL image.

    Args:
        data: tensor of 2-D data points, one per row.
        code: tensor of 2-D codes, one per row.
        indices: 1-D integer tensor; ``indices[i]`` is the row of ``code``
            matched to ``data[i]``.
        color: matplotlib color used for the arrows.
        title: plot title.

    Returns:
        A ``PIL.Image.Image`` opened from the rendered PNG.
    """
    points = data.numpy()
    centers = code.numpy()

    # Draw through the figure/axes objects instead of pyplot's implicit
    # "current figure", and always close the figure: pyplot keeps every
    # figure alive until it is closed, so each slider change used to leak one.
    fig = plt.figure(figsize=(3, 2.5), dpi=400)
    try:
        ax = fig.add_subplot(1, 1, 1)
        # One arrow from each data point to its matched code.
        for start, idx in zip(points, indices.tolist()):
            end = centers[idx]
            ax.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
                     head_width=0.05, head_length=0.05, fc=color, ec=color,
                     alpha=0.6, ls="-", lw=0.5)

        ax.scatter(points[:, 0], points[:, 1], s=10, marker="o", c="gray", label="Data")
        ax.scatter(centers[:, 0], centers[:, 1], s=25, marker="*", c="blue", label="Code")
        ax.legend(loc="lower right")
        ax.grid(color="gray", alpha=0.8, ls="-.", lw=0.5)
        ax.set_title(title)

        buf = BytesIO()
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)


def draw_NN(data, code):
    """Render the nearest-neighbor matching between ``data`` and ``code`` (red arrows)."""
    indices = nearest(data, code)
    return _render_matching(data, code, indices, color="red", title="Nearest neighbor")


def draw_optvq(data, code):
    """Render the Sinkhorn-based (OptVQ) matching between ``data`` and ``code`` (green arrows).

    The pairwise Euclidean cost matrix is normalized with ``normalize(...,
    mode="all")`` and passed to ``sinkhorn``; each data point is then matched
    to the column with the largest value in its row of the result.
    """
    cost = torch.cdist(data, code, p=2.0)
    cost = normalize(cost, dim, mode="all")
    # NOTE(review): Q is presumably a (num_data, num_code) assignment matrix;
    # confirm against optvq.models.quantizer.sinkhorn.
    Q = sinkhorn(cost, n_iters=5, epsilon=10, is_distributed=False)
    indices = torch.argmax(Q, dim=-1)
    return _render_matching(data, code, indices, color="green", title="Optimal Transport (OptVQ)")


def draw_process(x, y, std):
    """Sample random data and codes and render both matching plots.

    Args:
        x: offset added to the first coordinate of every code.
        y: offset added to the second coordinate of every code.
        std: scale applied to the standard-normal codes before shifting.

    Returns:
        ``(image_NN, image_optvq)``: the nearest-neighbor plot and the OptVQ
        plot, drawn on the same sampled points.
    """
    data = torch.randn(N_data, dim)
    code = torch.randn(N_code, dim) * std
    code[:, 0] += x
    code[:, 1] += y

    image_NN = draw_NN(data, code)
    image_optvq = draw_optvq(data, code)
    return image_NN, image_optvq


class Handler:
    """Loads the three pretrained models and reconstructs images with each of them."""

    def __init__(self, device):
        """Build the preprocessing pipeline and load the models onto ``device``."""
        # Resize to 256, center-crop to 256 x 256, convert to a tensor.
        self.transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(256),
            T.ToTensor()
        ])
        self.device = device

        self.basevq = self._load_model("BorelTHU/basevq-16x16x4")
        self.vqgan = self._load_model("BorelTHU/vqgan-16x16")
        self.optvq = self._load_model("BorelTHU/optvq-16x16x4")

    def _load_model(self, name):
        """Load the pretrained model ``name``, move it to the device, set eval mode."""
        model = VQModelHF.from_pretrained(name)
        model.to(self.device)
        model.eval()
        return model

    def tensor_to_image(self, tensor):
        """Convert a tensor with values nominally in [-1, 1] to a ``uint8`` array.

        A leading dimension of size 1 is squeezed away and the first axis is
        moved last via ``permute(1, 2, 0)``. Values are mapped from [-1, 1]
        to [0, 255] and clipped before the cast, so that model outputs
        slightly outside the nominal range saturate instead of wrapping
        around in ``uint8``.
        """
        img = tensor.detach().squeeze(0).cpu().permute(1, 2, 0).numpy()
        img = (img + 1) / 2 * 255
        img = np.clip(img, 0, 255)
        return img.astype("uint8")

    def process_image(self, img: np.ndarray):
        """Reconstruct ``img`` with BaseVQ, VQGAN and OptVQ.

        Args:
            img: input image as a numpy array, or ``None`` when no image was
                provided.

        Returns:
            ``(input, basevq_rec, vqgan_rec, optvq_rec)`` as ``uint8`` arrays,
            where ``input`` is the preprocessed (resized and cropped) input.
            Four ``None`` values are returned when ``img`` is ``None``.
        """
        if img is None:
            # Button pressed without an image: nothing to reconstruct.
            return None, None, None, None

        img = Image.fromarray(img.astype("uint8"))
        img = self.transform(img)
        img = img.unsqueeze(0).to(self.device)

        with torch.no_grad():
            # ToTensor yields values in [0, 1]; rescale to [-1, 1].
            img = 2 * img - 1
            # basevq
            quant, *_ = self.basevq.encode(img)
            basevq_rec = self.basevq.decode(quant)
            # vqgan
            quant, *_ = self.vqgan.encode(img)
            vqgan_rec = self.vqgan.decode(quant)
            # optvq
            quant, *_ = self.optvq.encode(img)
            optvq_rec = self.optvq.decode(quant)

        # tensors to uint8 image arrays
        img = self.tensor_to_image(img)
        basevq_rec = self.tensor_to_image(basevq_rec)
        vqgan_rec = self.tensor_to_image(vqgan_rec)
        optvq_rec = self.tensor_to_image(optvq_rec)
        return img, basevq_rec, vqgan_rec, optvq_rec


if __name__ == "__main__":
    # create the model handler
    handler = Handler(device=device)

    # create the interface
    with gr.Blocks() as demo:
        gr.Textbox(value="This demo shows the image reconstruction comparison between OptVQ and other methods. The input image is resized to 256 x 256 and then fed into the models. The output images are the reconstructed images from the latent codes.", label="Demo 1: Image reconstruction results")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input data", image_mode="RGB", type="numpy")
                btn_demo1 = gr.Button(value="Run reconstruction")
            image_basevq = gr.Image(label="BaseVQ rec.")
            image_vqgan = gr.Image(label="VQGAN rec.")
            image_optvq = gr.Image(label="OptVQ rec.")
        btn_demo1.click(fn=handler.process_image, inputs=[image_input], outputs=[image_input, image_basevq, image_vqgan, image_optvq])

        gr.Textbox(value="This demo shows the 2D visualizations of nearest neighbor and optimal transport (OptVQ) methods. The data points are randomly generated from a normal distribution, and the matching results are shown as arrows with different colors.", label="Demo 2: 2D visualizations of matching results")
        with gr.Row():
            with gr.Column():
                input_x = gr.Slider(label="x", value=0, minimum=-10, maximum=10, step=0.1)
                input_y = gr.Slider(label="y", value=0, minimum=-10, maximum=10, step=0.1)
                input_std = gr.Slider(label="std", value=1, minimum=0, maximum=5, step=0.1)
                btn_demo2 = gr.Button(value="Run 2D example")
            output_nn = gr.Image(label="NN")
            output_optvq = gr.Image(label="OptVQ")

        # Every slider change and the button trigger the same redraw.
        demo2_event = dict(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
        for slider in (input_x, input_y, input_std):
            slider.change(**demo2_event)
        btn_demo2.click(**demo2_event)

    demo.launch()