"""Gradio demo: turn a colour sketch into per-colour binary masks and prompts.

The user paints regions on a canvas; every sufficiently frequent colour other
than the most frequent one becomes a binary mask paired with its own prompt
textbox.

NOTE(review): the reviewed copy of this file had lost its line breaks and the
HTML/CSS markup inside several string literals.  Indentation, the nesting of
the layout containers and the markup in ``_swatch_html`` /
``_aspect_style_html`` are reconstructions -- confirm against the upstream
file.
"""
import gradio as gr
import numpy as np
import cv2
from PIL import Image

MAX_COLORS = 12

# Canvas size in pixels, as (width, height), for each aspect-ratio choice.
_ASPECT_SIZES = {
    "Square": (512, 512),
    "Horizontal": (768, 512),
    "Vertical": (512, 768),
}


def _swatch_html(css_color):
    """Return the HTML for one colour swatch filled with *css_color*.

    NOTE(review): markup reconstructed; only the ``color-bg-item`` class is
    attested by the stylesheet in this file.
    """
    return f'<div class="color-bg-item" style="background-color: {css_color}"></div>'


def _aspect_style_html(aspect):
    """Return a ``<style>`` snippet sizing the sketch area for *aspect*.

    Unknown values fall back to the "Square" size instead of failing.

    NOTE(review): markup reconstructed; ``#main-image`` is the elem_id of the
    sketch box below, the ``.fixed-height`` selector is an assumption.
    """
    width, height = _ASPECT_SIZES.get(aspect, _ASPECT_SIZES["Square"])
    return (
        f"<style>#main-image{{width: {width}px}} "
        f".fixed-height{{height: {height}px !important}}</style>"
    )


def get_high_freq_colors(image):
    """Return the frequently used colours of a PIL image.

    Args:
        image: a PIL image; ``image.getcolors`` is used to count colours.

    Returns:
        A list of ``(count, color)`` tuples sorted by decreasing count,
        restricted to colours whose count exceeds ``max(2, mean_count / 3)``.
        Empty when the image yields no colours.
    """
    # Never let getcolors() give up (it returns None above maxcolors).
    max_colors = max(1024 * 1024, image.width * image.height)
    counted = image.getcolors(maxcolors=max_colors) or []
    if not counted:
        return []
    sorted_colors = sorted(counted, key=lambda entry: entry[0], reverse=True)
    mean_freq = sum(count for count, _ in sorted_colors) / len(sorted_colors)
    # Ignore colors that occur very few times (2 or fewer) or no more than a
    # third of the average frequency.
    threshold = max(2, mean_freq / 3)
    return [entry for entry in sorted_colors if entry[0] > threshold]


def _snap_to_palette(image, palette):
    """Replace every pixel of *image* with its closest colour in *palette*.

    Args:
        image: array of shape (height, width, 3).
        palette: array of shape (n, 3) with n >= 1.

    Returns:
        A uint8 array of shape (height, width, 3) containing only palette
        colours.  Closeness is the squared Euclidean distance between
        channel triples.
    """
    palette = np.asarray(palette)
    # Signed arithmetic: subtracting uint8 values directly would wrap around.
    pixels = image.reshape(-1, 1, 3).astype(np.int32)
    candidates = palette.reshape(1, -1, 3).astype(np.int32)
    dists = np.sum((pixels - candidates) ** 2, axis=2)
    labels = np.argmin(dists, axis=1)
    snapped = palette[labels].reshape((image.shape[0], image.shape[1], 3))
    return snapped.astype(np.uint8)


def color_quantization(image, n_colors):
    """Reduce *image* to its *n_colors* most frequent colours.

    Args:
        image: array of shape (height, width, 3).
        n_colors: number of colours to keep; must be at least 1.

    Returns:
        A uint8 array with the same height and width in which every pixel
        has been replaced by the closest of the kept colours.

    Raises:
        ValueError: if *n_colors* is smaller than 1.
    """
    if n_colors < 1:
        raise ValueError("n_colors must be at least 1")
    # Count distinct colours directly rather than through a 256**3 histogram.
    unique_colors, counts = np.unique(
        image.reshape(-1, 3), axis=0, return_counts=True
    )
    most_frequent = unique_colors[np.argsort(counts)[::-1][:n_colors]]
    return _snap_to_palette(image, most_frequent)


def create_binary_matrix(img_arr, target_color):
    """Return an int matrix that is 1 where *img_arr* equals *target_color*.

    Args:
        img_arr: array whose last axis holds the colour channels.
        target_color: colour to match, compared against the last axis.

    Returns:
        An integer array with the last axis removed: 1 where every channel
        matches, 0 elsewhere.
    """
    # Create mask of pixels with target color, then convert it to 0/1 ints.
    mask = np.all(img_arr == target_color, axis=-1)
    return mask.astype(int)


def process_sketch(image, binary_matrixes):
    """Split the sketch into one binary mask per painted colour.

    The most frequent colour is skipped (presumably the canvas background);
    each remaining frequent colour, up to ``MAX_COLORS``, yields one mask,
    one visible prompt row and one swatch showing that colour.

    Args:
        image: PIL image from the sketch canvas, or None if nothing was drawn.
        binary_matrixes: previous mask list from the session state.  It is
            ignored and replaced so that masks from an earlier sketch cannot
            leak into the new result.

    Returns:
        A list ``[post_sketch_update, masks, *row_updates, *swatch_updates]``
        with ``MAX_COLORS`` row updates and ``MAX_COLORS`` swatch updates.
    """
    row_updates = [gr.update(visible=False) for _ in range(MAX_COLORS)]
    swatch_updates = [
        gr.update(value=_swatch_html("black")) for _ in range(MAX_COLORS)
    ]

    high_freq_colors = []
    if image is not None:
        # Guarantee three channels; the code below reshapes to (-1, 3).
        image = image.convert("RGB")
        # Most frequent colour plus at most MAX_COLORS region colours, so the
        # fixed-size output lists are never indexed out of range.
        high_freq_colors = get_high_freq_colors(image)[: MAX_COLORS + 1]
    if not high_freq_colors:
        return [gr.update(visible=False), [], *row_updates, *swatch_updates]

    # Snap to exactly the colours used below so every mask can match.
    palette = np.array([rgb for _, rgb in high_freq_colors], dtype=np.uint8)
    im2arr = np.array(image)  # im2arr.shape: height x width x channel
    im2arr = _snap_to_palette(im2arr, palette)

    binary_matrixes = []
    for n, (_, (r, g, b)) in enumerate(high_freq_colors[1:]):
        binary_matrixes.append(create_binary_matrix(im2arr, (r, g, b)))
        row_updates[n] = gr.update(visible=True)
        swatch_updates[n] = gr.update(value=_swatch_html(f"rgb({r},{g},{b})"))

    return [gr.update(visible=True), binary_matrixes, *row_updates, *swatch_updates]


def process_generation(binary_matrixes, master_prompt, *prompts):
    """Placeholder for the generation step; currently returns None.

    Args:
        binary_matrixes: masks produced by ``process_sketch``.
        master_prompt: the general prompt text.
        *prompts: one prompt per colour row; only the first
            ``len(binary_matrixes)`` correspond to a mask.
    """
    clipped_prompts = prompts[:len(binary_matrixes)]
    #Now: master_prompt can be used as the main prompt, and binary_matrixes and clipped_prompts can be used as the masked inputs
    pass


css = '''
#color-bg{display:flex;justify-content: center;align-items: center;}
.color-bg-item{width: 100%; height: 32px}
#main_button{width:100%}
'''


def update_css(aspect):
    """Return a Gradio update carrying the style snippet for *aspect*."""
    return gr.update(value=_aspect_style_html(aspect))


# NOTE(review): container nesting below is inferred -- confirm the layout.
with gr.Blocks(css=css) as demo:
    binary_matrixes = gr.State([])
    gr.Markdown('''## Control your Stable Diffusion generation with Sketches
This Space demonstrates MultiDiffusion region-based generation using Stable Diffusion model. To get started, draw your masks and type your prompts. More details in the [project page](https://multidiffusion.github.io).
![Examples](https://multidiffusion.github.io/pics/tight.jpg)
''')
    with gr.Row():
        with gr.Box(elem_id="main-image"):
            #with gr.Accordion(open=True, label="Your color sketch") as sketch_accordion:
            with gr.Column():
                with gr.Row():
                    image = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil")
                with gr.Row():
                    aspect = gr.Radio(["Square", "Horizontal", "Vertical"], value="Square", label="Aspect Ratio")
                button_run = gr.Button("I've finished my sketch", elem_id="main_button")

        prompts = []
        colors = []
        color_row = [None] * MAX_COLORS
        with gr.Column(visible=False) as post_sketch:
            general_prompt = gr.Textbox(label="General Prompt")
            for n in range(MAX_COLORS):
                with gr.Row(visible=False) as color_row[n]:
                    with gr.Box(elem_id="color-bg"):
                        colors.append(gr.HTML(_swatch_html("black")))
                    prompts.append(gr.Textbox(label="Prompt for this color"))
            final_run_btn = gr.Button("Generate!")

        out_image = gr.Image()

    # Holds the <style> snippet that update_css swaps when the aspect changes.
    css_height = gr.HTML(_aspect_style_html("Square"))

    aspect.change(update_css, inputs=aspect, outputs=css_height)
    button_run.click(process_sketch, inputs=[image, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors])
    final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=out_image)

demo.launch()