multimodalart HF staff commited on
Commit
94ca5b6
1 Parent(s): 923fca3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -10
app.py CHANGED
@@ -3,21 +3,61 @@ import numpy as np
3
  import cv2
4
  from PIL import Image
5
  import torch
 
 
 
6
  from region_control import MultiDiffusion, get_views, preprocess_mask
7
  from sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix
8
  MAX_COLORS = 12
9
 
10
  sd = MultiDiffusion("cuda", "2.0")
11
 
12
- def process_sketch(image, binary_matrixes):
13
- high_freq_colors, image = get_high_freq_colors(image)
14
- how_many_colors = len(high_freq_colors)
15
- im2arr = np.array(image) # im2arr.shape: height x width x channel
16
- im2arr = color_quantization(im2arr, high_freq_colors)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
 
18
  colors_fixed = []
19
- for color in high_freq_colors:
20
- r, g, b = color[1]
21
  if any(c != 255 for c in (r, g, b)):
22
  binary_matrix = create_binary_matrix(im2arr, (r,g,b))
23
  binary_matrixes.append(binary_matrix)
@@ -27,7 +67,7 @@ def process_sketch(image, binary_matrixes):
27
  for n in range(MAX_COLORS):
28
  visibilities.append(gr.update(visible=False))
29
  colors.append(gr.update(value=f'<div class="color-bg-item" style="background-color: black"></div>'))
30
- for n in range(how_many_colors-1):
31
  visibilities[n] = gr.update(visible=True)
32
  colors[n] = colors_fixed[n]
33
  return [gr.update(visible=True), binary_matrixes, *visibilities, *colors]
@@ -65,7 +105,10 @@ with gr.Blocks(css=css) as demo:
65
  with gr.Row():
66
  with gr.Box(elem_id="main-image"):
67
  #with gr.Row():
68
- image = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512,512), brush_radius=45)
 
 
 
69
  #image_horizontal = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(768,512), visible=False, brush_radius=45)
70
  #image_vertical = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512, 768), visible=False, brush_radius=45)
71
  #with gr.Row():
@@ -90,6 +133,7 @@ with gr.Blocks(css=css) as demo:
90
  ''')
91
  #css_height = gr.HTML("<style>#main-image{width: 512px} .fixed-height{height: 512px !important}</style>")
92
  #aspect.change(update_css, inputs=aspect, outputs=[image, image_horizontal, image_vertical])
93
- button_run.click(process_sketch, inputs=[image, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors])
94
  final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=out_image)
 
95
  demo.launch(debug=True)
 
3
  import cv2
4
  from PIL import Image
5
  import torch
6
+ import base64
7
+ import requests
8
+ from io import BytesIO
9
  from region_control import MultiDiffusion, get_views, preprocess_mask
10
  from sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix
11
  MAX_COLORS = 12
12
 
13
  sd = MultiDiffusion("cuda", "2.0")
14
 
15
+ canvas_html = "<div id='canvas-root'></div>"
16
+ load_js = """
17
+ async () => {
18
+ const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/sketch-canvas.js"
19
+ fetch(url)
20
+ .then(res => res.text())
21
+ .then(text => {
22
+ const script = document.createElement('script');
23
+ script.type = "module"
24
+ script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' }));
25
+ document.head.appendChild(script);
26
+ });
27
+ }
28
+ """
29
+
30
+ get_js_colors = """
31
+ async (canvasData) => {
32
+ const canvasEl = document.getElementById("canvas-root");
33
+ return [canvasEl._data]
34
+ }
35
+ """
36
+
37
+ set_canvas_size ="""
38
+ async (aspect) => {
39
+ if(aspect ==='square'){
40
+ _updateCanvas(512,512)
41
+ }
42
+ if(aspect ==='horizontal'){
43
+ _updateCanvas(768,512)
44
+ }
45
+ if(aspect ==='vertical'){
46
+ _updateCanvas(512,768)
47
+ }
48
+ }
49
+ """
50
+
51
+ def process_sketch(canvas_data, binary_matrixes):
52
+ base64_img = canvas_data['image']
53
+ image_data = base64.b64decode(base64_img.split(',')[1])
54
+ image = Image.open(BytesIO(image_data))
55
+ im2arr = np.array(image)
56
 
57
+ colors = [tuple(int(color.lstrip('#')[i:i+2], 16) for i in (0, 2, 4)) for color in canvas_data['colors']]
58
  colors_fixed = []
59
+ for color in colors:
60
+ r, g, b = color
61
  if any(c != 255 for c in (r, g, b)):
62
  binary_matrix = create_binary_matrix(im2arr, (r,g,b))
63
  binary_matrixes.append(binary_matrix)
 
67
  for n in range(MAX_COLORS):
68
  visibilities.append(gr.update(visible=False))
69
  colors.append(gr.update(value=f'<div class="color-bg-item" style="background-color: black"></div>'))
70
+ for n in range(len(colors)-1):
71
  visibilities[n] = gr.update(visible=True)
72
  colors[n] = colors_fixed[n]
73
  return [gr.update(visible=True), binary_matrixes, *visibilities, *colors]
 
105
  with gr.Row():
106
  with gr.Box(elem_id="main-image"):
107
  #with gr.Row():
108
+ canvas_data = gr.JSON(value={}, visible=False)
109
+ canvas = gr.HTML(canvas_html)
110
+ #image = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512,512), brush_radius=45)
111
+
112
  #image_horizontal = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(768,512), visible=False, brush_radius=45)
113
  #image_vertical = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512, 768), visible=False, brush_radius=45)
114
  #with gr.Row():
 
133
  ''')
134
  #css_height = gr.HTML("<style>#main-image{width: 512px} .fixed-height{height: 512px !important}</style>")
135
  #aspect.change(update_css, inputs=aspect, outputs=[image, image_horizontal, image_vertical])
136
+ button_run.click(process_sketch, inputs=[canvas_data, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors])
137
  final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=out_image)
138
+ demo.load(None, None, None, _js=load_js)
139
  demo.launch(debug=True)