multimodalart HF staff commited on
Commit
e889613
1 Parent(s): 95b841a

Update sketch_helper.py

Browse files
Files changed (1) hide show
  1. sketch_helper.py +19 -22
sketch_helper.py CHANGED
@@ -1,7 +1,25 @@
1
  import numpy as np
2
  import cv2
3
  from PIL import Image
4
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  def get_high_freq_colors(image):
6
  im = image.getcolors(maxcolors=1024*1024)
7
  sorted_colors = sorted(im, key=lambda x: x[0], reverse=True)
@@ -24,27 +42,6 @@ def color_quantization_old(image, n_colors):
24
  labels = np.argmin(dists, axis=1)
25
  return colors[labels].reshape((image.shape[0], image.shape[1], 3)).astype(np.uint8)
26
 
27
- def color_quantization(image, n_colors=8, rounds=1):
28
- h, w = image.shape[:2]
29
- samples = np.zeros([h*w,3], dtype=np.float32)
30
- count = 0
31
-
32
- for x in range(h):
33
- for y in range(w):
34
- samples[count] = image[x][y]
35
- count += 1
36
-
37
- compactness, labels, centers = cv2.kmeans(samples,
38
- n_colors,
39
- None,
40
- (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10000, 0.0001),
41
- rounds,
42
- cv2.KMEANS_RANDOM_CENTERS)
43
-
44
- centers = np.uint8(centers)
45
- res = centers[labels.flatten()]
46
- return res.reshape((image.shape))
47
-
48
  def create_binary_matrix(img_arr, target_color):
49
  # Create mask of pixels with target color
50
  mask = np.all(img_arr == target_color, axis=-1)
 
1
  import numpy as np
2
  import cv2
3
  from PIL import Image
4
+ from skimage.color import rgb2lab
5
+ from skimage.color import lab2rgb
6
+
7
+ def color_quantization(image, n_colors):
8
+ # Convert image to LAB color space
9
+ lab_image = rgb2lab(image)
10
+ # Reshape image to 2D array of pixels
11
+ pixels = lab_image.reshape(-1, 3)
12
+ # Perform K-means clustering
13
+ kmeans = KMeans(n_clusters=n_colors, random_state=0).fit(pixels)
14
+ # Replace each pixel with the closest color
15
+ labels = kmeans.predict(pixels)
16
+ colors = kmeans.cluster_centers_
17
+ quantized_pixels = colors[labels]
18
+ # Convert quantized image back to RGB color space
19
+ quantized_lab_image = quantized_pixels.reshape(lab_image.shape)
20
+ quantized_rgb_image = lab2rgb(quantized_lab_image)
21
+ return (quantized_rgb_image * 255).astype(np.uint8)
22
+
23
  def get_high_freq_colors(image):
24
  im = image.getcolors(maxcolors=1024*1024)
25
  sorted_colors = sorted(im, key=lambda x: x[0], reverse=True)
 
42
  labels = np.argmin(dists, axis=1)
43
  return colors[labels].reshape((image.shape[0], image.shape[1], 3)).astype(np.uint8)
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def create_binary_matrix(img_arr, target_color):
46
  # Create mask of pixels with target color
47
  mask = np.all(img_arr == target_color, axis=-1)