File size: 4,831 Bytes
ce3552b 3f6a115 527bf99 ce3552b c590a10 ce3552b 7120442 ce3552b 7120442 ce3552b 527bf99 fc8546c ce3552b d42756a ce3552b 527bf99 ce3552b 527bf99 ce3552b 527bf99 ce3552b 527bf99 ce3552b 79488ea 1467efe 79488ea da9710b fce9e32 527bf99 fce9e32 251a915 fce9e32 da9710b fce9e32 2104e5b 527bf99 d42756a a8d52ff 527bf99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import torch
from region_control import MultiDiffusion, get_views, preprocess_mask
from sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix
# Number of per-colour prompt rows created in the UI and of row/swatch updates
# that process_sketch returns on every call.
MAX_COLORS = 12
# Single MultiDiffusion instance created at import time and reused by
# process_generation for every request.
# NOTE(review): "cuda" / "2.0" presumably select the device and the Stable
# Diffusion version - confirm against region_control.MultiDiffusion.
sd = MultiDiffusion("cuda", "2.0")
def process_sketch(image, binary_matrixes):
    """Turn the user's colour sketch into one mask per painted colour.

    Args:
        image: The sketch coming from the canvas component; it is handed to
            ``get_high_freq_colors`` unchanged.
        binary_matrixes: The per-session state list.  It is emptied and then
            refilled in place with one ``create_binary_matrix`` result per
            non-white colour (at most ``MAX_COLORS`` of them).

    Returns:
        A list of ``2 + 2 * MAX_COLORS`` items, in the order the click handler
        expects: an update revealing the post-sketch column, the refreshed
        ``binary_matrixes`` list, ``MAX_COLORS`` row-visibility updates and
        ``MAX_COLORS`` swatch-HTML updates.
    """
    # The state list survives between clicks; without a reset a second
    # "finished my sketch" click would stack new masks on top of stale ones.
    binary_matrixes.clear()

    high_freq_colors, image = get_high_freq_colors(image)
    im2arr = np.array(image)  # im2arr.shape: height x width x channel
    im2arr = color_quantization(im2arr, high_freq_colors)

    colors_fixed = []
    for color in high_freq_colors:
        r, g, b = color[1]
        # Pure white is the untouched canvas, not a user-painted region.
        if all(c == 255 for c in (r, g, b)):
            continue
        # The UI only has MAX_COLORS prompt rows; extra colours cannot be
        # given a prompt, so they are dropped instead of overflowing the rows.
        if len(binary_matrixes) >= MAX_COLORS:
            break
        # NOTE(review): the value is embedded in a "file/..." URL here and
        # later passed to preprocess_mask as a path, so it is presumably a
        # file path - confirm against sketch_helper.create_binary_matrix.
        binary_matrix = create_binary_matrix(im2arr, (r, g, b))
        binary_matrixes.append(binary_matrix)
        colors_fixed.append(gr.update(value=f'<div style="display:flex;align-items: center;justify-content: center"><img width="20%" style="margin-right: 1em" src="file/{binary_matrix}" /><div class="color-bg-item" style="background-color: rgb({r},{g},{b})"></div></div>'))

    # Show exactly one row per stored mask.  Counting the masks themselves
    # (rather than "detected colours minus one") keeps rows and masks aligned
    # even when white is absent from, or repeated in, the detected colours.
    shown = len(colors_fixed)
    visibilities = [gr.update(visible=n < shown) for n in range(MAX_COLORS)]
    unused_swatch = '<div class="color-bg-item" style="background-color: black"></div>'
    colors = colors_fixed + [
        gr.update(value=unused_swatch) for _ in range(MAX_COLORS - shown)
    ]
    return [gr.update(visible=True), binary_matrixes, *visibilities, *colors]
def process_generation(binary_matrixes, master_prompt, *prompts):
    """Generate an image from the stored region masks and their prompts.

    Args:
        binary_matrixes: One entry per region; each is passed to
            ``preprocess_mask`` as its first argument.
        master_prompt: Prompt paired with the background mask.
        *prompts: Per-region prompts; only the first ``len(binary_matrixes)``
            are used.

    Returns:
        Whatever ``sd.generate`` returns for the assembled masks and prompts.
    """
    region_count = len(binary_matrixes)
    all_prompts = [master_prompt, *prompts[:region_count]]
    negative_prompts = ["" for _ in all_prompts]

    # Same 512 // 8 target size the masks have always been preprocessed to.
    mask_side = 512 // 8
    region_masks = []
    for mask_path in binary_matrixes:
        region_masks.append(preprocess_mask(mask_path, mask_side, mask_side, "cuda"))
    fg_masks = torch.cat(region_masks)

    # Background is whatever no region covers; overlapping regions would push
    # the difference below zero, so negatives are floored at 0.
    bg_mask = (1 - fg_masks.sum(dim=0, keepdim=True)).clamp_(min=0)

    # Background mask goes first, matching master_prompt at index 0.
    masks = torch.cat([bg_mask, fg_masks])
    print(masks.size())
    return sd.generate(masks, all_prompts, negative_prompts, 512, 512, 50, bootstrapping=20)
# Custom stylesheet handed to gr.Blocks(css=...): centres the #color-bg swatch
# container, gives every .color-bg-item swatch a fixed height, and stretches
# the #main_button to full width.
css = '''
#color-bg{display:flex;justify-content: center;align-items: center;}
.color-bg-item{width: 100%; height: 32px}
#main_button{width:100%}
'''
def update_css(aspect):
    """Return three visibility updates, only the one matching *aspect* shown.

    The updates are ordered Square, Horizontal, Vertical.  Any other value
    yields ``None``.  The only wiring for this handler in the file is
    currently commented out.
    """
    layouts = ('Square', 'Horizontal', 'Vertical')
    if aspect not in layouts:
        return None
    return [gr.update(visible=(aspect == name)) for name in layouts]
# Gradio UI: a sketch canvas, one prompt row per detected colour, and the
# generated result.
# NOTE(review): the nesting of the layout below was reconstructed from the
# context-manager structure - confirm it against the deployed layout.
with gr.Blocks(css=css) as demo:
    # Per-session list of region masks, filled by process_sketch and consumed
    # by process_generation.
    binary_matrixes = gr.State([])
    # The Markdown bodies are kept flush-left so the rendered text is unchanged.
    gr.Markdown('''## Control your Stable Diffusion generation with Sketches
This Space demonstrates MultiDiffusion region-based generation using Stable Diffusion model. To get started, draw your masks and type your prompts. More details in the [project page](https://multidiffusion.github.io).
''')
    with gr.Row():
        with gr.Box(elem_id="main-image"):
            #with gr.Row():
            image = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512,512), brush_radius=45)
            #image_horizontal = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(768,512), visible=False, brush_radius=45)
            #image_vertical = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512, 768), visible=False, brush_radius=45)
            #with gr.Row():
            #    aspect = gr.Radio(["Square", "Horizontal", "Vertical"], value="Square", label="Aspect Ratio")
            button_run = gr.Button("I've finished my sketch",elem_id="main_button", interactive=True)

        # Parallel lists, one entry per potential colour: the prompt textbox,
        # the swatch HTML, and the row that wraps both.
        prompts = []
        colors = []
        color_row = [None] * MAX_COLORS
        # Hidden until process_sketch reveals it.
        with gr.Column(visible=False) as post_sketch:
            general_prompt = gr.Textbox(label="General Prompt")
            for n in range(MAX_COLORS):
                # Every row starts hidden; process_sketch toggles visibility.
                with gr.Row(visible=False) as color_row[n]:
                    with gr.Box(elem_id="color-bg"):
                        colors.append(gr.HTML('<div class="color-bg-item" style="background-color: black"></div>'))
                    prompts.append(gr.Textbox(label="Prompt for this mask"))
            final_run_btn = gr.Button("Generate!")
            out_image = gr.Image(label="Result")
    gr.Markdown('''
![Examples](https://multidiffusion.github.io/pics/tight.jpg)
''')
    #css_height = gr.HTML("<style>#main-image{width: 512px} .fixed-height{height: 512px !important}</style>")
    #aspect.change(update_css, inputs=aspect, outputs=[image, image_horizontal, image_vertical])

    # Output order must match process_sketch's return list: column update,
    # state, MAX_COLORS row updates, MAX_COLORS swatch updates.
    button_run.click(process_sketch, inputs=[image, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors])
    final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=out_image)
demo.launch(debug=True)