# Duplicated by asafAdge from HansBug/color_clustering (revision 8917604, 3.85 kB).
import os
from typing import Tuple
import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.cluster import KMeans
def _image_resize(image: Image.Image, pixels: int = 90000, **kwargs):
    """Return a copy of *image* scaled down to roughly *pixels* total pixels.

    Both sides are divided by the same factor ``sqrt(width * height / pixels)``,
    so the aspect ratio is preserved. Images that already fit within the budget
    are returned as an unscaled copy; they are never enlarged.

    Args:
        image: Source image. It is not modified.
        pixels: Target pixel budget (width * height) for the result; must be > 0.
        **kwargs: Forwarded unchanged to ``image.resize`` (e.g. ``resample``).

    Returns:
        A new image whose area is at most about *pixels*, with each side >= 1.

    Raises:
        ValueError: If *pixels* is not positive.
    """
    if pixels <= 0:
        raise ValueError(f'pixels must be positive, got {pixels!r}')

    width, height = image.size
    ratio = (width * height / pixels) ** 0.5
    if ratio <= 1.0:
        return image.copy()

    # Clamp to 1: with an extreme aspect ratio the short side can truncate to 0,
    # which would yield an empty image with nothing left to cluster.
    new_size = (max(1, int(width / ratio)), max(1, int(height / ratio)))
    return image.resize(new_size, **kwargs)
def get_main_colors(image: Image.Image, n: int = 28, pixels: int = 90000) \
        -> Tuple[Image.Image, np.ndarray, np.ndarray, np.ndarray]:
    """Quantize *image* to its dominant colors with k-means.

    The clustering is fitted on a downscaled copy of about *pixels* pixels to
    keep it fast. Every pixel of the full-size image is then assigned to its
    nearest cluster center.

    Args:
        image: Input image. It is converted to RGB if needed and not modified.
        n: Requested number of clusters. It is clamped to the number of sampled
            pixels, because k-means cannot fit more clusters than samples.
        pixels: Approximate pixel budget for the downscaled fitting copy.

    Returns:
        A tuple ``(new_image, colors, counts, labels)``:

        - ``new_image``: RGB image with every pixel replaced by its cluster color.
        - ``colors``: ``(k, 3)`` uint8 array of rounded cluster centers.
        - ``counts``: ``(k,)`` array; ``counts[i]`` is the number of full-size
          pixels assigned to ``colors[i]`` (0 if a cluster received none).
        - ``labels``: ``(height, width)`` array of per-pixel cluster indices.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')

    sample = np.asarray(_image_resize(image, pixels)).reshape(-1, 3)
    kmeans = KMeans(n_clusters=min(n, len(sample)))
    kmeans.fit(sample)

    width, height = image.size
    labels = kmeans.predict(np.asarray(image).reshape(-1, 3))
    colors = kmeans.cluster_centers_.round().astype(np.uint8)

    new_image = Image.fromarray(colors[labels].reshape((height, width, 3)), mode='RGB')

    # bincount with minlength gives exactly one count per cluster index, including
    # zeros. np.unique would silently drop empty clusters (e.g. duplicate centers
    # on images with few distinct colors), leaving counts shorter than colors.
    counts = np.bincount(labels, minlength=len(colors))
    return new_image, colors, counts, labels.reshape((height, width))
def main_func(image: Image.Image, n: int, pixels: int, fixed_width: bool, width: int):
    """Quantize *image* and tabulate its palette (the Submit button callback).

    Args:
        image: The uploaded image. May be ``None`` when nothing was uploaded
            (presumably how the Gradio Image input reports an empty field).
        n: Number of k-means clusters, i.e. the palette size.
        pixels: Approximate pixel budget used to fit the clustering.
        fixed_width: If true, first rescale *image* to *width* pixels wide,
            keeping its aspect ratio.
        width: Target width used when *fixed_width* is set.

    Returns:
        ``(new_image, table)``. ``new_image`` is the quantized image. ``table``
        is a DataFrame with columns Hex, Pixels, Ratio, Red, Green and Blue,
        one row per palette color, sorted by Pixels in descending order.

    Raises:
        ValueError: If no image was supplied.
    """
    if image is None:
        raise ValueError('Please upload an image before submitting.')

    # Cast defensively in case slider values arrive as floats (e.g. 8.0);
    # KMeans and resize need integers.
    n, pixels, width = int(n), int(pixels), int(width)

    if fixed_width:
        old_width, old_height = image.size
        scale = width / old_width
        # Clamp to 1 so a very wide image cannot round down to zero height.
        new_size = (max(1, int(round(old_width * scale))),
                    max(1, int(round(old_height * scale))))
        image = image.resize(new_size)

    new_image, colors, _, labels = get_main_colors(image, n, pixels)
    # Recount from the label map so the table always has exactly one row per
    # entry of ``colors``, including clusters that received no pixels.
    counts = np.bincount(np.ravel(labels), minlength=len(colors))

    hexes = [f'#{r:02x}{g:02x}{b:02x}' for r, g, b in colors.tolist()]
    table = pd.DataFrame({
        'Hex': hexes,
        'Pixels': counts,
        'Ratio': counts / counts.sum(),
        'Red': colors[:, 0],
        'Green': colors[:, 1],
        'Blue': colors[:, 2],
    })
    return new_image, table.sort_values('Pixels', ascending=False)
if __name__ == '__main__':
    pd.set_option("display.precision", 3)

    with gr.Blocks() as demo:
        with gr.Row():
            # Left column: input image, clustering controls and the Submit button.
            with gr.Column():
                ch_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    ch_clusters = gr.Slider(value=8, minimum=2, maximum=256, step=2, label='Clusters')
                    ch_pixels = gr.Slider(value=100000, minimum=10000, maximum=1000000, step=10000,
                                          label='Pixels for Clustering')
                    ch_fixed_width = gr.Checkbox(value=True, label='Width Fixed')
                    ch_width = gr.Slider(value=200, minimum=12, maximum=2048, label='Width')
                ch_submit = gr.Button(value='Submit', variant='primary')

            # Right column: quantized image and palette table, in separate tabs.
            with gr.Column():
                with gr.Tabs():
                    with gr.Tab('Output Image'):
                        ch_output = gr.Image(type='pil', label='Output Image')
                    with gr.Tab('Color Map'):
                        ch_color_map = gr.Dataframe(
                            headers=['Hex', 'Pixels', 'Ratio', 'Red', 'Green', 'Blue'],
                            label='Color Map'
                        )

        ch_submit.click(
            main_func,
            inputs=[ch_image, ch_clusters, ch_pixels, ch_fixed_width, ch_width],
            outputs=[ch_output, ch_color_map],
        )

    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to a single worker instead of passing None to queue().
    demo.queue(os.cpu_count() or 1).launch()