Commit
•
b05fbf1
1
Parent(s):
98c2a34
rc1
Browse files
app.py
CHANGED
@@ -5,14 +5,17 @@ from PIL import Image
|
|
5 |
import torch
|
6 |
import base64
|
7 |
import requests
|
|
|
|
|
8 |
from io import BytesIO
|
9 |
-
from region_control import MultiDiffusion, get_views, preprocess_mask
|
10 |
from sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix
|
11 |
MAX_COLORS = 12
|
12 |
|
13 |
sd = MultiDiffusion("cuda", "2.1")
|
14 |
-
|
15 |
-
|
|
|
16 |
load_js = """
|
17 |
async () => {
|
18 |
const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/sketch-canvas.js"
|
@@ -49,6 +52,7 @@ async (aspect) => {
|
|
49 |
"""
|
50 |
|
51 |
def process_sketch(canvas_data, binary_matrixes):
|
|
|
52 |
base64_img = canvas_data['image']
|
53 |
image_data = base64.b64decode(base64_img.split(',')[1])
|
54 |
image = Image.open(BytesIO(image_data)).convert("RGB")
|
@@ -71,58 +75,55 @@ def process_sketch(canvas_data, binary_matrixes):
|
|
71 |
colors[n] = colors_fixed[n]
|
72 |
return [gr.update(visible=True), binary_matrixes, *visibilities, *colors]
|
73 |
|
74 |
-
def process_generation(binary_matrixes, master_prompt, *prompts):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
clipped_prompts = prompts[:len(binary_matrixes)]
|
76 |
prompts = [master_prompt] + list(clipped_prompts)
|
77 |
-
neg_prompts = [
|
78 |
-
fg_masks = torch.cat([preprocess_mask(mask_path,
|
79 |
bg_mask = 1 - torch.sum(fg_masks, dim=0, keepdim=True)
|
80 |
bg_mask[bg_mask < 0] = 0
|
81 |
masks = torch.cat([bg_mask, fg_masks])
|
82 |
print(masks.size())
|
83 |
-
image = sd.generate(masks, prompts, neg_prompts,
|
84 |
return(image)
|
85 |
|
86 |
css = '''
|
87 |
#color-bg{display:flex;justify-content: center;align-items: center;}
|
88 |
.color-bg-item{width: 100%; height: 32px}
|
89 |
#main_button{width:100%}
|
90 |
-
.isPopup.svelte-160vdtq {
|
91 |
-
top: -342px !important;
|
92 |
-
z-index: 10001 !important;
|
93 |
-
left: -25px !important;
|
94 |
-
}
|
95 |
-
.alpha.svelte-2ybi8r, .color.svelte-2ybi8r {
|
96 |
-
width: 25px !important;
|
97 |
-
height: 25px !important;
|
98 |
-
}
|
99 |
<style>
|
100 |
|
101 |
'''
|
102 |
-
def update_css(aspect):
|
103 |
-
if(aspect=='Square'):
|
104 |
-
return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
|
105 |
-
elif(aspect == 'Horizontal'):
|
106 |
-
return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)]
|
107 |
-
elif(aspect=='Vertical'):
|
108 |
-
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
|
109 |
|
110 |
with gr.Blocks(css=css) as demo:
|
111 |
binary_matrixes = gr.State([])
|
112 |
gr.Markdown('''## Control your Stable Diffusion generation with Sketches
|
113 |
This Space demonstrates MultiDiffusion region-based generation using Stable Diffusion model. To get started, draw your masks and type your prompts. More details in the [project page](https://multidiffusion.github.io).
|
|
|
114 |
''')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
with gr.Row():
|
116 |
with gr.Box(elem_id="main-image"):
|
117 |
-
#with gr.Row():
|
118 |
canvas_data = gr.JSON(value={}, visible=False)
|
119 |
canvas = gr.HTML(canvas_html)
|
120 |
-
|
121 |
-
|
122 |
-
#image_horizontal = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(768,512), visible=False, brush_radius=45)
|
123 |
-
#image_vertical = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512, 768), visible=False, brush_radius=45)
|
124 |
-
#with gr.Row():
|
125 |
-
# aspect = gr.Radio(["Square", "Horizontal", "Vertical"], value="Square", label="Aspect Ratio")
|
126 |
button_run = gr.Button("I've finished my sketch",elem_id="main_button", interactive=True)
|
127 |
|
128 |
prompts = []
|
@@ -135,15 +136,20 @@ with gr.Blocks(css=css) as demo:
|
|
135 |
with gr.Box(elem_id="color-bg"):
|
136 |
colors.append(gr.HTML('<div class="color-bg-item" style="background-color: black"></div>'))
|
137 |
prompts.append(gr.Textbox(label="Prompt for this mask"))
|
|
|
|
|
|
|
|
|
|
|
138 |
final_run_btn = gr.Button("Generate!")
|
139 |
|
140 |
-
out_image = gr.Image(label="Result")
|
141 |
gr.Markdown('''
|
142 |
![Examples](https://multidiffusion.github.io/pics/tight.jpg)
|
143 |
''')
|
144 |
#css_height = gr.HTML("<style>#main-image{width: 512px} .fixed-height{height: 512px !important}</style>")
|
145 |
-
|
146 |
button_run.click(process_sketch, inputs=[canvas_data, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors], _js=get_js_colors)
|
147 |
-
final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=out_image)
|
148 |
demo.load(None, None, None, _js=load_js)
|
149 |
demo.launch(debug=True)
|
|
|
5 |
import torch
|
6 |
import base64
|
7 |
import requests
|
8 |
+
import random
|
9 |
+
import os
|
10 |
from io import BytesIO
|
11 |
+
from region_control import MultiDiffusion, get_views, preprocess_mask, seed_everything
|
12 |
from sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix
|
13 |
MAX_COLORS = 12
|
14 |
|
15 |
sd = MultiDiffusion("cuda", "2.1")
|
16 |
+
is_shared_ui = True if "weizmannscience/multidiffusion-region-based" in os.environ['SPACE_ID'] else False
|
17 |
+
is_gpu_associated = True if torch.cuda.is_available() else False
|
18 |
+
canvas_html = "<div id='canvas-root' style='max-width:400px; margin: 0 auto'></div>"
|
19 |
load_js = """
|
20 |
async () => {
|
21 |
const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/sketch-canvas.js"
|
|
|
52 |
"""
|
53 |
|
54 |
def process_sketch(canvas_data, binary_matrixes):
|
55 |
+
binary_matrixes.clear()
|
56 |
base64_img = canvas_data['image']
|
57 |
image_data = base64.b64decode(base64_img.split(',')[1])
|
58 |
image = Image.open(BytesIO(image_data)).convert("RGB")
|
|
|
75 |
colors[n] = colors_fixed[n]
|
76 |
return [gr.update(visible=True), binary_matrixes, *visibilities, *colors]
|
77 |
|
78 |
+
def process_generation(model, binary_matrixes, boostrapping, aspect, steps, seed, master_prompt, negative_prompt, *prompts):
|
79 |
+
if(model != "stabilityai/stable-diffusion-2-1-base"):
|
80 |
+
sd = MultiDiffusion("cuda",model)
|
81 |
+
if(seed == -1):
|
82 |
+
seed = random.randint(1, 2147483647)
|
83 |
+
seed_everything(seed)
|
84 |
+
dimensions = {"square": (512, 512), "horizontal": (768, 512), "vertical": (512, 768)}
|
85 |
+
width, height = dimensions.get(aspect, dimensions["square"])
|
86 |
+
|
87 |
clipped_prompts = prompts[:len(binary_matrixes)]
|
88 |
prompts = [master_prompt] + list(clipped_prompts)
|
89 |
+
neg_prompts = [negative_prompt] * len(prompts)
|
90 |
+
fg_masks = torch.cat([preprocess_mask(mask_path, height // 8, width // 8, "cuda") for mask_path in binary_matrixes])
|
91 |
bg_mask = 1 - torch.sum(fg_masks, dim=0, keepdim=True)
|
92 |
bg_mask[bg_mask < 0] = 0
|
93 |
masks = torch.cat([bg_mask, fg_masks])
|
94 |
print(masks.size())
|
95 |
+
image = sd.generate(masks, prompts, neg_prompts, height, width, steps, bootstrapping=boostrapping)
|
96 |
return(image)
|
97 |
|
98 |
css = '''
|
99 |
#color-bg{display:flex;justify-content: center;align-items: center;}
|
100 |
.color-bg-item{width: 100%; height: 32px}
|
101 |
#main_button{width:100%}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
<style>
|
103 |
|
104 |
'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
with gr.Blocks(css=css) as demo:
|
107 |
binary_matrixes = gr.State([])
|
108 |
gr.Markdown('''## Control your Stable Diffusion generation with Sketches
|
109 |
This Space demonstrates MultiDiffusion region-based generation using Stable Diffusion model. To get started, draw your masks and type your prompts. More details in the [project page](https://multidiffusion.github.io).
|
110 |
+
|
111 |
''')
|
112 |
+
|
113 |
+
if(is_shared_ui):
|
114 |
+
gr.HTML(f'''
|
115 |
+
<div>To skip the queue or try the model with custom models, you may duplicate the space and associate a GPU to it <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></div>
|
116 |
+
''')
|
117 |
+
elif(not is_gpu_associated):
|
118 |
+
gr.HTML(f'''
|
119 |
+
<div>You have succesfully duplicated the Space 🎉, but it is running on CPU - which may break this application. Go to the <a href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}/settings">settings</a> page to associate a GPU to it</div>
|
120 |
+
''')
|
121 |
with gr.Row():
|
122 |
with gr.Box(elem_id="main-image"):
|
|
|
123 |
canvas_data = gr.JSON(value={}, visible=False)
|
124 |
canvas = gr.HTML(canvas_html)
|
125 |
+
aspect = gr.Radio(["square", "horizontal", "vertical"], value="square", label="Aspect Ratio", visible=False)
|
126 |
+
model = gr.Textbox(label="The id of any Hugging Face model in the diffusers format", value="stabilityai/stable-diffusion-2-1-base", visible=False if is_shared_ui else True)
|
|
|
|
|
|
|
|
|
127 |
button_run = gr.Button("I've finished my sketch",elem_id="main_button", interactive=True)
|
128 |
|
129 |
prompts = []
|
|
|
136 |
with gr.Box(elem_id="color-bg"):
|
137 |
colors.append(gr.HTML('<div class="color-bg-item" style="background-color: black"></div>'))
|
138 |
prompts.append(gr.Textbox(label="Prompt for this mask"))
|
139 |
+
with gr.Accordion("Advanced options", open=False):
|
140 |
+
negative_prompt = gr.Textbox(label="Global negative prompt for all prompts", value="low quality")
|
141 |
+
boostrapping = gr.Slider(label="Bootstrapping", minimum=1, maximum=100, value=20, step=1)
|
142 |
+
steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
|
143 |
+
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=-1, step=1)
|
144 |
final_run_btn = gr.Button("Generate!")
|
145 |
|
146 |
+
out_image = gr.Image(label="Result", ).style(width=512,height=512)
|
147 |
gr.Markdown('''
|
148 |
![Examples](https://multidiffusion.github.io/pics/tight.jpg)
|
149 |
''')
|
150 |
#css_height = gr.HTML("<style>#main-image{width: 512px} .fixed-height{height: 512px !important}</style>")
|
151 |
+
aspect.change(None, inputs=[aspect], outputs=None, _js = set_canvas_size)
|
152 |
button_run.click(process_sketch, inputs=[canvas_data, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors], _js=get_js_colors)
|
153 |
+
final_run_btn.click(process_generation, inputs=[model, binary_matrixes, boostrapping, aspect, steps, seed, general_prompt, negative_prompt, *prompts], outputs=out_image)
|
154 |
demo.load(None, None, None, _js=load_js)
|
155 |
demo.launch(debug=True)
|