# ColorExtractor / app.py
# Author: eienmojiki
# Uploaded with huggingface_hub (commit de22fe7, verified)
import gradio as gr
import numpy as np
from PIL import Image
from sklearn.cluster import KMeans
def extract_colors(img, num_colors):
    """Extract the dominant colors of an image and render them as an HTML palette.

    The image is normalized to RGB and downscaled to 150x150 pixels to keep
    clustering fast. Its pixels are then grouped with k-means. Each cluster
    center becomes one swatch, labelled with its hex code and the percentage
    of pixels assigned to it. Swatches are ordered from most to least common.

    Args:
        img: A PIL image (as produced by ``gr.Image(type="pil")``), or None.
        num_colors: Requested palette size. It is converted with ``int()``
            (the slider value may be a float) and capped at the number of
            distinct colors present in the downscaled image.

    Returns:
        An HTML string containing the palette, or a short message when no
        image was provided.
    """
    if img is None:
        return "<p>No image uploaded.</p>"

    # Normalize to 3-channel RGB *before* reshaping. convert() handles
    # grayscale ("L"), palette ("P"), RGBA/LA (alpha is dropped), CMYK, etc.
    # Reshaping a non-RGB array to (-1, 3) would silently mix channels
    # across neighbouring pixels.
    rgb = img.convert("RGB").resize((150, 150))
    pixels = np.asarray(rgb).reshape(-1, 3)

    # KMeans requires an integer n_clusters. Asking for more clusters than
    # there are distinct colors produces duplicate centers and a
    # ConvergenceWarning, so clamp to the distinct colors available.
    distinct = len(np.unique(pixels, axis=0))
    n_clusters = max(1, min(int(num_colors), distinct))

    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    labels = kmeans.fit_predict(pixels)

    # Centers are means of 0-255 channel values; round and clip so every
    # component formats as exactly two hex digits.
    colors = np.clip(kmeans.cluster_centers_.round(), 0, 255).astype(int)

    # Pixel share of each cluster (minlength guards against empty clusters).
    counts = np.bincount(labels, minlength=n_clusters)
    percentages = counts / counts.sum() * 100

    # Most dominant color first; stable sort keeps ties in cluster order.
    order = np.argsort(-counts, kind="stable")

    swatches = []
    for idx in order:
        hex_color = "#" + "".join(f"{int(c):02x}" for c in colors[idx])
        percent = percentages[idx]
        swatches.append(f'''
<div style="display: flex; flex-direction: column; align-items: center; gap: 5px;">
<div style="background-color: {hex_color}; width: 80px; height: 80px; border: 2px solid #333; border-radius: 10px;"></div>
<span style="font-size: 12px; font-weight: bold;">{hex_color} ({percent:.1f}%)</span>
</div>
''')

    return (
        '<div style="text-align: center;"><h2>Color Palette</h2>'
        '<div style="display: flex; flex-wrap: wrap; justify-content: center; gap: 10px;">'
        + "".join(swatches)
        + "</div></div>"
    )
with gr.Blocks(title="Image Color Palette Extractor") as demo:
    # Page header.
    gr.Markdown("# Image Color Palette Extractor")
    gr.Markdown("Upload an image to extract the main colors and generate a palette.")

    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Upload Image")
            num_colors = gr.Slider(2, 12, value=6, step=1, label="Number of Colors")
            extract_btn = gr.Button("Extract Colors", variant="primary")
        # Right column: rendered palette.
        with gr.Column(scale=2):
            palette_output = gr.HTML()

    # Clicking the button and changing the image both run the same
    # extraction. Listeners are registered in this order: click, then change.
    for register_listener in (extract_btn.click, input_img.change):
        register_listener(
            extract_colors,
            inputs=[input_img, num_colors],
            outputs=palette_output,
        )

# Enable the request queue, then start the server.
demo.queue()
demo.launch()