farm-recolor / recolorPaletteBasedTransfer.py
vettorazi's picture
copied local files. docker initial setup
52cbb9c
raw
history blame
No virus
4.03 kB
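# Inspired by Chang et al. 2015, "Palette-based Photo Recoloring": https://gfx.cs.princeton.edu/pubs/Chang_2015_PPR/index.php
# NOTE: this module uses a simplified approach: each pixel is replaced by the target color paired with its nearest KMeans palette color.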
import numpy as np
import cv2
from sklearn.cluster import KMeans
from PIL import Image
def extract_palette(image, n_colors=5, random_state=None):
    """
    Extracts dominant colors from the image using KMeans clustering.

    Parameters:
    - image: np.array, an (H, W, 3) image or an (N, 3) array of pixels
    - n_colors: int, the number of colors (clusters) to extract
    - random_state: int or None, forwarded to KMeans; pass an int for a
      reproducible palette (the default None keeps the unseeded behaviour)

    Returns:
    - palette: np.ndarray of shape (n_colors, 3) holding the float cluster
      centers, in the same channel order as `image`

    Raises:
    - ValueError: if the last axis of `image` does not have length 3
    """
    image = np.asarray(image)
    # reshape(-1, 3) on a 4-channel or 1-channel array would silently build
    # "pixels" from channels of neighbouring pixels, so reject it explicitly.
    if image.ndim == 0 or image.shape[-1] != 3:
        raise ValueError(
            f"expected an image with 3 channels in the last axis, got shape {image.shape}"
        )
    # Flatten to a list of pixels: one row per pixel, one column per channel.
    pixels = image.reshape(-1, 3)
    kmeans = KMeans(n_clusters=n_colors, random_state=random_state)
    kmeans.fit(pixels)
    # The cluster centers are the palette colors (floats, same channel order).
    return kmeans.cluster_centers_
def map_colors(source_img, source_palette, target_palette):
    """
    Maps colors from the source palette to the target palette.

    Every pixel is replaced by target_palette[k], where k is the index of the
    source_palette color closest to it in Euclidean distance. Ties go to the
    lowest index.

    Parameters:
    - source_img: np.array, an (H, W, C) image
    - source_palette: array-like of shape (K, C), the source color palette
    - target_palette: array-like of shape (>=K, C), replacement colors; entry k
      replaces the pixels nearest to source_palette[k]

    Returns:
    - recolored_img: np.array, a new array with the same shape and dtype as
      source_img (the input is not modified)

    Raises:
    - ValueError: if source_palette is empty and the image is not
    - IndexError: if a pixel maps to an index missing from target_palette
    """
    recolored_img = np.copy(source_img)
    if recolored_img.size == 0:
        return recolored_img

    n_pixels = source_img.shape[0] * source_img.shape[1]
    pixels = source_img.reshape(n_pixels, -1)
    # Work in float64 so that uint8 images/palettes cannot wrap around on
    # subtraction (e.g. 0 - 10 == 246 in uint8).
    palette = np.asarray(source_palette, dtype=np.float64)
    if len(palette) == 0:
        raise ValueError("source_palette must contain at least one color")

    best_idx = np.zeros(n_pixels, dtype=np.intp)
    best_dist = np.full(n_pixels, np.inf)
    # One vectorized pass per palette color keeps memory at O(pixels) instead
    # of materializing a (pixels, K, C) distance tensor.
    for k, color in enumerate(palette):
        diff = pixels - color
        dist = np.sqrt(np.sum(diff * diff, axis=1))
        # Strict '<' keeps the first index on ties, matching np.argmin.
        closer = dist < best_dist
        best_idx[closer] = k
        best_dist[closer] = dist[closer]

    mapped = np.asarray(target_palette)[best_idx]
    # Assigning through [...] casts to the image dtype and works for any
    # memory layout of the copy.
    recolored_img[...] = mapped.reshape(recolored_img.shape)
    return recolored_img
def palette_based_color_transfer(source_img, target_palette, n_colors=5):
    """
    Recolors an image by swapping each of its dominant colors for a target color.

    Parameters:
    - source_img: np.array, the source image in BGR channel order
    - target_palette: list, replacement colors in RGB order; entry i replaces
      the pixels closest to the i-th KMeans cluster (in whatever order KMeans
      returns its clusters)
    - n_colors: int, how many dominant colors to extract from the source
      (default is 5); target_palette needs at least this many entries

    Returns:
    - recolored_img: np.array, the recolored image in RGB channel order
    """
    # Palette extraction and mapping both operate on the RGB version.
    rgb_source = cv2.cvtColor(source_img, cv2.COLOR_BGR2RGB)
    return map_colors(rgb_source, extract_palette(rgb_source, n_colors), target_palette)
def create_color_palette(colors, palette_width=800, palette_height=200):
    """
    Render a list of hex colors as side-by-side vertical stripes.

    Parameters:
    - colors: list of str, hex colors such as "#db5a1e"
    - palette_width: int, width of the output image in pixels
    - palette_height: int, height of the output image in pixels

    Returns:
    - img: PIL.Image in RGB mode, palette_width wide and palette_height tall.
      Each color gets a stripe palette_width // len(colors) pixels wide; the
      last stripe absorbs any leftover columns, so no pixel is left black.

    Raises:
    - ValueError: if `colors` is empty
    """
    n_colors = len(colors)
    if n_colors == 0:
        raise ValueError("colors must contain at least one hex color")
    stripe_width = palette_width // n_colors
    # numpy arrays are (rows, cols) = (height, width); PIL sizes are (width, height).
    canvas = np.zeros((palette_height, palette_width, 3), dtype=np.uint8)
    for idx, hex_color in enumerate(colors):
        start = idx * stripe_width
        end = palette_width if idx == n_colors - 1 else start + stripe_width
        canvas[:, start:end] = hex_to_rgb(hex_color)
    img = Image.fromarray(canvas)
    # img.show()
    return img
def hex_to_rgb(hex_color):
    """
    Convert a hex color string to an [R, G, B] list of ints in 0-255.

    Accepts "#RRGGBB", "RRGGBB", the shorthand "#RGB"/"RGB", and
    "#RRGGBBAA" (the alpha byte is ignored). Surrounding whitespace is ignored.

    Raises:
    - ValueError: if the string is not 3, 6 or 8 hex digits after the optional '#'
    """
    digits = hex_color.strip()
    if digits.startswith("#"):
        digits = digits[1:]
    if len(digits) == 3:
        # Shorthand: "abc" means "aabbcc".
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8) or any(ch not in "0123456789abcdefABCDEF" for ch in digits):
        raise ValueError(f"invalid hex color: {hex_color!r}")
    return [int(digits[i:i + 2], 16) for i in (0, 2, 4)]
# Example usage:
# source_img_path = "estampa.jpg"
# source_img = cv2.imread(source_img_path)  # BGR array; the function takes an image, not a path
# # target_palette = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 0], [0, 255, 255]]
# hex_colors = ["#db5a1e", "#555115", "#9a690e", "#1f3a19", "#da8007", "#9a0633", "#b70406", "#d01b4b", "#e20b0f", "#f7515d"]
# target_palette = [hex_to_rgb(color) for color in hex_colors]
# recolored_img = palette_based_color_transfer(source_img, target_palette, 10)
# cv2.imwrite("recolored_image.jpg", cv2.cvtColor(recolored_img, cv2.COLOR_RGB2BGR))
def recolor(source, colors, n_colors=None):
    """
    Recolor an RGB image so that its dominant colors become the given hex colors.

    Parameters:
    - source: np.array, the input image in RGB channel order
    - colors: list of str, target colors as hex strings (e.g. "#db5a1e")
    - n_colors: int or None, number of dominant source colors to replace.
      None (default) means len(colors), capped at 10 (the previous
      hard-coded value). It must not exceed len(colors), because every
      extracted color needs a target color.

    Returns:
    - recolored: np.array, the recolored image in RGB channel order

    Side effects:
    - writes the result to "result.jpg" in the current working directory

    Raises:
    - ValueError: if `colors` is empty or n_colors exceeds len(colors)
    """
    target_palette = [hex_to_rgb(color) for color in colors]
    if not target_palette:
        raise ValueError("colors must contain at least one hex color")
    if n_colors is None:
        # A fixed 10 raised IndexError whenever fewer than 10 colors were given.
        n_colors = min(10, len(target_palette))
    elif n_colors > len(target_palette):
        raise ValueError(
            f"n_colors ({n_colors}) cannot exceed the number of target colors ({len(target_palette)})"
        )
    # palette_based_color_transfer expects BGR input (it converts BGR->RGB
    # itself), so flip the RGB source to BGR first; the result comes back RGB.
    source_bgr = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
    recolored = palette_based_color_transfer(source_bgr, target_palette, n_colors)
    # cv2.imwrite expects BGR data.
    cv2.imwrite("result.jpg", cv2.cvtColor(recolored, cv2.COLOR_RGB2BGR))
    return recolored