farm-recolor / recolorOTAlgo.py
vettorazi's picture
improvement in the Recolor OTA algorithm
397c2d2
import numpy as np
import cv2
from PIL import Image
def color_transfer(source, target):
    """Give ``source`` the per-channel Lab mean and spread of ``target``.

    Both images are converted from BGR to Lab. Each Lab channel of
    ``source`` is centered on its own mean, rescaled by the ratio of the
    target's standard deviation to the source's, and shifted to the target's
    mean. The result is clipped, cast to ``uint8`` and converted back to BGR.

    Args:
        source: Image whose content is kept.
            NOTE(review): assumed to be an 8-bit, 3-channel BGR array (what
            ``cv2.imread`` returns); the 0-255 clip below relies on OpenCV's
            8-bit Lab encoding -- confirm against callers.
        target: Image that supplies the color statistics; same format
            assumption as ``source``.

    Returns:
        A ``uint8`` BGR image with the same height and width as ``source``.
    """
    src_lab = cv2.cvtColor(source, cv2.COLOR_BGR2Lab).astype("float32")
    tar_lab = cv2.cvtColor(target, cv2.COLOR_BGR2Lab).astype("float32")

    matched = []
    # cv2.split returns independent arrays, so the in-place arithmetic below
    # never touches src_lab, which is still needed for the statistics.
    for idx, channel in enumerate(cv2.split(src_lab)):
        src_mean = src_lab[..., idx].mean()
        src_std = src_lab[..., idx].std()
        tar_mean = tar_lab[..., idx].mean()
        tar_std = tar_lab[..., idx].std()

        # A perfectly flat source channel has std 0. Dividing by it produced
        # inf/NaN and garbage after the uint8 cast. The centered channel is
        # all zeros in that case, so any finite scale gives the right answer
        # (the channel simply becomes the target mean).
        scale = tar_std / src_std if src_std > 0 else 1.0

        channel -= src_mean
        channel = scale * channel
        channel += tar_mean
        matched.append(np.clip(channel, 0, 255))

    transfer = cv2.merge(matched)
    return cv2.cvtColor(transfer.astype("uint8"), cv2.COLOR_Lab2BGR)
def hex_to_rgb(hex_val):
    """Convert a hex color string to an ``(r, g, b)`` tuple of ints (0-255).

    Accepted forms, with or without leading ``#`` and surrounding whitespace:
    ``rrggbb``, ``rrggbbaa`` (alpha is ignored), and the shorthand ``rgb`` /
    ``rgba`` where every digit is doubled (``f80`` means ``ff8800``).

    Args:
        hex_val: The color as a string, e.g. ``"#db5a1e"``.

    Returns:
        A 3-tuple ``(red, green, blue)``.

    Raises:
        ValueError: If the string has an unsupported number of digits or
            contains characters that are not valid hexadecimal.
    """
    digits = hex_val.strip().lstrip('#')
    if len(digits) in (3, 4):
        # CSS-style shorthand: expand each digit to a pair.
        digits = ''.join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8):
        # Previously odd lengths were either parsed into a wrong color or
        # failed with an unhelpful "invalid literal for int()" message.
        raise ValueError(
            "expected a hex color like '#rrggbb' or '#rgb', got %r" % (hex_val,)
        )
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))
def create_color_palette(colors, palette_width=800, palette_height=200):
    """Render a list of hex colors as an image of equal-width vertical stripes.

    The image is ``palette_width`` pixels wide and ``palette_height`` pixels
    tall. Every column belongs to exactly one color, so the whole image is
    covered even when ``palette_width`` is not a multiple of the number of
    colors (stripe widths then differ by at most one pixel). If there are
    more colors than columns, some colors receive no column.

    Args:
        colors: Non-empty sequence of hex color strings understood by
            ``hex_to_rgb``.
        palette_width: Image width in pixels.
        palette_height: Image height in pixels.

    Returns:
        A ``PIL.Image.Image`` in RGB mode.

    Raises:
        ValueError: If ``colors`` is empty.
    """
    if not colors:
        raise ValueError("colors must contain at least one hex color")

    rgb = np.array([hex_to_rgb(color) for color in colors], dtype=np.uint8)

    # Assign each column to a color index. The earlier pixel-list approach
    # gave every color width // n_colors columns' worth of pixels, leaving the
    # remainder of the image black and skewing the palette's statistics.
    column_owner = np.arange(palette_width) * len(rgb) // palette_width
    stripe_row = rgb[column_owner]  # shape: (palette_width, 3)

    # Array layout is (rows, columns, channels), i.e. (height, width, 3).
    pixels = np.tile(stripe_row, (palette_height, 1, 1))
    return Image.fromarray(pixels)
def recolor_statistical(source, colors, *, output_path="result.jpg",
                        diameter=15, sigma_color=30, sigma_space=20):
    """Recolor ``source`` toward a hex palette and save the result to disk.

    A palette image is built from ``colors``, its color statistics are
    transferred onto ``source`` with ``color_transfer``, the result is
    smoothed with a bilateral filter, and the smoothed image is written to
    ``output_path`` with JPEG quality 100 requested.

    Args:
        source: BGR image to recolor (passed straight to ``color_transfer``).
        colors: Sequence of hex color strings defining the palette.
        output_path: Destination file. The default keeps the historical
            ``"result.jpg"``; pass a unique path when several requests may
            run at once, otherwise they overwrite each other's output.
        diameter: Pixel neighborhood diameter for ``cv2.bilateralFilter``.
        sigma_color: Filter sigma in color space.
        sigma_space: Filter sigma in coordinate space.

    Returns:
        ``output_path``, the path of the written file.

    Raises:
        OSError: If OpenCV reports that the image could not be written.
    """
    palette_img = create_color_palette(colors)
    palette_bgr = cv2.cvtColor(np.array(palette_img), cv2.COLOR_RGB2BGR)
    recolored = color_transfer(source, palette_bgr)

    smoothed = cv2.bilateralFilter(recolored, diameter, sigma_color, sigma_space)

    # cv2.imwrite signals failure through its return value; ignoring it
    # meant handing callers a path to a file that was never written.
    written = cv2.imwrite(output_path, smoothed, [cv2.IMWRITE_JPEG_QUALITY, 100])
    if not written:
        raise OSError("could not write recolored image to %r" % (output_path,))
    return output_path
# Usage
# source_img = cv2.imread("path_to_your_source_image.jpg")
# hexcolors = ["#db5a1e", "#555115", "#9a690e", "#1f3a19", "#da8007",
# "#9a0633", "#b70406", "#d01b4b", "#e20b0f", "#f7515d"]
# result_path = recolor_statistical(source_img, hexcolors)
def recolor(source, colors):
# pallete_img = create_color_palette(colors)
# palette_bgr = cv2.cvtColor(np.array(pallete_img), cv2.COLOR_RGB2BGR)
# recolored = optimal_transport_color_transfer(source, palette_bgr)
# smooth = True#test with true for different results.
# if smooth:
# # Apply bilateral filtering
# diameter = 10 # diameter of each pixel neighborhood, adjust based on your image size
# sigma_color = 25 # larger value means colors farther to each other will mix together
# sigma_space = 15 # larger values means farther pixels will influence each other if their colors are close enough
# smoothed = cv2.bilateralFilter(recolored, diameter, sigma_color, sigma_space)
# recoloredFile = cv2.imwrite("result.jpg", smoothed, [cv2.IMWRITE_JPEG_QUALITY, 100])
# return recoloredFile
# else:
# recoloredFile = cv2.imwrite("result.jpg", recolored)
# return recoloredFile
source = source.astype(np.uint8)
source = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
source_img = source
hexcolors = ["#db5a1e", "#555115", "#9a690e", "#1f3a19", "#da8007",
"#9a0633", "#b70406", "#d01b4b", "#e20b0f", "#f7515d"]
result_path = recolor_statistical(source_img, colors)