"""Gradio demo: region-based MultiDiffusion generation driven by a colour sketch.

The user paints coloured regions on a canvas, each non-white colour becomes a
mask with its own prompt, and the masks plus prompts are handed to
``MultiDiffusion.generate``.
"""
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import torch

from region_control import MultiDiffusion, get_views, preprocess_mask
from sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix

# Number of (swatch, prompt) rows built in the UI; also the mask limit.
MAX_COLORS = 12
DEVICE = "cuda"
# Output width/height passed to the generator.
IMAGE_SIZE = 512
# Masks are prepared at 1/8 of the output size, exactly as the original
# ``512 // 8`` arguments did.
MASK_SIZE = IMAGE_SIZE // 8

sd = MultiDiffusion(DEVICE, "2.0")


def _swatch_html(css_color):
    """Return the HTML snippet for one colour swatch.

    NOTE(review): in the source as received these literals contained nothing
    but a line break (the markup had been stripped, which also made the file
    unparsable). The markup below is reconstructed from the ``.color-bg-item``
    CSS rule defined in this file; confirm it against the deployed app.
    """
    return f'<div class="color-bg-item" style="background-color: {css_color}"></div>'


def process_sketch(image, binary_matrixes):
    """Turn the painted sketch into per-colour masks and reveal the prompt rows.

    Args:
        image: The sketch coming from the canvas component.
        binary_matrixes: Previous mask list held in the session state. It is
            not reused: a fresh list is built on every call so that clicking
            the button again does not stack new masks on top of stale ones.

    Returns:
        A list matching the click handler's outputs: an update that shows the
        prompt column, the new mask list, ``MAX_COLORS`` row-visibility
        updates and ``MAX_COLORS`` swatch updates.
    """
    high_freq_colors, image = get_high_freq_colors(image)
    im2arr = np.array(image)  # im2arr.shape: height x width x channel
    im2arr = color_quantization(im2arr, high_freq_colors)

    masks = []
    swatches = []
    for color in high_freq_colors:
        r, g, b = color[1]
        # Pure white is the untouched canvas, not a painted region.
        if all(c == 255 for c in (r, g, b)):
            continue
        # Only MAX_COLORS prompt rows exist; extra colours would leave masks
        # without a prompt and break generation.
        if len(masks) == MAX_COLORS:
            break
        masks.append(create_binary_matrix(im2arr, (r, g, b)))
        swatches.append(gr.update(value=_swatch_html(f"rgb({r},{g},{b})")))

    # Show exactly one row per mask so the prompts line up with the masks
    # consumed by process_generation.
    shown = len(masks)
    visibilities = [gr.update(visible=n < shown) for n in range(MAX_COLORS)]
    colors = swatches + [
        gr.update(value=_swatch_html("black")) for _ in range(MAX_COLORS - shown)
    ]
    return [gr.update(visible=True), masks, *visibilities, *colors]


def process_generation(binary_matrixes, master_prompt, *prompts):
    """Generate an image from the stored masks and their prompts.

    Args:
        binary_matrixes: Masks produced by ``process_sketch``.
        master_prompt: Prompt paired with the background mask.
        *prompts: One prompt per UI row; only as many as there are masks are
            used.

    Returns:
        Whatever ``sd.generate`` returns (displayed in the result image).

    Raises:
        gr.Error: If there is no mask/prompt pair to generate from, e.g. the
            sketch was blank.
    """
    # Keep masks and prompts the same length whichever one is shorter.
    region_count = min(len(binary_matrixes), len(prompts))
    if region_count == 0:
        # torch.cat on an empty list would fail with an opaque RuntimeError.
        raise gr.Error(
            "No masks to generate from: paint at least one coloured region "
            "and click \"I've finished my sketch\" first."
        )

    all_prompts = [master_prompt, *prompts[:region_count]]
    neg_prompts = [""] * len(all_prompts)

    fg_masks = torch.cat(
        [
            preprocess_mask(mask, MASK_SIZE, MASK_SIZE, DEVICE)
            for mask in binary_matrixes[:region_count]
        ]
    )
    # Background is whatever no foreground mask covers; overlapping
    # foreground masks can push the difference below zero, so clamp it.
    bg_mask = 1 - torch.sum(fg_masks, dim=0, keepdim=True)
    bg_mask[bg_mask < 0] = 0
    masks = torch.cat([bg_mask, fg_masks])
    print(masks.size())

    # NOTE(review): 50 is presumably the number of sampling steps; confirm
    # against MultiDiffusion.generate.
    return sd.generate(
        masks, all_prompts, neg_prompts, IMAGE_SIZE, IMAGE_SIZE, 50, bootstrapping=20
    )


css = '''
#color-bg{display:flex;justify-content: center;align-items: center;}
.color-bg-item{width: 100%; height: 32px}
#main_button{width:100%}
'''

# Markdown literals are kept at column 0 so the heading/image syntax is not
# indented into a code block.
# NOTE(review): line breaks inside these literals were lost in the source as
# received and are reconstructed here.
INTRO_MARKDOWN = '''## Control your Stable Diffusion generation with Sketches
This Space demonstrates MultiDiffusion region-based generation using Stable Diffusion model. To get started, draw your masks and type your prompts. More details in the [project page](https://multidiffusion.github.io).
'''

EXAMPLES_MARKDOWN = '''
![Examples](https://multidiffusion.github.io/pics/tight.jpg)
'''


def update_css(aspect):
    """Return visibility updates for the square/horizontal/vertical canvases.

    Exactly one of the three updates is visible, chosen by ``aspect``; any
    other value returns ``None``. Currently unused: the handler that would
    call it is commented out in the layout below.
    """
    options = ("Square", "Horizontal", "Vertical")
    if aspect not in options:
        return None
    return [gr.update(visible=(aspect == option)) for option in options]


# NOTE(review): the nesting of the layout below is reconstructed; the source
# as received had lost its indentation. Confirm against the deployed app.
with gr.Blocks(css=css) as demo:
    binary_matrixes = gr.State([])
    gr.Markdown(INTRO_MARKDOWN)
    with gr.Row():
        with gr.Box(elem_id="main-image"):
            #with gr.Row():
            image = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512,512), brush_radius=45)
            #image_horizontal = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(768,512), visible=False, brush_radius=45)
            #image_vertical = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512, 768), visible=False, brush_radius=45)
            #with gr.Row():
            #    aspect = gr.Radio(["Square", "Horizontal", "Vertical"], value="Square", label="Aspect Ratio")
            button_run = gr.Button("I've finished my sketch", elem_id="main_button", interactive=True)

        prompts = []
        colors = []
        color_row = [None] * MAX_COLORS
        with gr.Column(visible=False) as post_sketch:
            general_prompt = gr.Textbox(label="General Prompt")
            # One hidden (swatch, prompt) row per possible colour;
            # process_sketch reveals as many as there are masks.
            for n in range(MAX_COLORS):
                with gr.Row(visible=False) as color_row[n]:
                    with gr.Box(elem_id="color-bg"):
                        colors.append(gr.HTML(_swatch_html("black")))
                    prompts.append(gr.Textbox(label="Prompt for this mask"))
            final_run_btn = gr.Button("Generate!")

        out_image = gr.Image(label="Result")
    gr.Markdown(EXAMPLES_MARKDOWN)
    #css_height = gr.HTML("")
    #aspect.change(update_css, inputs=aspect, outputs=[image, image_horizontal, image_vertical])
    button_run.click(
        process_sketch,
        inputs=[image, binary_matrixes],
        outputs=[post_sketch, binary_matrixes, *color_row, *colors],
    )
    final_run_btn.click(
        process_generation,
        inputs=[binary_matrixes, general_prompt, *prompts],
        outputs=out_image,
    )

demo.launch(debug=True)