File size: 4,589 Bytes
ce3552b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4a8adf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
import numpy as np
import cv2
from PIL import Image

# Number of (colour swatch, prompt textbox) rows pre-built in the UI below.
MAX_COLORS = 12

def get_high_freq_colors(image):
  """Return the frequently used colours of an image, most frequent first.

  Args:
    image: a PIL image (anything exposing ``getcolors``, ``width`` and
      ``height``).

  Returns:
    A list of ``(count, color)`` entries as produced by ``getcolors``,
    sorted by count in descending order. Only entries whose count is
    greater than both 2 and a third of the mean count over all distinct
    colours are kept. An empty list is returned when the image yields no
    colours.
  """
  # getcolors() returns None when the image holds more distinct colours than
  # ``maxcolors``. The number of distinct colours can never exceed the pixel
  # count, so bounding by width * height guarantees a real result. The old
  # fixed limit is kept as a floor.
  max_colors = max(1024 * 1024, image.width * image.height)
  counted = image.getcolors(maxcolors=max_colors)
  if not counted:
    return []

  sorted_colors = sorted(counted, key=lambda x: x[0], reverse=True)
  mean_freq = sum(c[0] for c in sorted_colors) / len(sorted_colors)

  # Drop stray colours (e.g. anti-aliasing fringes). Keep only those seen
  # more than twice and more than a third of the average frequency.
  threshold = max(2, mean_freq / 3)
  return [c for c in sorted_colors if c[0] > threshold]

def color_quantization(image, n_colors):
    """Snap every pixel to the nearest of the image's ``n_colors`` most frequent colours.

    Args:
        image: array-like of shape (height, width, 3) with integer colour
            values in 0..255.
        n_colors: number of palette colours to keep; must be >= 1.

    Returns:
        A uint8 array of shape (height, width, 3) in which each pixel is
        replaced by the closest (squared Euclidean distance) palette colour.
        Palette colours are chosen by descending pixel count. Ties are broken
        in ascending lexicographic colour order.

    Raises:
        ValueError: if ``n_colors`` is less than 1.
    """
    if n_colors < 1:
        raise ValueError("n_colors must be at least 1")

    arr = np.asarray(image)
    height, width = arr.shape[0], arr.shape[1]
    # int64 so the distance arithmetic below cannot wrap around as uint8 would.
    pixels = arr.reshape(-1, 3).astype(np.int64)
    if pixels.shape[0] == 0:
        return np.zeros((height, width, 3), dtype=np.uint8)

    # Count exact colours directly. The former 256**3-bin histogram allocated
    # ~128 MB of float64 on every call.
    unique_colors, inverse, counts = np.unique(
        pixels, axis=0, return_inverse=True, return_counts=True
    )
    order = np.argsort(-counts, kind="stable")
    palette = unique_colors[order[:n_colors]]

    # Find the nearest palette entry once per *distinct* colour, then map it
    # back to every pixel through the inverse index. This is much smaller than
    # a (pixels x palette) distance matrix.
    dists = np.sum((unique_colors[:, None, :] - palette[None, :, :]) ** 2, axis=2)
    nearest = np.argmin(dists, axis=1)
    quantized = palette[nearest[inverse.reshape(-1)]]
    return quantized.reshape((height, width, 3)).astype(np.uint8)

def create_binary_matrix(img_arr, target_color):
    """Return a 0/1 integer mask marking pixels that exactly equal ``target_color``.

    Args:
        img_arr: NumPy array whose last axis holds the colour channels.
        target_color: sequence of channel values, compared with every pixel
            by broadcasting along the last axis.

    Returns:
        An integer array with the shape of ``img_arr`` minus its last axis.
        It holds 1 where all channels match and 0 elsewhere.
    """
    # The leftover debug print(target_color) was removed: it wrote to stdout
    # once per colour on every sketch submission.
    mask = np.all(img_arr == target_color, axis=-1)
    return mask.astype(int)

def process_sketch(image, binary_matrixes):
  """Split the sketch into per-colour masks and reveal one prompt row per colour.

  Args:
    image: the sketch delivered by the canvas component (``type="pil"``).
    binary_matrixes: the current value of the session's ``gr.State``. It is
      not mutated; a freshly built list is returned in its place.

  Returns:
    A list matching the click handler's ``outputs``:
    - an update showing the post-sketch column,
    - the new list of masks,
    - MAX_COLORS row-visibility updates,
    - MAX_COLORS colour-swatch updates.
  """
  high_freq_colors = get_high_freq_colors(image)
  im2arr = np.array(image)  # im2arr.shape: height x width x channel
  im2arr = color_quantization(im2arr, n_colors=len(high_freq_colors))

  # Start from an empty list instead of appending to the incoming State.
  # The State persists between clicks, so appending piled up masks from
  # earlier runs and misaligned them with the prompt rows.
  masks = []
  swatches = []
  # Entry 0 (the most frequent colour) is skipped; presumably this is the
  # canvas background. At most MAX_COLORS colours are kept because that is
  # how many prompt rows exist; more used to raise IndexError.
  for _count, rgb in high_freq_colors[1:MAX_COLORS + 1]:
    r, g, b = rgb[0], rgb[1], rgb[2]
    masks.append(create_binary_matrix(im2arr, (r, g, b)))
    swatches.append(gr.update(value=f'<div class="color-bg-item" style="background-color: rgb({r},{g},{b})"></div>'))

  shown = len(swatches)
  # One fresh update object per row; rows beyond the detected colours stay hidden.
  visibilities = [gr.update(visible=n < shown) for n in range(MAX_COLORS)]
  placeholders = [
    gr.update(value='<div class="color-bg-item" style="background-color: black"></div>')
    for _ in range(MAX_COLORS - shown)
  ]
  return [gr.update(visible=True), masks, *visibilities, *swatches, *placeholders]

def process_generation(binary_matrixes, master_prompt, *prompts):
    """Placeholder callback for the "Generate!" button.

    Keeps one prompt per colour mask and prints those prompts; nothing is
    generated yet, so the gallery receives None.

    Args:
        binary_matrixes: the masks stored by the sketch step.
        master_prompt: the general prompt (currently unused).
        *prompts: values of every per-colour prompt textbox, in row order.
    """
    mask_count = len(binary_matrixes)
    # Textboxes past the number of masks belong to hidden rows.
    clipped_prompts = prompts[:mask_count]
    print(clipped_prompts)
    # TODO: use master_prompt as the main prompt, and pair binary_matrixes
    # with clipped_prompts as the masked inputs.
    return None

# Extra stylesheet passed to gr.Blocks below:
# - #color-bg centres its contents,
# - .color-bg-item swatches get full width and a 32px height,
# - #main_button is stretched to full width.
css = '''
#color-bg{display:flex;justify-content: center;align-items: center;}
.color-bg-item{width: 100%; height: 32px}
#main_button{width:100%}
'''
# Canvas (width, height) in pixels for each choice of the aspect-ratio Radio.
ASPECT_SIZES = {
  "Square": (512, 512),
  "Horizontal": (768, 512),
  "Vertical": (512, 768),
}


def update_css(aspect):
  """Return an update replacing the <style> rule that sizes the canvas.

  Args:
    aspect: one of the keys of ASPECT_SIZES ("Square", "Horizontal",
      "Vertical").

  Returns:
    A ``gr.update`` whose value is a <style> element that sets the
    #main-image width and the .fixed-height height for the chosen aspect.
    An unrecognised value falls back to the Square size. The former if/elif
    chain left width/height unbound and raised UnboundLocalError.
  """
  width, height = ASPECT_SIZES.get(aspect, ASPECT_SIZES["Square"])
  return gr.update(value=f"<style>#main-image{{width: {width}px}} .fixed-height{{height: {height}px !important}}</style>")

# UI layout: a sketch canvas, an aspect-ratio selector, and a column that is
# hidden until the sketch is processed. That column holds a general prompt
# plus MAX_COLORS (colour swatch, prompt) rows.
with gr.Blocks(css=css) as demo:
  # Per-session list of colour masks written by process_sketch.
  binary_matrixes = gr.State([])
  with gr.Box(elem_id="main-image"):
    with gr.Row():
      image = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil")
    with gr.Row():
      aspect = gr.Radio(["Square", "Horizontal", "Vertical"], value="Square", label="Aspect Ratio")
    button_run = gr.Button("I've finished my sketch",elem_id="main_button")

    prompts = []
    colors = []
    color_row = [None] * MAX_COLORS
    with gr.Column(visible=False) as post_sketch:
      general_prompt = gr.Textbox(label="General Prompt")
      # Pre-build every row; process_sketch toggles their visibility and
      # swatch colour.
      for n in range(MAX_COLORS):
        with gr.Row(visible=False) as color_row[n]:
          with gr.Box(elem_id="color-bg"):
            colors.append(gr.HTML('<div class="color-bg-item" style="background-color: black"></div>'))
          prompts.append(gr.Textbox(label="Prompt for this color"))
      final_run_btn = gr.Button("Generate!")
      gallery = gr.Gallery()

  # Holds the <style> rule that sizes the canvas; replaced by update_css.
  css_height = gr.HTML("<style>#main-image{width: 512px} .fixed-height{height: 512px !important}</style>")

  aspect.change(update_css, inputs=aspect, outputs=css_height)
  button_run.click(process_sketch, inputs=[image, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors])
  final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=gallery)

# Launch only when executed as a script, so importing this module does not
# start a server as a side effect.
if __name__ == "__main__":
  demo.launch()