multimodalart HF staff commited on
Commit
ce3552b
1 Parent(s): e3d135e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -0
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import cv2
4
+ from PIL import Image
5
+
6
+ MAX_COLORS = 12
7
+
8
+ def get_high_freq_colors(image):
9
+ im = image.getcolors(maxcolors=1024*1024)
10
+ sorted_colors = sorted(im, key=lambda x: x[0], reverse=True)
11
+
12
+ freqs = [c[0] for c in sorted_colors]
13
+ mean_freq = sum(freqs) / len(freqs)
14
+
15
+ high_freq_colors = [c for c in sorted_colors if c[0] > max(2, mean_freq/3)] # Ignore colors that occur very few times (less than 2) or less than half the average frequency
16
+ return high_freq_colors
17
+
18
+ def color_quantization(image, n_colors):
19
+ # Get color histogram
20
+ hist, _ = np.histogramdd(image.reshape(-1, 3), bins=(256, 256, 256), range=((0, 256), (0, 256), (0, 256)))
21
+ # Get most frequent colors
22
+ colors = np.argwhere(hist > 0)
23
+ colors = colors[np.argsort(hist[colors[:, 0], colors[:, 1], colors[:, 2]])[::-1]]
24
+ colors = colors[:n_colors]
25
+ # Replace each pixel with the closest color
26
+ dists = np.sum((image.reshape(-1, 1, 3) - colors.reshape(1, -1, 3))**2, axis=2)
27
+ labels = np.argmin(dists, axis=1)
28
+ return colors[labels].reshape((image.shape[0], image.shape[1], 3)).astype(np.uint8)
29
+
30
+ def create_binary_matrix(img_arr, target_color):
31
+ print(target_color)
32
+ # Create mask of pixels with target color
33
+ mask = np.all(img_arr == target_color, axis=-1)
34
+
35
+ # Convert mask to binary matrix
36
+ binary_matrix = mask.astype(int)
37
+ return binary_matrix
38
+
39
+ def process_sketch(image, binary_matrixes):
40
+ high_freq_colors = get_high_freq_colors(image)
41
+ how_many_colors = len(high_freq_colors)
42
+ im2arr = np.array(image) # im2arr.shape: height x width x channel
43
+ im2arr = color_quantization(im2arr, n_colors=how_many_colors)
44
+
45
+ colors_fixed = []
46
+ for color in high_freq_colors[1:]:
47
+ r = color[1][0]
48
+ g = color[1][1]
49
+ b = color[1][2]
50
+ binary_matrix = create_binary_matrix(im2arr, (r,g,b))
51
+ binary_matrixes.append(binary_matrix)
52
+ colors_fixed.append(gr.update(value=f'<div class="color-bg-item" style="background-color: rgb({r},{g},{b})"></div>'))
53
+ visibilities = []
54
+ colors = []
55
+ for n in range(MAX_COLORS):
56
+ visibilities.append(gr.update(visible=False))
57
+ colors.append(gr.update(value=f'<div class="color-bg-item" style="background-color: black"></div>'))
58
+ for n in range(how_many_colors-1):
59
+ visibilities[n] = gr.update(visible=True)
60
+ colors[n] = colors_fixed[n]
61
+ return [gr.update(visible=True), binary_matrixes, *visibilities, *colors]
62
+
63
+ def process_generation(binary_matrixes, master_prompt, *prompts):
64
+ clipped_prompts = prompts[:len(binary_matrixes)]
65
+ print(clipped_prompts)
66
+ #Now: master_prompt can be used as the main prompt, and binary_matrixes and clipped_prompts can be used as the masked inputs
67
+ pass
68
+
69
+ css = '''
70
+ #color-bg{display:flex;justify-content: center;align-items: center;}
71
+ .color-bg-item{width: 100%; height: 32px}
72
+ #main_button{width:100%}
73
+ '''
74
+ def update_css(aspect):
75
+ if(aspect=='Square'):
76
+ width = 512
77
+ height = 512
78
+ elif(aspect == 'Horizontal'):
79
+ width = 768
80
+ height = 512
81
+ elif(aspect=='Vertical'):
82
+ width = 512
83
+ height = 768
84
+ return gr.update(value=f"<style>#main-image{{width: {width}px}} .fixed-height{{height: {height}px !important}}</style>")
85
+
86
+ with gr.Blocks(css=css) as demo:
87
+ binary_matrixes = gr.State([])
88
+ with gr.Box(elem_id="main-image"):
89
+ with gr.Row():
90
+ image = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil")
91
+ with gr.Row():
92
+ aspect = gr.Radio(["Square", "Horizontal", "Vertical"], value="Square", label="Aspect Ratio")
93
+ button_run = gr.Button("I've finished my sketch",elem_id="main_button")
94
+
95
+ prompts = []
96
+ colors = []
97
+ color_row = [None] * MAX_COLORS
98
+ with gr.Column(visible=False) as post_sketch:
99
+ general_prompt = gr.Textbox(label="General Prompt")
100
+ for n in range(MAX_COLORS):
101
+ with gr.Row(visible=False) as color_row[n]:
102
+ with gr.Box(elem_id="color-bg"):
103
+ colors.append(gr.HTML('<div class="color-bg-item" style="background-color: black"></div>'))
104
+ prompts.append(gr.Textbox(label="Prompt for this color"))
105
+ final_run_btn = gr.Button("Generate!")
106
+ gallery = gr.Gallery()
107
+
108
+ css_height = gr.HTML("<style>#main-image{width: 512px} .fixed-height{height: 512px !important}</style>")
109
+
110
+ aspect.change(update_css, inputs=aspect, outputs=css_height)
111
+ button_run.click(process_sketch, inputs=[image, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors])
112
+ final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=gallery)
113
+ demo.launch(share=True, debug=True)