"""Gradio front end for the GLIGen demo: sketch-pad input plus generated-image outputs."""

import gradio as gr
import torch
from omegaconf import OmegaConf
from gligen.task_grounded_generation import grounded_generation_box, load_ckpt, load_common_ckpt

import json
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from functools import partial
from collections import Counter
import math
import gc

from gradio import processing_utils
from typing import Optional

import warnings

from datetime import datetime

from huggingface_hub import hf_hub_download

# Rebind the downloader so every call carries library_name="gligen_demo".
hf_hub_download = partial(hf_hub_download, library_name="gligen_demo")

import sys


class ImageMask(gr.components.Image):
    """Image component preset to source="upload", tool="sketch", interactive=True.

    ``preprocess`` additionally accepts a payload that is not yet an
    image/mask dict and wraps it into one before delegating to the parent.
    """

    is_template = True

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to ``gr.components.Image`` with the sketch presets applied."""
        super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)

    def preprocess(self, x):
        """Normalise the incoming payload, then run the parent ``preprocess``.

        Args:
            x: ``None``, a dict payload, or a non-dict payload that is passed
                to ``processing_utils.decode_base64_to_image`` (presumably a
                base64-encoded image string; confirm against the Gradio
                version in use).

        Returns:
            ``None`` when ``x`` is ``None``; otherwise whatever the parent
            ``preprocess`` returns for the (possibly wrapped) payload.
        """
        if x is None:
            return x

        needs_mask = (
            self.tool == "sketch"
            and self.source in ["upload", "webcam"]
            # isinstance (not an exact type comparison) so dict subclasses are
            # also treated as already-wrapped payloads instead of being sent
            # to the base64 decoder.
            and not isinstance(x, dict)
        )
        if needs_mask:
            decode_image = processing_utils.decode_base64_to_image(x)
            width, height = decode_image.size
            # 4-channel uint8 array of zeros with the last channel set to 255,
            # sized to match the decoded image.
            mask = np.zeros((height, width, 4), dtype=np.uint8)
            mask[..., -1] = 255
            # Run the array through this component's own postprocess so the
            # mask is in the same encoding the parent preprocess expects
            # (assumption; verify against gr.components.Image).
            mask = self.postprocess(mask)
            x = {'image': x, 'mask': mask}

        return super().preprocess(x)


# NOTE(review): the nesting below was reconstructed from the layout semantics;
# confirm it against the rendered UI. No event handlers are attached in the
# code shown here.
with gr.Blocks(
    analytics_enabled=False,
    title="GLIGen demo",
) as main:
    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column(scale=4):
            # Invisible Number components.
            sketch_pad_trigger = gr.Number(value=0, visible=False)
            sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
            init_white_trigger = gr.Number(value=0, visible=False)
            image_scale = gr.Number(value=0, elem_id="image_scale", visible=False)
            new_image_trigger = gr.Number(value=0, visible=False)

            task = gr.Radio(
                choices=["Grounded Generation", 'Grounded Inpainting'],
                type="value",
                value="Grounded Generation",
                label="Task",
            )
            language_instruction = gr.Textbox(
                label="Language instruction",
            )
            grounding_instruction = gr.Textbox(
                label="Grounding instruction (Separated by semicolon)",
            )

            with gr.Row():
                sketch_pad = ImageMask(label="Sketch Pad", elem_id="img2img_image")
                out_imagebox = gr.Image(type="pil", label="Parsed Sketch Pad")

            with gr.Row():
                clear_btn = gr.Button(value='Clear')
                gen_btn = gr.Button(value='Generate')

            with gr.Accordion("Advanced Options", open=False):
                with gr.Column():
                    alpha_sample = gr.Slider(
                        minimum=0,
                        maximum=1.0,
                        step=0.1,
                        value=0.3,
                        label="Scheduled Sampling (τ)",
                    )
                    guidance_scale = gr.Slider(
                        minimum=0,
                        maximum=50,
                        step=0.5,
                        value=7.5,
                        label="Guidance Scale",
                    )
                    batch_size = gr.Slider(
                        minimum=1,
                        maximum=4,
                        step=1,
                        value=2,
                        label="Number of Samples",
                    )
                    append_grounding = gr.Checkbox(
                        value=True,
                        label="Append grounding instructions to the caption",
                    )
                    use_actual_mask = gr.Checkbox(
                        value=False,
                        label="Use actual mask for inpainting",
                        visible=False,
                    )
                    with gr.Row():
                        fix_seed = gr.Checkbox(value=True, label="Fixed seed")
                        rand_seed = gr.Slider(
                            minimum=0,
                            maximum=1000,
                            step=1,
                            value=0,
                            label="Seed",
                        )
                    # Style-condition controls live in a row created with
                    # visible=False.
                    with gr.Row(visible=False):
                        use_style_cond = gr.Checkbox(value=False, label="Enable Style Condition")
                        style_cond_image = gr.Image(
                            type="pil",
                            label="Style Condition",
                            visible=False,
                            interactive=True,
                        )

        # Right column: up to four output images; the second row starts hidden.
        with gr.Column(scale=4):
            gr.HTML('Generated Images')
            with gr.Row():
                out_gen_1 = gr.Image(type="pil", visible=True, show_label=False)
                out_gen_2 = gr.Image(type="pil", visible=True, show_label=False)
            with gr.Row():
                out_gen_3 = gr.Image(type="pil", visible=False, show_label=False)
                out_gen_4 = gr.Image(type="pil", visible=False, show_label=False)

    # State component initialised with an empty dict.
    state = gr.State({})

# Enable the request queue (concurrency_count=1, API closed) and start the
# server without a public share link.
main.queue(concurrency_count=1, api_open=False).launch(share=False, show_api=False, show_error=True)