"""Gradio app for image generation.

Builds a tabbed UI (Text, Image, Control, Menu), wires its events, and, when
run as a script, downloads the configured model and LoRA files before
launching the server.
"""

import argparse
import json
import os
import random

import gradio as gr

from lib import (
    CannyAnnotator,
    Config,
    async_call,
    disable_progress_bars,
    download_civit_file,
    download_repo_files,
    generate,
    get_valid_size,
    read_file,
    resize_image,
)

# the CSS `content` attribute expects a string so we need to wrap the number in quotes
refresh_seed_js = """
() => {
  const n = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER);
  const button = document.getElementById("refresh");
  button.style.setProperty("--seed", `"${n}"`);
  return n;
}
"""

# mirror a (possibly programmatic) seed change into the refresh button's CSS variable
seed_js = """
(seed) => {
  const button = document.getElementById("refresh");
  button.style.setProperty("--seed", `"${seed}"`);
  return seed;
}
"""

# "W,H" dropdown value -> [width, height]; a falsy value ("Custom") keeps the current size
aspect_ratio_js = """
(ar, w, h) => {
  if (!ar) return [w, h];
  const [width, height] = ar.split(",");
  return [parseInt(width), parseInt(height)];
}
"""


def create_image_dropdown(images, locked=False):
    """Build a gallery-image picker dropdown.

    Args:
        images: Current gallery value (may be None). One choice is offered per
            entry, labelled 1-based and valued by its 0-based index.
        locked: When True, return a non-interactive dropdown whose only choice
            is the lock (-2), so the picker cannot replace the input image.

    Returns:
        A ``gr.Dropdown`` update. Unlocked dropdowns always reset to "None" (-1).
    """
    if locked:
        return gr.Dropdown(
            choices=[("🔒", -2)],
            interactive=False,
            value=-2,
        )
    else:
        return gr.Dropdown(
            choices=[("None", -1)] + [(str(i + 1), i) for i, _ in enumerate(images or [])],
            interactive=True,
            value=-1,
        )


async def gallery_fn(images, image, ip_image):
    """Rebuild both gallery pickers after the output gallery changes.

    A picker is locked when its image input already holds an image, so the
    gallery update does not overwrite it.

    Returns:
        ``(initial_image_dropdown, ip_adapter_image_dropdown)`` updates.
    """
    return (
        create_image_dropdown(images, locked=image is not None),
        create_image_dropdown(images, locked=ip_image is not None),
    )


async def image_prompt_fn(images):
    """Return an unlocked picker reset to "None" (used when an image input is cleared)."""
    return create_image_dropdown(images)


# handle selecting an image from the gallery
# -2 is the lock icon, -1 is None
async def image_select_fn(images, image, i):
    """Load the gallery entry chosen in a picker into its image input.

    Args:
        images: Current gallery value; entry ``images[i][0]`` is used as the image.
        image: The image currently in the input (kept when locked).
        i: Selected dropdown value: -2 = locked, -1 = none, >= 0 = gallery index.

    Returns:
        A ``gr.Image`` update.
    """
    if i == -2:
        # locked: leave the existing input image untouched
        return gr.Image(image)
    # "None", or a stale index left over after the gallery was cleared/shrunk:
    # clear the input instead of raising TypeError/IndexError
    if i is None or i < 0 or not images or i >= len(images):
        return gr.Image(None)
    return gr.Image(images[i][0])


async def random_fn():
    """Return a Textbox update holding a random prompt from ``data/prompts.json``."""
    prompts = json.loads(read_file("data/prompts.json"))
    return gr.Textbox(value=random.choice(prompts))


# TODO: move this to another file once more annotators are added; will need @GPU decorator
async def annotate_fn(image, annotator):
    """Resize the control image and run the selected annotator on it.

    Args:
        image: Control image from the Control tab (``type="pil"`` input), or None.
        annotator: Annotator id; only ``"canny"`` is handled.

    Returns:
        The annotator output, or None for an unrecognised annotator id.

    Raises:
        gr.Error: If no control image was provided.
    """
    if image is None:
        # fail with a readable message instead of crashing inside get_valid_size
        raise gr.Error("You must upload a control image")
    size = get_valid_size(image)
    image = resize_image(image, size)
    if annotator == "canny":
        canny = CannyAnnotator()
        return canny(image, size)
    return None


async def generate_fn(*args, progress=gr.Progress(track_tqdm=True)):
    """Validate inputs and run ``generate`` through ``async_call``.

    ``args`` are the generate event's inputs in order. The first is the
    prompt, indices 2 and 3 are the initial and IP-Adapter images, and the
    final two are the "disable initial image" / "disable IP-Adapter image"
    session flags. Those two flags are stripped before forwarding.

    Returns:
        Whatever ``generate`` returns (shown in the output gallery).

    Raises:
        gr.Error: If the prompt is empty, or generation raises RuntimeError.
    """
    prompt = args[0] if args else None
    if prompt is None or prompt.strip() == "":
        raise gr.Error("You must enter a prompt")

    # always the last arguments
    disable_image_prompt, disable_ip_image_prompt = args[-2:]
    gen_args = list(args[:-2])

    # drop the image inputs without clearing them from the UI
    if disable_image_prompt:
        gen_args[2] = None
    if disable_ip_image_prompt:
        gen_args[3] = None

    try:
        if Config.ZERO_GPU:
            progress((0, 100), desc="ZeroGPU init")

        images = await async_call(
            generate,
            *gen_args,
            Error=gr.Error,
            Info=gr.Info,
            progress=progress,
        )
    except RuntimeError as e:
        # keep the original cause attached for server-side debugging
        raise gr.Error("Error: Please try again") from e

    return images


with gr.Blocks(
    head=read_file("./partials/head.html"),
    css="./app.css",
    js="./app.js",
    theme=gr.themes.Default(
        # colors
        neutral_hue=gr.themes.colors.gray,
        primary_hue=gr.themes.colors.orange,
        secondary_hue=gr.themes.colors.blue,
        # sizing
        text_size=gr.themes.sizes.text_md,
        radius_size=gr.themes.sizes.radius_sm,
        spacing_size=gr.themes.sizes.spacing_md,
        # fonts
        font=[gr.themes.GoogleFont("Inter"), *Config.SANS_FONTS],
        font_mono=[gr.themes.GoogleFont("Ubuntu Mono"), *Config.MONO_FONTS],
    ).set(
        layout_gap="8px",
        block_shadow="0 0 #0000",
        block_shadow_dark="0 0 #0000",
        block_background_fill=gr.themes.colors.gray.c50,
        block_background_fill_dark=gr.themes.colors.gray.c900,
    ),
) as demo:
    # override image inputs without clearing them
    DISABLE_IMAGE_PROMPT = gr.State(False)
    DISABLE_IP_IMAGE_PROMPT = gr.State(False)

    gr.HTML(read_file("./partials/intro.html"))

    with gr.Tabs():
        with gr.TabItem("🏠 Text"):
            with gr.Column():
                output_images = gr.Gallery(
                    elem_classes=["gallery"],
                    show_share_button=False,
                    object_fit="cover",
                    interactive=False,
                    show_label=False,
                    label="Output",
                    format="png",
                    columns=2,
                )
                prompt = gr.Textbox(
                    placeholder="What do you want to see?",
                    autoscroll=False,
                    show_label=False,
                    label="Prompt",
                    max_lines=3,
                    lines=3,
                )

                # Buttons
                with gr.Row():
                    generate_btn = gr.Button("Generate", variant="primary")
                    random_btn = gr.Button(
                        elem_classes=["icon-button", "popover"],
                        variant="secondary",
                        elem_id="random",
                        min_width=0,
                        value="🎲",
                    )
                    refresh_btn = gr.Button(
                        elem_classes=["icon-button", "popover"],
                        variant="secondary",
                        elem_id="refresh",
                        min_width=0,
                        value="🔄",
                    )
                    clear_btn = gr.ClearButton(
                        elem_classes=["icon-button", "popover"],
                        components=[output_images],
                        variant="secondary",
                        elem_id="clear",
                        min_width=0,
                        value="🗑️",
                    )

        # img2img tab
        with gr.TabItem("🖼️ Image"):
            with gr.Row():
                image_prompt = gr.Image(
                    show_share_button=False,
                    label="Initial Image",
                    min_width=320,
                    format="png",
                    type="pil",
                )
                ip_image_prompt = gr.Image(
                    show_share_button=False,
                    label="IP-Adapter Image",
                    min_width=320,
                    format="png",
                    type="pil",
                )

            with gr.Row():
                image_select = gr.Dropdown(
                    info="Use an initial image from the gallery",
                    choices=[("None", -1)],
                    label="Gallery Image",
                    interactive=True,
                    filterable=False,
                    value=-1,
                )
                ip_image_select = gr.Dropdown(
                    info="Use an IP-Adapter image from the gallery",
                    label="Gallery Image",
                    choices=[("None", -1)],
                    interactive=True,
                    filterable=False,
                    value=-1,
                )

            with gr.Row():
                denoising_strength = gr.Slider(
                    value=Config.DENOISING_STRENGTH,
                    label="Denoising Strength",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                )

            with gr.Row():
                disable_image = gr.Checkbox(
                    elem_classes=["checkbox"],
                    label="Disable Initial Image",
                    value=False,
                )
                disable_ip_image = gr.Checkbox(
                    elem_classes=["checkbox"],
                    label="Disable IP-Adapter Image",
                    value=False,
                )
                use_ip_face = gr.Checkbox(
                    elem_classes=["checkbox"],
                    label="Use IP-Adapter Face",
                    value=False,
                )

        # controlnet tab
        with gr.TabItem("🎮 Control"):
            with gr.Row():
                control_image_input = gr.Image(
                    show_share_button=False,
                    label="Control Image",
                    min_width=320,
                    format="png",
                    type="pil",
                )
                control_image_prompt = gr.Image(
                    interactive=False,
                    show_share_button=False,
                    label="Control Image Output",
                    show_label=False,
                    min_width=320,
                    format="png",
                    type="pil",
                )

            with gr.Row():
                control_annotator = gr.Dropdown(
                    choices=[("Canny", "canny")],
                    label="Annotator",
                    filterable=False,
                    value="canny",
                )

            with gr.Row():
                annotate_btn = gr.Button("Annotate", variant="primary")
                clear_control_btn = gr.ClearButton(
                    elem_classes=["icon-button", "popover"],
                    components=[control_image_prompt],
                    variant="secondary",
                    elem_id="clear-control",
                    min_width=0,
                    value="🗑️",
                )

        with gr.TabItem("⚙️ Menu"):
            with gr.Group():
                negative_prompt = gr.Textbox(
                    value="nsfw+",
                    label="Negative Prompt",
                    lines=2,
                )

                with gr.Row():
                    model = gr.Dropdown(
                        choices=Config.MODELS,
                        filterable=False,
                        value=Config.MODEL,
                        label="Model",
                        min_width=240,
                    )
                    scheduler = gr.Dropdown(
                        choices=Config.SCHEDULERS.keys(),
                        value=Config.SCHEDULER,
                        elem_id="scheduler",
                        label="Scheduler",
                        filterable=False,
                    )

                with gr.Row():
                    styles = json.loads(read_file("data/styles.json"))
                    # keys starting with "_" are skipped (not offered as styles)
                    style_ids = list(styles.keys())
                    style_ids = [sid for sid in style_ids if not sid.startswith("_")]
                    style = gr.Dropdown(
                        value=Config.STYLE,
                        label="Style",
                        min_width=240,
                        choices=[("None", "none")] + [(styles[sid]["name"], sid) for sid in style_ids],
                    )
                    embeddings = gr.Dropdown(
                        elem_id="embeddings",
                        label="Embeddings",
                        choices=[(f"<{e}>", e) for e in Config.EMBEDDINGS],
                        multiselect=True,
                        value=[Config.EMBEDDING],
                        min_width=240,
                    )

                with gr.Row():
                    with gr.Group(elem_classes=["gap-0"]):
                        lora_1 = gr.Dropdown(
                            min_width=240,
                            label="LoRA #1",
                            value="none",
                            choices=[("None", "none")]
                            + [
                                (lora["name"], lora_id)
                                for lora_id, lora in Config.CIVIT_LORAS.items()
                            ],
                        )
                        lora_1_weight = gr.Slider(
                            value=0.0,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.1,
                            show_label=False,
                        )
                    with gr.Group(elem_classes=["gap-0"]):
                        lora_2 = gr.Dropdown(
                            min_width=240,
                            label="LoRA #2",
                            value="none",
                            choices=[("None", "none")]
                            + [
                                (lora["name"], lora_id)
                                for lora_id, lora in Config.CIVIT_LORAS.items()
                            ],
                        )
                        lora_2_weight = gr.Slider(
                            value=0.0,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.1,
                            show_label=False,
                        )

                with gr.Row():
                    guidance_scale = gr.Slider(
                        value=Config.GUIDANCE_SCALE,
                        label="Guidance Scale",
                        minimum=1.0,
                        maximum=15.0,
                        step=0.1,
                    )
                    inference_steps = gr.Slider(
                        value=Config.INFERENCE_STEPS,
                        label="Inference Steps",
                        minimum=1,
                        maximum=50,
                        step=1,
                    )
                    deepcache_interval = gr.Slider(
                        value=Config.DEEPCACHE_INTERVAL,
                        label="DeepCache",
                        minimum=1,
                        maximum=4,
                        step=1,
                    )

                with gr.Row():
                    width = gr.Slider(
                        value=Config.WIDTH,
                        label="Width",
                        minimum=256,
                        maximum=768,
                        step=32,
                    )
                    height = gr.Slider(
                        value=Config.HEIGHT,
                        label="Height",
                        minimum=256,
                        maximum=768,
                        step=32,
                    )
                    aspect_ratio = gr.Dropdown(
                        value=f"{Config.WIDTH},{Config.HEIGHT}",
                        label="Aspect Ratio",
                        filterable=False,
                        choices=[
                            ("Custom", None),
                            ("4:7 (384x672)", "384,672"),
                            ("7:9 (448x576)", "448,576"),
                            ("1:1 (512x512)", "512,512"),
                            ("9:7 (576x448)", "576,448"),
                            ("7:4 (672x384)", "672,384"),
                        ],
                    )

                with gr.Row():
                    file_format = gr.Dropdown(
                        choices=["png", "jpeg", "webp"],
                        label="File Format",
                        filterable=False,
                        value="png",
                    )
                    num_images = gr.Dropdown(
                        choices=list(range(1, 5)),
                        value=Config.NUM_IMAGES,
                        filterable=False,
                        label="Images",
                    )
                    scale = gr.Dropdown(
                        choices=[(f"{s}x", s) for s in Config.SCALES],
                        filterable=False,
                        value=Config.SCALE,
                        label="Scale",
                    )
                    seed = gr.Number(
                        value=Config.SEED,
                        label="Seed",
                        minimum=-1,
                        maximum=(2**64) - 1,
                    )

                with gr.Row():
                    use_karras = gr.Checkbox(
                        elem_classes=["checkbox"],
                        label="Karras σ",
                        value=True,
                    )
                    use_taesd = gr.Checkbox(
                        elem_classes=["checkbox"],
                        label="Tiny VAE",
                        value=False,
                    )
                    use_freeu = gr.Checkbox(
                        elem_classes=["checkbox"],
                        label="FreeU",
                        value=False,
                    )
                    use_clip_skip = gr.Checkbox(
                        elem_classes=["checkbox"],
                        label="Clip skip",
                        value=False,
                    )

    annotate_btn.click(
        annotate_fn,
        inputs=[control_image_input, control_annotator],
        outputs=[control_image_prompt],
    )

    random_btn.click(random_fn, inputs=[], outputs=[prompt], show_api=False)

    # seed refresh and CSS sync run purely client-side (fn=None + js)
    refresh_btn.click(None, inputs=[], outputs=[seed], js=refresh_seed_js)

    seed.change(None, inputs=[seed], outputs=[], js=seed_js)

    # keep gallery and image inputs encoding in the selected file format
    file_format.change(
        lambda f: (gr.Gallery(format=f), gr.Image(format=f), gr.Image(format=f)),
        inputs=[file_format],
        outputs=[output_images, image_prompt, ip_image_prompt],
        show_api=False,
    )

    # input events are only user input; change events are both user and programmatic
    aspect_ratio.input(
        None,
        inputs=[aspect_ratio, width, height],
        outputs=[width, height],
        js=aspect_ratio_js,
    )

    # lock the input images so you don't lose them when the gallery updates
    output_images.change(
        gallery_fn,
        inputs=[output_images, image_prompt, ip_image_prompt],
        outputs=[image_select, ip_image_select],
        show_api=False,
    )

    # show the selected image in the image input
    image_select.change(
        image_select_fn,
        inputs=[output_images, image_prompt, image_select],
        outputs=[image_prompt],
        show_api=False,
    )
    ip_image_select.change(
        image_select_fn,
        inputs=[output_images, ip_image_prompt, ip_image_select],
        outputs=[ip_image_prompt],
        show_api=False,
    )

    # reset the dropdown on clear
    image_prompt.clear(
        image_prompt_fn,
        inputs=[output_images],
        outputs=[image_select],
        show_api=False,
    )
    ip_image_prompt.clear(
        image_prompt_fn,
        inputs=[output_images],
        outputs=[ip_image_select],
        show_api=False,
    )

    # show "Custom" aspect ratio when manually changing width or height
    gr.on(
        triggers=[width.input, height.input],
        fn=None,
        inputs=[],
        outputs=[aspect_ratio],
        js="() => { return null; }",
    )

    # toggle image prompts by updating session state
    gr.on(
        triggers=[disable_image.input, disable_ip_image.input],
        fn=lambda disable_image, disable_ip_image: (disable_image, disable_ip_image),
        inputs=[disable_image, disable_ip_image],
        outputs=[DISABLE_IMAGE_PROMPT, DISABLE_IP_IMAGE_PROMPT],
    )

    # generate images
    # NOTE: generate_fn relies on this order: prompt first, image_prompt at
    # index 2, ip_image_prompt at index 3, and the two disable flags last
    gr.on(
        triggers=[generate_btn.click, prompt.submit],
        fn=generate_fn,
        api_name="generate",
        outputs=[output_images],
        inputs=[
            prompt,
            negative_prompt,
            image_prompt,
            ip_image_prompt,
            control_image_prompt,
            lora_1,
            lora_1_weight,
            lora_2,
            lora_2_weight,
            embeddings,
            style,
            seed,
            model,
            scheduler,
            control_annotator,
            width,
            height,
            guidance_scale,
            inference_steps,
            denoising_strength,
            deepcache_interval,
            scale,
            num_images,
            use_karras,
            use_taesd,
            use_freeu,
            use_clip_skip,
            use_ip_face,
            DISABLE_IMAGE_PROMPT,
            DISABLE_IP_IMAGE_PROMPT,
        ],
    )

if __name__ == "__main__":
    parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    parser.add_argument("-s", "--server", type=str, metavar="STR", default="0.0.0.0")
    parser.add_argument("-p", "--port", type=int, metavar="INT", default=7860)
    args = parser.parse_args()

    disable_progress_bars()
    for repo_id, allow_patterns in Config.HF_MODELS.items():
        download_repo_files(repo_id, allow_patterns, token=Config.HF_TOKEN)

    # download civit loras into ./loras next to this file (same dir for every LoRA)
    loras_dir = os.path.join(os.path.dirname(__file__), "loras")
    for lora_id, lora in Config.CIVIT_LORAS.items():
        download_civit_file(
            lora_id,
            lora["model_version_id"],
            file_path=loras_dir,
            token=Config.CIVIT_TOKEN,
        )

    # https://www.gradio.app/docs/gradio/interface#interface-queue
    demo.queue(default_concurrency_limit=1).launch(
        server_name=args.server,
        server_port=args.port,
    )