File size: 5,003 Bytes
ce3552b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d42756a
ce3552b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79488ea
1467efe
79488ea
 
da9710b
fce9e32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da9710b
fce9e32
ce3552b
 
 
 
d42756a
a8d52ff
f4a8adf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
import numpy as np
import cv2
from PIL import Image

# Number of per-color prompt rows the UI builds; process_sketch returns this
# many row-visibility updates and this many swatch updates.
MAX_COLORS = 12

def get_high_freq_colors(image):
  """Return the dominant colors of ``image``, most frequent first.

  Args:
    image: Object exposing PIL's ``getcolors(maxcolors)`` method and ``size``
      attribute, e.g. a ``PIL.Image.Image``.

  Returns:
    A list of ``(count, color)`` tuples sorted by descending ``count``,
    keeping only the entries whose count is strictly greater than
    ``max(2, mean_count / 3)``. Empty when the image reports no colors.
  """
  color_counts = image.getcolors(maxcolors=1024*1024)
  if color_counts is None:
    # getcolors() returns None when the image holds more than ``maxcolors``
    # distinct colors. An image cannot have more colors than pixels, so retry
    # with the pixel count as the limit.
    width, height = image.size
    color_counts = image.getcolors(maxcolors=width * height)
  if not color_counts:
    # Nothing to rank; this also avoids dividing by zero below.
    return []

  sorted_colors = sorted(color_counts, key=lambda entry: entry[0], reverse=True)
  mean_freq = sum(entry[0] for entry in sorted_colors) / len(sorted_colors)

  # Drop noise: colors seen at most twice, or no more than a third of the
  # average frequency.
  threshold = max(2, mean_freq / 3)
  return [entry for entry in sorted_colors if entry[0] > threshold]

def color_quantization(image, n_colors):
    """Snap every pixel of ``image`` to one of its most frequent colors.

    Args:
        image: Array that reshapes to rows of 3 channel values and whose first
            two dimensions are height and width. Assumed to hold 8-bit
            (0-255) channel values, as the result is cast to ``uint8``.
        n_colors: How many of the most frequent colors form the palette.

    Returns:
        A ``uint8`` array of shape ``(height, width, 3)`` in which each pixel
        is replaced by the palette color with the smallest squared distance.
    """
    pixels = image.reshape(-1, 3)
    # Count only the colors that actually occur. A dense 256**3 histogram
    # costs roughly 134 MB per call just to find a handful of non-zero bins.
    unique_colors, counts = np.unique(pixels, axis=0, return_counts=True)
    # Most frequent colors first, then keep the top n_colors as the palette.
    palette = unique_colors[np.argsort(counts)[::-1]][:n_colors]
    # Widen to int64 before subtracting so 8-bit differences cannot wrap around.
    palette = palette.astype(np.int64)
    diffs = pixels.astype(np.int64)[:, None, :] - palette[None, :, :]
    dists = np.sum(diffs ** 2, axis=2)
    # Replace each pixel with the closest palette color.
    labels = np.argmin(dists, axis=1)
    return palette[labels].reshape((image.shape[0], image.shape[1], 3)).astype(np.uint8)

def create_binary_matrix(img_arr, target_color):
    """Build an integer 0/1 matrix marking the pixels equal to ``target_color``.

    A pixel is marked 1 only when every value along the last axis of
    ``img_arr`` matches the corresponding entry of ``target_color``.
    """
    print(target_color)
    # Element-wise comparison first, then collapse the channel axis.
    per_channel_equal = img_arr == target_color
    pixel_matches = np.all(per_channel_equal, axis=-1)
    return pixel_matches.astype(int)

def process_sketch(image, binary_matrixes):
  """Split a finished sketch into one binary mask per drawn color.

  Args:
    image: The sketch, as accepted by ``get_high_freq_colors`` and
      ``np.array`` (the UI supplies a PIL image).
    binary_matrixes: Previous mask list held in the UI state. It is ignored
      and replaced, so masks from an earlier click do not accumulate.

  Returns:
    A list made of: an update showing the post-sketch panel, the new list of
    masks, ``MAX_COLORS`` row-visibility updates and ``MAX_COLORS`` swatch
    HTML updates. Rows beyond the detected colors stay hidden with a black
    swatch.
  """
  high_freq_colors = get_high_freq_colors(image)
  how_many_colors = len(high_freq_colors)

  # The most frequent color is skipped (presumably the canvas background —
  # confirm). Cap at MAX_COLORS because the UI has only that many rows.
  region_colors = high_freq_colors[1:MAX_COLORS + 1]

  # Always start from a fresh list: reusing the incoming state would append
  # this sketch's masks after those of every previous click.
  binary_matrixes = []
  colors_fixed = []
  if region_colors:
    im2arr = np.array(image) # im2arr.shape: height x width x channel
    im2arr = color_quantization(im2arr, n_colors=how_many_colors)
    for color in region_colors:
      r = color[1][0]
      g = color[1][1]
      b = color[1][2]
      binary_matrixes.append(create_binary_matrix(im2arr, (r,g,b)))
      colors_fixed.append(gr.update(value=f'<div class="color-bg-item" style="background-color: rgb({r},{g},{b})"></div>'))

  # Default every row to hidden with a black swatch, then reveal one row per
  # detected region color.
  visibilities = [gr.update(visible=False) for _ in range(MAX_COLORS)]
  colors = [
    gr.update(value='<div class="color-bg-item" style="background-color: black"></div>')
    for _ in range(MAX_COLORS)
  ]
  for n, swatch in enumerate(colors_fixed):
    visibilities[n] = gr.update(visible=True)
    colors[n] = swatch
  return [gr.update(visible=True), binary_matrixes, *visibilities, *colors]

def process_generation(binary_matrixes, master_prompt, *prompts):
    """Placeholder for the generation step; currently returns ``None``.

    Only as many of ``prompts`` as there are entries in ``binary_matrixes``
    are kept, so each mask lines up with one prompt.
    """
    mask_count = len(binary_matrixes)
    clipped_prompts = prompts[:mask_count]
    # Now: master_prompt can be used as the main prompt, and binary_matrixes
    # and clipped_prompts can be used as the masked inputs.
    return None

# Stylesheet passed to gr.Blocks(css=...). It centers the "color-bg" swatch
# box, sizes each "color-bg-item" swatch, and stretches "main_button" to the
# full width.
css = '''
#color-bg{display:flex;justify-content: center;align-items: center;}
.color-bg-item{width: 100%; height: 32px}
#main_button{width:100%}
'''
def update_css(aspect):
  """Return an update whose value is a ``<style>`` tag sized for ``aspect``.

  Args:
    aspect: Aspect-ratio label: 'Square', 'Horizontal' or 'Vertical'.
      Any other value falls back to the square size.

  Returns:
    A ``gr.update`` setting the width of ``#main-image`` and the height of
    ``.fixed-height`` elements, in pixels.
  """
  # (width, height) in pixels for each aspect-ratio label.
  dimensions = {
    'Square': (512, 512),
    'Horizontal': (768, 512),
    'Vertical': (512, 768),
  }
  # An unrecognized label used to leave width/height unbound and raise
  # UnboundLocalError; fall back to the square size instead.
  width, height = dimensions.get(aspect, dimensions['Square'])
  return gr.update(value=f"<style>#main-image{{width: {width}px}} .fixed-height{{height: {height}px !important}}</style>")

# Build the Gradio UI: the sketch panel and the result image share one row,
# and the event wiring is declared at the bottom.
with gr.Blocks(css=css) as demo:
  # State holding the list of masks; passed into and returned from process_sketch.
  binary_matrixes = gr.State([])
  gr.Markdown('''## Control your Stable Diffusion generation with Sketches
  This Space demonstrates MultiDiffusion region-based generation using Stable Diffusion model. To get started, draw your masks and type your prompts. More details in the [project page](https://multidiffusion.github.io).
  ![Examples](https://multidiffusion.github.io/pics/tight.jpg)
  ''')
  with gr.Row():
    # "main-image" is the element id whose width the injected <style> tag controls.
    with gr.Box(elem_id="main-image"):
      with gr.Row():
          # Color sketch canvas; its value is handed to process_sketch as a PIL image.
          image = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil")
      with gr.Row():
          # The labels must match the strings update_css compares against.
          aspect = gr.Radio(["Square", "Horizontal", "Vertical"], value="Square", label="Aspect Ratio")
      button_run = gr.Button("I've finished my sketch",elem_id="main_button")
      
      # Per-color widgets, created up front and hidden; process_sketch reveals
      # one row per detected color.
      prompts = []
      colors = []
      color_row = [None] * MAX_COLORS
      with gr.Column(visible=False) as post_sketch:
        general_prompt = gr.Textbox(label="General Prompt")
        for n in range(MAX_COLORS):
          # Each row pairs a color swatch with the prompt textbox for that color.
          with gr.Row(visible=False) as color_row[n]:
            with gr.Box(elem_id="color-bg"):
              colors.append(gr.HTML('<div class="color-bg-item" style="background-color: black"></div>'))
            prompts.append(gr.Textbox(label="Prompt for this color"))
        final_run_btn = gr.Button("Generate!")
    
    out_image = gr.Image(label="Result")
  
  # HTML component carrying a <style> tag; update_css rewrites it when the aspect ratio changes.
  # NOTE(review): no component in this block is given the "fixed-height" class that this style targets — confirm.
  css_height = gr.HTML("<style>#main-image{width: 512px} .fixed-height{height: 512px !important}</style>")
  
  aspect.change(update_css, inputs=aspect, outputs=css_height)
  # The output order must match the list process_sketch returns:
  # panel update, masks, then MAX_COLORS row updates and MAX_COLORS swatch updates.
  button_run.click(process_sketch, inputs=[image, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors])
  # NOTE(review): process_generation currently returns nothing, so out_image receives no image.
  final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=out_image)
# Starts the app whenever this module is executed.
demo.launch()