File size: 10,318 Bytes
db57927
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
import modules.scripts as scripts
import gradio as gr

from modules import images
from modules.processing import process_images
from modules.shared import opts
import numpy as np


class Script(scripts.Script):
    """txt2img script that replaces each generated image with a color-palette strip.

    After normal generation, a palette is extracted from every (non-grid)
    image with Median cut or KMeans, sorted by the chosen key, rendered as a
    row of swatches and saved through ``images.save_image``.
    """

    def title(self):
        return "txt2palette"

    def show(self, is_img2img):
        # Only offered on the txt2img tab.
        return not is_img2img

    def ui(self, is_img2img):
        # The default must lie inside [minimum, maximum]: the former default of
        # 0 gave a single-color Median-cut palette and made KMeans fail.
        palette_size = gr.Slider(minimum=1, maximum=64, step=1, value=8,
            label="Palette size")
        method = gr.Radio(choices=['Median cut', 'KMeans'], value='Median cut', label='Palette extraction method')
        sort_by = gr.Radio(choices=["luminance", "hue", "saturation", "value", "lightness"], value="luminance", label="Sort colors by")
        overwrite = gr.Checkbox(False, label="Overwrite existing files")
        return [palette_size, method, sort_by, overwrite]

    def run(self, p, palette_size, method, sort_by, overwrite):
        """Generate images, turn each into a palette strip and save it.

        :param p: processing object handed to ``process_images``
        :param palette_size: number of colors to extract per image
        :param method: 'Median cut' or 'KMeans'
        :param sort_by: key name from ``sort_methods`` below
        :param overwrite: if True, the generated samples are not saved
            (``p.do_not_save_samples``); only the palettes are
        :return: the ``process_images`` result with images replaced by palettes
        """
        # Imported lazily so the optional sklearn dependency only matters at run time.
        import colorsys
        from PIL import Image
        try:
            from sklearn.cluster import KMeans
        except ImportError:
            if method == 'KMeans':
                print('"sklearn" library is not installed, switching the extraction method to Median cut.')
                method = "Median cut"

        # UI sliders may deliver floats; KMeans and the file name need a positive int.
        palette_size = max(1, int(palette_size))

        class Color:
            """A single palette color (0-255 RGB ints) with its relative frequency."""

            # Rec. 709 relative-luminance coefficients for R, G, B.
            luminance_weights = np.array([0.2126, 0.7152, 0.0722])

            def __init__(self, RGB, frequency):
                # Store plain Python ints so PIL and colorsys receive native numbers.
                self.rgb = tuple(int(c) for c in RGB)
                self.freq = frequency

            def display(self, w=50, h=50):
                """
                Displays the represented color in a w x h window.

                :param w: width in pixels
                :param h: height in pixels
                """
                img = Image.new("RGB", size=(w, h), color=self.rgb)
                img.show()

            def __lt__(self, other):
                return self.freq < other.freq

            def get_colors(self, colorspace="rgb"):
                """
                Get the color in terms of a colorspace (string).

                :param colorspace: rgb/hsv/hls
                :return: corresponding color values
                """
                colors = {"rgb": self.rgb, "hsv": self.hsv, "hls": self.hls}
                return colors[colorspace]

            @property
            def _unit_rgb(self):
                # colorsys works on channels in [0, 1]; self.rgb holds 0-255 values.
                return tuple(c / 255.0 for c in self.rgb)

            @property
            def hsv(self):
                return colorsys.rgb_to_hsv(*self._unit_rgb)

            @property
            def hls(self):
                # Note the order: colorsys returns (hue, lightness, saturation).
                return colorsys.rgb_to_hls(*self._unit_rgb)

            @property
            def luminance(self):
                return np.dot(self.luminance_weights, self.rgb)

        class ColorBox:
            """
            Represents a box in the RGB color space, with associated attributes, used in the Median Cut algorithm.
            """

            def __init__(self, colors):
                """
                Initialize with a numpy array of RGB colors.

                :param colors: np.ndarray (width * height, 3); must be non-empty
                """
                self.colors = colors
                self._get_min_max()

            def _get_min_max(self):
                self.min_channel = np.min(self.colors, axis=0)
                self.max_channel = np.max(self.colors, axis=0)

            def __lt__(self, other):
                """Compare boxes by volume."""
                return self.size < other.size

            @property
            def size(self):
                return self.volume

            def _get_dominant_channel(self):
                # Channel with the widest value range.
                return np.argmax(self.max_channel - self.min_channel)

            @property
            def average(self):
                """
                Returns the average color contained in ColorBox

                :return: [R, G, B]
                """
                return np.mean(self.colors, axis=0)

            @property
            def volume(self):
                return np.prod(self.max_channel - self.min_channel)

            def split(self):
                """
                Splits the ColorBox into two ColorBoxes at the median of the dominant color channel.

                Requires at least two colors, otherwise one half would be empty.

                :return: [ColorBox1, ColorBox2]
                """
                dominant_channel = self._get_dominant_channel()
                # Sort colors along the dominant channel, then cut at the median.
                self.colors = self.colors[self.colors[:, dominant_channel].argsort()]
                median_index = len(self.colors) // 2
                return [
                    ColorBox(self.colors[:median_index]),
                    ColorBox(self.colors[median_index:]),
                ]

        class Palette:
            def __init__(self, colors):
                """
                Initializes a color palette with a list of Color objects.

                :param colors: a list of Color-objects
                """
                self.colors = colors
                self.frequencies = [c.freq for c in colors]
                self.number_of_colors = len(colors)

            def get_image(self, w=50, h=50):
                """Render the palette as a horizontal strip of w x h swatches."""
                arr = np.zeros((h, w * self.number_of_colors, 3), dtype=np.uint8)
                for i, c in enumerate(self.colors):
                    # Swatch i occupies columns [i*w, (i+1)*w) (was mistakenly sliced by h).
                    arr[:, i * w : (i + 1) * w, :] = c.rgb
                return Image.fromarray(arr, "RGB")

        def k_means_extraction(arr, height, width, palette_size):
            """
            Extracts a color palette using KMeans.

            :param arr: pixel array (height, width, 3)
            :param height: height
            :param width: width
            :param palette_size: number of colors
            :return: a list of Color objects with pixel-share frequencies
            """
            arr = np.reshape(arr, (width * height, -1))
            model = KMeans(n_clusters=palette_size)
            labels = model.fit_predict(arr)
            # Round (rather than truncate) the float cluster centers to RGB ints.
            palette = np.rint(model.cluster_centers_).astype(int)
            # minlength keeps counts aligned with the centers even if a trailing cluster is empty.
            color_count = np.bincount(labels, minlength=palette_size)
            color_frequency = color_count / float(np.sum(color_count))
            return [Color(color, freq) for color, freq in zip(palette, color_frequency)]

        def median_cut_extraction(arr, height, width, palette_size):
            """
            Extracts a color palette using the median cut algorithm.

            :param arr: pixel array (height, width, 3)
            :param height: height
            :param width: width
            :param palette_size: maximum number of colors
            :return: a list of Color objects with pixel-share frequencies; may be
                shorter than palette_size if every box is down to one pixel
            """
            arr = arr.reshape((width * height, -1))
            total_pixels = len(arr)
            boxes = [ColorBox(arr)]
            # Each iteration, split the largest box and replace it with its two halves.
            while len(boxes) < palette_size:
                # Only boxes with >= 2 pixels can be split; splitting a 1-pixel box
                # would create an empty ColorBox and np.min would raise.
                splittable = [i for i, box in enumerate(boxes) if len(box.colors) > 1]
                if not splittable:
                    break
                # max() keeps the first index on ties, matching the original np.argmax.
                largest_idx = max(splittable, key=lambda i: boxes[i].size)
                boxes = boxes[:largest_idx] + boxes[largest_idx].split() + boxes[largest_idx + 1 :]
            # Frequency is the share of pixels in each box; the former volume ratio
            # divided by zero for flat (single-color) images.
            return [Color(map(int, box.average), len(box.colors) / total_pixels) for box in boxes]

        sort_methods = {
            "luminance": lambda c: c.luminance,
            "hue": lambda c: c.hsv[0],
            "saturation": lambda c: c.hsv[1],
            "value": lambda c: c.hsv[2],
            # rgb_to_hls returns (h, l, s): lightness is index 1, not 2.
            "lightness": lambda c: c.hls[1],
        }

        def extract_colors(image, palette_size=5, resize=True, mode="Median cut", sort_mode=None):
            """
            Extracts a set of 'palette_size' colors from the given image.

            :param image: PIL.Image object or path to an image file
            :param palette_size: number of colors to extract
            :param resize: whether to resize the image to 256x256 before processing, yielding faster results with lower quality
            :param mode: the color quantization algorithm to use: "KMeans" or "Median cut"
            :param sort_mode: a key of sort_methods, or None to sort by frequency (most frequent first)
            :return: a Palette of the extracted colors
            :raises NotImplementedError: for an unknown mode or sort_mode
            """
            if isinstance(image, Image.Image):
                img = image.convert("RGB")
            else:
                # Close the file handle once the pixels are converted.
                with Image.open(image) as opened:
                    img = opened.convert("RGB")
            if resize:
                img = img.resize((256, 256))
            width, height = img.size
            arr = np.asarray(img)

            if mode == "KMeans":
                colors = k_means_extraction(arr, height, width, palette_size)
            elif mode == "Median cut":
                colors = median_cut_extraction(arr, height, width, palette_size)
            else:
                raise NotImplementedError("Extraction mode not implemented!")

            if sort_mode is None:
                colors.sort(key=lambda c: c.freq, reverse=True)
            elif sort_mode in sort_methods:
                colors.sort(key=sort_methods[sort_mode])
            else:
                raise NotImplementedError("Sorting mode not implemented!")
            return Palette(colors)

        if overwrite:
            # Do not save the generated samples; only the palettes are written.
            p.do_not_save_samples = True
            # basename was previously left undefined here (UnboundLocalError on save).
            # NOTE(review): an empty basename presumably gives the palette the regular
            # sample file name — confirm against images.save_image.
            basename = ""
        else:
            basename = f"_palette_{palette_size}x"

        proc = process_images(p)

        # Do not make palettes out of grids.
        # NOTE(review): assumes index 0 is the grid whenever several images are
        # returned — confirm this holds when grid output is disabled.
        first_index = 1 if len(proc.images) > 1 else 0

        for i, index in enumerate(range(first_index, len(proc.images))):
            pal = extract_colors(proc.images[index], palette_size=palette_size, sort_mode=sort_by, mode=method)
            proc.images[index] = pal.get_image()

            images.save_image(proc.images[index], p.outpath_samples, basename,
                proc.seed + i, proc.prompt, opts.samples_format, info=proc.info, p=p)

        return proc