File size: 2,524 Bytes
396f175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de22fe7
 
 
396f175
 
 
 
 
 
 
 
de22fe7
 
 
 
396f175
 
 
 
 
de22fe7
396f175
521b301
 
de22fe7
396f175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import gradio as gr
import numpy as np
from PIL import Image
from sklearn.cluster import KMeans

def extract_colors(img, num_colors):
    """Extract the dominant colors of *img* and render them as an HTML palette.

    The image is converted to RGB, downsampled to 150x150 and its pixels are
    clustered with k-means. When the image contains no more distinct colors
    than requested, those exact colors are reported instead of clustering.

    Args:
        img: A PIL image in any mode, or ``None`` when nothing is uploaded.
        num_colors: Number of palette colors to extract. Floats (as a UI
            slider may deliver) are truncated to int; values below 1 are
            clamped to 1.

    Returns:
        An HTML string with one swatch per color, ordered from the most to
        the least common. Each swatch is labelled with its hex code and its
        share of the (resized) image's pixels. If *img* is ``None``, a short
        "No image uploaded." paragraph is returned instead.
    """
    if img is None:
        return "<p>No image uploaded.</p>"

    # KMeans requires an integer cluster count.
    n_colors = max(1, int(num_colors))

    # Normalize to 3-channel RGB *before* flattening. Reshaping an RGBA, L or
    # P image straight to (-1, 3) would splice unrelated channel values (or
    # palette indices) into bogus "pixels". Resizing keeps clustering fast.
    pixels = np.asarray(img.convert("RGB").resize((150, 150))).reshape(-1, 3)

    unique_colors, unique_counts = np.unique(pixels, axis=0, return_counts=True)
    if len(unique_colors) <= n_colors:
        # Few enough distinct colors: report them exactly. k-means would
        # otherwise produce duplicate centers and warn about convergence.
        colors = unique_colors.astype(int)
        counts = unique_counts
    else:
        kmeans = KMeans(n_clusters=n_colors, random_state=42)
        labels = kmeans.fit_predict(pixels)
        colors = kmeans.cluster_centers_.round().astype(int)
        counts = np.bincount(labels, minlength=n_colors)

    total = counts.sum()
    # Most dominant color first; a stable sort keeps ties in cluster order.
    order = np.argsort(-counts, kind="stable")

    swatches = []
    for idx in order:
        hex_color = "#" + "".join(f"{int(c):02x}" for c in colors[idx])
        percent = counts[idx] / total * 100
        swatches.append(f'''
        <div style="display: flex; flex-direction: column; align-items: center; gap: 5px;">
            <div style="background-color: {hex_color}; width: 80px; height: 80px; border: 2px solid #333; border-radius: 10px;"></div>
            <span style="font-size: 12px; font-weight: bold;">{hex_color} ({percent:.1f}%)</span>
        </div>
        ''')

    return (
        '<div style="text-align: center;"><h2>Color Palette</h2><div style="display: flex; flex-wrap: wrap; justify-content: center; gap: 10px;">'
        + "".join(swatches)
        + "</div></div>"
    )

with gr.Blocks(title="Image Color Palette Extractor") as demo:
    # Page header.
    gr.Markdown("# Image Color Palette Extractor")
    gr.Markdown("Upload an image to extract the main colors and generate a palette.")

    with gr.Row():
        # Left column: image upload and extraction controls.
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Upload Image")
            num_colors = gr.Slider(
                minimum=2,
                maximum=12,
                value=6,
                step=1,
                label="Number of Colors",
            )
            extract_btn = gr.Button("Extract Colors", variant="primary")

        # Right column (twice as wide): the rendered HTML palette.
        with gr.Column(scale=2):
            palette_output = gr.HTML()

    # Both an explicit button click and a change of the uploaded image
    # re-run extract_colors with the same inputs and output target.
    for register_listener in (extract_btn.click, input_img.change):
        register_listener(
            extract_colors,
            inputs=[input_img, num_colors],
            outputs=palette_output,
        )

# Enable request queuing, then start the app server.
demo.queue().launch()