import colorsys

import gradio as gr
import numpy as np
from PIL import Image

import modules.scripts as scripts
from modules import images
from modules.processing import process_images
from modules.shared import opts


# Rec. 709 luma weights applied to the (R, G, B) channels.
_LUMINANCE_WEIGHTS = np.array([0.2126, 0.7152, 0.0722])


class _Color:
    """One palette entry: an 8-bit RGB triple and its share of the image."""

    def __init__(self, rgb, frequency):
        """
        :param rgb: iterable of three channel values in 0-255 (numpy scalars are fine)
        :param frequency: fraction of the image's pixels represented by this color
        """
        # Plain Python ints so PIL / numpy assignment accept the value no matter
        # whether it came from a numpy array or a generator.
        self.rgb = tuple(int(c) for c in rgb)
        self.freq = frequency

    def __lt__(self, other):
        # Natural ordering is by frequency.
        return self.freq < other.freq

    @property
    def _unit_rgb(self):
        # colorsys works on channels in [0, 1]; 0-255 input makes rgb_to_hls
        # produce meaningless saturation values.
        return tuple(c / 255.0 for c in self.rgb)

    @property
    def hsv(self):
        """(hue, saturation, value), each in [0, 1]."""
        return colorsys.rgb_to_hsv(*self._unit_rgb)

    @property
    def hls(self):
        """(hue, lightness, saturation), each in [0, 1] -- note the HLS order."""
        return colorsys.rgb_to_hls(*self._unit_rgb)

    @property
    def luminance(self):
        """Weighted channel sum on the 0-255 scale."""
        return float(np.dot(_LUMINANCE_WEIGHTS, self.rgb))


class _ColorBox:
    """Bounding box in RGB space around a group of pixels (one median-cut cell)."""

    def __init__(self, colors):
        """
        :param colors: array of shape (n_pixels, 3); must be non-empty, because
                       min/max reductions raise on empty input
        """
        self.colors = colors
        self.min_channel = colors.min(axis=0)
        self.max_channel = colors.max(axis=0)

    @property
    def ranges(self):
        """Per-channel extent of the box, as int64 so products stay exact."""
        return self.max_channel.astype(np.int64) - self.min_channel.astype(np.int64)

    @property
    def volume(self):
        return int(np.prod(self.ranges))

    @property
    def average(self):
        """Mean color of the pixels in the box, as floats."""
        return self.colors.mean(axis=0)

    @property
    def can_split(self):
        # Needs two pixels (so both halves are non-empty) and some spread on at
        # least one channel (splitting identical pixels only duplicates a color).
        return len(self.colors) > 1 and int(self.ranges.max()) > 0

    @property
    def split_priority(self):
        # Largest volume first (the original criterion); the longest side breaks
        # ties so flat boxes -- zero range on one channel, hence zero volume --
        # are still refined instead of repeatedly splitting the first box.
        ranges = self.ranges
        return int(np.prod(ranges)), int(ranges.max())

    def split(self):
        """Split at the median of the widest channel.

        :return: [lower_half, upper_half] as two non-empty _ColorBox objects
                 (only valid when ``can_split`` is true)
        """
        channel = int(np.argmax(self.ranges))
        ordered = self.colors[np.argsort(self.colors[:, channel], kind="stable")]
        middle = len(ordered) // 2
        return [_ColorBox(ordered[:middle]), _ColorBox(ordered[middle:])]


class _Palette:
    """An ordered list of _Color objects that can be rendered as swatches."""

    def __init__(self, colors):
        self.colors = list(colors)

    def get_image(self, w=50, h=50):
        """Render the palette as a row of ``w`` x ``h`` swatches, left to right."""
        arr = np.zeros((h, w * len(self.colors), 3), dtype=np.uint8)
        for i, color in enumerate(self.colors):
            # Each swatch is w pixels wide (the original code stepped by h).
            arr[:, i * w:(i + 1) * w] = color.rgb
        return Image.fromarray(arr)


def _kmeans_extraction(arr, palette_size):
    """Cluster the pixels of an (H, W, 3) array into ``palette_size`` colors with KMeans.

    :return: list of _Color, in cluster order, with frequency = share of pixels
    """
    # sklearn is optional; Script.run falls back to median cut when it is missing.
    from sklearn.cluster import KMeans

    pixels = arr.reshape(-1, 3).astype(np.float64)
    model = KMeans(n_clusters=palette_size)
    labels = model.fit_predict(pixels)
    centers = np.clip(np.rint(model.cluster_centers_), 0, 255).astype(int)
    # minlength keeps one count per cluster even if trailing clusters are empty,
    # so zip() below cannot silently drop colors.
    counts = np.bincount(labels, minlength=len(centers))
    frequencies = counts / counts.sum()
    return [_Color(center, freq) for center, freq in zip(centers, frequencies)]


def _median_cut_extraction(arr, palette_size):
    """Quantize the pixels of an (H, W, 3) array with median cut.

    May return fewer than ``palette_size`` colors when the image does not have
    enough distinct colors to split further.

    :return: list of _Color with frequency = share of pixels in each box
    """
    pixels = arr.reshape(-1, 3)
    total = len(pixels)
    boxes = [_ColorBox(pixels)]
    while len(boxes) < palette_size:
        candidates = [i for i, box in enumerate(boxes) if box.can_split]
        if not candidates:
            break
        target = max(candidates, key=lambda i: boxes[i].split_priority)
        boxes[target:target + 1] = boxes[target].split()
    return [_Color(np.rint(box.average), len(box.colors) / total) for box in boxes]


_SORT_METHODS = {
    "luminance": lambda c: c.luminance,
    "hue": lambda c: c.hsv[0],
    "saturation": lambda c: c.hsv[1],
    "value": lambda c: c.hsv[2],
    # colorsys.rgb_to_hls returns (h, l, s): lightness is index 1.
    "lightness": lambda c: c.hls[1],
}


def _extract_colors(image, palette_size=5, resize=True, mode="Median cut", sort_mode=None):
    """Extract a palette of up to ``palette_size`` colors from an image.

    :param image: PIL.Image object or path to an image file
    :param palette_size: number of colors to extract
    :param resize: downscale to 256x256 first (faster, slightly less precise)
    :param mode: "Median cut" or "KMeans"
    :param sort_mode: a key of _SORT_METHODS, or None to keep extraction order
    :return: _Palette
    :raises NotImplementedError: for an unknown ``mode`` or ``sort_mode``
    """
    # Validate before doing any expensive work.
    if sort_mode is not None and sort_mode not in _SORT_METHODS:
        raise NotImplementedError("Sorting mode not implemented!")

    if isinstance(image, Image.Image):
        img = image.convert("RGB")
    else:
        with Image.open(image) as source:
            img = source.convert("RGB")
    if resize:
        img = img.resize((256, 256))
    arr = np.asarray(img)

    if mode == "KMeans":
        colors = _kmeans_extraction(arr, palette_size)
    elif mode == "Median cut":
        colors = _median_cut_extraction(arr, palette_size)
    else:
        raise NotImplementedError("Extraction mode not implemented!")

    if sort_mode is not None:
        colors.sort(key=_SORT_METHODS[sort_mode])
    return _Palette(colors)


class Script(scripts.Script):
    """txt2img script that replaces each generated image with its color palette."""

    def title(self):
        return "txt2palette"

    def show(self, is_img2img):
        # Only offered on the txt2img tab.
        return not is_img2img

    def ui(self, is_img2img):
        # The default must lie inside [minimum, maximum]; the previous default
        # of 0 made KMeans fail with n_clusters=0.
        palette_size = gr.Slider(minimum=1, maximum=64, step=1, value=8,
                                 label="Palette size")
        method = gr.Radio(choices=['Median cut', 'KMeans'], value='Median cut', label='Palette extraction method')
        sort_by = gr.Radio(choices=["luminance", "hue", "saturation", "value", "lightness"], value="luminance", label="Sort colors by")
        overwrite = gr.Checkbox(False, label="Overwrite existing files")
        return [palette_size, method, sort_by, overwrite]

    def run(self, p, palette_size, method, sort_by, overwrite):
        """Generate images, swap each for its palette swatch image, and save it.

        :param p: processing object passed to process_images
        :param palette_size: requested number of colors (slider value, may be float)
        :param method: "Median cut" or "KMeans"
        :param sort_by: key of the color-sorting method
        :param overwrite: when true, the generated images are not saved and the
                          palettes are saved under the default file name
        :return: the Processed result with palette images substituted
        """
        # Slider values may arrive as floats; KMeans needs an int >= 1.
        palette_size = max(1, int(palette_size))

        if method == 'KMeans':
            try:
                import sklearn.cluster  # noqa: F401 -- availability check only
            except ImportError:
                print('"sklearn" library is not installed, switching the extraction method to Median cut.')
                method = "Median cut"

        if overwrite:
            p.do_not_save_samples = True
            # Previously left unassigned here, which raised NameError at save time.
            basename = ""
        else:
            basename = f"_palette_{palette_size}x"

        proc = process_images(p)

        # When several images are returned, index 0 is skipped (presumably the
        # batch grid, as in the original logic); an empty result is a no-op.
        first = 1 if len(proc.images) > 1 else 0
        for i, index in enumerate(range(first, len(proc.images))):
            palette = _extract_colors(proc.images[index], palette_size=palette_size,
                                      sort_mode=sort_by, mode=method)
            proc.images[index] = palette.get_image()
            images.save_image(proc.images[index], p.outpath_samples, basename,
                              proc.seed + i, proc.prompt, opts.samples_format, info=proc.info, p=p)

        return proc