multimodalart HF staff commited on
Commit
b05fbf1
1 Parent(s): 98c2a34
Files changed (1) hide show
  1. app.py +39 -33
app.py CHANGED
@@ -5,14 +5,17 @@ from PIL import Image
5
  import torch
6
  import base64
7
  import requests
 
 
8
  from io import BytesIO
9
- from region_control import MultiDiffusion, get_views, preprocess_mask
10
  from sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix
11
  MAX_COLORS = 12
12
 
13
  sd = MultiDiffusion("cuda", "2.1")
14
-
15
- canvas_html = "<div id='canvas-root'></div>"
 
16
  load_js = """
17
  async () => {
18
  const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/sketch-canvas.js"
@@ -49,6 +52,7 @@ async (aspect) => {
49
  """
50
 
51
  def process_sketch(canvas_data, binary_matrixes):
 
52
  base64_img = canvas_data['image']
53
  image_data = base64.b64decode(base64_img.split(',')[1])
54
  image = Image.open(BytesIO(image_data)).convert("RGB")
@@ -71,58 +75,55 @@ def process_sketch(canvas_data, binary_matrixes):
71
  colors[n] = colors_fixed[n]
72
  return [gr.update(visible=True), binary_matrixes, *visibilities, *colors]
73
 
74
- def process_generation(binary_matrixes, master_prompt, *prompts):
 
 
 
 
 
 
 
 
75
  clipped_prompts = prompts[:len(binary_matrixes)]
76
  prompts = [master_prompt] + list(clipped_prompts)
77
- neg_prompts = ["low quality"] * len(prompts)
78
- fg_masks = torch.cat([preprocess_mask(mask_path, 512 // 8, 512 // 8, "cuda") for mask_path in binary_matrixes])
79
  bg_mask = 1 - torch.sum(fg_masks, dim=0, keepdim=True)
80
  bg_mask[bg_mask < 0] = 0
81
  masks = torch.cat([bg_mask, fg_masks])
82
  print(masks.size())
83
- image = sd.generate(masks, prompts, neg_prompts, 512, 512, 50, bootstrapping=20)
84
  return(image)
85
 
86
  css = '''
87
  #color-bg{display:flex;justify-content: center;align-items: center;}
88
  .color-bg-item{width: 100%; height: 32px}
89
  #main_button{width:100%}
90
- .isPopup.svelte-160vdtq {
91
- top: -342px !important;
92
- z-index: 10001 !important;
93
- left: -25px !important;
94
- }
95
- .alpha.svelte-2ybi8r, .color.svelte-2ybi8r {
96
- width: 25px !important;
97
- height: 25px !important;
98
- }
99
  <style>
100
 
101
  '''
102
- def update_css(aspect):
103
- if(aspect=='Square'):
104
- return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
105
- elif(aspect == 'Horizontal'):
106
- return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)]
107
- elif(aspect=='Vertical'):
108
- return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
109
 
110
  with gr.Blocks(css=css) as demo:
111
  binary_matrixes = gr.State([])
112
  gr.Markdown('''## Control your Stable Diffusion generation with Sketches
113
  This Space demonstrates MultiDiffusion region-based generation using Stable Diffusion model. To get started, draw your masks and type your prompts. More details in the [project page](https://multidiffusion.github.io).
 
114
  ''')
 
 
 
 
 
 
 
 
 
115
  with gr.Row():
116
  with gr.Box(elem_id="main-image"):
117
- #with gr.Row():
118
  canvas_data = gr.JSON(value={}, visible=False)
119
  canvas = gr.HTML(canvas_html)
120
- #image = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512,512), brush_radius=45)
121
-
122
- #image_horizontal = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(768,512), visible=False, brush_radius=45)
123
- #image_vertical = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512, 768), visible=False, brush_radius=45)
124
- #with gr.Row():
125
- # aspect = gr.Radio(["Square", "Horizontal", "Vertical"], value="Square", label="Aspect Ratio")
126
  button_run = gr.Button("I've finished my sketch",elem_id="main_button", interactive=True)
127
 
128
  prompts = []
@@ -135,15 +136,20 @@ with gr.Blocks(css=css) as demo:
135
  with gr.Box(elem_id="color-bg"):
136
  colors.append(gr.HTML('<div class="color-bg-item" style="background-color: black"></div>'))
137
  prompts.append(gr.Textbox(label="Prompt for this mask"))
 
 
 
 
 
138
  final_run_btn = gr.Button("Generate!")
139
 
140
- out_image = gr.Image(label="Result")
141
  gr.Markdown('''
142
  ![Examples](https://multidiffusion.github.io/pics/tight.jpg)
143
  ''')
144
  #css_height = gr.HTML("<style>#main-image{width: 512px} .fixed-height{height: 512px !important}</style>")
145
- #aspect.change(update_css, inputs=aspect, outputs=[image, image_horizontal, image_vertical])
146
  button_run.click(process_sketch, inputs=[canvas_data, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors], _js=get_js_colors)
147
- final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=out_image)
148
  demo.load(None, None, None, _js=load_js)
149
  demo.launch(debug=True)
 
5
  import torch
6
  import base64
7
  import requests
8
+ import random
9
+ import os
10
  from io import BytesIO
11
+ from region_control import MultiDiffusion, get_views, preprocess_mask, seed_everything
12
  from sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix
13
  MAX_COLORS = 12
14
 
15
  sd = MultiDiffusion("cuda", "2.1")
16
+ is_shared_ui = True if "weizmannscience/multidiffusion-region-based" in os.environ['SPACE_ID'] else False
17
+ is_gpu_associated = True if torch.cuda.is_available() else False
18
+ canvas_html = "<div id='canvas-root' style='max-width:400px; margin: 0 auto'></div>"
19
  load_js = """
20
  async () => {
21
  const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/sketch-canvas.js"
 
52
  """
53
 
54
  def process_sketch(canvas_data, binary_matrixes):
55
+ binary_matrixes.clear()
56
  base64_img = canvas_data['image']
57
  image_data = base64.b64decode(base64_img.split(',')[1])
58
  image = Image.open(BytesIO(image_data)).convert("RGB")
 
75
  colors[n] = colors_fixed[n]
76
  return [gr.update(visible=True), binary_matrixes, *visibilities, *colors]
77
 
78
+ def process_generation(model, binary_matrixes, boostrapping, aspect, steps, seed, master_prompt, negative_prompt, *prompts):
79
+ if(model != "stabilityai/stable-diffusion-2-1-base"):
80
+ sd = MultiDiffusion("cuda",model)
81
+ if(seed == -1):
82
+ seed = random.randint(1, 2147483647)
83
+ seed_everything(seed)
84
+ dimensions = {"square": (512, 512), "horizontal": (768, 512), "vertical": (512, 768)}
85
+ width, height = dimensions.get(aspect, dimensions["square"])
86
+
87
  clipped_prompts = prompts[:len(binary_matrixes)]
88
  prompts = [master_prompt] + list(clipped_prompts)
89
+ neg_prompts = [negative_prompt] * len(prompts)
90
+ fg_masks = torch.cat([preprocess_mask(mask_path, height // 8, width // 8, "cuda") for mask_path in binary_matrixes])
91
  bg_mask = 1 - torch.sum(fg_masks, dim=0, keepdim=True)
92
  bg_mask[bg_mask < 0] = 0
93
  masks = torch.cat([bg_mask, fg_masks])
94
  print(masks.size())
95
+ image = sd.generate(masks, prompts, neg_prompts, height, width, steps, bootstrapping=boostrapping)
96
  return(image)
97
 
98
  css = '''
99
  #color-bg{display:flex;justify-content: center;align-items: center;}
100
  .color-bg-item{width: 100%; height: 32px}
101
  #main_button{width:100%}
 
 
 
 
 
 
 
 
 
102
  <style>
103
 
104
  '''
 
 
 
 
 
 
 
105
 
106
  with gr.Blocks(css=css) as demo:
107
  binary_matrixes = gr.State([])
108
  gr.Markdown('''## Control your Stable Diffusion generation with Sketches
109
  This Space demonstrates MultiDiffusion region-based generation using Stable Diffusion model. To get started, draw your masks and type your prompts. More details in the [project page](https://multidiffusion.github.io).
110
+
111
  ''')
112
+
113
+ if(is_shared_ui):
114
+ gr.HTML(f'''
115
+ <div>To skip the queue or try the model with custom models, you may duplicate the space and associate a GPU to it &nbsp;&nbsp;<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></div>
116
+ ''')
117
+ elif(not is_gpu_associated):
118
+ gr.HTML(f'''
119
+ <div>You have succesfully duplicated the Space 🎉, but it is running on CPU - which may break this application. Go to the <a href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}/settings">settings</a> page to associate a GPU to it</div>
120
+ ''')
121
  with gr.Row():
122
  with gr.Box(elem_id="main-image"):
 
123
  canvas_data = gr.JSON(value={}, visible=False)
124
  canvas = gr.HTML(canvas_html)
125
+ aspect = gr.Radio(["square", "horizontal", "vertical"], value="square", label="Aspect Ratio", visible=False)
126
+ model = gr.Textbox(label="The id of any Hugging Face model in the diffusers format", value="stabilityai/stable-diffusion-2-1-base", visible=False if is_shared_ui else True)
 
 
 
 
127
  button_run = gr.Button("I've finished my sketch",elem_id="main_button", interactive=True)
128
 
129
  prompts = []
 
136
  with gr.Box(elem_id="color-bg"):
137
  colors.append(gr.HTML('<div class="color-bg-item" style="background-color: black"></div>'))
138
  prompts.append(gr.Textbox(label="Prompt for this mask"))
139
+ with gr.Accordion("Advanced options", open=False):
140
+ negative_prompt = gr.Textbox(label="Global negative prompt for all prompts", value="low quality")
141
+ boostrapping = gr.Slider(label="Bootstrapping", minimum=1, maximum=100, value=20, step=1)
142
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
143
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=-1, step=1)
144
  final_run_btn = gr.Button("Generate!")
145
 
146
+ out_image = gr.Image(label="Result", ).style(width=512,height=512)
147
  gr.Markdown('''
148
  ![Examples](https://multidiffusion.github.io/pics/tight.jpg)
149
  ''')
150
  #css_height = gr.HTML("<style>#main-image{width: 512px} .fixed-height{height: 512px !important}</style>")
151
+ aspect.change(None, inputs=[aspect], outputs=None, _js = set_canvas_size)
152
  button_run.click(process_sketch, inputs=[canvas_data, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors], _js=get_js_colors)
153
+ final_run_btn.click(process_generation, inputs=[model, binary_matrixes, boostrapping, aspect, steps, seed, general_prompt, negative_prompt, *prompts], outputs=out_image)
154
  demo.load(None, None, None, _js=load_js)
155
  demo.launch(debug=True)