import gradio as gr
import os
import PIL
from PIL import Image
from pathlib import Path
import numpy as np
import numpy.random as npr
from contextlib import nullcontext

import torch
import torchvision.transforms as tvtrans

from lib.cfg_helper import model_cfg_bank
from lib.model_zoo import get_model

n_sample_image_default = 2
n_sample_text_default = 4
cache_examples = True

hfm_repo_id = 'shi-labs/versatile-diffusion-model'
hfm_filename = 'pretrained_pth/vd-four-flow-v1-0-fp16.pth'

def highlight_print(info):
    print('')
    print(''.join(['#']*(len(info)+4)))
    print('# '+info+' #')
    print(''.join(['#']*(len(info)+4)))
    print('')

class color_adjust(object):
    def __init__(self, ref_from, ref_to):
        x0, m0, std0 = self.get_data_and_stat(ref_from)
        x1, m1, std1 = self.get_data_and_stat(ref_to)
        self.ref_from_stat = (m0, std0)
        self.ref_to_stat = (m1, std1)
        self.ref_from = self.preprocess(x0).reshape(-1, 3)
        self.ref_to = x1.reshape(-1, 3)

    def get_data_and_stat(self, x):
        if isinstance(x, str):
            x = np.array(PIL.Image.open(x))
        elif isinstance(x, PIL.Image.Image):
            x = np.array(x)
        elif isinstance(x, torch.Tensor):
            x = torch.clamp(x, min=0.0, max=1.0)
            x = np.array(tvtrans.ToPILImage()(x))
        elif isinstance(x, np.ndarray):
            pass
        else:
            raise ValueError
        x = x.astype(float)
        m = np.reshape(x, (-1, 3)).mean(0)
        s = np.reshape(x, (-1, 3)).std(0)
        return x, m, s

    def preprocess(self, x):
        m0, s0 = self.ref_from_stat
        m1, s1 = self.ref_to_stat
        y = ((x-m0)/s0)*s1 + m1
        return y

    def __call__(self, xin, keep=0, simple=False):
        xin, _, _ = self.get_data_and_stat(xin)
        x = self.preprocess(xin)
        if simple:
            y = (x*(1-keep) + xin*keep)
            y = np.clip(y, 0, 255).astype(np.uint8)
            return y

        h, w = x.shape[:2]
        x = x.reshape(-1, 3)
        y = []
        for chi in range(3):
            yi = self.pdf_transfer_1d(self.ref_from[:, chi], self.ref_to[:, chi], x[:, chi])
            y.append(yi)
        y = np.stack(y, axis=1)
        y = y.reshape(h, w, 3)
        y = (y.astype(float)*(1-keep) + xin.astype(float)*keep)
        y = np.clip(y, 0, 255).astype(np.uint8)
        return y

    def pdf_transfer_1d(self, arr_fo, arr_to, arr_in, n=600):
        arr = np.concatenate((arr_fo, arr_to))
        min_v = arr.min() - 1e-6
        max_v = arr.max() + 1e-6
        min_vto = arr_to.min() - 1e-6
        max_vto = arr_to.max() + 1e-6
        xs = np.array(
            [min_v + (max_v - min_v) * i / n for i in range(n + 1)])
        hist_fo, _ = np.histogram(arr_fo, xs)
        hist_to, _ = np.histogram(arr_to, xs)
        xs = xs[:-1]
        # compute probability distribution
        cum_fo = np.cumsum(hist_fo)
        cum_to = np.cumsum(hist_to)
        d_fo = cum_fo / cum_fo[-1]
        d_to = cum_to / cum_to[-1]
        # transfer
        t_d = np.interp(d_fo, d_to, xs)
        t_d[d_fo <= d_to[0]] = min_vto
        t_d[d_fo >= d_to[-1]] = max_vto
        arr_out = np.interp(arr_in, xs, t_d)
        return arr_out
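
# Illustrative sketch (not called anywhere in this app): how the channel-wise
# PDF-transfer color calibration above can be used standalone. The file names
# 'generated.png' and 'reference.png' are placeholders.
def _color_adjust_example():
    # Match the colors of a generated image to a reference image.
    matcher = color_adjust(ref_from='generated.png', ref_to='reference.png')
    # keep=0.5 blends the transferred colors 50/50 with the original colors;
    # simple=True would fall back to mean/std matching instead of the 1D PDF transfer.
    adjusted = matcher('generated.png', keep=0.5, simple=False)
    return PIL.Image.fromarray(adjusted)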

class vd_inference(object):
    def __init__(self, pth=None, hfm_repo=None, fp16=False, device=0):
        cfgm_name = 'vd_noema'
        cfgm = model_cfg_bank()('vd_noema')
        net = get_model()(cfgm)
        if fp16:
            highlight_print('Running in FP16')
            net.clip.fp16 = True
            net = net.half()

        if pth is not None:
            sd = torch.load(pth, map_location='cpu')
            print('Load pretrained weight from {}'.format(pth))
        else:
            from huggingface_hub import hf_hub_download
            temppath = hf_hub_download(hfm_repo[0], hfm_repo[1])
            sd = torch.load(temppath, map_location='cpu')
            print('Load pretrained weight from {}/{}'.format(*hfm_repo))

        net.load_state_dict(sd, strict=False)
        net.to(device)
        self.device = device
        self.model_name = cfgm_name
        self.net = net
        self.fp16 = fp16

        from lib.model_zoo.ddim_vd import DDIMSampler_VD
        self.sampler = DDIMSampler_VD(net)

    def regularize_image(self, x):
        BICUBIC = PIL.Image.Resampling.BICUBIC
        if isinstance(x, str):
            x = Image.open(x).resize([512, 512], resample=BICUBIC)
            x = tvtrans.ToTensor()(x)
        elif isinstance(x, PIL.Image.Image):
            x = x.resize([512, 512], resample=BICUBIC)
            x = tvtrans.ToTensor()(x)
        elif isinstance(x, np.ndarray):
            x = PIL.Image.fromarray(x).resize([512, 512], resample=BICUBIC)
            x = tvtrans.ToTensor()(x)
        elif isinstance(x, torch.Tensor):
            pass
        else:
            assert False, 'Unknown image type'

        assert (x.shape[1]==512) & (x.shape[2]==512), \
            'Wrong image size'
        x = x.to(self.device)
        if self.fp16:
            x = x.half()
        return x

    def decode(self, z, xtype, ctype, color_adj='None', color_adj_to=None):
        net = self.net
        if xtype == 'image':
            x = net.autokl_decode(z)

            color_adj_flag = (color_adj!='none') and (color_adj!='None') and (color_adj is not None)
            color_adj_simple = (color_adj=='Simple') or color_adj=='simple'
            color_adj_keep_ratio = 0.5

            if color_adj_flag and (ctype=='vision'):
                x_adj = []
                for xi in x:
                    color_adj_f = color_adjust(ref_from=(xi+1)/2, ref_to=color_adj_to)
                    xi_adj = color_adj_f((xi+1)/2, keep=color_adj_keep_ratio, simple=color_adj_simple)
                    x_adj.append(xi_adj)
                x = x_adj
            else:
                x = torch.clamp((x+1.0)/2.0, min=0.0, max=1.0)
                x = [tvtrans.ToPILImage()(xi) for xi in x]
            return x

        elif xtype == 'text':
            prompt_temperature = 1.0
            prompt_merge_same_adj_word = True
            x = net.optimus_decode(z, temperature=prompt_temperature)
            if prompt_merge_same_adj_word:
                xnew = []
                for xi in x:
                    xi_split = xi.split()
                    xinew = []
                    for idxi, wi in enumerate(xi_split):
                        if idxi!=0 and wi==xi_split[idxi-1]:
                            continue
                        xinew.append(wi)
                    xnew.append(' '.join(xinew))
                x = xnew
            return x

    def inference(self, xtype, cin, ctype, scale=7.5, n_samples=None, color_adj=None,):
        net = self.net
        sampler = self.sampler
        ddim_steps = 50
        ddim_eta = 0.0

        if xtype == 'image':
            n_samples = n_sample_image_default if n_samples is None else n_samples
        elif xtype == 'text':
            n_samples = n_sample_text_default if n_samples is None else n_samples

        if ctype in ['prompt', 'text']:
            c = net.clip_encode_text(n_samples * [cin])
            u = None
            if scale != 1.0:
                u = net.clip_encode_text(n_samples * [""])
        elif ctype in ['vision', 'image']:
            cin = self.regularize_image(cin)
            ctemp = cin*2 - 1
            ctemp = ctemp[None].repeat(n_samples, 1, 1, 1)
            c = net.clip_encode_vision(ctemp)
            u = None
            if scale != 1.0:
                dummy = torch.zeros_like(ctemp)
                u = net.clip_encode_vision(dummy)

        u, c = [u.half(), c.half()] if self.fp16 else [u, c]

        if xtype == 'image':
            h, w = [512, 512]
            shape = [n_samples, 4, h//8, w//8]
            z, _ = sampler.sample(
                steps=ddim_steps,
                shape=shape,
                conditioning=c,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=u,
                xtype=xtype, ctype=ctype,
                eta=ddim_eta,
                verbose=False,)
            x = self.decode(z, xtype, ctype, color_adj=color_adj, color_adj_to=cin)
            return x

        elif xtype == 'text':
            n = 768
            shape = [n_samples, n]
            z, _ = sampler.sample(
                steps=ddim_steps,
                shape=shape,
                conditioning=c,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=u,
                xtype=xtype, ctype=ctype,
                eta=ddim_eta,
                verbose=False,)
            x = self.decode(z, xtype, ctype)
            return x
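
    # Disentanglement (descriptive note): the CLIP vision conditioning is split into
    # a global token (c[:, 0:1]) and local tokens (c[:, 1:]). Keeping only a low-rank
    # PCA approximation of the local tokens (levels 1, 2) biases generation toward
    # style, while zeroing out the top principal components (levels -1, -2) biases it
    # toward semantics; level 0 leaves the conditioning untouched (plain Image-Variation).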
    def application_disensemble(self, cin, n_samples=None, level=0, color_adj=None,):
        net = self.net
        scale = 7.5
        sampler = self.sampler
        ddim_steps = 50
        ddim_eta = 0.0
        n_samples = n_sample_image_default if n_samples is None else n_samples

        cin = self.regularize_image(cin)
        ctemp = cin*2 - 1
        ctemp = ctemp[None].repeat(n_samples, 1, 1, 1)
        c = net.clip_encode_vision(ctemp)
        u = None
        if scale != 1.0:
            dummy = torch.zeros_like(ctemp)
            u = net.clip_encode_vision(dummy)
        u, c = [u.half(), c.half()] if self.fp16 else [u, c]

        if level == 0:
            pass
        else:
            c_glb = c[:, 0:1]
            c_loc = c[:, 1: ]
            u_glb = u[:, 0:1]
            u_loc = u[:, 1: ]

            if level == -1:
                c_loc = self.remove_low_rank(c_loc, demean=True, q=50, q_remove=1)
                u_loc = self.remove_low_rank(u_loc, demean=True, q=50, q_remove=1)
            if level == -2:
                c_loc = self.remove_low_rank(c_loc, demean=True, q=50, q_remove=2)
                u_loc = self.remove_low_rank(u_loc, demean=True, q=50, q_remove=2)
            if level == 1:
                c_loc = self.find_low_rank(c_loc, demean=True, q=10)
                u_loc = self.find_low_rank(u_loc, demean=True, q=10)
            if level == 2:
                c_loc = self.find_low_rank(c_loc, demean=True, q=2)
                u_loc = self.find_low_rank(u_loc, demean=True, q=2)

            c = torch.cat([c_glb, c_loc], dim=1)
            u = torch.cat([u_glb, u_loc], dim=1)

        h, w = [512, 512]
        shape = [n_samples, 4, h//8, w//8]
        z, _ = sampler.sample(
            steps=ddim_steps,
            shape=shape,
            conditioning=c,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=u,
            xtype='image', ctype='vision',
            eta=ddim_eta,
            verbose=False,)
        x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=cin)
        return x

    def find_low_rank(self, x, demean=True, q=20, niter=10):
        if demean:
            x_mean = x.mean(-1, keepdim=True)
            x_input = x - x_mean
        else:
            x_input = x

        if x_input.dtype == torch.float16:
            fp16 = True
            x_input = x_input.float()
        else:
            fp16 = False

        u, s, v = torch.pca_lowrank(x_input, q=q, center=False, niter=niter)
        ss = torch.stack([torch.diag(si) for si in s])
        x_lowrank = torch.bmm(torch.bmm(u, ss), torch.permute(v, [0, 2, 1]))

        if fp16:
            x_lowrank = x_lowrank.half()
        if demean:
            x_lowrank += x_mean
        return x_lowrank

    def remove_low_rank(self, x, demean=True, q=20, niter=10, q_remove=10):
        if demean:
            x_mean = x.mean(-1, keepdim=True)
            x_input = x - x_mean
        else:
            x_input = x

        if x_input.dtype == torch.float16:
            fp16 = True
            x_input = x_input.float()
        else:
            fp16 = False

        u, s, v = torch.pca_lowrank(x_input, q=q, center=False, niter=niter)
        s[:, 0:q_remove] = 0
        ss = torch.stack([torch.diag(si) for si in s])
        x_lowrank = torch.bmm(torch.bmm(u, ss), torch.permute(v, [0, 2, 1]))

        if fp16:
            x_lowrank = x_lowrank.half()
        if demean:
            x_lowrank += x_mean
        return x_lowrank

    def application_dualguided(self, cim, ctx, n_samples=None, mixing=0.5, color_adj=None, ):
        net = self.net
        scale = 7.5
        sampler = self.sampler
        ddim_steps = 50
        ddim_eta = 0.0
        n_samples = n_sample_image_default if n_samples is None else n_samples

        ctemp0 = self.regularize_image(cim)
        ctemp1 = ctemp0*2 - 1
        ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1)
        cim = net.clip_encode_vision(ctemp1)
        uim = None
        if scale != 1.0:
            dummy = torch.zeros_like(ctemp1)
            uim = net.clip_encode_vision(dummy)

        ctx = net.clip_encode_text(n_samples * [ctx])
        utx = None
        if scale != 1.0:
            utx = net.clip_encode_text(n_samples * [""])

        uim, cim = [uim.half(), cim.half()] if self.fp16 else [uim, cim]
        utx, ctx = [utx.half(), ctx.half()] if self.fp16 else [utx, ctx]

        h, w = [512, 512]
        shape = [n_samples, 4, h//8, w//8]
        z, _ = sampler.sample_dc(
            steps=ddim_steps,
            shape=shape,
            first_conditioning=[uim, cim],
            second_conditioning=[utx, ctx],
            unconditional_guidance_scale=scale,
            xtype='image',
            first_ctype='vision',
            second_ctype='prompt',
            eta=ddim_eta,
            verbose=False,
            mixed_ratio=(1-mixing), )
        x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=ctemp0)
        return x
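
    # Latent I2T2I (descriptive note): first a text latent is sampled from the
    # reference image, then it is edited in latent space by projecting out the
    # direction of the negative prompt latent and adding the positive prompt
    # latent, and finally an image is regenerated under dual (vision + edited
    # text) guidance.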
    def application_i2t2i(self, cim, ctx_n, ctx_p, n_samples=None, color_adj=None,):
        net = self.net
        scale = 7.5
        sampler = self.sampler
        ddim_steps = 50
        ddim_eta = 0.0
        prompt_temperature = 1.0
        n_samples = n_sample_image_default if n_samples is None else n_samples

        ctemp0 = self.regularize_image(cim)
        ctemp1 = ctemp0*2 - 1
        ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1)
        cim = net.clip_encode_vision(ctemp1)
        uim = None
        if scale != 1.0:
            dummy = torch.zeros_like(ctemp1)
            uim = net.clip_encode_vision(dummy)
        uim, cim = [uim.half(), cim.half()] if self.fp16 else [uim, cim]

        n = 768
        shape = [n_samples, n]
        zt, _ = sampler.sample(
            steps=ddim_steps,
            shape=shape,
            conditioning=cim,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uim,
            xtype='text', ctype='vision',
            eta=ddim_eta,
            verbose=False,)

        ztn = net.optimus_encode([ctx_n])
        ztp = net.optimus_encode([ctx_p])

        ztn_norm = ztn / ztn.norm(dim=1)
        zt_proj_mag = torch.matmul(zt, ztn_norm[0])
        zt_perp = zt - zt_proj_mag[:, None] * ztn_norm
        zt_newd = zt_perp + ztp
        ctx_new = net.optimus_decode(zt_newd, temperature=prompt_temperature)

        ctx_new = net.clip_encode_text(ctx_new)
        ctx_p = net.clip_encode_text([ctx_p])
        ctx_new = torch.cat([ctx_new, ctx_p.repeat(n_samples, 1, 1)], dim=1)
        utx_new = net.clip_encode_text(n_samples * [""])
        utx_new = torch.cat([utx_new, utx_new], dim=1)

        cim_loc = cim[:, 1: ]
        cim_loc_new = self.find_low_rank(cim_loc, demean=True, q=10)
        cim_new = cim_loc_new
        uim_new = uim[:, 1:]

        h, w = [512, 512]
        shape = [n_samples, 4, h//8, w//8]
        z, _ = sampler.sample_dc(
            steps=ddim_steps,
            shape=shape,
            first_conditioning=[uim_new, cim_new],
            second_conditioning=[utx_new, ctx_new],
            unconditional_guidance_scale=scale,
            xtype='image',
            first_ctype='vision',
            second_ctype='prompt',
            eta=ddim_eta,
            verbose=False,
            mixed_ratio=0.33, )
        x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=ctemp0)
        return x

vd_inference = vd_inference(hfm_repo=[hfm_repo_id, hfm_filename], fp16=True, device='cuda')

def main(mode,
         image=None,
         prompt=None,
         nprompt=None,
         pprompt=None,
         color_adj=None,
         disentanglement_level=None,
         dual_guided_mixing=None,
         seed=0,):
    if seed<0:
        seed = 0
    np.random.seed(seed)
    torch.manual_seed(seed+100)

    if mode == 'Text-to-Image':
        if (prompt is None) or (prompt == ""):
            return None, None
        with torch.no_grad():
            rv = vd_inference.inference(
                xtype = 'image',
                cin = prompt,
                ctype = 'prompt', )
        return rv, None

    elif mode == 'Image-Variation':
        if image is None:
            return None, None
        with torch.no_grad():
            rv = vd_inference.inference(
                xtype = 'image',
                cin = image,
                ctype = 'vision',
                color_adj = color_adj,)
        return rv, None

    elif mode == 'Image-to-Text':
        if image is None:
            return None, None
        with torch.no_grad():
            rv = vd_inference.inference(
                xtype = 'text',
                cin = image,
                ctype = 'vision',)
        return None, '\n'.join(rv)

    elif mode == 'Text-Variation':
        if prompt is None:
            return None, None
        with torch.no_grad():
            rv = vd_inference.inference(
                xtype = 'text',
                cin = prompt,
                ctype = 'prompt',)
        return None, '\n'.join(rv)

    elif mode == 'Disentanglement':
        if image is None:
            return None, None
        with torch.no_grad():
            rv = vd_inference.application_disensemble(
                cin = image,
                level = disentanglement_level,
                color_adj = color_adj,)
        return rv, None

    elif mode == 'Dual-Guided':
        if (image is None) or (prompt is None) or (prompt==""):
            return None, None
        with torch.no_grad():
            rv = vd_inference.application_dualguided(
                cim = image,
                ctx = prompt,
                mixing = dual_guided_mixing,
                color_adj = color_adj,)
        return rv, None

    elif mode == 'Latent-I2T2I':
        if (image is None) or (nprompt is None) or (nprompt=="") \
                or (pprompt is None) or (pprompt==""):
            return None, None
        with torch.no_grad():
            rv = vd_inference.application_i2t2i(
                cim = image,
                ctx_n = nprompt,
                ctx_p = pprompt,
                color_adj = color_adj,)
        return rv, None

    else:
        assert False, "No such mode!"
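
# Illustrative sketch (not wired into the UI): main() always returns an
# (image_results, text_result) pair, so a programmatic call looks like the
# following. The prompt, image path, and seed are placeholders.
#
#     images, _ = main('Text-to-Image', prompt='a red sports car', seed=20)
#     _, caption = main('Image-to-Text', image='example.png', seed=20)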

def get_instruction(mode):
    t2i_instruction = ["Generate an image from a text prompt."]
    i2i_instruction = [
        "Generate an image conditioned on a reference image.",
        "Color Calibration provides an option to adjust the image colors according to the reference image.", ]
    i2t_instruction = ["Generate text from a reference image."]
    t2t_instruction = ["Generate text from a reference text prompt. (The model is insufficiently trained, so results are still experimental.)"]
    dis_instruction = [
        "Generate a variation of the reference image that is disentangled toward semantics or style.",
        "Color Calibration provides an option to adjust the image colors according to the reference image.",
        "Disentanglement Level controls the focus towards semantics (-2, -1) or style (1, 2). Level 0 is equivalent to Image-Variation.", ]
    dug_instruction = [
        "Generate an image from the dual guidance of a reference image and a text prompt.",
        "Color Calibration provides an option to adjust the image colors according to the reference image.",
        "Guidance Mixing linearly balances the image and text contexts. (0 towards image, 1 towards text)", ]
    iti_instruction = [
        "Generate image variations via image-to-text, text-latent editing, and then text-to-image. (Still under exploration)",
        "Color Calibration provides an option to adjust the image colors according to the reference image.",
        "Input prompt that will be subtracted from the text/text latent code.",
        "Input prompt that will be added to the text/text latent code.", ]

    if mode == "Text-to-Image":
        return '\n'.join(t2i_instruction)
    elif mode == "Image-Variation":
        return '\n'.join(i2i_instruction)
    elif mode == "Image-to-Text":
        return '\n'.join(i2t_instruction)
    elif mode == "Text-Variation":
        return '\n'.join(t2t_instruction)
    elif mode == "Disentanglement":
        return '\n'.join(dis_instruction)
    elif mode == "Dual-Guided":
        return '\n'.join(dug_instruction)
    elif mode == "Latent-I2T2I":
        return '\n'.join(iti_instruction)

#############
# Interface #
#############

if True:
    img_output = gr.Gallery(label="Image Result").style(grid=n_sample_image_default)
    txt_output = gr.Textbox(lines=4, label='Text Result', visible=False)

with gr.Blocks() as demo:
    gr.HTML(
        """