vettorazi commited on
Commit
397c2d2
1 Parent(s): 8b6c774

improvement in the Recolor OTA algorithim

Browse files
Files changed (1) hide show
  1. recolorOTAlgo.py +69 -68
recolorOTAlgo.py CHANGED
@@ -1,52 +1,44 @@
1
  import numpy as np
2
  import cv2
3
- import ot
4
  from PIL import Image
5
 
6
- def transfer_channel(source_channel, target_channel):
7
- source_hist, _ = np.histogram(source_channel, bins=256, range=(0, 256))
8
- target_hist, _ = np.histogram(target_channel, bins=256, range=(0, 256))
9
 
10
- source_hist = source_hist.astype(np.float64) / source_hist.sum()
11
- target_hist = target_hist.astype(np.float64) / target_hist.sum()
12
 
13
- r = np.arange(256).reshape((-1, 1))
14
- c = np.arange(256).reshape((1, -1))
15
- M = (r - c) ** 2
16
- M = M / M.max()
17
 
18
- P = ot.emd(source_hist, target_hist, M)
19
-
20
- transferred_channel = np.zeros_like(source_channel)
21
- for i in range(256):
22
- transferred_channel[source_channel == i] = P[i].argmax()
23
-
24
- return transferred_channel
25
 
26
- def optimal_transport_color_transfer(source, target):
27
- source_lab = cv2.cvtColor(source, cv2.COLOR_BGR2Lab).astype(np.float64)
28
- target_lab = cv2.cvtColor(target, cv2.COLOR_BGR2Lab).astype(np.float64)
29
 
30
- # Transfer all channels L, a, and b
31
- for ch in range(3):
32
- source_lab[:, :, ch] = transfer_channel(source_lab[:, :, ch], target_lab[:, :, ch])
33
 
34
- transferred_rgb = cv2.cvtColor(source_lab.astype(np.uint8), cv2.COLOR_Lab2BGR)
35
- return transferred_rgb
36
-
37
 
38
- def rgb_to_hex(rgb):
39
- return '#{:02x}{:02x}{:02x}'.format(rgb[0], rgb[1], rgb[2])
40
 
41
- def hex_to_rgb(hex):
42
- hex = hex.lstrip('#')
43
- return tuple(int(hex[i:i+2], 16) for i in (0, 2, 4))
44
 
 
 
 
45
 
46
  def create_color_palette(colors, palette_width=800, palette_height=200):
47
- """
48
- Receives a list of colors in hex format and creates a palette image
49
- """
50
  pixels = []
51
  n_colors = len(colors)
52
  for i in range(n_colors):
@@ -54,44 +46,53 @@ def create_color_palette(colors, palette_width=800, palette_height=200):
54
  for j in range(palette_width//n_colors * palette_height):
55
  pixels.append(color)
56
  img = Image.new('RGB', (palette_height, palette_width))
57
- img.putdata(pixels)
58
- # img.show()
59
  return img
60
 
61
-
62
- # if __name__ == "__main__":
63
- # source = cv2.imread("estampa-test.png")
64
- # target = cv2.imread("color.png")
65
- # transferred = optimal_transport_color_transfer(source, target)
66
-
67
- # smooth = False#test with true to have a different result.
68
- # if smooth:
69
- # # Apply bilateral filtering
70
- # diameter = 30 # diameter of each pixel neighborhood, adjust based on your image size
71
- # sigma_color = 25 # larger value means colors farther to each other will mix together
72
- # sigma_space = 25 # larger values means farther pixels will influence each other if their colors are close enough
73
- # smoothed = cv2.bilateralFilter(transferred, diameter, sigma_color, sigma_space)
74
- # cv2.imwrite("result_OTA.jpg", smoothed)
75
- # else:
76
- # cv2.imwrite("result_OTA.jpg", transferred)
77
-
78
-
79
- def recolor(source, colors):
80
  pallete_img = create_color_palette(colors)
81
  palette_bgr = cv2.cvtColor(np.array(pallete_img), cv2.COLOR_RGB2BGR)
82
- recolored = optimal_transport_color_transfer(source, palette_bgr)
83
- smooth = True#test with true for different results.
84
- if smooth:
85
- # Apply bilateral filtering
86
- diameter = 10 # diameter of each pixel neighborhood, adjust based on your image size
87
- sigma_color = 25 # larger value means colors farther to each other will mix together
88
- sigma_space = 15 # larger values means farther pixels will influence each other if their colors are close enough
89
- smoothed = cv2.bilateralFilter(recolored, diameter, sigma_color, sigma_space)
90
- recoloredFile = cv2.imwrite("result.jpg", smoothed, [cv2.IMWRITE_JPEG_QUALITY, 100])
91
- return recoloredFile
92
- else:
93
- recoloredFile = cv2.imwrite("result.jpg", recolored)
94
- return recoloredFile
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
 
 
1
  import numpy as np
2
  import cv2
 
3
  from PIL import Image
4
 
5
+ def color_transfer(source, target):
6
+ source = cv2.cvtColor(source, cv2.COLOR_BGR2Lab).astype("float32")
7
+ target = cv2.cvtColor(target, cv2.COLOR_BGR2Lab).astype("float32")
8
 
9
+ lMeanSrc, aMeanSrc, bMeanSrc = (source[..., i].mean() for i in range(3))
10
+ lStdSrc, aStdSrc, bStdSrc = (source[..., i].std() for i in range(3))
11
 
12
+ lMeanTar, aMeanTar, bMeanTar = (target[..., i].mean() for i in range(3))
13
+ lStdTar, aStdTar, bStdTar = (target[..., i].std() for i in range(3))
 
 
14
 
15
+ (l, a, b) = cv2.split(source)
16
+ l -= lMeanSrc
17
+ a -= aMeanSrc
18
+ b -= bMeanSrc
 
 
 
19
 
20
+ l = (lStdTar / lStdSrc) * l
21
+ a = (aStdTar / aStdSrc) * a
22
+ b = (bStdTar / bStdSrc) * b
23
 
24
+ l += lMeanTar
25
+ a += aMeanTar
26
+ b += bMeanTar
27
 
28
+ l = np.clip(l, 0, 255)
29
+ a = np.clip(a, 0, 255)
30
+ b = np.clip(b, 0, 255)
31
 
32
+ transfer = cv2.merge([l, a, b])
33
+ transfer = cv2.cvtColor(transfer.astype("uint8"), cv2.COLOR_Lab2BGR)
34
 
35
+ return transfer
 
 
36
 
37
+ def hex_to_rgb(hex_val):
38
+ hex_val = hex_val.lstrip('#')
39
+ return tuple(int(hex_val[i:i+2], 16) for i in (0, 2, 4))
40
 
41
  def create_color_palette(colors, palette_width=800, palette_height=200):
 
 
 
42
  pixels = []
43
  n_colors = len(colors)
44
  for i in range(n_colors):
 
46
  for j in range(palette_width//n_colors * palette_height):
47
  pixels.append(color)
48
  img = Image.new('RGB', (palette_height, palette_width))
49
+ img.putdata(pixels)
 
50
  return img
51
 
52
+ def recolor_statistical(source, colors):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  pallete_img = create_color_palette(colors)
54
  palette_bgr = cv2.cvtColor(np.array(pallete_img), cv2.COLOR_RGB2BGR)
55
+ recolored = color_transfer(source, palette_bgr)
56
+
57
+ # Adjust bilateral filtering parameters
58
+ diameter = 15
59
+ sigma_color = 30
60
+ sigma_space = 20
61
+ smoothed = cv2.bilateralFilter(recolored, diameter, sigma_color, sigma_space)
62
+
63
+ # Save image at maximum quality
64
+ filename = "result.jpg"
65
+ cv2.imwrite(filename, smoothed, [cv2.IMWRITE_JPEG_QUALITY, 100])
66
+ return filename
 
67
 
68
+ # Usage
69
+ # source_img = cv2.imread("path_to_your_source_image.jpg")
70
+ # hexcolors = ["#db5a1e", "#555115", "#9a690e", "#1f3a19", "#da8007",
71
+ # "#9a0633", "#b70406", "#d01b4b", "#e20b0f", "#f7515d"]
72
+ # result_path = recolor_statistical(source_img, hexcolors)
73
+
74
+ def recolor(source, colors):
75
+ # pallete_img = create_color_palette(colors)
76
+ # palette_bgr = cv2.cvtColor(np.array(pallete_img), cv2.COLOR_RGB2BGR)
77
+ # recolored = optimal_transport_color_transfer(source, palette_bgr)
78
+ # smooth = True#test with true for different results.
79
+ # if smooth:
80
+ # # Apply bilateral filtering
81
+ # diameter = 10 # diameter of each pixel neighborhood, adjust based on your image size
82
+ # sigma_color = 25 # larger value means colors farther to each other will mix together
83
+ # sigma_space = 15 # larger values means farther pixels will influence each other if their colors are close enough
84
+ # smoothed = cv2.bilateralFilter(recolored, diameter, sigma_color, sigma_space)
85
+ # recoloredFile = cv2.imwrite("result.jpg", smoothed, [cv2.IMWRITE_JPEG_QUALITY, 100])
86
+ # return recoloredFile
87
+ # else:
88
+ # recoloredFile = cv2.imwrite("result.jpg", recolored)
89
+ # return recoloredFile
90
+ source = source.astype(np.uint8)
91
+ source = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
92
+
93
+ source_img = source
94
+ hexcolors = ["#db5a1e", "#555115", "#9a690e", "#1f3a19", "#da8007",
95
+ "#9a0633", "#b70406", "#d01b4b", "#e20b0f", "#f7515d"]
96
+ result_path = recolor_statistical(source_img, colors)
97
 
98