import gc
import random
import sys
from os.path import exists as path_exists

import torch
import gradio as gr
from git.repo.base import Repo

# Clone the v-diffusion and CLIP repos on first run so their modules can be imported.
if not path_exists("v-diffusion-pytorch"):
    Repo.clone_from("https://github.com/crowsonkb/v-diffusion-pytorch", "v-diffusion-pytorch")
if not path_exists("CLIP"):
    Repo.clone_from("https://github.com/openai/CLIP", "CLIP")
sys.path.append('v-diffusion-pytorch')

from huggingface_hub import hf_hub_download
from CLIP import clip
from diffusion import get_model, sampling, utils

# Download the cc12m_1_cfg checkpoint from the Hugging Face Hub.
cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg",
                              filename="cc12m_1_cfg.pth")

model = get_model('cc12m_1_cfg')()
_, side_y, side_x = model.shape
model.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
model = model.half().cuda().eval().requires_grad_(False)
# Load the CLIP text encoder the diffusion model was conditioned on.
clip_model = clip.load(model.clip_model, jit=False, device='cpu')[0]


def run_all(prompt, steps, n_images, weight):
    seed = random.randint(0, 2147483647)
    target_embed = clip_model.encode_text(clip.tokenize(prompt)).float().cuda()

    def cfg_model_fn(x, t):
        """The classifier-free guidance (CFG) wrapper: runs the conditional and
        unconditional branches in a single batch and blends their v-predictions."""
        n = x.shape[0]
        x_in = x.repeat([2, 1, 1, 1])
        t_in = t.repeat([2])
        clip_embed_repeat = target_embed.repeat([n, 1])
        # A zero embedding selects the unconditional branch.
        clip_embed_in = torch.cat([torch.zeros_like(clip_embed_repeat), clip_embed_repeat])
        v_uncond, v_cond = model(x_in, t_in, clip_embed_in).chunk(2, dim=0)
        # Push the prediction away from the unconditional one by `weight`.
        v = v_uncond + (v_cond - v_uncond) * weight
        return v

    gc.collect()
    torch.cuda.empty_cache()
    torch.manual_seed(seed)
    x = torch.randn([n_images, 3, side_y, side_x], device='cuda')
    t = torch.linspace(1, 0, steps + 1, device='cuda')[:-1]
    step_list = utils.get_spliced_ddpm_cosine_schedule(t)
    # Sample with PLMS over the spliced DDPM/cosine schedule.
    outs = sampling.plms_sample(cfg_model_fn, x, step_list, {})
    return [utils.to_pil_image(out) for out in outs]


##################### START GRADIO HERE ############################
gallery = gr.Gallery(css={"height": "256px", "width": "256px"})
iface = gr.Interface(
    fn=run_all,
    inputs=[
        gr.inputs.Textbox(label="Prompt - try adding modifiers such as 'oil on canvas', 'a painting' or 'a book cover'",
                          default="chalk pastel drawing of a dog wearing a funny hat"),
        gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate",
                         default=50, maximum=250, minimum=1, step=1),
        gr.inputs.Slider(label="Number of images in parallel", default=2, maximum=4, minimum=1, step=1),
        gr.inputs.Slider(label="Weight - how strongly the image should follow the prompt",
                         default=5, maximum=15, minimum=0, step=1),
    ],
    outputs=gallery,
    title="Generate images from text with V-Diffusion CC12M CFG",
    description="",  # NOTE: the original description text was truncated in the source
)
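# Assumed entry point: the source breaks off inside the gr.Interface(...) call
# above, so this launch line is a sketch of the standard Gradio pattern rather
# than the original code.
iface.launch()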