Spaces:
Runtime error
Runtime error
from skimage.color import rgb2lab, deltaE_cie76 | |
import numpy as np | |
import cv2 | |
from sklearn.cluster import MiniBatchKMeans | |
def quantize_global(image, k, *, random_state=None):
    """Quantize every sample value of *image* to one of *k* shared levels.

    All channel values are pooled and clustered as scalars (the array is
    flattened to shape ``(-1, 1)``), so each sample is snapped independently
    to the nearest of ``k`` levels learned by MiniBatchKMeans. This is not a
    k-colour palette: a 3-channel image can still contain up to ``k ** 3``
    distinct colours afterwards.

    Args:
        image: Array-like image. Presumably uint8 in [0, 255], since the
            centres are cast back to uint8 -- confirm with callers.
        k: Number of clusters, i.e. distinct output levels.
        random_state: Optional seed forwarded to MiniBatchKMeans for
            reproducible output. ``None`` keeps the original unseeded behaviour.

    Returns:
        A uint8 array with the same shape as ``image``.
    """
    image = np.asarray(image)
    samples = image.reshape(-1, 1)
    k_means = MiniBatchKMeans(
        n_clusters=k, compute_labels=False, random_state=random_state
    )
    k_means.fit(samples)
    labels = k_means.predict(samples)
    # Round instead of truncating the float centres: a plain uint8 cast would
    # bias every level downward by up to one step. Clip guards the cast.
    levels = np.clip(np.rint(k_means.cluster_centers_), 0, 255).astype(np.uint8)
    return levels[labels].reshape(image.shape)
def _hex_to_rgb(color_hex):
    """Parse ``'#rrggbb'`` / ``'rrggbb'`` (or ``'#rgb'`` shorthand) into an (r, g, b) int tuple.

    Raises:
        ValueError: if the string is not 3 or 6 hex digits after stripping
            surrounding whitespace and leading ``'#'`` characters.
    """
    digits = color_hex.strip().lstrip('#')
    if len(digits) == 3:
        # CSS-style shorthand: '#abc' means '#aabbcc'.
        digits = ''.join(ch * 2 for ch in digits)
    if len(digits) != 6:
        raise ValueError(f"expected a colour like '#rrggbb', got {color_hex!r}")
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))


def color_transfer(image, old_colors_hex, new_colors_hex, *, tolerance=15.0,
                   output_path='result.jpg'):
    """Replace pixels close to each old colour with the paired new colour.

    Matching is done in CIELAB space with the CIE76 Delta E distance, on a
    copy of the image quantized by ``quantize_global(..., 16)``. Replacement
    colours are written into a copy of the original (unquantized) pixels.

    Args:
        image: RGB image, as a NumPy array or anything ``np.asarray`` accepts
            (e.g. a PIL image). Presumably uint8 -- confirm with callers.
        old_colors_hex: Iterable of hex colour strings to look for.
        new_colors_hex: Iterable of hex replacement colours, paired
            positionally with ``old_colors_hex``. Extra entries in the longer
            list are ignored (``zip`` semantics).
        tolerance: A pixel matches when its Delta E (CIE76) to the old colour
            is strictly below this value.
        output_path: File to write the result to via ``cv2.imwrite``, after
            conversion to BGR. ``None`` skips writing.

    Returns:
        The recoloured image as an RGB NumPy array.

    Raises:
        ValueError: if a colour string is not valid hex.
    """
    image_array = np.asarray(image)
    # Quantizing first lets near-identical shades match together; LAB makes
    # the Delta E distance perceptually meaningful.
    image_lab = rgb2lab(quantize_global(image_array, 16))

    # Each old colour becomes a (1, 1, 3) LAB array that broadcasts against
    # the (H, W, 3) image inside deltaE_cie76.
    old_colors_lab = [
        rgb2lab(np.uint8([[_hex_to_rgb(color)]])) for color in old_colors_hex
    ]
    new_colors_rgb = [np.uint8(_hex_to_rgb(color)) for color in new_colors_hex]

    # Copy the ndarray, not the caller's object: a PIL image's .copy() would
    # not support the boolean-mask assignment below.
    recolored_image = image_array.copy()
    for old_color_lab, new_color_rgb in zip(old_colors_lab, new_colors_rgb):
        # Masks always come from the unmodified LAB image, so an earlier
        # replacement never affects a later match.
        mask = deltaE_cie76(old_color_lab, image_lab) < tolerance
        recolored_image[mask] = new_color_rgb

    if output_path is not None:
        # OpenCV writes images in BGR channel order.
        cv2.imwrite(output_path, cv2.cvtColor(recolored_image, cv2.COLOR_RGB2BGR))

    return recolored_image
def recolor(source, old_colors, new_colors):
    """Recolour *source* by delegating to ``color_transfer``.

    The three arguments are forwarded positionally, unchanged, and whatever
    ``color_transfer`` returns is handed straight back to the caller.
    """
    return color_transfer(source, old_colors, new_colors)