multimodalart HF staff commited on
Commit
f797ad8
1 Parent(s): 7b90989

Create sketch_helper.py

Browse files
Files changed (1) hide show
  1. sketch_helper.py +38 -0
sketch_helper.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ from PIL import Image
4
+
5
+ def get_high_freq_colors(image):
6
+ im = image.getcolors(maxcolors=1024*1024)
7
+ sorted_colors = sorted(im, key=lambda x: x[0], reverse=True)
8
+
9
+ freqs = [c[0] for c in sorted_colors]
10
+ mean_freq = sum(freqs) / len(freqs)
11
+
12
+ high_freq_colors = [c for c in sorted_colors if c[0] > max(2, mean_freq)] # Ignore colors that occur very few times (less than 2) or less than half the average frequency
13
+ return high_freq_colors
14
+
15
+ def color_quantization(image, n_colors):
16
+ # Get color histogram
17
+ hist, _ = np.histogramdd(image.reshape(-1, 3), bins=(256, 256, 256), range=((0, 256), (0, 256), (0, 256)))
18
+ # Get most frequent colors
19
+ colors = np.argwhere(hist > 0)
20
+ colors = colors[np.argsort(hist[colors[:, 0], colors[:, 1], colors[:, 2]])[::-1]]
21
+ colors = colors[:n_colors]
22
+ # Replace each pixel with the closest color
23
+ dists = np.sum((image.reshape(-1, 1, 3) - colors.reshape(1, -1, 3))**2, axis=2)
24
+ labels = np.argmin(dists, axis=1)
25
+ return colors[labels].reshape((image.shape[0], image.shape[1], 3)).astype(np.uint8)
26
+
27
+ def create_binary_matrix(img_arr, target_color):
28
+ # Create mask of pixels with target color
29
+ mask = np.all(img_arr == target_color, axis=-1)
30
+
31
+ # Convert mask to binary matrix
32
+ binary_matrix = mask.astype(int)
33
+ from datetime import datetime
34
+ binary_file_name = f'mask-{datetime.now().timestamp()}.png'
35
+ cv2.imwrite(binary_file_name, binary_matrix * 255)
36
+
37
+ #binary_matrix = torch.from_numpy(binary_matrix).unsqueeze(0).unsqueeze(0)
38
+ return binary_file_name