import gradio as gr
import os
import PIL
from PIL import Image
from pathlib import Path
import numpy as np
import numpy.random as npr
from contextlib import nullcontext
import torch
import torchvision.transforms as tvtrans

from lib.cfg_helper import model_cfg_bank
from lib.model_zoo import get_model

n_sample_image_default = 2
n_sample_text_default = 4
cache_examples = True

hfm_repo_id = 'shi-labs/versatile-diffusion-model'
hfm_filename = 'pretrained_pth/vd-four-flow-v1-0-fp16.pth'

def highlight_print(info):
    print('')
    print(''.join(['#']*(len(info)+4)))
    print('# '+info+' #')
    print(''.join(['#']*(len(info)+4)))
    print('')

class color_adjust(object):
    def __init__(self, ref_from, ref_to):
        x0, m0, std0 = self.get_data_and_stat(ref_from)
        x1, m1, std1 = self.get_data_and_stat(ref_to)
        self.ref_from_stat = (m0, std0)
        self.ref_to_stat = (m1, std1)
        self.ref_from = self.preprocess(x0).reshape(-1, 3)
        self.ref_to = x1.reshape(-1, 3)

    def get_data_and_stat(self, x):
        if isinstance(x, str):
            x = np.array(PIL.Image.open(x))
        elif isinstance(x, PIL.Image.Image):
            x = np.array(x)
        elif isinstance(x, torch.Tensor):
            x = torch.clamp(x, min=0.0, max=1.0)
            x = np.array(tvtrans.ToPILImage()(x))
        elif isinstance(x, np.ndarray):
            pass
        else:
            raise ValueError('Unknown image type')
        x = x.astype(float)
        m = np.reshape(x, (-1, 3)).mean(0)
        s = np.reshape(x, (-1, 3)).std(0)
        return x, m, s

    def preprocess(self, x):
        m0, s0 = self.ref_from_stat
        m1, s1 = self.ref_to_stat
        y = ((x - m0) / s0) * s1 + m1
        return y

    def __call__(self, xin, keep=0, simple=False):
        xin, _, _ = self.get_data_and_stat(xin)
        x = self.preprocess(xin)
        if simple:
            y = x * (1 - keep) + xin * keep
            y = np.clip(y, 0, 255).astype(np.uint8)
            return y

        h, w = x.shape[:2]
        x = x.reshape(-1, 3)
        y = []
        for chi in range(3):
            yi = self.pdf_transfer_1d(self.ref_from[:, chi], self.ref_to[:, chi], x[:, chi])
            y.append(yi)
        y = np.stack(y, axis=1)
        y = y.reshape(h, w, 3)
        y = y.astype(float) * (1 - keep) + xin.astype(float) * keep
        y = np.clip(y, 0, 255).astype(np.uint8)
        return y

    def pdf_transfer_1d(self, arr_fo, arr_to, arr_in, n=600):
        arr = np.concatenate((arr_fo, arr_to))
        min_v = arr.min() - 1e-6
        max_v = arr.max() + 1e-6
        min_vto = arr_to.min() - 1e-6
        max_vto = arr_to.max() + 1e-6
        xs = np.array(
            [min_v + (max_v - min_v) * i / n for i in range(n + 1)])
        hist_fo, _ = np.histogram(arr_fo, xs)
        hist_to, _ = np.histogram(arr_to, xs)
        xs = xs[:-1]
        # compute cumulative distributions
        cum_fo = np.cumsum(hist_fo)
        cum_to = np.cumsum(hist_to)
        d_fo = cum_fo / cum_fo[-1]
        d_to = cum_to / cum_to[-1]
        # transfer: map through the source CDF, then the inverse target CDF
        t_d = np.interp(d_fo, d_to, xs)
        t_d[d_fo <= d_to[0]] = min_vto
        t_d[d_fo >= d_to[-1]] = max_vto
        arr_out = np.interp(arr_in, xs, t_d)
        return arr_out
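
# `pdf_transfer_1d` above is classic per-channel histogram (CDF) matching: both
# reference channels are turned into empirical CDFs on a shared grid, and each
# input value is mapped through the source CDF followed by the inverse target
# CDF. Below is a minimal usage sketch of `color_adjust`; it is never called in
# this app and the file paths are hypothetical placeholders.
def _example_color_transfer(src_path='generated.png', ref_path='reference.png'):
    # Recolor `src_path` towards the color statistics of `ref_path`,
    # blending 50% of the original colors back in.
    adjuster = color_adjust(ref_from=src_path, ref_to=ref_path)
    matched = adjuster(src_path, keep=0.5, simple=False)  # uint8 HxWx3 array
    return PIL.Image.fromarray(matched)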

class vd_inference(object):
    def __init__(self, pth=None, hfm_repo=None, fp16=False, device=0):
        cfgm_name = 'vd_noema'
        cfgm = model_cfg_bank()(cfgm_name)
        net = get_model()(cfgm)
        if fp16:
            highlight_print('Running in FP16')
            net.clip.fp16 = True
            net = net.half()

        if pth is not None:
            sd = torch.load(pth, map_location='cpu')
            print('Load pretrained weight from {}'.format(pth))
        else:
            from huggingface_hub import hf_hub_download
            temppath = hf_hub_download(hfm_repo[0], hfm_repo[1])
            sd = torch.load(temppath, map_location='cpu')
            print('Load pretrained weight from {}/{}'.format(*hfm_repo))
        net.load_state_dict(sd, strict=False)
        net.to(device)

        self.device = device
        self.model_name = cfgm_name
        self.net = net
        self.fp16 = fp16

        from lib.model_zoo.ddim_vd import DDIMSampler_VD
        self.sampler = DDIMSampler_VD(net)

    def regularize_image(self, x):
        BICUBIC = PIL.Image.Resampling.BICUBIC
        if isinstance(x, str):
            x = Image.open(x).resize([512, 512], resample=BICUBIC)
            x = tvtrans.ToTensor()(x)
        elif isinstance(x, PIL.Image.Image):
            x = x.resize([512, 512], resample=BICUBIC)
            x = tvtrans.ToTensor()(x)
        elif isinstance(x, np.ndarray):
            x = PIL.Image.fromarray(x).resize([512, 512], resample=BICUBIC)
            x = tvtrans.ToTensor()(x)
        elif isinstance(x, torch.Tensor):
            pass
        else:
            assert False, 'Unknown image type'

        assert (x.shape[1] == 512) & (x.shape[2] == 512), \
            'Wrong image size'
        x = x.to(self.device)
        if self.fp16:
            x = x.half()
        return x

    def decode(self, z, xtype, ctype, color_adj='None', color_adj_to=None):
        net = self.net
        if xtype == 'image':
            x = net.autokl_decode(z)

            color_adj_flag = (color_adj != 'none') and (color_adj != 'None') and (color_adj is not None)
            color_adj_simple = (color_adj == 'Simple') or (color_adj == 'simple')
            color_adj_keep_ratio = 0.5

            if color_adj_flag and (ctype == 'vision'):
                x_adj = []
                for xi in x:
                    color_adj_f = color_adjust(ref_from=(xi+1)/2, ref_to=color_adj_to)
                    xi_adj = color_adj_f((xi+1)/2, keep=color_adj_keep_ratio, simple=color_adj_simple)
                    x_adj.append(xi_adj)
                x = x_adj
            else:
                x = torch.clamp((x+1.0)/2.0, min=0.0, max=1.0)
                x = [tvtrans.ToPILImage()(xi) for xi in x]
            return x

        elif xtype == 'text':
            prompt_temperature = 1.0
            prompt_merge_same_adj_word = True
            x = net.optimus_decode(z, temperature=prompt_temperature)
            if prompt_merge_same_adj_word:
                # collapse immediate word repetitions, e.g. "a a dog" -> "a dog"
                xnew = []
                for xi in x:
                    xi_split = xi.split()
                    xinew = []
                    for idxi, wi in enumerate(xi_split):
                        if idxi != 0 and wi == xi_split[idxi-1]:
                            continue
                        xinew.append(wi)
                    xnew.append(' '.join(xinew))
                x = xnew
            return x

    def inference(self, xtype, cin, ctype, scale=7.5, n_samples=None, color_adj=None):
        net = self.net
        sampler = self.sampler
        ddim_steps = 50
        ddim_eta = 0.0

        if xtype == 'image':
            n_samples = n_sample_image_default if n_samples is None else n_samples
        elif xtype == 'text':
            n_samples = n_sample_text_default if n_samples is None else n_samples

        if ctype in ['prompt', 'text']:
            c = net.clip_encode_text(n_samples * [cin])
            u = None
            if scale != 1.0:
                u = net.clip_encode_text(n_samples * [""])
        elif ctype in ['vision', 'image']:
            cin = self.regularize_image(cin)
            ctemp = cin*2 - 1
            ctemp = ctemp[None].repeat(n_samples, 1, 1, 1)
            c = net.clip_encode_vision(ctemp)
            u = None
            if scale != 1.0:
                dummy = torch.zeros_like(ctemp)
                u = net.clip_encode_vision(dummy)

        u, c = [u.half(), c.half()] if self.fp16 else [u, c]

        if xtype == 'image':
            h, w = [512, 512]
            shape = [n_samples, 4, h//8, w//8]
            z, _ = sampler.sample(
                steps=ddim_steps,
                shape=shape,
                conditioning=c,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=u,
                xtype=xtype, ctype=ctype,
                eta=ddim_eta,
                verbose=False,)
            x = self.decode(z, xtype, ctype, color_adj=color_adj, color_adj_to=cin)
            return x

        elif xtype == 'text':
            n = 768
            shape = [n_samples, n]
            z, _ = sampler.sample(
                steps=ddim_steps,
                shape=shape,
                conditioning=c,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=u,
                xtype=xtype, ctype=ctype,
                eta=ddim_eta,
                verbose=False,)
            x = self.decode(z, xtype, ctype)
            return x
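
    # `scale` is the classifier-free guidance weight: assuming DDIMSampler_VD
    # follows the standard recipe, the conditional and unconditional
    # predictions are combined roughly as
    #     eps = eps_uncond + scale * (eps_cond - eps_uncond)
    # which is why the unconditional embedding `u` (empty prompt / zero image)
    # is only computed when scale != 1.0.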

    def application_disensemble(self, cin, n_samples=None, level=0, color_adj=None):
        net = self.net
        scale = 7.5
        sampler = self.sampler
        ddim_steps = 50
        ddim_eta = 0.0
        n_samples = n_sample_image_default if n_samples is None else n_samples

        cin = self.regularize_image(cin)
        ctemp = cin*2 - 1
        ctemp = ctemp[None].repeat(n_samples, 1, 1, 1)
        c = net.clip_encode_vision(ctemp)
        u = None
        if scale != 1.0:
            dummy = torch.zeros_like(ctemp)
            u = net.clip_encode_vision(dummy)
        u, c = [u.half(), c.half()] if self.fp16 else [u, c]

        if level == 0:
            pass
        else:
            c_glb = c[:, 0:1]
            c_loc = c[:, 1:]
            u_glb = u[:, 0:1]
            u_loc = u[:, 1:]

            if level == -1:
                c_loc = self.remove_low_rank(c_loc, demean=True, q=50, q_remove=1)
                u_loc = self.remove_low_rank(u_loc, demean=True, q=50, q_remove=1)
            if level == -2:
                c_loc = self.remove_low_rank(c_loc, demean=True, q=50, q_remove=2)
                u_loc = self.remove_low_rank(u_loc, demean=True, q=50, q_remove=2)
            if level == 1:
                c_loc = self.find_low_rank(c_loc, demean=True, q=10)
                u_loc = self.find_low_rank(u_loc, demean=True, q=10)
            if level == 2:
                c_loc = self.find_low_rank(c_loc, demean=True, q=2)
                u_loc = self.find_low_rank(u_loc, demean=True, q=2)

            c = torch.cat([c_glb, c_loc], dim=1)
            u = torch.cat([u_glb, u_loc], dim=1)

        h, w = [512, 512]
        shape = [n_samples, 4, h//8, w//8]
        z, _ = sampler.sample(
            steps=ddim_steps,
            shape=shape,
            conditioning=c,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=u,
            xtype='image', ctype='vision',
            eta=ddim_eta,
            verbose=False,)
        x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=cin)
        return x

    def find_low_rank(self, x, demean=True, q=20, niter=10):
        if demean:
            x_mean = x.mean(-1, keepdim=True)
            x_input = x - x_mean
        else:
            x_input = x

        if x_input.dtype == torch.float16:
            fp16 = True
            x_input = x_input.float()
        else:
            fp16 = False

        u, s, v = torch.pca_lowrank(x_input, q=q, center=False, niter=niter)
        ss = torch.stack([torch.diag(si) for si in s])
        x_lowrank = torch.bmm(torch.bmm(u, ss), torch.permute(v, [0, 2, 1]))

        if fp16:
            x_lowrank = x_lowrank.half()
        if demean:
            x_lowrank += x_mean
        return x_lowrank

    def remove_low_rank(self, x, demean=True, q=20, niter=10, q_remove=10):
        if demean:
            x_mean = x.mean(-1, keepdim=True)
            x_input = x - x_mean
        else:
            x_input = x

        if x_input.dtype == torch.float16:
            fp16 = True
            x_input = x_input.float()
        else:
            fp16 = False

        u, s, v = torch.pca_lowrank(x_input, q=q, center=False, niter=niter)
        s[:, 0:q_remove] = 0
        ss = torch.stack([torch.diag(si) for si in s])
        x_lowrank = torch.bmm(torch.bmm(u, ss), torch.permute(v, [0, 2, 1]))

        if fp16:
            x_lowrank = x_lowrank.half()
        if demean:
            x_lowrank += x_mean
        return x_lowrank
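
    # Both helpers run a batched PCA (torch.pca_lowrank) over the local CLIP
    # conditioning tokens: `find_low_rank` keeps only the top-q principal
    # components (used for the style-focused levels 1 and 2 above), while
    # `remove_low_rank` zeroes the leading q_remove components and keeps the
    # residual detail (used for the semantics-focused levels -1 and -2).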

    def application_dualguided(self, cim, ctx, n_samples=None, mixing=0.5, color_adj=None):
        net = self.net
        scale = 7.5
        sampler = self.sampler
        ddim_steps = 50
        ddim_eta = 0.0
        n_samples = n_sample_image_default if n_samples is None else n_samples

        ctemp0 = self.regularize_image(cim)
        ctemp1 = ctemp0*2 - 1
        ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1)
        cim = net.clip_encode_vision(ctemp1)
        uim = None
        if scale != 1.0:
            dummy = torch.zeros_like(ctemp1)
            uim = net.clip_encode_vision(dummy)

        ctx = net.clip_encode_text(n_samples * [ctx])
        utx = None
        if scale != 1.0:
            utx = net.clip_encode_text(n_samples * [""])

        uim, cim = [uim.half(), cim.half()] if self.fp16 else [uim, cim]
        utx, ctx = [utx.half(), ctx.half()] if self.fp16 else [utx, ctx]

        h, w = [512, 512]
        shape = [n_samples, 4, h//8, w//8]
        z, _ = sampler.sample_dc(
            steps=ddim_steps,
            shape=shape,
            first_conditioning=[uim, cim],
            second_conditioning=[utx, ctx],
            unconditional_guidance_scale=scale,
            xtype='image',
            first_ctype='vision',
            second_ctype='prompt',
            eta=ddim_eta,
            verbose=False,
            mixed_ratio=(1-mixing), )
        x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=ctemp0)
        return x

    def application_i2t2i(self, cim, ctx_n, ctx_p, n_samples=None, color_adj=None):
        net = self.net
        scale = 7.5
        sampler = self.sampler
        ddim_steps = 50
        ddim_eta = 0.0
        prompt_temperature = 1.0
        n_samples = n_sample_image_default if n_samples is None else n_samples

        ctemp0 = self.regularize_image(cim)
        ctemp1 = ctemp0*2 - 1
        ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1)
        cim = net.clip_encode_vision(ctemp1)
        uim = None
        if scale != 1.0:
            dummy = torch.zeros_like(ctemp1)
            uim = net.clip_encode_vision(dummy)
        uim, cim = [uim.half(), cim.half()] if self.fp16 else [uim, cim]

        # image-to-text: sample a text latent conditioned on the image
        n = 768
        shape = [n_samples, n]
        zt, _ = sampler.sample(
            steps=ddim_steps,
            shape=shape,
            conditioning=cim,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uim,
            xtype='text', ctype='vision',
            eta=ddim_eta,
            verbose=False,)

        # text-latent editing: remove the "Remove Prompt" direction, add the "Add Prompt" latent
        ztn = net.optimus_encode([ctx_n])
        ztp = net.optimus_encode([ctx_p])
        ztn_norm = ztn / ztn.norm(dim=1)
        zt_proj_mag = torch.matmul(zt, ztn_norm[0])
        zt_perp = zt - zt_proj_mag[:, None] * ztn_norm
        zt_newd = zt_perp + ztp
        ctx_new = net.optimus_decode(zt_newd, temperature=prompt_temperature)

        # text-to-image: condition on the edited caption plus a low-rank image context
        ctx_new = net.clip_encode_text(ctx_new)
        ctx_p = net.clip_encode_text([ctx_p])
        ctx_new = torch.cat([ctx_new, ctx_p.repeat(n_samples, 1, 1)], dim=1)
        utx_new = net.clip_encode_text(n_samples * [""])
        utx_new = torch.cat([utx_new, utx_new], dim=1)

        cim_loc = cim[:, 1:]
        cim_loc_new = self.find_low_rank(cim_loc, demean=True, q=10)
        cim_new = cim_loc_new
        uim_new = uim[:, 1:]

        h, w = [512, 512]
        shape = [n_samples, 4, h//8, w//8]
        z, _ = sampler.sample_dc(
            steps=ddim_steps,
            shape=shape,
            first_conditioning=[uim_new, cim_new],
            second_conditioning=[utx_new, ctx_new],
            unconditional_guidance_scale=scale,
            xtype='image',
            first_ctype='vision',
            second_ctype='prompt',
            eta=ddim_eta,
            verbose=False,
            mixed_ratio=0.33, )
        x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=ctemp0)
        return x
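
# The text-latent edit in `application_i2t2i` is a simple vector operation:
# project the sampled latent onto the unit direction of the "remove" prompt,
# subtract that component, then add the "add" prompt's latent. A minimal,
# self-contained sketch of the same math on dummy tensors (the shapes mirror
# the 768-d text latents used above; this function is illustrative only and
# never called):
def _example_latent_direction_edit(zt=None, ztn=None, ztp=None):
    zt = torch.randn(2, 768) if zt is None else zt     # sampled text latents
    ztn = torch.randn(1, 768) if ztn is None else ztn  # latent of the prompt to remove
    ztp = torch.randn(1, 768) if ztp is None else ztp  # latent of the prompt to add
    ztn_unit = ztn / ztn.norm(dim=1, keepdim=True)
    proj = (zt * ztn_unit).sum(dim=1, keepdim=True)    # component along the "remove" direction
    return zt - proj * ztn_unit + ztp                  # drop it, then shift towards the "add" prompt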

# Build the global model instance (note: the instance reuses, and thus shadows,
# the `vd_inference` class name from here on).
vd_inference = vd_inference(hfm_repo=[hfm_repo_id, hfm_filename], fp16=True, device='cuda')

def main(mode,
         image=None,
         prompt=None,
         nprompt=None,
         pprompt=None,
         color_adj=None,
         disentanglement_level=None,
         dual_guided_mixing=None,
         seed=0,):
    if seed < 0:
        seed = 0
    np.random.seed(seed)
    torch.manual_seed(seed + 100)

    if mode == 'Text-to-Image':
        if (prompt is None) or (prompt == ""):
            return None, None
        with torch.no_grad():
            rv = vd_inference.inference(
                xtype='image',
                cin=prompt,
                ctype='prompt',)
        return rv, None
    elif mode == 'Image-Variation':
        if image is None:
            return None, None
        with torch.no_grad():
            rv = vd_inference.inference(
                xtype='image',
                cin=image,
                ctype='vision',
                color_adj=color_adj,)
        return rv, None
    elif mode == 'Image-to-Text':
        if image is None:
            return None, None
        with torch.no_grad():
            rv = vd_inference.inference(
                xtype='text',
                cin=image,
                ctype='vision',)
        return None, '\n'.join(rv)
    elif mode == 'Text-Variation':
        if prompt is None:
            return None, None
        with torch.no_grad():
            rv = vd_inference.inference(
                xtype='text',
                cin=prompt,
                ctype='prompt',)
        return None, '\n'.join(rv)
    elif mode == 'Disentanglement':
        if image is None:
            return None, None
        with torch.no_grad():
            rv = vd_inference.application_disensemble(
                cin=image,
                level=disentanglement_level,
                color_adj=color_adj,)
        return rv, None
    elif mode == 'Dual-Guided':
        if (image is None) or (prompt is None) or (prompt == ""):
            return None, None
        with torch.no_grad():
            rv = vd_inference.application_dualguided(
                cim=image,
                ctx=prompt,
                mixing=dual_guided_mixing,
                color_adj=color_adj,)
        return rv, None
    elif mode == 'Latent-I2T2I':
        if (image is None) or (nprompt is None) or (nprompt == "") \
                or (pprompt is None) or (pprompt == ""):
            return None, None
        with torch.no_grad():
            rv = vd_inference.application_i2t2i(
                cim=image,
                ctx_n=nprompt,
                ctx_p=pprompt,
                color_adj=color_adj,)
        return rv, None
    else:
        assert False, "No such mode!"
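
# `main` can also be driven without the UI. A minimal sketch (never called in
# this app; the prompt and seed are taken from the examples further below):
def _example_text_to_image():
    images, _ = main(mode='Text-to-Image',
                     prompt='a beautiful grand nebula in the universe',
                     seed=24)
    return images  # a list of n_sample_image_default PIL images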

def get_instruction(mode):
    t2i_instruction = ["Generate an image from a text prompt."]
    i2i_instruction = [
        "Generate an image conditioned on a reference image.",
        "Color Calibration provides an option to adjust image colors to match the reference image.", ]
    i2t_instruction = ["Generate text from a reference image."]
    t2t_instruction = ["Generate text from a reference text prompt. (The model is insufficiently trained on this task, so results are still experimental.)"]
    dis_instruction = [
        "Generate a variation of the reference image, disentangled towards semantics or style.",
        "Color Calibration provides an option to adjust image colors to match the reference image.",
        "Disentanglement level controls the focus towards semantics (-2, -1) or style (1, 2). Level 0 behaves like Image-Variation.", ]
    dug_instruction = [
        "Generate an image from the dual guidance of a reference image and a text prompt.",
        "Color Calibration provides an option to adjust image colors to match the reference image.",
        "Guidance Mixing sets a linear balance between the image and text contexts (0 towards image, 1 towards text).", ]
    iti_instruction = [
        "Generate image variations via image-to-text, text-latent editing, and then text-to-image. (Still under exploration.)",
        "Color Calibration provides an option to adjust image colors to match the reference image.",
        "'Remove Prompt' is subtracted from the text / text latent code.",
        "'Add Prompt' is added to the text / text latent code.", ]
    if mode == "Text-to-Image":
        return '\n'.join(t2i_instruction)
    elif mode == "Image-Variation":
        return '\n'.join(i2i_instruction)
    elif mode == "Image-to-Text":
        return '\n'.join(i2t_instruction)
    elif mode == "Text-Variation":
        return '\n'.join(t2t_instruction)
    elif mode == "Disentanglement":
        return '\n'.join(dis_instruction)
    elif mode == "Dual-Guided":
        return '\n'.join(dug_instruction)
    elif mode == "Latent-I2T2I":
        return '\n'.join(iti_instruction)

#############
# Interface #
#############

if True:
    img_output = gr.Gallery(label="Image Result").style(grid=n_sample_image_default)
    txt_output = gr.Textbox(lines=4, label='Text Result', visible=False)

with gr.Blocks() as demo:
    gr.HTML(
        """
        <div style="position: relative; float: left; text-align: center; width: 60%; min-width:600px; height: 160px; margin: 20px 0 20px 20%;">
            <h1 style="font-weight: 900; font-size: 3rem;">
                Versatile Diffusion
            </h1>
            <br>
            <h2 style="font-weight: 450; font-size: 1rem;">
                We built <b>Versatile Diffusion (VD), the first unified multi-flow multimodal diffusion framework</b>, as a step towards <b>Universal Generative AI</b>.
                VD natively supports image-to-text, image-variation, text-to-image, and text-variation,
                and can be further extended to other applications such as
                semantic-style disentanglement, image-text dual-guided generation, latent image-to-text-to-image editing, and more.
                Future versions will support more modalities such as speech, music, video, and 3D.
            </h2>
            <br>
            <h3>Xingqian Xu, Atlas Wang, Eric Zhang, Kai Wang,
                and <a href="https://www.humphreyshi.com/home">Humphrey Shi</a>
                [<a href="https://arxiv.org/abs/2211.08332" style="color:blue;">arXiv</a>]
                [<a href="https://github.com/SHI-Labs/Versatile-Diffusion" style="color:blue;">GitHub</a>]
            </h3>
        </div>
        <div style="position: relative; float: right; width: 19.9%; min-width:200px; margin: 20px auto;">
            <img src="https://huggingface.co/spaces/shi-labs/Versatile-Diffusion/resolve/main/assets/figures/share_instruction.png">
        </div>
        """)

    mode_input = gr.Radio([
        "Text-to-Image", "Image-Variation", "Image-to-Text", "Text-Variation",
        "Disentanglement", "Dual-Guided", "Latent-I2T2I"],
        value='Text-to-Image', label="VD Flows and Applications")
    instruction = gr.Textbox(get_instruction("Text-to-Image"), label='Info')

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label='Image Input', visible=False)
            txt_input = gr.Textbox(lines=4, placeholder="Input prompt...", label='Text Input')
            ntxt_input = gr.Textbox(label='Remove Prompt', visible=False)
            ptxt_input = gr.Textbox(label='Add Prompt', visible=False)
            coladj_input = gr.Radio(["None", "Simple"], value='Simple', label="Color Calibration", visible=False)
            dislvl_input = gr.Slider(-2, 2, value=0, step=1, label="Disentanglement level", visible=False)
            dguide_input = gr.Slider(0, 1, value=0.5, step=0.01, label="Guidance Mixing", visible=False)
            seed_input = gr.Number(100, label="Seed", precision=0)
            btn = gr.Button("Run")
            btn.click(
                main,
                inputs=[
                    mode_input,
                    img_input,
                    txt_input,
                    ntxt_input,
                    ptxt_input,
                    coladj_input,
                    dislvl_input,
                    dguide_input,
                    seed_input, ],
                outputs=[img_output, txt_output])
        with gr.Column():
            img_output.render()
            txt_output.render()

    example_mode = [
        "Text-to-Image",
        "Image-Variation",
        "Image-to-Text",
        "Text-Variation",
        "Disentanglement",
        "Dual-Guided",
        "Latent-I2T2I"]

    def get_example(mode):
        if mode == 'Text-to-Image':
            case = [
                ['a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ', 23],
                ['a beautiful grand nebula in the universe', 24],
                ['heavy arms gundam penguin mech', 25],
            ]
        elif mode == "Image-Variation":
            case = [
                ['assets/space.jpg', 'None', 26],
                ['assets/train.jpg', 'Simple', 27],
            ]
        elif mode == "Image-to-Text":
            case = [
                ['assets/boy_and_girl.jpg', 28],
                ['assets/house_by_lake.jpg', 29],
            ]
        elif mode == "Text-Variation":
            case = [
                ['a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ', 32],
                ['a beautiful grand nebula in the universe', 33],
                ['heavy arms gundam penguin mech', 34],
            ]
        elif mode == "Disentanglement":
            case = [
                ['assets/vermeer.jpg', 'Simple', -2, 30],
                ['assets/matisse.jpg', 'Simple', 2, 31],
            ]
        elif mode == "Dual-Guided":
            case = [
                ['assets/benz.jpg', 'cyberpunk 2077', 'Simple', 0.75, 22],
                ['assets/vermeer.jpg', 'a girl with a diamond necklace', 'Simple', 0.66, 21],
            ]
        elif mode == "Latent-I2T2I":
            case = [
                ['assets/ghibli.jpg', 'white house', 'tall castle', 'Simple', 20],
                ['assets/matisse.jpg', 'fruits and bottles on the table', 'flowers on the table', 'Simple', 21],
            ]
        else:
            raise ValueError
        case = [[mode] + casei for casei in case]
        return case

    def get_example_iof(mode):
        if mode == 'Text-to-Image':
            inps = [txt_input, seed_input]
            oups = [img_output]
            fn = lambda m, x, y: \
                main(mode=m, prompt=x, seed=y)[0]
        elif mode == "Image-Variation":
            inps = [img_input, coladj_input, seed_input]
            oups = [img_output]
            fn = lambda m, x, y, z: \
                main(mode=m, image=x, color_adj=y, seed=z)[0]
        elif mode == "Image-to-Text":
            inps = [img_input, seed_input]
            oups = [txt_output]
            fn = lambda m, x, y: \
                main(mode=m, image=x, seed=y)[1]
        elif mode == "Text-Variation":
            inps = [txt_input, seed_input]
            oups = [txt_output]
            fn = lambda m, x, y: \
                main(mode=m, prompt=x, seed=y)[1]
        elif mode == "Disentanglement":
            inps = [img_input, coladj_input, dislvl_input, seed_input]
            oups = [img_output]
            fn = lambda m, x, y, z, w: \
                main(mode=m, image=x, color_adj=y, disentanglement_level=z, seed=w)[0]
        elif mode == "Dual-Guided":
            inps = [img_input, txt_input, coladj_input, dguide_input, seed_input]
            oups = [img_output]
            fn = lambda m, x, y, z, w, u: \
                main(mode=m, image=x, prompt=y, color_adj=z, dual_guided_mixing=w, seed=u)[0]
        elif mode == "Latent-I2T2I":
            inps = [img_input, ntxt_input, ptxt_input, coladj_input, seed_input]
            oups = [img_output]
            fn = lambda m, x, y, z, w, u: \
                main(mode=m, image=x, nprompt=y, pprompt=z, color_adj=w, seed=u)[0]
        else:
            raise ValueError
        return [mode_input] + inps, oups, fn

    with gr.Row():
        for emode in example_mode[0:4]:
            with gr.Column():
                gr.Examples(
                    label=emode + ' Examples',
                    examples=get_example(emode),
                    inputs=get_example_iof(emode)[0],
                    outputs=get_example_iof(emode)[1],
                    fn=get_example_iof(emode)[2],
                    cache_examples=cache_examples)
    with gr.Row():
        for emode in example_mode[4:7]:
            with gr.Column():
                gr.Examples(
                    label=emode + ' Examples',
                    examples=get_example(emode),
                    inputs=get_example_iof(emode)[0],
                    outputs=get_example_iof(emode)[1],
                    fn=get_example_iof(emode)[2],
                    cache_examples=cache_examples)

    mode_input.change(
        fn=lambda x: gr.update(value=get_instruction(x)),
        inputs=mode_input,
        outputs=instruction,)
    mode_input.change(
        fn=lambda x: gr.update(visible=(x not in ['Text-to-Image', 'Text-Variation'])),
        inputs=mode_input,
        outputs=img_input,)
    mode_input.change(
        fn=lambda x: gr.update(visible=(x in ['Text-to-Image', 'Text-Variation', 'Dual-Guided'])),
        inputs=mode_input,
        outputs=txt_input,)
    mode_input.change(
        fn=lambda x: gr.update(visible=(x in ['Latent-I2T2I'])),
        inputs=mode_input,
        outputs=ntxt_input,)
    mode_input.change(
        fn=lambda x: gr.update(visible=(x in ['Latent-I2T2I'])),
        inputs=mode_input,
        outputs=ptxt_input,)
    mode_input.change(
        fn=lambda x: gr.update(visible=(x not in ['Text-to-Image', 'Image-to-Text', 'Text-Variation'])),
        inputs=mode_input,
        outputs=coladj_input,)
    mode_input.change(
        fn=lambda x: gr.update(visible=(x == 'Disentanglement')),
        inputs=mode_input,
        outputs=dislvl_input,)
    mode_input.change(
        fn=lambda x: gr.update(visible=(x == 'Dual-Guided')),
        inputs=mode_input,
        outputs=dguide_input,)
    mode_input.change(
        fn=lambda x: gr.update(visible=(x not in ['Image-to-Text', 'Text-Variation'])),
        inputs=mode_input,
        outputs=img_output,)
    mode_input.change(
        fn=lambda x: gr.update(visible=(x in ['Image-to-Text', 'Text-Variation'])),
        inputs=mode_input,
        outputs=txt_output,)

    gr.HTML(
        """
        <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
            <h3>
                <b>Caution</b>:
                We would like to raise users' awareness of this demo's potential issues and concerns.
                Like previous large foundation models, Versatile Diffusion can be problematic in some cases, partly due to imperfect training data and pretrained networks (VAEs / context encoders) with limited scope.
                In its future research phase, VD may do better on tasks such as text-to-image and image-to-text with the help of more powerful VAEs, more sophisticated network designs, and cleaner data.
                So far, we keep all features available for research testing, both to show the great potential of the VD framework and to collect important feedback for improving the model in the future.
                We welcome researchers and users to report issues via the HuggingFace community discussion feature or by emailing the authors.
            </h3>
            <br>
            <h3>
                <b>Biases and content acknowledgement</b>:
                Beware that VD may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography, and violence.
                VD was trained on the LAION-2B dataset, which scraped non-curated online images and text, and may contain unintended exceptions even though we removed illegal content.
                VD in this demo is meant solely for research purposes.
            </h3>
        </div>
        """)

# demo.launch(share=True)
demo.launch(debug=True)