# ColorHarmonization / utils.py
# Committed by kargaranamir
# Commit dc4014d: "add color harmonization"
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Color Harmonization utility functions.
Some code is adapted from https://github.com/tartarskunk/ColorHarmonization
"""
# Import Libraries
import cv2
import numpy as np
import matplotlib.pyplot as plt
import io
# Constants
# Harmonic hue templates. Each template is a list of (center, width) pairs
# given as fractions of the full hue circle; HarmonicScheme.reset_sectors
# scales both by 360 and rotates every center by the scheme's alpha.
HueTemplates = {
    "i": [(0.00, 0.05)],
    "V": [(0.00, 0.26)],
    "L": [(0.00, 0.05), (0.25, 0.22)],
    "mirror_L": [(0.00, 0.05), (-0.25, 0.22)],
    "I": [(0.00, 0.05), (0.50, 0.05)],
    "T": [(0.25, 0.50)],
    "Y": [(0.00, 0.26), (0.50, 0.05)],
    "X": [(0.00, 0.26), (0.50, 0.26)],
}
# Template names, in definition order, and how many there are.
template_types = list(HueTemplates)
M = len(template_types)
# 360 -- presumably the hue-circle size in degrees (not referenced by the
# code visible in this file; verify against callers).
A = 360
def deg_distance(a, b):
    """Return the circular (shortest-arc) distance between angles a and b.

    Angles are in degrees and may be scalars or NumPy arrays (broadcast
    element-wise). Inputs do not need to lie in [0, 360): sector borders,
    for instance, can be negative or exceed 360. The difference is reduced
    modulo 360 first, so the result is always in [0, 180].
    """
    d = np.remainder(np.abs(a - b), 360)
    # Going the other way around the circle may be shorter.
    return np.minimum(d, 360 - d)
def normalized_gaussian(X, mu, S):
    """Evaluate exp(-(X - mu)^2 / (2 * S^2)), a Gaussian with peak value 1.

    X, mu and S are in degrees and are converted to radians before use.
    X and S may be scalars or array-likes (cast to float64); they broadcast
    element-wise. S == 0 divides by zero (result 0, or nan where X == mu).
    """
    offset = np.deg2rad(np.asarray(X).astype(np.float64) - mu)
    sigma = np.deg2rad(np.asarray(S).astype(np.float64))
    return np.exp(-(offset * offset) / (2 * (sigma * sigma)))
class HueSector:
    """One angular sector of the hue wheel.

    ``center`` and ``width`` are in degrees. The center is not normalized, so
    it and the two borders may lie outside [0, 360); every distance below is
    computed with the circular helper ``deg_distance``.
    """

    def __init__(self, center, width):
        self.center = center
        self.width = width
        half = self.width / 2
        # [lower border, upper border], in degrees.
        self.border = [self.center - half, self.center + half]

    def is_in_sector(self, H):
        """Boolean mask: True where hue H is strictly within width/2 of the center."""
        half_width = self.width / 2
        return deg_distance(H, self.center) < half_width

    def distance_to_border(self, H):
        """Distance (degrees) from hue H to the nearer of the two borders."""
        lower, upper = self.border
        return np.minimum(deg_distance(H, lower), deg_distance(H, upper))

    def closest_border(self, H):
        """-1.0 where the lower border is nearer to H, +1.0 where the upper one is.

        Ties resolve to -1.0 (the lower border), because argmin picks the first.
        """
        lower, upper = self.border
        to_lower = deg_distance(H, lower)
        to_upper = deg_distance(H, upper)
        nearest = np.argmin((to_lower, to_upper), axis=0)
        # Map index 0/1 to direction -1/+1.
        return 2 * (nearest - 0.5)

    def distance_to_center(self, H):
        """Circular distance (degrees) from hue H to the sector center."""
        return deg_distance(H, self.center)
class HarmonicScheme:
    """A hue harmonic template rotated around the hue wheel.

    The template ``m`` (a key of ``HueTemplates``) is turned into a list of
    ``HueSector`` objects in ``self.sectors``: each (center, width) fraction is
    scaled to degrees and every center is offset by ``alpha`` degrees.
    """

    def __init__(self, m, alpha):
        # m: template name, a key into HueTemplates (e.g. "i", "V", "L", ...).
        self.m = m
        # alpha: rotation of the template, in degrees.
        self.alpha = alpha
        self.reset_sectors()

    def reset_sectors(self):
        """(Re)build ``self.sectors`` from the current ``m`` and ``alpha``."""
        self.sectors = [
            HueSector(center * 360 + self.alpha, width * 360)
            for center, width in HueTemplates[self.m]
        ]

    def harmony_score(self, X):
        """Return how far the hues of HSV image X fall outside the template.

        Each pixel contributes its distance (in radians) to the nearest sector,
        weighted by its saturation; pixels inside a sector contribute 0, so a
        lower score means a better fit.
        """
        # OpenCV stores H as [0, 180) --> rescale to degrees [0, 360)
        H = X[:, :, 0].astype(np.int32) * 2
        # OpenCV stores S as [0, 255] --> rescale to [0, 1]
        S = X[:, :, 1].astype(np.float32) / 255.0
        H_dis = np.deg2rad(self.hue_distance(H))
        return np.sum(np.multiply(H_dis, S))

    def hue_distance(self, H):
        """Per-pixel distance (degrees) from hue H to the closest sector.

        Hues inside any sector get 0; all others get the distance to the
        nearest border over all sectors.
        """
        per_sector = []
        for sector in self.sectors:
            dist = sector.distance_to_border(H)
            dist[sector.is_in_sector(H)] = 0
            per_sector.append(dist)
        return np.asarray(per_sector).min(axis=0)

    def hue_shifted(self, X, num_superpixels=-1):
        """Return a copy of HSV image X with hues pulled into the template.

        Args:
            X: OpenCV HSV image (hue channel in [0, 180)).
            num_superpixels: when positive, X is segmented into that many SEEDS
                superpixels and each superpixel is assigned as a whole to the
                sector most of its pixels would choose, so one region is not
                split between sectors. -1 (the default) or any non-positive
                value disables this step.

        Returns:
            A copy of X whose hue channel is remapped; hues that already lie
            inside a sector are left unchanged.
        """
        Y = X.copy()
        H = X[:, :, 0].astype(np.int32) * 2
        # Assign every pixel to the sector whose border is closest to its hue.
        H_d2b = np.asarray([sector.distance_to_border(H) for sector in self.sectors])
        H_cls = np.argmin(H_d2b, axis=0)
        if num_superpixels > 0:
            H_cls = self._superpixel_majority(X, H_cls, num_superpixels)
        Y[:, :, 0] = self._shift_toward_sectors(H, H_cls)
        return Y

    @staticmethod
    def _superpixel_majority(X, H_cls, num_superpixels):
        """Replace each pixel's sector index by the majority index of its superpixel."""
        seeds = cv2.ximgproc.createSuperpixelSEEDS(
            X.shape[1], X.shape[0], X.shape[2], num_superpixels, 10)
        seeds.iterate(X, 4)
        labels = seeds.getLabels()
        smoothed = H_cls.copy()
        for label in range(seeds.getNumberOfSuperpixels()):
            mask = labels == label
            if not mask.any():
                # Empty label: nothing to reassign (and no vote to take).
                continue
            # Majority vote over sector indices; ties go to the lowest index.
            smoothed[mask] = np.bincount(H_cls[mask]).argmax()
        return smoothed

    def _shift_toward_sectors(self, H, H_cls):
        """Map hues H (degrees) into their assigned sectors; return OpenCV hues (uint8, [0, 180))."""
        H_ctr = np.zeros(H.shape)
        H_wid = np.zeros(H.shape)
        H_d2c = np.zeros(H.shape)
        H_dir = np.zeros(H.shape)
        for i, sector in enumerate(self.sectors):
            mask = H_cls == i
            H_ctr[mask] = sector.center
            H_wid[mask] = sector.width
            # -1 / +1: move toward the sector's lower / upper border.
            H_dir += sector.closest_border(H) * mask
            H_d2c += sector.distance_to_center(H) * mask
        # Gaussian falloff with sigma = half the sector width: hues close to the
        # center stay close to it, far-away hues end up near the border.
        H_sgm = H_wid / 2
        H_gau = normalized_gaussian(H_d2c, 0, H_sgm)
        H_shf = np.multiply(H_dir, np.multiply(H_wid / 2, 1 - H_gau))
        # floor (not truncation toward zero) so negative targets, e.g. around a
        # -90 degree center, round the same way as positive ones before wrapping.
        H_new = np.floor(H_ctr + H_shf).astype(np.int32)
        # Hues already inside a sector keep their original value.
        for sector in self.sectors:
            np.copyto(H_new, H, where=sector.is_in_sector(H))
        H_new = np.remainder(H_new, 360)
        # Back to OpenCV's half-degree hue scale.
        return (H_new // 2).astype(np.uint8)
def count_hue_histogram(X):
    """Return a 360-bin hue histogram of HSV image X, weighted by saturation.

    Args:
        X: OpenCV-style HSV image; channel 0 is hue in [0, 180) (half degrees),
            channel 1 is saturation in [0, 255].

    Returns:
        float64 array of length 360. Bin ``h`` is the sum of the saturations
        (scaled to [0, 1]) of all pixels whose hue is ``h`` degrees. Hues are
        doubled from the OpenCV scale, so only even bins can be non-zero.

    Raises:
        ValueError: if a doubled hue falls outside [0, 360).
    """
    N = 360
    H = X[:, :, 0].astype(np.int32).ravel() * 2
    S = X[:, :, 1].astype(np.float64).ravel() / 255.0
    # Out-of-range hues would otherwise crash (too large) or silently wrap
    # into the wrong bin (negative).
    if H.size and (H.min() < 0 or H.max() >= N):
        raise ValueError("hue channel must lie in [0, 180) (OpenCV HSV scale)")
    # Vectorized weighted count; replaces a per-pixel Python loop.
    return np.bincount(H, weights=S, minlength=N)
def plothis(hue_histo, harmonic_scheme, caption: str):
    """Draw a polar hue histogram with a harmonic template overlaid.

    Args:
        hue_histo: length-360 sequence of per-degree weights (e.g. the output
            of count_hue_histogram). A float copy is scaled so its peak is 1;
            the caller's data is not modified.
        harmonic_scheme: object with a ``sectors`` list whose items expose
            ``center`` and ``width`` in degrees (e.g. a HarmonicScheme).
        caption: figure title.

    Returns:
        The matplotlib Figure. It is closed through pyplot before returning
        (so pyplot does not keep it open), but can still be rendered with
        ``fig.savefig`` (e.g. via get_img_from_fig).
    """
    N = 360
    # One 1-degree bar per hue, starting at 0 rad.
    theta = np.linspace(0.0, 2 * np.pi, N, endpoint=False)
    width = np.pi / 180

    # Bar colors: convert all 360 hues in a single cvtColor call. OpenCV's
    # 8-bit hue is degrees / 2, and the output is BGR, so reverse it to RGB.
    hsv_row = np.zeros((1, N, 3), dtype=np.uint8)
    hsv_row[0, :, 0] = np.arange(N) // 2
    hsv_row[0, :, 1:] = 255
    rgb = cv2.cvtColor(hsv_row, cv2.COLOR_HSV2BGR)[0, :, ::-1] / 255.0
    hue_colors = np.column_stack((rgb, np.ones(N)))

    # Opaque black for the template "shadow"; its transparency comes from the
    # alpha of the corresponding bar call below.
    shadow_colors = np.tile((0.0, 0.0, 0.0, 1.0), (N, 1))

    # Normalize a float copy of the histogram so its peak reaches radius 1.
    hue_histo = np.asarray(hue_histo).astype(float)
    peak = float(np.max(hue_histo))
    if peak != 0.0:
        hue_histo /= peak

    # Thin colored ring just outside radius 1, acting as a hue legend.
    guide_histo = np.full(N, 0.05)

    # 1.0 on every whole degree covered by a template sector, 0.0 elsewhere.
    shadow_histo = np.zeros(N)
    for sector in harmonic_scheme.sectors:
        start = int((sector.center - sector.width / 2) % 360)
        end = int((sector.center + sector.width / 2) % 360)
        if start < end:
            shadow_histo[start:end] = 1.0
        else:
            # The sector wraps through 0 degrees.
            shadow_histo[start:] = 1.0
            shadow_histo[:end] = 1.0

    # 3.2 x 4 inches, i.e. 320 x 400 pixels when rendered at 100 dpi.
    fig = plt.figure(figsize=(3.2, 4))
    ax = fig.add_subplot(111, projection='polar')
    # Hue histogram.
    ax.bar(theta, hue_histo, width=width, bottom=0.0, color=hue_colors, alpha=1.0)
    # Guide ring.
    ax.bar(theta, guide_histo, width=width, bottom=1.0, color=hue_colors, alpha=1.0)
    # Shaded template sectors.
    ax.bar(theta, shadow_histo, width=width, bottom=0.0, color=shadow_colors, alpha=0.1)
    ax.set_title(caption, pad=15)
    # Close this figure explicitly rather than whatever pyplot considers current.
    plt.close(fig)
    return fig
# https://stackoverflow.com/questions/7821518/matplotlib-save-plot-to-numpy-array
def get_img_from_fig(fig, dpi=100):
    """Render a matplotlib figure into an OpenCV image array.

    The figure is saved as PNG into an in-memory buffer at ``dpi`` and decoded
    with ``cv2.imdecode`` using flag 1 (color), so the result follows OpenCV's
    BGR channel order.
    """
    with io.BytesIO() as stream:
        fig.savefig(stream, format="png", dpi=dpi)
        png_bytes = stream.getvalue()
    encoded = np.frombuffer(png_bytes, dtype=np.uint8)
    return cv2.imdecode(encoded, 1)