File size: 4,786 Bytes
62456b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import numpy as np
import matplotlib.pyplot as plt
import cv2
import snowy
import os


def get_resized_image(img, size):
    """Resize *img* so that its shorter side is *size* pixels, keeping aspect ratio.

    A 2-D (single-channel) array is first turned into three identical
    channels. The longer side is scaled by the same factor and rounded up.
    When the resized array is float32 it is clipped in place to [0, 1].

    Args:
        img: image array of shape (H, W) or (H, W, C).
        size: target length, in pixels, of the shorter side (an int).

    Returns:
        The resized image array.
    """
    if img.ndim == 2:
        img = np.repeat(np.expand_dims(img, 2), 3, 2)

    height, width = img.shape[:2]

    # Exact integer ceil(long_side * size / short_side). The previous
    # float "ratio" formulation rounds twice and could turn an exact integer
    # quotient into e.g. 512.0000000001, which ceil() bumps by a whole pixel.
    if height < width:
        new_width = -(-width * size // height)
        img = cv2.resize(img, (new_width, size), interpolation=cv2.INTER_AREA)
    else:
        new_height = -(-height * size // width)
        img = cv2.resize(img, (size, new_height), interpolation=cv2.INTER_AREA)

    if img.dtype == np.float32:
        np.clip(img, 0, 1, out=img)

    return img


def get_sketch_image(img, sketcher, mult_val):
    """Return ``sketcher.get_sketch_with_resize(img)`` for *img*.

    ``mult_val`` is forwarded as the ``mult`` keyword only when it is truthy;
    for a falsy value (``None``, ``0``, ...) the method is called without
    ``mult`` so whatever default the sketcher defines applies.
    """
    extra = {'mult': mult_val} if mult_val else {}
    return sketcher.get_sketch_with_resize(img, **extra)


def get_dfm_image(sketch):
    """Build a distance-field map from a sketch array.

    The mask handed to ``snowy.generate_sdf`` marks pixels where
    ``1 - sketch`` is non-zero (presumably the drawn strokes, assuming the
    sketch is a float image with 1.0 = background — confirm). A trailing
    channel axis is added for snowy, the field is passed through
    ``snowy.unitize``, and singleton axes are squeezed from the result.
    """
    stroke_mask = np.expand_dims(1 - sketch, 2) != 0
    field = snowy.generate_sdf(stroke_mask)
    return snowy.unitize(field).squeeze()

def get_sketch(image, sketcher, dfm, mult=None):
    """Produce one sketch of *image* and, optionally, its distance-field map.

    Args:
        image: input image handed to ``get_sketch_image``.
        sketcher: object providing ``get_sketch_with_resize``.
        dfm: when truthy, also compute ``get_dfm_image`` of the sketch.
        mult: optional multiplier forwarded to the sketcher.

    Returns:
        ``(sketch, dfm_map)`` where both are multiplied by 255 and cast to
        ``uint8``; ``dfm_map`` is ``None`` when *dfm* is falsy.
    """
    def as_uint8(arr):
        return (arr * 255).astype('uint8')

    raw_sketch = get_sketch_image(image, sketcher, mult)

    if not dfm:
        return as_uint8(raw_sketch), None

    # The distance field is computed from the float sketch, before the
    # uint8 conversion.
    raw_dfm = get_dfm_image(raw_sketch)
    return as_uint8(raw_sketch), as_uint8(raw_dfm)

def get_sketches(image, sketcher, mult_list, dfm):
    """Lazily yield one ``get_sketch`` result per multiplier in *mult_list*.

    Each item is the ``(sketch, dfm_map)`` tuple returned by ``get_sketch``,
    produced in the order of *mult_list*.
    """
    yield from (get_sketch(image, sketcher, dfm, factor) for factor in mult_list)


def create_resized_dataset(source_path, target_path, side_size):
    """Resize every readable image in *source_path* and save it as PNG.

    Each entry ``name.ext`` of *source_path* is written to
    ``target_path/name.png`` with its shorter side equal to *side_size*
    (see ``get_resized_image``). Outputs that already exist are skipped, so
    an interrupted run can be resumed. Entries that cannot be read, resized
    or written are reported on stdout and skipped.

    Args:
        source_path: directory containing the input images.
        target_path: output directory; created if missing (unless empty,
            which means the current directory).
        side_size: target length in pixels of each image's shorter side.
    """
    if target_path:
        os.makedirs(target_path, exist_ok=True)

    for image_name in os.listdir(source_path):
        # splitext leaves extension-less names intact; the former
        # rfind('.') slice returned -1 for them and chopped the last char.
        new_image_name = os.path.splitext(image_name)[0] + '.png'
        new_path = os.path.join(target_path, new_image_name)

        if os.path.exists(new_path):
            continue

        try:
            image = cv2.imread(os.path.join(source_path, image_name))
            if image is None:
                raise OSError('could not read image')

            image = get_resized_image(image, side_size)

            # cv2.imwrite can signal failure by returning False instead of raising.
            if not cv2.imwrite(new_path, image):
                raise OSError('could not write {}'.format(new_path))
        except Exception as err:
            # Not a bare except: Ctrl-C / SystemExit still stop the run.
            print('Failed to process {}: {}'.format(image_name, err))
    

def create_sketches_dataset(source_path, target_path, sketcher, mult_list, dfm = False):
    """Write sketches (and optionally distance-field maps) for every image.

    For each readable entry ``name.ext`` of *source_path*, one sketch per
    element of *mult_list* is written to ``target_path/name_<i>.png``, where
    ``i`` is the element's index. When *dfm* is truthy, the matching
    distance-field map goes to ``target_path/name_<i>_dfm.png``. Entries that
    fail are reported on stdout and skipped.

    Args:
        source_path: directory containing the input images.
        target_path: output directory; created if missing (unless empty).
        sketcher: object providing ``get_sketch_with_resize``.
        mult_list: iterable of multipliers, one sketch per element.
        dfm: also write distance-field maps.
    """
    if target_path:
        os.makedirs(target_path, exist_ok=True)

    def _save(path, data):
        # cv2.imwrite can signal failure by returning False instead of raising.
        if not cv2.imwrite(path, data):
            raise OSError('could not write {}'.format(path))

    for image_name in os.listdir(source_path):
        # splitext keeps extension-less names intact (rfind('.') gave -1).
        stem = os.path.splitext(image_name)[0]
        try:
            image = cv2.imread(os.path.join(source_path, image_name))
            if image is None:
                raise OSError('could not read image')

            sketches = get_sketches(image, sketcher, mult_list, dfm)
            for number, (sketch_image, dfm_image) in enumerate(sketches):
                _save(os.path.join(target_path, '{}_{}.png'.format(stem, number)), sketch_image)

                if dfm:
                    _save(os.path.join(target_path, '{}_{}_dfm.png'.format(stem, number)), dfm_image)
        except Exception as err:
            # Not a bare except: Ctrl-C / SystemExit still stop the run.
            print('Failed to process {}: {}'.format(image_name, err))
    
    
def create_dataset(source_path, target_path, sketcher, mult_list, side_size, dfm = False):
    """Build a paired colour / sketch dataset from the images in *source_path*.

    For each readable entry ``name.ext``:

    * ``target_path/color/name.png`` is the image resized so that its shorter
      side is *side_size* (``get_resized_image``).
    * ``target_path/bw/name_<i>.png`` is one sketch of the resized image per
      element of *mult_list*, with ``i`` the element's index.
    * ``target_path/bw/name_<i>_dfm.png`` is the matching distance-field map,
      written only when *dfm* is truthy.

    Entries that fail are reported on stdout and processing continues.

    Args:
        source_path: directory containing the input images.
        target_path: root output directory; ``color`` and ``bw`` are created
            inside it if missing.
        sketcher: object providing ``get_sketch_with_resize``.
        mult_list: iterable of multipliers, one sketch per element.
        side_size: target length in pixels of each image's shorter side.
        dfm: also write distance-field maps.
    """
    # Listed before the output dirs are created, so they are never picked up
    # as inputs even if target_path == source_path.
    images = os.listdir(source_path)

    color_path = os.path.join(target_path, 'color')
    sketch_path = os.path.join(target_path, 'bw')

    # exist_ok avoids the exists()/makedirs() race of a separate check.
    os.makedirs(color_path, exist_ok=True)
    os.makedirs(sketch_path, exist_ok=True)

    def _save(path, data):
        # cv2.imwrite can signal failure by returning False instead of raising.
        if not cv2.imwrite(path, data):
            raise OSError('could not write {}'.format(path))

    for image_name in images:
        # splitext keeps extension-less names intact (rfind('.') gave -1).
        stem = os.path.splitext(image_name)[0]

        try:
            image = cv2.imread(os.path.join(source_path, image_name))
            if image is None:
                raise OSError('could not read image')

            resized_image = get_resized_image(image, side_size)
            _save(os.path.join(color_path, stem + '.png'), resized_image)

            sketches = get_sketches(resized_image, sketcher, mult_list, dfm)
            for number, (sketch_image, dfm_image) in enumerate(sketches):
                _save(os.path.join(sketch_path, '{}_{}.png'.format(stem, number)), sketch_image)

                if dfm:
                    _save(os.path.join(sketch_path, '{}_{}_dfm.png'.format(stem, number)), dfm_image)
        except Exception as err:
            # Not a bare except: Ctrl-C / SystemExit still stop the run.
            print('Failed to process {}: {}'.format(image_name, err))