##!/usr/bin/python3 # -*- coding: utf-8 -*- import os, random, sys import numpy as np import requests import torch import spaces import gradio as gr from PIL import Image from huggingface_hub import hf_hub_download, snapshot_download from scipy.ndimage import binary_dilation, binary_erosion from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration, Qwen2VLForConditionalGeneration, Qwen2VLProcessor) from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler from diffusers.image_processor import VaeImageProcessor from app.src.vlm_pipeline import ( vlm_response_editing_type, vlm_response_object_wait_for_edit, vlm_response_mask, vlm_response_prompt_after_apply_instruction ) from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline from app.utils.utils import load_grounding_dino_model from app.src.vlm_template import vlms_template from app.src.base_model_template import base_models_template from app.src.aspect_ratio_template import aspect_ratios from openai import OpenAI # base_openai_url = "" #### Description #### logo = r"""
BrushEdit logo
""" head = r"""

BrushEdit: All-In-One Image Inpainting and Editing

Project Page

""" descriptions = r""" Official Gradio Demo for BrushEdit: All-In-One Image Inpainting and Editing
🧙 BrushEdit enables precise, user-friendly instruction-based image editing via a inpainting model.
""" instructions = r""" Currently, we support two modes: fully automated command editing and interactive command editing. 🛠️ Fully automated instruction-based editing: 🛠️ Interactive instruction-based editing: We strongly recommend using GPT-4o for reasoning. After selecting the VLM model as gpt4-o, enter the API KEY and click the Submit and Verify button. If the output is success, you can use gpt4-o normally. Secondarily, we recommend using the Qwen2VL model. We recommend zooming out in your browser for a better viewing range and experience. For more detailed feature descriptions, see the bottom. ☕️ Have fun! 🎄 Wishing you a merry Christmas! """ tips = r""" 💡 Some Tips: 💡 Detailed Features: 💡 Advanced Features: """ citation = r""" If BrushEdit is helpful, please help to ⭐ the Github Repo. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/BrushEdit?style=social)](https://github.com/TencentARC/BrushEdit) --- 📝 **Citation**
If our work is useful for your research, please consider citing: ```bibtex @misc{li2024brushedit, title={BrushEdit: All-In-One Image Inpainting and Editing}, author={Yaowei Li and Yuxuan Bian and Xuan Ju and Zhaoyang Zhang and and Junhao Zhuang and Ying Shan and Yuexian Zou and Qiang Xu}, year={2024}, eprint={2412.10316}, archivePrefix={arXiv}, primaryClass={cs.CV} } ``` 📧 **Contact**
If you have any questions, please feel free to reach me out at liyaowei@gmail.com. """ # - - - - - examples - - - - - # EXAMPLES = [ [ Image.open("./assets/frog/frog.jpeg").convert("RGBA"), "add a magic hat on frog head.", 642087011, "frog", "frog", True, False, "GPT4-o (Highly Recommended)" ], [ Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"), "replace the background to ancient China.", 648464818, "chinese_girl", "chinese_girl", True, False, "GPT4-o (Highly Recommended)" ], [ Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"), "remove the deer.", 648464818, "angel_christmas", "angel_christmas", False, False, "GPT4-o (Highly Recommended)" ], [ Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"), "add a wreath on head.", 648464818, "sunflower_girl", "sunflower_girl", True, False, "GPT4-o (Highly Recommended)" ], [ Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"), "add a butterfly fairy.", 648464818, "girl_on_sun", "girl_on_sun", True, False, "GPT4-o (Highly Recommended)" ], [ Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"), "remove the christmas hat.", 642087011, "spider_man_rm", "spider_man_rm", False, False, "GPT4-o (Highly Recommended)" ], [ Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"), "remove the flower.", 642087011, "anime_flower", "anime_flower", False, False, "GPT4-o (Highly Recommended)" ], [ Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"), "replace the clothes to a delicated floral skirt.", 648464818, "chenduling", "chenduling", True, False, "GPT4-o (Highly Recommended)" ], [ Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"), "make the hedgehog in Italy.", 648464818, "hedgehog_rp_bg", "hedgehog_rp_bg", True, False, "GPT4-o (Highly Recommended)" ], ] INPUT_IMAGE_PATH = { "frog": "./assets/frog/frog.jpeg", "chinese_girl": "./assets/chinese_girl/chinese_girl.png", "angel_christmas": "./assets/angel_christmas/angel_christmas.png", "sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png", "girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png", "spider_man_rm": "./assets/spider_man_rm/spider_man.png", "anime_flower": "./assets/anime_flower/anime_flower.png", "chenduling": "./assets/chenduling/chengduling.jpg", "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png", } MASK_IMAGE_PATH = { "frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png", "chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png", "angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png", "sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png", "girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png", "spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png", "anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png", "chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png", "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png", } MASKED_IMAGE_PATH = { "frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png", "chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png", "angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png", "sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png", "girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png", "spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png", "anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png", "chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png", "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png", } OUTPUT_IMAGE_PATH = { "frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png", "chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png", "angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png", "sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png", "girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png", "spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png", "anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png", "chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png", "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png", } # os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir' # os.makedirs('gradio_temp_dir', exist_ok=True) VLM_MODEL_NAMES = list(vlms_template.keys()) DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)" BASE_MODELS = list(base_models_template.keys()) DEFAULT_BASE_MODEL = "realisticVision (Default)" ASPECT_RATIO_LABELS = list(aspect_ratios) DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0] ## init device try: if torch.cuda.is_available(): device = "cuda" elif sys.platform == "darwin" and torch.backends.mps.is_available(): device = "mps" else: device = "cpu" except: device = "cpu" # ## init torch dtype # if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): # torch_dtype = torch.bfloat16 # else: # torch_dtype = torch.float16 # if device == "mps": # torch_dtype = torch.float16 torch_dtype = torch.float16 # download hf models BrushEdit_path = "models/" if not os.path.exists(BrushEdit_path): BrushEdit_path = snapshot_download( repo_id="TencentARC/BrushEdit", local_dir=BrushEdit_path, token=os.getenv("HF_TOKEN"), ) ## init default VLM vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME] if vlm_processor != "" and vlm_model != "": vlm_model.to(device) else: gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.") ## init base model base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE") brushnet_path = os.path.join(BrushEdit_path, "brushnetX") sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth") groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth") # input brushnetX ckpt path brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype) pipe = StableDiffusionBrushNetPipeline.from_pretrained( base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False ) # speed up diffusion process with faster scheduler and memory optimization pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) # remove following line if xformers is not installed or when using Torch 2.0. # pipe.enable_xformers_memory_efficient_attention() pipe.enable_model_cpu_offload() ## init SAM sam = build_sam(checkpoint=sam_path) sam.to(device=device) sam_predictor = SamPredictor(sam) sam_automask_generator = SamAutomaticMaskGenerator(sam) ## init groundingdino_model config_file = 'app/utils/GroundingDINO_SwinT_OGC.py' groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device) ## Ordinary function def crop_and_resize(image: Image.Image, target_width: int, target_height: int) -> Image.Image: """ Crops and resizes an image while preserving the aspect ratio. Args: image (Image.Image): Input PIL image to be cropped and resized. target_width (int): Target width of the output image. target_height (int): Target height of the output image. Returns: Image.Image: Cropped and resized image. """ # Original dimensions original_width, original_height = image.size original_aspect = original_width / original_height target_aspect = target_width / target_height # Calculate crop box to maintain aspect ratio if original_aspect > target_aspect: # Crop horizontally new_width = int(original_height * target_aspect) new_height = original_height left = (original_width - new_width) / 2 top = 0 right = left + new_width bottom = original_height else: # Crop vertically new_width = original_width new_height = int(original_width / target_aspect) left = 0 top = (original_height - new_height) / 2 right = original_width bottom = top + new_height # Crop and resize cropped_image = image.crop((left, top, right, bottom)) resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST) return resized_image ## Ordinary function def resize(image: Image.Image, target_width: int, target_height: int) -> Image.Image: """ Crops and resizes an image while preserving the aspect ratio. Args: image (Image.Image): Input PIL image to be cropped and resized. target_width (int): Target width of the output image. target_height (int): Target height of the output image. Returns: Image.Image: Cropped and resized image. """ # Original dimensions resized_image = image.resize((target_width, target_height), Image.NEAREST) return resized_image def move_mask_func(mask, direction, units): binary_mask = mask.squeeze()>0 rows, cols = binary_mask.shape moved_mask = np.zeros_like(binary_mask, dtype=bool) if direction == 'down': # move down moved_mask[max(0, units):, :] = binary_mask[:rows - units, :] elif direction == 'up': # move up moved_mask[:rows - units, :] = binary_mask[units:, :] elif direction == 'right': # move left moved_mask[:, max(0, units):] = binary_mask[:, :cols - units] elif direction == 'left': # move right moved_mask[:, :cols - units] = binary_mask[:, units:] return moved_mask def random_mask_func(mask, dilation_type='square', dilation_size=20): # Randomly select the size of dilation binary_mask = mask.squeeze()>0 if dilation_type == 'square_dilation': structure = np.ones((dilation_size, dilation_size), dtype=bool) dilated_mask = binary_dilation(binary_mask, structure=structure) elif dilation_type == 'square_erosion': structure = np.ones((dilation_size, dilation_size), dtype=bool) dilated_mask = binary_erosion(binary_mask, structure=structure) elif dilation_type == 'bounding_box': # find the most left top and left bottom point rows, cols = np.where(binary_mask) if len(rows) == 0 or len(cols) == 0: return mask # return original mask if no valid points min_row = np.min(rows) max_row = np.max(rows) min_col = np.min(cols) max_col = np.max(cols) # create a bounding box dilated_mask = np.zeros_like(binary_mask, dtype=bool) dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True elif dilation_type == 'bounding_ellipse': # find the most left top and left bottom point rows, cols = np.where(binary_mask) if len(rows) == 0 or len(cols) == 0: return mask # return original mask if no valid points min_row = np.min(rows) max_row = np.max(rows) min_col = np.min(cols) max_col = np.max(cols) # calculate the center and axis length of the ellipse center = ((min_col + max_col) // 2, (min_row + max_row) // 2) a = (max_col - min_col) // 2 # half long axis b = (max_row - min_row) // 2 # half short axis # create a bounding ellipse y, x = np.ogrid[:mask.shape[0], :mask.shape[1]] ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1 dilated_mask = np.zeros_like(binary_mask, dtype=bool) dilated_mask[ellipse_mask] = True else: raise ValueError("dilation_type must be 'square' or 'ellipse'") # use binary dilation dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255 return dilated_mask ## Gradio component function def update_vlm_model(vlm_name): global vlm_model, vlm_processor if vlm_model is not None: del vlm_model torch.cuda.empty_cache() vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name] ## we recommend using preload models, otherwise it will take a long time to download the model. you can edit the code via vlm_template.py if vlm_type == "llava-next": if vlm_processor != "" and vlm_model != "": vlm_model.to(device) return vlm_model_dropdown else: if os.path.exists(vlm_local_path): vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path) vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto") else: if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)": vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf") vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto") elif vlm_name == "llama3-llava-next-8b-hf (Preload)": vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf") vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto") elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)": vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf") vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto") elif vlm_name == "llava-v1.6-34b-hf (Preload)": vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf") vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto") elif vlm_name == "llava-next-72b-hf (Preload)": vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf") vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto") elif vlm_type == "qwen2-vl": if vlm_processor != "" and vlm_model != "": vlm_model.to(device) return vlm_model_dropdown else: if os.path.exists(vlm_local_path): vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path) vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto") else: if vlm_name == "qwen2-vl-2b-instruct (Preload)": vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct") vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto") elif vlm_name == "qwen2-vl-7b-instruct (Preload)": vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto") elif vlm_name == "qwen2-vl-72b-instruct (Preload)": vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct") vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto") elif vlm_type == "openai": pass return "success" def update_base_model(base_model_name): global pipe ## we recommend using preload models, otherwise it will take a long time to download the model. you can edit the code via base_model_template.py if pipe is not None: del pipe torch.cuda.empty_cache() base_model_path, pipe = base_models_template[base_model_name] if pipe != "": pipe.to(device) else: if os.path.exists(base_model_path): pipe = StableDiffusionBrushNetPipeline.from_pretrained( base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False ) # pipe.enable_xformers_memory_efficient_attention() pipe.enable_model_cpu_offload() else: raise gr.Error(f"The base model {base_model_name} does not exist") return "success" def submit_GPT4o_KEY(GPT4o_KEY): global vlm_model, vlm_processor if vlm_model is not None: del vlm_model torch.cuda.empty_cache() try: vlm_model = OpenAI(api_key=GPT4o_KEY) vlm_processor = "" response = vlm_model.chat.completions.create( model="gpt-4o-2024-08-06", messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Say this is a test"} ] ) response_str = response.choices[0].message.content return "Success, " + response_str, "GPT4-o (Highly Recommended)" except Exception as e: return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)" @spaces.GPU(duration=180) def process(input_image, original_image, original_mask, prompt, negative_prompt, control_strength, seed, randomize_seed, guidance_scale, num_inference_steps, num_samples, blending, category, target_prompt, resize_default, aspect_ratio_name, invert_mask_state): if original_image is None: if input_image is None: raise gr.Error('Please upload the input image') else: image_pil = input_image["background"].convert("RGB") original_image = np.array(image_pil) if prompt is None or prompt == "": raise gr.Error("Please input your instructions, e.g., remove the xxx") alpha_mask = input_image["layers"][0].split()[3] input_mask = np.asarray(alpha_mask) output_w, output_h = aspect_ratios[aspect_ratio_name] if output_w == "" or output_h == "": output_h, output_w = original_image.shape[:2] if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") else: gr.Info(f"Output aspect ratio: {output_w}:{output_h}") pass else: if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) if invert_mask_state: original_mask = original_mask else: if input_mask.max() == 0: original_mask = original_mask else: original_mask = input_mask if category is not None: pass else: category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device) if original_mask is not None: original_mask = np.clip(original_mask, 0, 255).astype(np.uint8) else: object_wait_for_edit = vlm_response_object_wait_for_edit( vlm_processor, vlm_model, original_image, category, prompt, device) original_mask = vlm_response_mask(vlm_processor, vlm_model, category, original_image, prompt, object_wait_for_edit, sam, sam_predictor, sam_automask_generator, groundingdino_model, device) if original_mask.ndim == 2: original_mask = original_mask[:,:,None] if len(target_prompt) <= 1: prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction( vlm_processor, vlm_model, original_image, prompt, device) else: prompt_after_apply_instruction = target_prompt generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed) with torch.autocast(device): image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe, prompt_after_apply_instruction, original_mask, original_image, generator, num_inference_steps, guidance_scale, control_strength, negative_prompt, num_samples, blending) original_image = np.array(init_image_np) masked_image = original_image * (1 - (mask_np>0)) masked_image = masked_image.astype(np.uint8) masked_image = Image.fromarray(masked_image) # Save the images (optional) # import uuid # uuid = str(uuid.uuid4()) # image[0].save(f"outputs/image_edit_{uuid}_0.png") # image[1].save(f"outputs/image_edit_{uuid}_1.png") # image[2].save(f"outputs/image_edit_{uuid}_2.png") # image[3].save(f"outputs/image_edit_{uuid}_3.png") # mask_image.save(f"outputs/mask_{uuid}.png") # masked_image.save(f"outputs/masked_image_{uuid}.png") return image, [mask_image], [masked_image], prompt, '', prompt_after_apply_instruction, False def generate_target_prompt(input_image, original_image, prompt): # load example image if isinstance(original_image, str): original_image = input_image prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction( vlm_processor, vlm_model, original_image, prompt, device) return prompt_after_apply_instruction, prompt_after_apply_instruction def process_mask(input_image, original_image, prompt, resize_default, aspect_ratio_name): if original_image is None: raise gr.Error('Please upload the input image') if prompt is None: raise gr.Error("Please input your instructions, e.g., remove the xxx") ## load mask alpha_mask = input_image["layers"][0].split()[3] input_mask = np.array(alpha_mask) # load example image if isinstance(original_image, str): original_image = input_image["background"] if input_mask.max() == 0: category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device) object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor, vlm_model, original_image, category, prompt, device) # original mask: h,w,1 [0, 255] original_mask = vlm_response_mask( vlm_processor, vlm_model, category, original_image, prompt, object_wait_for_edit, sam, sam_predictor, sam_automask_generator, groundingdino_model, device) else: original_mask = input_mask category = None ## resize mask if needed output_w, output_h = aspect_ratios[aspect_ratio_name] if output_w == "" or output_h == "": output_h, output_w = original_image.shape[:2] if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") else: gr.Info(f"Output aspect ratio: {output_w}:{output_h}") pass else: if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) if original_mask.ndim == 2: original_mask = original_mask[:,:,None] mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB") masked_image = original_image * (1 - (original_mask>0)) masked_image = masked_image.astype(np.uint8) masked_image = Image.fromarray(masked_image) return [masked_image], [mask_image], original_mask.astype(np.uint8), category def process_random_mask(input_image, original_image, original_mask, resize_default, aspect_ratio_name, ): alpha_mask = input_image["layers"][0].split()[3] input_mask = np.asarray(alpha_mask) output_w, output_h = aspect_ratios[aspect_ratio_name] if output_w == "" or output_h == "": output_h, output_w = original_image.shape[:2] if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") else: gr.Info(f"Output aspect ratio: {output_w}:{output_h}") pass else: if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) if input_mask.max() == 0: original_mask = original_mask else: original_mask = input_mask if original_mask is None: raise gr.Error('Please generate mask first') if original_mask.ndim == 2: original_mask = original_mask[:,:,None] dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse']) random_mask = random_mask_func(original_mask, dilation_type).squeeze() mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB") masked_image = original_image * (1 - (random_mask[:,:,None]>0)) masked_image = masked_image.astype(original_image.dtype) masked_image = Image.fromarray(masked_image) return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8) def process_dilation_mask(input_image, original_image, original_mask, resize_default, aspect_ratio_name, dilation_size=20): alpha_mask = input_image["layers"][0].split()[3] input_mask = np.asarray(alpha_mask) output_w, output_h = aspect_ratios[aspect_ratio_name] if output_w == "" or output_h == "": output_h, output_w = original_image.shape[:2] if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") else: gr.Info(f"Output aspect ratio: {output_w}:{output_h}") pass else: if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) if input_mask.max() == 0: original_mask = original_mask else: original_mask = input_mask if original_mask is None: raise gr.Error('Please generate mask first') if original_mask.ndim == 2: original_mask = original_mask[:,:,None] dilation_type = np.random.choice(['square_dilation']) random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze() mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB") masked_image = original_image * (1 - (random_mask[:,:,None]>0)) masked_image = masked_image.astype(original_image.dtype) masked_image = Image.fromarray(masked_image) return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8) def process_erosion_mask(input_image, original_image, original_mask, resize_default, aspect_ratio_name, dilation_size=20): alpha_mask = input_image["layers"][0].split()[3] input_mask = np.asarray(alpha_mask) output_w, output_h = aspect_ratios[aspect_ratio_name] if output_w == "" or output_h == "": output_h, output_w = original_image.shape[:2] if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") else: gr.Info(f"Output aspect ratio: {output_w}:{output_h}") pass else: if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) if input_mask.max() == 0: original_mask = original_mask else: original_mask = input_mask if original_mask is None: raise gr.Error('Please generate mask first') if original_mask.ndim == 2: original_mask = original_mask[:,:,None] dilation_type = np.random.choice(['square_erosion']) random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze() mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB") masked_image = original_image * (1 - (random_mask[:,:,None]>0)) masked_image = masked_image.astype(original_image.dtype) masked_image = Image.fromarray(masked_image) return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8) def move_mask_left(input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio_name): alpha_mask = input_image["layers"][0].split()[3] input_mask = np.asarray(alpha_mask) output_w, output_h = aspect_ratios[aspect_ratio_name] if output_w == "" or output_h == "": output_h, output_w = original_image.shape[:2] if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") else: gr.Info(f"Output aspect ratio: {output_w}:{output_h}") pass else: if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) if input_mask.max() == 0: original_mask = original_mask else: original_mask = input_mask if original_mask is None: raise gr.Error('Please generate mask first') if original_mask.ndim == 2: original_mask = original_mask[:,:,None] moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze() mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB") masked_image = original_image * (1 - (moved_mask[:,:,None]>0)) masked_image = masked_image.astype(original_image.dtype) masked_image = Image.fromarray(masked_image) if moved_mask.max() <= 1: moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8) original_mask = moved_mask return [masked_image], [mask_image], original_mask.astype(np.uint8) def move_mask_right(input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio_name): alpha_mask = input_image["layers"][0].split()[3] input_mask = np.asarray(alpha_mask) output_w, output_h = aspect_ratios[aspect_ratio_name] if output_w == "" or output_h == "": output_h, output_w = original_image.shape[:2] if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") else: gr.Info(f"Output aspect ratio: {output_w}:{output_h}") pass else: if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) if input_mask.max() == 0: original_mask = original_mask else: original_mask = input_mask if original_mask is None: raise gr.Error('Please generate mask first') if original_mask.ndim == 2: original_mask = original_mask[:,:,None] moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze() mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB") masked_image = original_image * (1 - (moved_mask[:,:,None]>0)) masked_image = masked_image.astype(original_image.dtype) masked_image = Image.fromarray(masked_image) if moved_mask.max() <= 1: moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8) original_mask = moved_mask return [masked_image], [mask_image], original_mask.astype(np.uint8) def move_mask_up(input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio_name): alpha_mask = input_image["layers"][0].split()[3] input_mask = np.asarray(alpha_mask) output_w, output_h = aspect_ratios[aspect_ratio_name] if output_w == "" or output_h == "": output_h, output_w = original_image.shape[:2] if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") else: gr.Info(f"Output aspect ratio: {output_w}:{output_h}") pass else: if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) if input_mask.max() == 0: original_mask = original_mask else: original_mask = input_mask if original_mask is None: raise gr.Error('Please generate mask first') if original_mask.ndim == 2: original_mask = original_mask[:,:,None] moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze() mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB") masked_image = original_image * (1 - (moved_mask[:,:,None]>0)) masked_image = masked_image.astype(original_image.dtype) masked_image = Image.fromarray(masked_image) if moved_mask.max() <= 1: moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8) original_mask = moved_mask return [masked_image], [mask_image], original_mask.astype(np.uint8) def move_mask_down(input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio_name): alpha_mask = input_image["layers"][0].split()[3] input_mask = np.asarray(alpha_mask) output_w, output_h = aspect_ratios[aspect_ratio_name] if output_w == "" or output_h == "": output_h, output_w = original_image.shape[:2] if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") else: gr.Info(f"Output aspect ratio: {output_w}:{output_h}") pass else: if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) if input_mask.max() == 0: original_mask = original_mask else: original_mask = input_mask if original_mask is None: raise gr.Error('Please generate mask first') if original_mask.ndim == 2: original_mask = original_mask[:,:,None] moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze() mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB") masked_image = original_image * (1 - (moved_mask[:,:,None]>0)) masked_image = masked_image.astype(original_image.dtype) masked_image = Image.fromarray(masked_image) if moved_mask.max() <= 1: moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8) original_mask = moved_mask return [masked_image], [mask_image], original_mask.astype(np.uint8) def invert_mask(input_image, original_image, original_mask, ): alpha_mask = input_image["layers"][0].split()[3] input_mask = np.asarray(alpha_mask) if input_mask.max() == 0: original_mask = 1 - (original_mask>0).astype(np.uint8) else: original_mask = 1 - (input_mask>0).astype(np.uint8) if original_mask is None: raise gr.Error('Please generate mask first') original_mask = original_mask.squeeze() mask_image = Image.fromarray(original_mask*255).convert("RGB") if original_mask.ndim == 2: original_mask = original_mask[:,:,None] if original_mask.max() <= 1: original_mask = (original_mask * 255).astype(np.uint8) masked_image = original_image * (1 - (original_mask>0)) masked_image = masked_image.astype(original_image.dtype) masked_image = Image.fromarray(masked_image) return [masked_image], [mask_image], original_mask, True def init_img(base, init_type, prompt, aspect_ratio, example_change_times ): image_pil = base["background"].convert("RGB") original_image = np.array(image_pil) if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0: raise gr.Error('image aspect ratio cannot be larger than 2.0') if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2: mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")] masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")] result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")] width, height = image_pil.size image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True) height_new, width_new = image_processor.get_default_height_width(image_pil, height, width) image_pil = image_pil.resize((width_new, height_new)) mask_gallery[0] = mask_gallery[0].resize((width_new, height_new)) masked_gallery[0] = masked_gallery[0].resize((width_new, height_new)) result_gallery[0] = result_gallery[0].resize((width_new, height_new)) original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1 return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "", "Custom resolution", False, False, example_change_times else: return base, original_image, None, "", None, None, None, "", "", "", aspect_ratio, True, False, 0 def reset_func(input_image, original_image, original_mask, prompt, target_prompt, target_prompt_output): input_image = None original_image = None original_mask = None prompt = '' mask_gallery = [] masked_gallery = [] result_gallery = [] target_prompt = '' target_prompt_output = '' if torch.cuda.is_available(): torch.cuda.empty_cache() return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, target_prompt_output, True, False def update_example(example_type, prompt, example_change_times): input_image = INPUT_IMAGE_PATH[example_type] image_pil = Image.open(input_image).convert("RGB") mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")] masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")] result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")] width, height = image_pil.size image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True) height_new, width_new = image_processor.get_default_height_width(image_pil, height, width) image_pil = image_pil.resize((width_new, height_new)) mask_gallery[0] = mask_gallery[0].resize((width_new, height_new)) masked_gallery[0] = masked_gallery[0].resize((width_new, height_new)) result_gallery[0] = result_gallery[0].resize((width_new, height_new)) original_image = np.array(image_pil) original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1 aspect_ratio = "Custom resolution" example_change_times += 1 return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", "", False, example_change_times block = gr.Blocks( theme=gr.themes.Soft( radius_size=gr.themes.sizes.radius_none, text_size=gr.themes.sizes.text_md ) ).queue() with block as demo: with gr.Row(): with gr.Column(): gr.HTML(head) gr.Markdown(descriptions) with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"): with gr.Row(equal_height=True): gr.Markdown(instructions) original_image = gr.State(value=None) original_mask = gr.State(value=None) category = gr.State(value=None) status = gr.State(value=None) invert_mask_state = gr.State(value=False) example_change_times = gr.State(value=0) with gr.Row(): with gr.Column(): with gr.Row(): input_image = gr.ImageEditor( label="Input Image", type="pil", brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"), layers = False, interactive=True, height=1024, sources=["upload"], ) vlm_model_dropdown = gr.Dropdown(label="VLM model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True) with gr.Group(): with gr.Row(): GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1) GPT4o_KEY_submit = gr.Button("Submit and Verify") aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO) resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True) prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1) run_button = gr.Button("💫 Run") with gr.Row(): mask_button = gr.Button("Generate Mask") random_mask_button = gr.Button("Square/Circle Mask ") with gr.Row(): generate_target_prompt_button = gr.Button("Generate Target Prompt") target_prompt = gr.Text( label="Input Target Prompt", max_lines=5, placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)", value='', lines=2 ) with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"): base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True) negative_prompt = gr.Text( label="Negative Prompt", max_lines=5, placeholder="Please input your negative prompt", value='ugly, low quality',lines=1 ) control_strength = gr.Slider( label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01 ) with gr.Group(): seed = gr.Slider( label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818 ) randomize_seed = gr.Checkbox(label="Randomize seed", value=False) blending = gr.Checkbox(label="Blending mode", value=True) num_samples = gr.Slider( label="Num samples", minimum=0, maximum=4, step=1, value=4 ) with gr.Group(): with gr.Row(): guidance_scale = gr.Slider( label="Guidance scale", minimum=1, maximum=12, step=0.1, value=7.5, ) num_inference_steps = gr.Slider( label="Number of inference steps", minimum=1, maximum=50, step=1, value=50, ) with gr.Column(): with gr.Row(): with gr.Tab(elem_classes="feedback", label="Masked Image"): masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360) with gr.Tab(elem_classes="feedback", label="Mask"): mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360) invert_mask_button = gr.Button("Invert Mask") dilation_size = gr.Slider( label="Dilation size: ", minimum=0, maximum=50, step=1, value=20 ) with gr.Row(): dilation_mask_button = gr.Button("Dilation Generated Mask") erosion_mask_button = gr.Button("Erosion Generated Mask") moving_pixels = gr.Slider( label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1 ) with gr.Row(): move_left_button = gr.Button("Move Left") move_right_button = gr.Button("Move Right") with gr.Row(): move_up_button = gr.Button("Move Up") move_down_button = gr.Button("Move Down") with gr.Tab(elem_classes="feedback", label="Output"): result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400) target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False) reset_button = gr.Button("Reset") init_type = gr.Textbox(label="Init Name", value="", visible=False) example_type = gr.Textbox(label="Example Name", value="", visible=False) with gr.Row(): example = gr.Examples( label="Quick Example", examples=EXAMPLES, inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown], examples_per_page=10, cache_examples=False, ) with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"): with gr.Row(equal_height=True): gr.Markdown(tips) with gr.Row(): gr.Markdown(citation) ## gr.examples can not be used to update the gr.Gallery, so we need to use the following two functions to update the gr.Gallery. ## And we need to solve the conflict between the upload and change example functions. input_image.upload( init_img, [input_image, init_type, prompt, aspect_ratio, example_change_times], [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, target_prompt_output, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times] ) example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, target_prompt_output, invert_mask_state, example_change_times]) ## vlm and base model dropdown vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status]) base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status]) GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown]) invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state]) ips=[input_image, original_image, original_mask, prompt, negative_prompt, control_strength, seed, randomize_seed, guidance_scale, num_inference_steps, num_samples, blending, category, target_prompt, resize_default, aspect_ratio, invert_mask_state] ## run brushedit run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, target_prompt_output, invert_mask_state]) ## mask func mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category]) random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask]) dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask]) erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask]) ## move mask func move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask]) move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask]) move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask]) move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask]) ## prompt func generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt, target_prompt_output]) ## reset func reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt, target_prompt_output], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, target_prompt_output, resize_default, invert_mask_state]) demo.launch(server_name="0.0.0.0", server_port=12345, share=False)