Spaces:
Runtime error
Runtime error
#https://gfx.cs.princeton.edu/pubs/Chang_2015_PPR/index.php | |
import numpy as np | |
import cv2 | |
from sklearn.cluster import KMeans | |
from PIL import Image | |
def extract_palette(image, n_colors=5, random_state=None):
    """
    Extracts dominant colors from the image using KMeans clustering.

    Parameters:
    - image: np.array, the input image. For a 3-D (H, W, C) array the last
      axis is treated as the color channels; any other shape is flattened
      into rows of 3 values (e.g. an (N, 3) list of pixels).
    - n_colors: int, the number of colors to extract
    - random_state: int or None, seed forwarded to KMeans so the palette can
      be made reproducible (default None keeps the previous behaviour)

    Returns:
    - palette: np.ndarray of shape (n_colors, C) holding the float cluster
      centers, in the same channel order as `image`

    Notes:
    - n_init='auto' requires scikit-learn >= 1.2.
    """
    image = np.asarray(image)
    # One row per pixel. Use the real channel count for (H, W, C) images:
    # a fixed reshape(-1, 3) would silently regroup the values of a
    # 4-channel image into meaningless triples.
    if image.ndim == 3:
        pixels = image.reshape(-1, image.shape[-1])
    else:
        pixels = image.reshape(-1, 3)
    kmeans = KMeans(n_clusters=n_colors, n_init='auto', random_state=random_state)
    kmeans.fit(pixels)
    # The cluster centers are the dominant colors (floats, not rounded).
    return kmeans.cluster_centers_
def map_colors(source_img, source_palette, target_palette):
    """
    Maps colors from the source palette to the target palette.

    Every pixel is assigned to its nearest source-palette color (Euclidean
    distance; ties go to the lowest palette index, as with np.argmin) and is
    replaced by the target-palette color at that same index.

    Parameters:
    - source_img: np.array, the source image, shape (H, W, C) or (H, W)
    - source_palette: array-like of shape (K, C), the source color palette
    - target_palette: array-like with at least K colors; entry k replaces
      the pixels closest to source_palette[k]

    Returns:
    - recolored_img: np.array, a new image with the same shape and dtype as
      source_img (target values are cast to that dtype on assignment);
      source_img itself is not modified

    Raises:
    - ValueError: if source_palette is empty and the image has pixels
    - IndexError: if target_palette has fewer colors than an index needed
    """
    recolored_img = np.copy(source_img)
    height, width = source_img.shape[:2]
    n_pixels = height * width
    if n_pixels == 0:
        return recolored_img

    palette = np.asarray(source_palette, dtype=np.float64)
    if len(palette) == 0:
        raise ValueError("source_palette must contain at least one color")
    targets = np.asarray(target_palette)

    # One row per pixel; a 2-D (grayscale) image becomes a single column,
    # which broadcasts against the palette colors like a scalar would.
    pixels = source_img.reshape(n_pixels, -1).astype(np.float64)

    best_dist = np.full(n_pixels, np.inf)
    labels = np.zeros(n_pixels, dtype=np.intp)
    # Loop over the few palette entries instead of the many pixels: each pass
    # is vectorised and memory stays O(pixels) rather than O(pixels * K).
    for idx, color in enumerate(palette):
        dist = np.linalg.norm(pixels - color, axis=1)
        # Strict '<' keeps the earlier index on ties, matching np.argmin.
        closer = dist < best_dist
        best_dist[closer] = dist[closer]
        labels[closer] = idx

    recolored_img[...] = targets[labels].reshape(recolored_img.shape)
    return recolored_img
def palette_based_color_transfer(source_img, target_palette, n_colors=5):
    """
    Performs palette based color transfer.

    Parameters:
    - source_img: np.array, the source image in OpenCV's BGR channel order
      (it is converted with cv2.COLOR_BGR2RGB before processing)
    - target_palette: sequence of colors, the target color palette. If it
      has fewer than n_colors entries it is repeated cyclically; extra
      entries are ignored. The caller's object is never modified.
    - n_colors: int, the number of colors in the source palette (default is 5)

    Returns:
    - recolored_img: np.array, the recolored image in RGB channel order

    Raises:
    - ValueError: if target_palette is empty
    """
    # Copy into a fresh list. The former `target_palette += target_palette`
    # extended the caller's list in place, added element-wise for numpy
    # arrays, and never terminated for an empty palette.
    target_colors = list(target_palette)
    if not target_colors:
        raise ValueError("target_palette must contain at least one color")

    # Convert the source image to RGB
    source_img = cv2.cvtColor(source_img, cv2.COLOR_BGR2RGB)
    # Extract the source palette
    source_palette = extract_palette(source_img, n_colors)
    # One target color per source cluster: repeat cyclically, then truncate.
    matched_palette = [target_colors[i % len(target_colors)] for i in range(n_colors)]
    # Perform color mapping
    return map_colors(source_img, source_palette, matched_palette)
def create_color_palette(colors, palette_width=800, palette_height=200):
    """
    Receives a list of colors in hex format and creates a palette image.

    The colors are drawn as side-by-side vertical stripes, left to right, in
    the order given. Stripe widths differ by at most one pixel, so the whole
    image is covered even when palette_width is not a multiple of the number
    of colors.

    Parameters:
    - colors: list of str, hex colors such as "#db5a1e"
    - palette_width: int, image width in pixels
    - palette_height: int, image height in pixels

    Returns:
    - img: PIL.Image.Image in mode 'RGB' of size (palette_width, palette_height)

    Raises:
    - ValueError: if colors is empty
    """
    if not colors:
        raise ValueError("colors must contain at least one hex color")
    rgb = np.array([hex_to_rgb(color) for color in colors], dtype=np.uint8)
    # Map each column to a color index. Integer scaling spreads any remainder
    # across the stripes instead of leaving black pixels at the end.
    column_colors = rgb[np.arange(palette_width) * len(rgb) // palette_width]
    # PIL sizes are (width, height), while numpy images are (rows, cols, channels).
    pixels = np.tile(column_colors[np.newaxis, :, :], (palette_height, 1, 1))
    return Image.fromarray(pixels)
def hex_to_rgb(hex_color):
    """
    Converts a hex color string to a list [R, G, B] of ints in 0..255.

    Accepts "#rrggbb" (the format used throughout this module), the same
    without the leading "#", the shorthand "#rgb", and the 8-/4-digit forms
    that carry an alpha channel, whose alpha component is ignored.
    Surrounding whitespace is ignored.

    Raises:
    - ValueError: if hex_color is not one of the forms above
    """
    digits = hex_color.strip()
    if digits.startswith("#"):
        digits = digits[1:]
    if len(digits) in (3, 4):
        # Shorthand: each digit is doubled ("#fa0" -> "#ffaa00").
        digits = "".join(ch * 2 for ch in digits)
    is_hex = all(ch in "0123456789abcdefABCDEF" for ch in digits)
    if len(digits) not in (6, 8) or not is_hex:
        raise ValueError(f"invalid hex color: {hex_color!r}")
    return [int(digits[i:i + 2], 16) for i in (0, 2, 4)]
# Example usage: | |
# source_img_path = "estampa.jpg" | |
# # target_palette = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 0], [0, 255, 255]] | |
# hex_colors = ["#db5a1e", "#555115", "#9a690e", "#1f3a19", "#da8007", "#9a0633", "#b70406", "#d01b4b", "#e20b0f", "#f7515d"] | |
# target_palette = [hex_to_rgb(color) for color in hex_colors] | |
# recolored_img = palette_based_color_transfer(cv2.imread(source_img_path), target_palette, 10)
# cv2.imwrite("recolored_image.jpg", cv2.cvtColor(recolored_img, cv2.COLOR_RGB2BGR)) | |
def recolor(source, colors, n_colors=10, output_path="result.jpg"):
    """
    Recolors an RGB image so that its dominant colors follow a hex palette.

    Parameters:
    - source: np.array, the input image in RGB channel order
    - colors: list of str, target palette as hex strings such as "#db5a1e"
    - n_colors: int, number of dominant colors extracted from `source`
      (default 10, the value that used to be hard-coded)
    - output_path: str or None, file the recolored image is written to with
      cv2.imwrite (default "result.jpg"); pass None to skip writing

    Returns:
    - recolored: np.array, the recolored image in RGB channel order

    Raises:
    - ValueError: if colors is empty
    """
    import warnings

    if not colors:
        raise ValueError("colors must contain at least one hex color")
    target_palette = [hex_to_rgb(color) for color in colors]
    # palette_based_color_transfer expects OpenCV's BGR order, so convert the
    # RGB input first; it converts back internally and returns RGB.
    source_bgr = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
    recolored = palette_based_color_transfer(source_bgr, target_palette, n_colors)
    if output_path is not None:
        # cv2.imwrite expects BGR and reports failure only via its return value.
        if not cv2.imwrite(output_path, cv2.cvtColor(recolored, cv2.COLOR_RGB2BGR)):
            warnings.warn(f"could not write recolored image to {output_path!r}")
    return recolored