################################################################################ # Copyright (C) 2023 Xingqian Xu - All Rights Reserved # # # # Please visit Versatile Diffusion's arXiv paper for more details, link at # # arxiv.org/abs/2211.08332 # # # # Besides, this work is also inspired by many established techniques including:# # Denoising Diffusion Probablistic Model; Denoising Diffusion Implicit Model; # # Latent Diffusion Model; Stable Diffusion; Stable Diffusion - Img2Img; Stable # # Diffusion - Variation; ImageMixer; DreamBooth; Stable Diffusion - Lora; More # # Control for Free; Prompt-to-Prompt; # # # ################################################################################ import gradio as gr import os import PIL from PIL import Image from pathlib import Path import numpy as np import numpy.random as npr from contextlib import nullcontext import types import torch import torchvision.transforms as tvtrans from lib.cfg_helper import model_cfg_bank from lib.model_zoo import get_model from cusomized_gradio_blocks import create_myexamples, customized_as_example, customized_postprocess n_sample_image = 2 n_sample_text = 4 cache_examples = True from lib.model_zoo.ddim import DDIMSampler ########## # helper # ########## def highlight_print(info): print('') print(''.join(['#']*(len(info)+4))) print('# '+info+' #') print(''.join(['#']*(len(info)+4))) print('') def decompose(x, q=20, niter=100): x_mean = x.mean(-1, keepdim=True) x_input = x - x_mean u, s, v = torch.pca_lowrank(x_input, q=q, center=False, niter=niter) ss = torch.stack([torch.diag(si) for si in s]) x_lowrank = torch.bmm(torch.bmm(u, ss), torch.permute(v, [0, 2, 1])) x_remain = x_input - x_lowrank return u, s, v, x_mean, x_remain class adjust_rank(object): def __init__(self, max_drop_rank=[1, 5], q=20): self.max_semantic_drop_rank = max_drop_rank[0] self.max_style_drop_rank = max_drop_rank[1] self.q = q def t2y0_semf_wrapper(t0, y00, t1, y01): return lambda t: (np.exp((t-0.5)*2)-t0)/(t1-t0)*(y01-y00)+y00 t0, y00 = np.exp((0 -0.5)*2), -self.max_semantic_drop_rank t1, y01 = np.exp((0.5-0.5)*2), 1 self.t2y0_semf = t2y0_semf_wrapper(t0, y00, t1, y01) def x2y_semf_wrapper(x0, x1, y1): return lambda x, y0: (x-x0)/(x1-x0)*(y1-y0)+y0 x0 = 0 x1, y1 = self.max_semantic_drop_rank+1, 1 self.x2y_semf = x2y_semf_wrapper(x0, x1, y1) def t2y0_styf_wrapper(t0, y00, t1, y01): return lambda t: (np.exp((t-0.5)*2)-t0)/(t1-t0)*(y01-y00)+y00 t0, y00 = np.exp((1 -0.5)*2), -(q-self.max_style_drop_rank) t1, y01 = np.exp((0.5-0.5)*2), 1 self.t2y0_styf = t2y0_styf_wrapper(t0, y00, t1, y01) def x2y_styf_wrapper(x0, x1, y1): return lambda x, y0: (x-x0)/(x1-x0)*(y1-y0)+y0 x0 = q-1 x1, y1 = self.max_style_drop_rank-1, 1 self.x2y_styf = x2y_styf_wrapper(x0, x1, y1) def __call__(self, x, lvl): if lvl == 0.5: return x if x.dtype == torch.float16: fp16 = True x = x.float() else: fp16 = False std_save = x.std(axis=[-2, -1]) u, s, v, x_mean, x_remain = decompose(x, q=self.q) if lvl < 0.5: assert lvl>=0 for xi in range(0, self.max_semantic_drop_rank+1): y0 = self.t2y0_semf(lvl) yi = self.x2y_semf(xi, y0) yi = 0 if yi<0 else yi s[:, xi] *= yi elif lvl > 0.5: assert lvl <= 1 for xi in range(self.max_style_drop_rank, self.q): y0 = self.t2y0_styf(lvl) yi = self.x2y_styf(xi, y0) yi = 0 if yi<0 else yi s[:, xi] *= yi x_remain = 0 ss = torch.stack([torch.diag(si) for si in s]) x_lowrank = torch.bmm(torch.bmm(u, ss), torch.permute(v, [0, 2, 1])) x_new = x_lowrank + x_mean + x_remain std_new = x_new.std(axis=[-2, -1]) x_new = x_new / std_new * std_save if fp16: x_new = x_new.half() return x_new def remove_duplicate_word(tx): def combine_words(input, length): combined_inputs = [] if len(splitted_input)>1: for i in range(len(input)-1): combined_inputs.append(input[i]+" "+last_word_of(splitted_input[i+1],length)) #add the last word of the right-neighbour (overlapping) sequence (before it has expanded), which is the next word in the original sentence return combined_inputs, length+1 def remove_duplicates(input, length): bool_broke=False #this means we didn't find any duplicates here for i in range(len(input) - length): if input[i]==input[i + length]: #found a duplicate piece of sentence! for j in range(0, length): #remove the overlapping sequences in reverse order del input[i + length - j] bool_broke = True break #break the for loop as the loop length does not matches the length of splitted_input anymore as we removed elements if bool_broke: return remove_duplicates(input, length) #if we found a duplicate, look for another duplicate of the same length return input def last_word_of(input, length): splitted = input.split(" ") if len(splitted)==0: return input else: return splitted[length-1] def split_and_puncsplit(text): tx = text.split(" ") txnew = [] for txi in tx: txqueue=[] while True: if txi[0] in '([{': txqueue.extend([txi[:1], '']) txi = txi[1:] if len(txi) == 0: break else: break txnew += txqueue txstack=[] if len(txi) == 0: continue while True: if txi[-1] in '?!.,:;}])': txstack = ['', txi[-1:]] + txstack txi = txi[:-1] if len(txi) == 0: break else: break if len(txi) != 0: txnew += [txi] txnew += txstack return txnew if tx == '': return tx splitted_input = split_and_puncsplit(tx) word_length = 1 intermediate_output = False while len(splitted_input)>1: splitted_input = remove_duplicates(splitted_input, word_length) if len(splitted_input)>1: splitted_input, word_length = combine_words(splitted_input, word_length) if intermediate_output: print(splitted_input) print(word_length) output = splitted_input[0] output = output.replace(' ', '') return output def get_instruction(mode): t2i_instruction = ["Generate image from text prompt."] i2i_instruction = ["Generate image conditioned on reference image.",] i2t_instruction = ["Generate text from reference image. "] t2t_instruction = ["Generate text from reference text prompt. "] dcg_instruction = ["Generate image conditioned on both text and image."] tcg_instruction = ["Generate image conditioned on text and up to two images."] mcg_instruction = ["Generate image from multiple contexts."] if mode == "Text-to-Image": return '\n'.join(t2i_instruction) elif mode == "Image-Variation": return '\n'.join(i2i_instruction) elif mode == "Image-to-Text": return '\n'.join(i2t_instruction) elif mode == "Text-Variation": return '\n'.join(t2t_instruction) elif mode == "Dual-Context": return '\n'.join(dcg_instruction) elif mode == "Triple-Context": return '\n'.join(tcg_instruction) elif mode == "Multi-Context": return '\n'.join(mcg_instruction) else: assert False ######## # main # ######## class vd_dummy(object): def __init__(self, *args, **kwarg): self.which = 'Vdummy' def inference_t2i(self, *args, **kwarg): pass def inference_i2i(self, *args, **kwarg): pass def inference_i2t(self, *args, **kwarg): pass def inference_t2t(self, *args, **kwarg): pass def inference_dcg(self, *args, **kwarg): pass def inference_tcg(self, *args, **kwarg): pass def inference_mcg(self, *args, **kwarg): return None, None class vd_inference(object): def __init__(self, fp16=False, which='v2.0'): highlight_print(which) self.which = which if self.which == 'v1.0': cfgm = model_cfg_bank()('vd_four_flow_v1-0') else: assert False, 'Model type not supported' net = get_model()(cfgm) if fp16: highlight_print('Running in FP16') if self.which == 'v1.0': net.ctx['text'].fp16 = True net.ctx['image'].fp16 = True net = net.half() self.dtype = torch.float16 else: self.dtype = torch.float32 if self.which == 'v1.0': # if fp16: # sd = torch.load('pretrained/vd-four-flow-v1-0-fp16.pth', map_location='cpu') # else: # sd = torch.load('pretrained/vd-four-flow-v1-0.pth', map_location='cpu') from huggingface_hub import hf_hub_download if fp16: temppath = hf_hub_download('shi-labs/versatile-diffusion-model', 'pretrained_pth/vd-four-flow-v1-0-fp16.pth') else: temppath = hf_hub_download('shi-labs/versatile-diffusion-model', 'pretrained_pth/vd-four-flow-v1-0.pth') sd = torch.load(temppath, map_location='cpu') net.load_state_dict(sd, strict=False) self.use_cuda = torch.cuda.is_available() if self.use_cuda: net.to('cuda') self.net = net self.sampler = DDIMSampler(net) self.output_dim = [512, 512] self.n_sample_image = n_sample_image self.n_sample_text = n_sample_text self.ddim_steps = 50 self.ddim_eta = 0.0 self.scale_textto = 7.5 self.image_latent_dim = 4 self.text_latent_dim = 768 self.text_temperature = 1 if which == 'v1.0': self.adjust_rank_f = adjust_rank(max_drop_rank=[1, 5], q=20) self.scale_imgto = 7.5 self.disentanglement_noglobal = True def inference_t2i(self, text, seed): n_samples = self.n_sample_image scale = self.scale_textto sampler = self.sampler h, w = self.output_dim u = self.net.ctx_encode([""], which='text').repeat(n_samples, 1, 1) c = self.net.ctx_encode([text], which='text').repeat(n_samples, 1, 1) shape = [n_samples, self.image_latent_dim, h//8, w//8] np.random.seed(seed) torch.manual_seed(seed + 100) x, _ = sampler.sample( steps=self.ddim_steps, x_info={'type':'image'}, c_info={'type':'text', 'conditioning':c, 'unconditional_conditioning':u, 'unconditional_guidance_scale':scale}, shape=shape, verbose=False, eta=self.ddim_eta) im = self.net.vae_decode(x, which='image') im = [tvtrans.ToPILImage()(i) for i in im] return im def inference_i2i(self, im, fid_lvl, fcs_lvl, clr_adj, seed): n_samples = self.n_sample_image scale = self.scale_imgto sampler = self.sampler h, w = self.output_dim device = self.net.device BICUBIC = PIL.Image.Resampling.BICUBIC im = im.resize([w, h], resample=BICUBIC) if fid_lvl == 1: return [im]*n_samples cx = tvtrans.ToTensor()(im)[None].to(device).to(self.dtype) c = self.net.ctx_encode(cx, which='image') if self.disentanglement_noglobal: c_glb = c[:, 0:1] c_loc = c[:, 1: ] c_loc = self.adjust_rank_f(c_loc, fcs_lvl) c = torch.cat([c_glb, c_loc], dim=1).repeat(n_samples, 1, 1) else: c = self.adjust_rank_f(c, fcs_lvl).repeat(n_samples, 1, 1) u = torch.zeros_like(c) shape = [n_samples, self.image_latent_dim, h//8, w//8] np.random.seed(seed) torch.manual_seed(seed + 100) if fid_lvl!=0: x0 = self.net.vae_encode(cx, which='image').repeat(n_samples, 1, 1, 1) step = int(self.ddim_steps * (1-fid_lvl)) x, _ = sampler.sample( steps=self.ddim_steps, x_info={'type':'image', 'x0':x0, 'x0_forward_timesteps':step}, c_info={'type':'image', 'conditioning':c, 'unconditional_conditioning':u, 'unconditional_guidance_scale':scale}, shape=shape, verbose=False, eta=self.ddim_eta) else: x, _ = sampler.sample( steps=self.ddim_steps, x_info={'type':'image',}, c_info={'type':'image', 'conditioning':c, 'unconditional_conditioning':u, 'unconditional_guidance_scale':scale}, shape=shape, verbose=False, eta=self.ddim_eta) imout = self.net.vae_decode(x, which='image') if clr_adj == 'Simple': cx_mean = cx.view(3, -1).mean(-1)[:, None, None] cx_std = cx.view(3, -1).std(-1)[:, None, None] imout_mean = [imouti.view(3, -1).mean(-1)[:, None, None] for imouti in imout] imout_std = [imouti.view(3, -1).std(-1)[:, None, None] for imouti in imout] imout = [(ii-mi)/si*cx_std+cx_mean for ii, mi, si in zip(imout, imout_mean, imout_std)] imout = [torch.clamp(ii, 0, 1) for ii in imout] imout = [tvtrans.ToPILImage()(i) for i in imout] return imout def inference_i2t(self, im, seed): n_samples = self.n_sample_text scale = self.scale_imgto sampler = self.sampler h, w = self.output_dim device = self.net.device BICUBIC = PIL.Image.Resampling.BICUBIC im = im.resize([w, h], resample=BICUBIC) cx = tvtrans.ToTensor()(im)[None].to(device) c = self.net.ctx_encode(cx, which='image').repeat(n_samples, 1, 1) u = self.net.ctx_encode(torch.zeros_like(cx), which='image').repeat(n_samples, 1, 1) shape = [n_samples, self.text_latent_dim] np.random.seed(seed) torch.manual_seed(seed + 100) x, _ = sampler.sample( steps=self.ddim_steps, x_info={'type':'text',}, c_info={'type':'image', 'conditioning':c, 'unconditional_conditioning':u, 'unconditional_guidance_scale':scale}, shape=shape, verbose=False, eta=self.ddim_eta) tx = self.net.vae_decode(x, which='text', temperature=self.text_temperature) tx = [remove_duplicate_word(txi) for txi in tx] tx_combined = '\n'.join(tx) return tx_combined def inference_t2t(self, text, seed): n_samples = self.n_sample_text scale = self.scale_textto sampler = self.sampler u = self.net.ctx_encode([""], which='text').repeat(n_samples, 1, 1) c = self.net.ctx_encode([text], which='text').repeat(n_samples, 1, 1) shape = [n_samples, self.text_latent_dim] np.random.seed(seed) torch.manual_seed(seed + 100) x, _ = sampler.sample( steps=self.ddim_steps, x_info={'type':'text',}, c_info={'type':'text', 'conditioning':c, 'unconditional_conditioning':u, 'unconditional_guidance_scale':scale}, shape=shape, verbose=False, eta=self.ddim_eta) tx = self.net.vae_decode(x, which='text', temperature=self.text_temperature) tx = [remove_duplicate_word(txi) for txi in tx] tx_combined = '\n'.join(tx) return tx_combined def inference_dcg(self, imctx, fcs_lvl, textctx, textstrength, seed): n_samples = self.n_sample_image sampler = self.sampler h, w = self.output_dim device = self.net.device c_info_list = [] if (textctx is not None) and (textctx != "") and (textstrength != 0): ut = self.net.ctx_encode([""], which='text').repeat(n_samples, 1, 1) ct = self.net.ctx_encode([textctx], which='text').repeat(n_samples, 1, 1) scale = self.scale_imgto*(1-textstrength) + self.scale_textto*textstrength c_info_list.append({ 'type':'text', 'conditioning':ct, 'unconditional_conditioning':ut, 'unconditional_guidance_scale':scale, 'ratio': textstrength, }) else: scale = self.scale_imgto textstrength = 0 BICUBIC = PIL.Image.Resampling.BICUBIC cx = imctx.resize([w, h], resample=BICUBIC) cx = tvtrans.ToTensor()(cx)[None].to(device).to(self.dtype) ci = self.net.ctx_encode(cx, which='image') if self.disentanglement_noglobal: ci_glb = ci[:, 0:1] ci_loc = ci[:, 1: ] ci_loc = self.adjust_rank_f(ci_loc, fcs_lvl) ci = torch.cat([ci_glb, ci_loc], dim=1).repeat(n_samples, 1, 1) else: ci = self.adjust_rank_f(ci, fcs_lvl).repeat(n_samples, 1, 1) c_info_list.append({ 'type':'image', 'conditioning':ci, 'unconditional_conditioning':torch.zeros_like(ci), 'unconditional_guidance_scale':scale, 'ratio': (1-textstrength), }) shape = [n_samples, self.image_latent_dim, h//8, w//8] np.random.seed(seed) torch.manual_seed(seed + 100) x, _ = sampler.sample_multicontext( steps=self.ddim_steps, x_info={'type':'image',}, c_info_list=c_info_list, shape=shape, verbose=False, eta=self.ddim_eta) imout = self.net.vae_decode(x, which='image') imout = [tvtrans.ToPILImage()(i) for i in imout] return imout def inference_tcg(self, *args): args_imag = list(args[0:10]) + [None, None, None, None, None]*2 args_rest = args[10:] imin, imout = self.inference_mcg(*args_imag, *args_rest) return imin, imout def inference_mcg(self, *args): imctx = [args[0:5], args[5:10], args[10:15], args[15:20]] textctx, textstrength, seed = args[20:] n_samples = self.n_sample_image sampler = self.sampler h, w = self.output_dim device = self.net.device c_info_list = [] if (textctx is not None) and (textctx != "") and (textstrength != 0): ut = self.net.ctx_encode([""], which='text').repeat(n_samples, 1, 1) ct = self.net.ctx_encode([textctx], which='text').repeat(n_samples, 1, 1) scale = self.scale_imgto*(1-textstrength) + self.scale_textto*textstrength c_info_list.append({ 'type':'text', 'conditioning':ct, 'unconditional_conditioning':ut, 'unconditional_guidance_scale':scale, 'ratio': textstrength, }) else: scale = self.scale_imgto textstrength = 0 input_save = [] imc = [] for im, imm, strength, fcs_lvl, use_mask in imctx: if (im is None) and (imm is None): continue BILINEAR = PIL.Image.Resampling.BILINEAR BICUBIC = PIL.Image.Resampling.BICUBIC if use_mask: cx = imm['image'].resize([w, h], resample=BICUBIC) cx = tvtrans.ToTensor()(cx)[None].to(self.dtype).to(device) m = imm['mask'].resize([w, h], resample=BILINEAR) m = tvtrans.ToTensor()(m)[None, 0:1].to(self.dtype).to(device) m = (1-m) cx_show = cx*m ci = self.net.ctx_encode(cx, which='image', masks=m) else: cx = im.resize([w, h], resample=BICUBIC) cx = tvtrans.ToTensor()(cx)[None].to(self.dtype).to(device) ci = self.net.ctx_encode(cx, which='image') cx_show = cx input_save.append(tvtrans.ToPILImage()(cx_show[0])) if self.disentanglement_noglobal: ci_glb = ci[:, 0:1] ci_loc = ci[:, 1: ] ci_loc = self.adjust_rank_f(ci_loc, fcs_lvl) ci = torch.cat([ci_glb, ci_loc], dim=1).repeat(n_samples, 1, 1) else: ci = self.adjust_rank_f(ci, fcs_lvl).repeat(n_samples, 1, 1) imc.append(ci * strength) cis = torch.cat(imc, dim=1) c_info_list.append({ 'type':'image', 'conditioning':cis, 'unconditional_conditioning':torch.zeros_like(cis), 'unconditional_guidance_scale':scale, 'ratio': (1-textstrength), }) shape = [n_samples, self.image_latent_dim, h//8, w//8] np.random.seed(seed) torch.manual_seed(seed + 100) x, _ = sampler.sample_multicontext( steps=self.ddim_steps, x_info={'type':'image',}, c_info_list=c_info_list, shape=shape, verbose=False, eta=self.ddim_eta) imout = self.net.vae_decode(x, which='image') imout = [tvtrans.ToPILImage()(i) for i in imout] return input_save, imout # vd_inference = vd_dummy() vd_inference = vd_inference(which='v1.0', fp16=True) ################# # sub interface # ################# def t2i_interface(with_example=False): gr.HTML('

  Description: ' + get_instruction("Text-to-Image") + '

') with gr.Row(): with gr.Column(): text = gr.Textbox(lines=4, placeholder="Input prompt...", label='Text Input') seed = gr.Number(20, label="Seed", precision=0) button = gr.Button("Run") with gr.Column(): img_output = gr.Gallery(label="Image Result", elem_id='customized_imbox').style(grid=n_sample_image) button.click( vd_inference.inference_t2i, inputs=[text, seed], outputs=[img_output]) if with_example: gr.Examples( label='Examples', examples=get_example('Text-to-Image'), fn=vd_inference.inference_t2i, inputs=[text, seed], outputs=[img_output], cache_examples=cache_examples), def i2i_interface(with_example=False): gr.HTML('

  Description: ' + get_instruction("Image-Variation") + '

') with gr.Row(): with gr.Column(): img_input = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox') sim_flag = gr.Checkbox(label='Show Detail Controls') with gr.Row(): fid_lvl = gr.Slider(label="Fidelity (Dislike -- Same)", minimum=0, maximum=1, value=0, step=0.02, visible=False) fcs_lvl = gr.Slider(label="Focus (Semantic -- Style)", minimum=0, maximum=1, value=0.5, step=0.02, visible=False) clr_adj = gr.Radio(label="Color Adjustment", choices=["None", "Simple"], value='Simple', visible=False) explain = gr.HTML('

  Fidelity: How likely the output image looks like the referece image (0-dislike (default), 1-same).

'+ '

  Focus: What the output image should focused on (0-semantic, 0.5-balanced (default), 1-style).

', visible=False) seed = gr.Number(20, label="Seed", precision=0) button = gr.Button("Run") with gr.Column(): img_output = gr.Gallery(label="Image Result", elem_id='customized_imbox').style(grid=n_sample_image) sim_flag.change( fn=lambda x: { explain : gr.update(visible=x), fid_lvl : gr.update(visible=x), fcs_lvl : gr.update(visible=x), clr_adj : gr.update(visible=x), }, inputs=sim_flag, outputs=[explain, fid_lvl, fcs_lvl, clr_adj, seed],) button.click( vd_inference.inference_i2i, inputs=[img_input, fid_lvl, fcs_lvl, clr_adj, seed], outputs=[img_output]) if with_example: gr.Examples( label='Examples', examples=get_example('Image-Variation'), fn=vd_inference.inference_i2i, inputs=[img_input, fid_lvl, fcs_lvl, clr_adj, seed], outputs=[img_output], cache_examples=cache_examples), def i2t_interface(with_example=False): gr.HTML('

  Description: ' + get_instruction("Image-to-Text") + '

') with gr.Row(): with gr.Column(): img_input = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox') seed = gr.Number(20, label="Seed", precision=0) button = gr.Button("Run") with gr.Column(): txt_output = gr.Textbox(lines=4, label='Text Result') button.click( vd_inference.inference_i2t, inputs=[img_input, seed], outputs=[txt_output]) if with_example: gr.Examples( label='Examples', examples=get_example('Image-to-Text'), fn=vd_inference.inference_i2t, inputs=[img_input, seed], outputs=[txt_output], cache_examples=cache_examples), def t2t_interface(with_example=False): gr.HTML('

  Description: ' + get_instruction("Text-Variation") + '

') with gr.Row(): with gr.Column(): text = gr.Textbox(lines=4, placeholder="Input prompt...", label='Text Input') seed = gr.Number(20, label="Seed", precision=0) button = gr.Button("Run") with gr.Column(): txt_output = gr.Textbox(lines=4, label='Text Result') button.click( vd_inference.inference_t2t, inputs=[text, seed], outputs=[txt_output]) if with_example: gr.Examples( label='Examples', examples=get_example('Text-Variation'), fn=vd_inference.inference_t2t, inputs=[text, seed], outputs=[txt_output], cache_examples=cache_examples, ) class image_mimage_swap(object): def __init__(self, block0, block1): self.block0 = block0 self.block1 = block1 self.which_update = 'both' def __call__(self, x0, x1, flag): if self.which_update == 'both': return self.update_both(x0, x1, flag) elif self.which_update == 'visible': return self.update_visible(x0, x1, flag) elif self.which_update == 'visible_oneoff': return self.update_visible_oneoff(x0, x1, flag) else: assert False def update_both(self, x0, x1, flag): if flag: ug0 = gr.update(visible=False) if x0 is None: ug1 = gr.update(value=None, visible=True) else: if (x1 is not None) and ('mask' in x1): value1 = {'image':x0, 'mask':x1['mask']} else: value1 = {'image':x0, 'mask':None} ug1 = gr.update(value=value1, visible=True) else: if (x1 is not None) and ('image' in x1): value0 = x1['image'] else: value0 = None ug0 = gr.update(value=value0, visible=True) ug1 = gr.update(visible=False) return { self.block0 : ug0, self.block1 : ug1,} def update_visible(self, x0, x1, flag): return { self.block0 : gr.update(visible=not flag), self.block1 : gr.update(visible=flag), } def update_visible_oneoff(self, x0, x1, flag): self.which_update = 'both' return { self.block0 : gr.update(visible=not flag), self.block1 : gr.update(visible=flag), } class example_visible_only_hack(object): def __init__(self, checkbox_list, functor_list): self.checkbox_list = checkbox_list self.functor_list = functor_list def __call__(self, *args): for bi, fi, vi in zip(self.checkbox_list, self.functor_list, args): if bi.value != vi: fi.which_update = 'visible_oneoff' def dcg_interface(with_example=False): gr.HTML('

  Description: ' + get_instruction("Dual-Context") + '

') with gr.Row(): input_session = [] with gr.Column(): img = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox') fcs = gr.Slider(label="Focus (Semantic -- Style)", minimum=0, maximum=1, value=0.5, step=0.02) gr.HTML('

  Focus: Focus on what aspect of the image? (0-semantic, 0.5-balanced (default), 1-style).

') text = gr.Textbox(lines=2, placeholder="Input prompt...", label='Text Input') tstrength = gr.Slider(label="Text Domination (NoEffect -- TextOnly)", minimum=0, maximum=1, value=0, step=0.02) seed = gr.Number(20, label="Seed", precision=0) button = gr.Button("Run") with gr.Column(): output_gallary = gr.Gallery(label="Image Result", elem_id='customized_imbox').style(grid=n_sample_image) input_list = [] for i in input_session: input_list += i button.click( vd_inference.inference_dcg, inputs=[img, fcs, text, tstrength, seed], outputs=[output_gallary]) if with_example: gr.Examples( label='Examples', examples=get_example('Dual-Context'), fn=vd_inference.inference_dcg, inputs=[img, fcs, text, tstrength, seed], outputs=[output_gallary], cache_examples=cache_examples) def tcg_interface(with_example=False): gr.HTML('

  Description: ' + get_instruction("Triple-Context") + '

') with gr.Row(): input_session = [] with gr.Column(min_width=940): with gr.Row(): with gr.Column(): img0 = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox') img0.as_example = types.MethodType(customized_as_example, img0) imgm0 = gr.Image(label='Image Input with Mask', type='pil', elem_id='customized_imbox', tool='sketch', source="upload", visible=False) imgm0.postprocess = types.MethodType(customized_postprocess, imgm0) imgm0.as_example = types.MethodType(customized_as_example, imgm0) istrength0 = gr.Slider(label="Weight", minimum=0, maximum=1, value=1, step=0.02) fcs0 = gr.Slider(label="Focus (Semantic -- Style)", minimum=0, maximum=1, value=0.5, step=0.02) msk0 = gr.Checkbox(label='Use mask?') swapf0 = image_mimage_swap(img0, imgm0) msk0.change( fn=swapf0, inputs=[img0, imgm0, msk0], outputs=[img0, imgm0],) input_session.append([img0, imgm0, istrength0, fcs0, msk0]) with gr.Column(): img1 = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox') img1.as_example = types.MethodType(customized_as_example, img1) imgm1 = gr.Image(label='Image Input with Mask', type='pil', elem_id='customized_imbox', tool='sketch', source="upload", visible=False) imgm1.postprocess = types.MethodType(customized_postprocess, imgm1) imgm1.as_example = types.MethodType(customized_as_example, imgm1) istrength1 = gr.Slider(label="Weight", minimum=0, maximum=1, value=1, step=0.02) fcs1 = gr.Slider(label="Focus (Semantic -- Style)", minimum=0, maximum=1, value=0.5, step=0.02) msk1 = gr.Checkbox(label='Use mask?') swapf1 = image_mimage_swap(img1, imgm1) msk1.change( fn=swapf1, inputs=[img1, imgm1, msk1], outputs=[img1, imgm1],) input_session.append([img1, imgm1, istrength1, fcs1, msk1]) gr.HTML('

  Weight: The strength of the reference image. This weight is subject to Text Domination).

'+ '

  Focus: Focus on what aspect of the image? (0-semantic, 0.5-balanced (default), 1-style).

'+ '

  Mask: Remove regions on reference image so they will not influence the output.

',) text = gr.Textbox(lines=2, placeholder="Input prompt...", label='Text Input') tstrength = gr.Slider(label="Text Domination (NoEffect -- TextOnly)", minimum=0, maximum=1, value=0, step=0.02) seed = gr.Number(20, label="Seed", precision=0) button = gr.Button("Run") with gr.Column(min_width=470): input_gallary = gr.Gallery(label="Input Display", elem_id="customized_imbox").style(grid=2) output_gallary = gr.Gallery(label="Image Result", elem_id="customized_imbox").style(grid=n_sample_image) input_list = [] for i in input_session: input_list += i input_list += [text, tstrength, seed] button.click( vd_inference.inference_tcg, inputs=input_list, outputs=[input_gallary, output_gallary]) if with_example: create_myexamples( label='Examples', examples=get_example('Triple-Context'), fn=vd_inference.inference_tcg, inputs=input_list, outputs=[input_gallary, output_gallary, ], cache_examples=cache_examples, ) gr.HTML('

  How to add mask: Please see the following instructions.


'+ '
'+ ''+ ''+ ''+ '
') def mcg_interface(with_example=False): num_img_input = 4 gr.HTML('

  Description: ' + get_instruction("Multi-Context") + '

') with gr.Row(): input_session = [] with gr.Column(): for idx in range(num_img_input): with gr.Tab('Image{}'.format(idx+1)): img = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox') img.as_example = types.MethodType(customized_as_example, img) imgm = gr.Image(label='Image Input with Mask', type='pil', elem_id='customized_imbox', tool='sketch', source="upload", visible=False) imgm.postprocess = types.MethodType(customized_postprocess, imgm) imgm.as_example = types.MethodType(customized_as_example, imgm) with gr.Row(): istrength = gr.Slider(label="Weight", minimum=0, maximum=1, value=1, step=0.02) fcs = gr.Slider(label="Focus (Semantic -- Style)", minimum=0, maximum=1, value=0.5, step=0.02) msk = gr.Checkbox(label='Use mask?') gr.HTML('

  Weight: The strength of the reference image. This weight is subject to Text Domination).

'+ '

  Focus: Focus on what aspect of the image? (0-semantic, 0.5-balanced (default), 1-style).

'+ '

  Mask: Remove regions on reference image so they will not influence the output.

',) msk.change( fn=image_mimage_swap(img, imgm), inputs=[img, imgm, msk], outputs=[img, imgm],) input_session.append([img, imgm, istrength, fcs, msk]) text = gr.Textbox(lines=2, placeholder="Input prompt...", label='Text Input') tstrength = gr.Slider(label="Text Domination (NoEffect -- TextOnly)", minimum=0, maximum=1, value=0, step=0.02) seed = gr.Number(20, label="Seed", precision=0) button = gr.Button("Run") with gr.Column(): input_gallary = gr.Gallery(label="Input Display", elem_id='customized_imbox').style(grid=4) output_gallary = gr.Gallery(label="Image Result", elem_id='customized_imbox').style(grid=n_sample_image) input_list = [] for i in input_session: input_list += i input_list += [text, tstrength, seed] button.click( vd_inference.inference_mcg, inputs=input_list, outputs=[input_gallary, output_gallary], ) if with_example: create_myexamples( label='Examples', examples=get_example('Multi-Context'), fn=vd_inference.inference_mcg, inputs=input_list, outputs=[input_gallary, output_gallary], cache_examples=cache_examples, ) gr.HTML('

  How to add mask: Please see the following instructions.


'+ '
'+ ''+ ''+ ''+ '
') ########### # Example # ########### def get_example(mode): if mode == 'Text-to-Image': case = [ ['a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ', 23], ['a beautiful landscape with mountains and rivers', 20], ] elif mode == "Image-Variation": case = [ ['assets/demo/reg_example/ghibli.jpg', 0, 0.5, 'None', 20], ['assets/demo/reg_example/ghibli.jpg', 0.5, 0.5, 'None', 20], ['assets/demo/reg_example/matisse.jpg', 0, 0, 'None', 20], ['assets/demo/reg_example/matisse.jpg', 0, 1, 'Simple', 20], ['assets/demo/reg_example/vermeer.jpg', 0.2, 0.3, 'None', 30], ] elif mode == "Image-to-Text": case = [ ['assets/demo/reg_example/house_by_lake.jpg', 20], ] elif mode == "Text-Variation": case = [ ['heavy arms gundam penguin mech', 20], ] elif mode == "Dual-Context": case = [ ['assets/demo/reg_example/benz.jpg', 0.5, 'cyberpunk 2077', 0.7, 22], ['assets/demo/reg_example/ghibli.jpg', 1, 'Red maple on a hill in golden Autumn.', 0.5, 21], ] elif mode == "Triple-Context": case = [ [ 'assets/demo/reg_example/night_light.jpg', None, 1 , 0.5, False, 'assets/demo/reg_example/paris.jpg' , None, 0.94, 0.5, False, "snow on the street", 0.4, 28], [ 'assets/demo/tcg_example/e1i0.jpg', None, 1 , 0.5, False, 'assets/demo/tcg_example/e1i1.jpg', None, 0.94, 0.5, False, "a painting of an elegant woman in front of the moon", 0.2, 217], [ 'assets/demo/tcg_example/e2i0.jpg', None, 1, 0.5, False, 'assets/demo/reg_example/paris.jpg', None, 1, 0.5, False, "", 0, 29], [ 'assets/demo/tcg_example/e0i0.jpg', None, 1 , 0.5, False, 'assets/demo/tcg_example/e0i1.jpg', None, 0.9, 0.5, False, "rose blooms on the tree", 0.2, 20], [ 'assets/demo/reg_example/ghibli.jpg', None, 1 , 1 , False, 'assets/demo/reg_example/space.jpg' , None, 0.84, 0.5, False, "", 0, 20], [ 'assets/demo/reg_example/train.jpg' , None, 0.8, 0.5, False, 'assets/demo/reg_example/matisse.jpg', None, 1 , 1 , False, "", 0, 20], ] elif mode == "Multi-Context": case = [ [ 'assets/demo/mcg_example/e0i0.jpg', None, 1, 0.5, False, 'assets/demo/mcg_example/e0i1.jpg', None, 1, 0.5, False, 'assets/demo/mcg_example/e0i2.jpg', None, 0.86, 0.5, False, None, None, 1, 0.5, False, "", 0, 20], ] else: raise ValueError return case ############# # Interface # ############# css = """ #customized_imbox { min-height: 450px; } #customized_imbox>div[data-testid="image"] { min-height: 450px; } #customized_imbox>div[data-testid="image"]>div { min-height: 450px; } #customized_imbox>div[data-testid="image"]>iframe { min-height: 450px; } #customized_imbox>div.unpadded_box { min-height: 450px; } #myinst { font-size: 0.8rem; margin: 0rem; color: #6B7280; } #maskinst { text-align: justify; min-width: 1200px; } #maskinst>img { min-width:399px; max-width:450px; vertical-align: top; display: inline-block; } #maskinst:after { content: ""; width: 100%; display: inline-block; } """ if True: with gr.Blocks(css=css) as demo: gr.HTML( """

Versatile Diffusion

We built Versatile Diffusion (VD), the first unified multi-flow multimodal diffusion framework, as a step towards Universal Generative AI. VD can natively support image-to-text, image-variation, text-to-image, and text-variation, and can be further extended to other applications such as semantic-style disentanglement, image-text dual-guided generation, latent image-to-text-to-image editing, and more. Future versions will support more modalities such as speech, music, video and 3D.

Xingqian Xu, Atlas Wang, Eric Zhang, Kai Wang, and Humphrey Shi [arXiv] [GitHub]

""") with gr.Tab('Text-to-Image'): t2i_interface(with_example=True) with gr.Tab('Image-Variation'): i2i_interface(with_example=True) with gr.Tab('Image-to-Text'): i2t_interface(with_example=True) with gr.Tab('Text-Variation'): t2t_interface(with_example=True) with gr.Tab('Dual-Context Image-Generation'): dcg_interface(with_example=True) with gr.Tab('Triple-Context Image-Blender'): tcg_interface(with_example=True) with gr.Tab('Multi-Context Image-Blender'): mcg_interface(with_example=True) gr.HTML( """

Version: {}

Caution: We would like the raise the awareness of users of this demo of its potential issues and concerns. Like previous large foundation models, Versatile Diffusion could be problematic in some cases, partially due to the imperfect training data and pretrained network (VAEs / context encoders) with limited scope. In its future research phase, VD may do better on tasks such as text-to-image, image-to-text, etc., with the help of more powerful VAEs, more sophisticated network designs, and more cleaned data. So far, we keep all features available for research testing both to show the great potential of the VD framework and to collect important feedback to improve the model in the future. We welcome researchers and users to report issues with the HuggingFace community discussion feature or email the authors.

Biases and content acknowledgement: Beware that VD may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography, and violence. VD was trained on the LAION-2B dataset, which scraped non-curated online images and text, and may contained unintended exceptions as we removed illegal content. VD in this demo is meant only for research purposes.

""".format(' '+vd_inference.which)) demo.launch(share=True) # demo.launch(debug=True)