n1kkqt committed on
Commit
6e7b2f8
1 Parent(s): 8ff6b46
README.md CHANGED
@@ -1,12 +1,22 @@
- ---
- title: Line Art Colorization
- emoji: 🐠
- colorFrom: purple
- colorTo: red
- sdk: gradio
- sdk_version: 3.17.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Line Art Colorization
+ This project is a minimalistic implementation of AlacGAN. It is based on the paper "User-Guided Deep Anime Line Art Colorization with Conditional Adversarial Networks" (https://arxiv.org/pdf/1808.03240.pdf) as well as its GitHub repository.
+
+ ## Colab example to play with the model (just run!)
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jInaIELLo-Y1M8MnIgA-7aXSWRYlHnC9?usp=sharing)
+
+ ## Differences from the original implementation
+ 1. Fewer line-thickness variants (this did not make the model perform significantly worse)
+ 2. Different images in the dataset
+ 3. No local features network
+ 4. All image pairs are generated with the XDoG algorithm, whereas the paper also used real line art images to train the model
+
+ Because of these differences the results are slightly worse, but the model trained significantly faster and collecting the data did not take long.
+
+ ## Model weights
+ https://download938.mediafire.com/nd1xp1xdgitg/aig8n36f4vrne6t/gen_373000.pth
+
+ ## Colorization examples
+ All of these images were colorized by the AlacGAN network.
+
+ ![Results of colorization](https://i.imgur.com/qngw4BI.png)
+
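For quick local testing, here is a minimal inference sketch. It assumes, as in `gradio_app.py`, that `model/model.pth` is a full pickled `Generator`, and `lineart.png` is a placeholder for your own grayscale line-art image.

```python
# Minimal inference sketch: colorize a line-art image with the pretrained generator,
# the same way predict_img is used in gradio_app.py.
import torch
from PIL import Image

from model import Generator, ResNeXtBottleneck  # needed so torch.load can unpickle the model
from data_utils import predict_img, device

gen = torch.load('model/model.pth', map_location=device)
gen.eval()

sketch = Image.open('lineart.png').convert('L')   # placeholder input file
with torch.no_grad():
    colored = predict_img(gen, sketch, hnt=None)  # no user color hints
colored.save('colored.png')
```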
data_utils.py ADDED
@@ -0,0 +1,280 @@
+ # +
+ import os
+ import math
+ import random
+ import numbers
+ import requests
+ import shutil
+ import numpy as np
+ import scipy.stats as stats
+ from PIL import Image
+ from tqdm.auto import tqdm
+
+ from xdog import to_sketch
+ # -
+
+ import torch
+ import torch.nn as nn
+ import torch.utils.data as data
+ from torch.utils.data.sampler import Sampler
+
+ from torchvision import transforms
+ from torchvision.transforms import Resize, CenterCrop
+
+ mu, sigma = 1, 0.005
+ X = stats.truncnorm((0 - mu) / sigma, (1 - mu) / sigma, loc=mu, scale=sigma)  # keep-probability sampler for hint masks
+
+ denormalize = transforms.Compose([transforms.Normalize(mean=[0., 0., 0.],
+                                                        std=[1 / 0.5, 1 / 0.5, 1 / 0.5]),
+                                   transforms.Normalize(mean=[-0.5, -0.5, -0.5],
+                                                        std=[1., 1., 1.])])
+
+ etrans = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Normalize((0.5), (0.5))
+ ])
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def predict_img(gen, sk, hnt=None):
+     # sk is a grayscale PIL image of the line art
+     sk = etrans(sk)
+
+     pad_w = 16 - sk.shape[1] % 16 if sk.shape[1] % 16 != 0 else 0  # pad both sides to a multiple of 16
+     pad_h = 16 - sk.shape[2] % 16 if sk.shape[2] % 16 != 0 else 0  # (the generator downsamples 4 times)
+     pad = nn.ZeroPad2d((pad_h, 0, pad_w, 0))
+     sk = pad(sk)
+
+     sk = sk.unsqueeze(0)
+     sk = sk.to(device)
+
+     if hnt is None:
+         hnt = torch.zeros((1, 4, sk.shape[2] // 4, sk.shape[3] // 4))  # empty color hints
+
+     hnt = hnt.to(device)
+
+     img_gen = gen(sk, hnt, sketch_feat=None).squeeze(0)
+     img_gen = denormalize(img_gen) * 255
+     img_gen = img_gen.permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
+     # return img_gen[pad_w:, pad_h:]
+     return Image.fromarray(img_gen[pad_w:, pad_h:])
+
+ def files(img_path, img_size=512):
+     img_path = os.path.abspath(img_path)
+     line_widths = sorted([el for el in os.listdir(os.path.join(img_path, 'pics_sketch')) if el != '.ipynb_checkpoints'])
+     images_names = sorted([el for el in os.listdir(os.path.join(img_path, 'pics_sketch', line_widths[0])) if '.jpg' in el])
+
+     images_names = [el for el in images_names if np.all(np.array(Image.open(os.path.join(img_path, 'pics', el)).size) >= np.array([img_size, img_size]))]
+
+     images_color = [os.path.join(img_path, 'pics', el) for el in images_names]
+     images_sketch = {line_width: [os.path.join(img_path, 'pics_sketch', line_width, el) for el in images_names] for line_width in line_widths}
+     return images_color, images_sketch
+
+ def mask_gen(img_size=512, bs=4):
+     maskS = img_size // 4
+
+     mask1 = torch.cat([torch.rand(1, 1, maskS, maskS).ge(X.rvs(1)[0]).float() for _ in range(bs // 2)], 0)  # sparse random hint masks
+     mask2 = torch.cat([torch.zeros(1, 1, maskS, maskS).float() for _ in range(bs // 2)], 0)  # no hints at all
+     mask = torch.cat([mask1, mask2], 0)
+     return mask
+
+ def jitter(x):
+     ran = random.uniform(0.7, 1)
+     return x * ran + 1 - ran
+
+ def make_trans(img_size):
+     vtrans = transforms.Compose([
+         RandomSizedCrop(img_size // 4, Image.BICUBIC),
+         transforms.ToTensor(),
+         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+     ])
+
+     ctrans = transforms.Compose([
+         transforms.Resize(img_size, Image.BICUBIC),
+         transforms.ToTensor(),
+         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+     ])
+
+     strans = transforms.Compose([
+         transforms.Resize(img_size, Image.BICUBIC),
+         transforms.ToTensor(),
+         transforms.Lambda(jitter),
+         transforms.Normalize((0.5), (0.5))
+     ])
+
+     return vtrans, ctrans, strans
+
+ class RandomCrop(object):
+     """Crops the given PIL.Image at a random location to have a region of
+     the given size. size can be a tuple (target_height, target_width)
+     or an integer, in which case the target will be of a square shape (size, size).
+     """
+
+     def __init__(self, size):
+         if isinstance(size, numbers.Number):
+             self.size = (int(size), int(size))
+         else:
+             self.size = size
+
+     def __call__(self, img1, img2):
+         w, h = img1.size
+         th, tw = self.size
+         if w == tw and h == th:  # nothing to crop; also avoids randrange(0, 0)
+             return img1, img2
+
+         if w == tw:
+             x1 = 0
+             y1 = random.randint(0, h - th)
+             return img1.crop((x1, y1, x1 + tw, y1 + th)), img2.crop((x1, y1, x1 + tw, y1 + th))
+
+         elif h == th:
+             x1 = random.randint(0, w - tw)
+             y1 = 0
+             return img1.crop((x1, y1, x1 + tw, y1 + th)), img2.crop((x1, y1, x1 + tw, y1 + th))
+
+         else:
+             x1 = random.randint(0, w - tw)
+             y1 = random.randint(0, h - th)
+             return img1.crop((x1, y1, x1 + tw, y1 + th)), img2.crop((x1, y1, x1 + tw, y1 + th))
+
+ class RandomSizedCrop(object):
+     """Randomly crops the given PIL.Image to 0.9-1.0 of the original area
+     with an aspect ratio between 7/8 and 8/7, then resizes the crop to a
+     square of the given size.
+     size: size of the output square
+     interpolation: Default: PIL.Image.BICUBIC
+     """
+
+     def __init__(self, size, interpolation=Image.BICUBIC):
+         self.size = size
+         self.interpolation = interpolation
+
+     def __call__(self, img):
+         for attempt in range(10):
+             area = img.size[0] * img.size[1]
+             target_area = random.uniform(0.9, 1.) * area
+             aspect_ratio = random.uniform(7. / 8, 8. / 7)
+
+             w = int(round(math.sqrt(target_area * aspect_ratio)))
+             h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+             if random.random() < 0.5:
+                 w, h = h, w
+
+             if w <= img.size[0] and h <= img.size[1]:
+                 x1 = random.randint(0, img.size[0] - w)
+                 y1 = random.randint(0, img.size[1] - h)
+
+                 img = img.crop((x1, y1, x1 + w, y1 + h))
+                 assert (img.size == (w, h))
+
+                 return img.resize((self.size, self.size), self.interpolation)
+
+         # Fallback: use a lowercase name so the imported Resize class is not shadowed
+         resize = Resize(self.size, interpolation=self.interpolation)
+         crop = CenterCrop(self.size)
+         return crop(resize(img))
+
+
+ class ImageFolder(data.Dataset):
+     def __init__(self, img_path, img_size):
+
+         self.images_color, self.images_sketch = files(img_path, img_size)
+         if any([len(self.images_sketch[key]) == 0 for key in self.images_sketch]) or (len(self.images_color) == 0):
+             raise RuntimeError("Found 0 images in one of the folders.")
+         if any([len(self.images_sketch[key]) != len(self.images_color) for key in self.images_sketch]):
+             raise RuntimeError("The number of sketches is not equal to the number of colorized images.")
+         self.img_path = img_path
+         self.img_size = img_size
+         self.vtrans, self.ctrans, self.strans = make_trans(img_size)
+
+     def __getitem__(self, index):
+         color = Image.open(self.images_color[index]).convert('RGB')
+
+         random_line_width = random.choice(list(self.images_sketch.keys()))
+         sketch = Image.open(self.images_sketch[random_line_width][index]).convert('L')
+         # TODO: the image can be smaller than img_size, fix!
+         color, sketch = RandomCrop(self.img_size)(color, sketch)
+         if random.random() < 0.5:
+             color, sketch = color.transpose(Image.FLIP_LEFT_RIGHT), sketch.transpose(Image.FLIP_LEFT_RIGHT)
+
+         color, color_down, sketch = self.ctrans(color), self.vtrans(color), self.strans(sketch)
+
+         return color, color_down, sketch
+
+     def __len__(self):
+         return len(self.images_color)
+
+
+ class GivenIterationSampler(Sampler):
+     def __init__(self, dataset, total_iter, batch_size, diter, last_iter=-1):
+         self.dataset = dataset
+         self.total_iter = total_iter
+         self.batch_size = batch_size
+         self.diter = diter
+         self.last_iter = last_iter
+
+         self.total_size = self.total_iter * self.batch_size * (self.diter + 1)
+
+         self.indices = self.gen_new_list()
+         self.call = 0
+
+
+     def __iter__(self):
+         # if self.call == 0:
+         #     self.call = 1
+         return iter(self.indices[(self.last_iter + 1) * self.batch_size * (self.diter + 1):])
+         # else:
+         #     raise RuntimeError("this sampler is not designed to be called more than once!!")
+
+     def gen_new_list(self):
+         # every process shuffles the whole list with the same seed
+         np.random.seed(0)
+
+         indices = np.arange(len(self.dataset))
+         indices = indices[:self.total_size]
+         num_repeat = (self.total_size - 1) // indices.shape[0] + 1
+         indices = np.tile(indices, num_repeat)
+         indices = indices[:self.total_size]
+
+         np.random.shuffle(indices)
+         assert len(indices) == self.total_size
+         return indices
+
+     def __len__(self):
+         # the last iteration is not taken into account here: __len__ is only
+         # used for display, and the correct remaining size is handled by the
+         # dataloader
+         # return self.total_size - (self.last_iter + 1) * self.batch_size
+         return self.total_size
+
+ def get_dataloader(img_path, img_size=512, seed=0, total_iter=250000, bs=4, diters=1, last_iter=-1):
+
+     random.seed(seed)
+
+     train_dataset = ImageFolder(img_path=img_path, img_size=img_size)
+
+     train_sampler = GivenIterationSampler(train_dataset, total_iter, bs, diters, last_iter=last_iter)
+
+     return data.DataLoader(train_dataset, batch_size=bs, shuffle=False, pin_memory=True, num_workers=4, sampler=train_sampler)
+
+
+ def get_data(links, img_path='alacgan_data', line_widths=[0.3, 0.5]):
+     c = 0
+
+     for line_width in line_widths:
+         lw = str(line_width)
+         if lw not in os.listdir(os.path.join(img_path, 'pics_sketch')):
+             os.mkdir(os.path.join(img_path, 'pics_sketch', lw))
+         else:
+             shutil.rmtree(os.path.join(img_path, 'pics_sketch', lw))
+             os.mkdir(os.path.join(img_path, 'pics_sketch', lw))
+
+     for link in tqdm(links):
+         img_orig = Image.open(requests.get(link, stream=True).raw).convert('RGB')
+         img_orig.save(os.path.join(img_path, 'pics', str(c) + '.jpg'), 'JPEG')
+         for line_width in line_widths:
+             sketch_test = to_sketch(img_orig, sigma=line_width, k=5, gamma=0.96, epsilon=-1, phi=10e15, area_min=2)
+             sketch_test.save(os.path.join(img_path, 'pics_sketch', str(line_width), str(c) + '.jpg'), 'JPEG')
+
+         c += 1
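For orientation, here is a minimal sketch of how the helpers above are meant to be chained. `image_links` is a hypothetical list of image URLs, and the `pics`/`pics_sketch` folders are created up front because `get_data` and `files` expect them to exist.

```python
# Hedged usage sketch for the data pipeline in data_utils.py (image_links is a placeholder).
import os
from data_utils import get_data, get_dataloader, mask_gen

os.makedirs('alacgan_data/pics', exist_ok=True)
os.makedirs('alacgan_data/pics_sketch', exist_ok=True)

image_links = ['https://example.com/illustration.jpg']   # hypothetical image URLs
get_data(image_links, img_path='alacgan_data', line_widths=[0.3, 0.5])

loader = get_dataloader('alacgan_data', img_size=512, total_iter=100, bs=4, diters=1)
for color, color_down, sketch in loader:
    mask = mask_gen(img_size=512, bs=4)   # which hint pixels would be revealed during training
    break                                 # a single batch is enough for a smoke test
```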
examples/Genshin-Impact-anime.jpg ADDED
examples/Screenshot 2023-02-07 140049.jpg ADDED
examples/anime2.jpg ADDED
gradio_app.py ADDED
@@ -0,0 +1,50 @@
+ import gradio as gr
+ from xdog import to_sketch
+ from model import Generator, ResNeXtBottleneck
+ import torch
+ from data_utils import *
+ import glob
+ gen = torch.load('model/model.pth', map_location=device)  # device is imported from data_utils
+
+ def convert_to_lineart(img, sigma, k, gamma, epsilon, phi, area_min):
+     phi = 10 ** phi  # the slider picks the exponent (default 15, matching the 10e15 used elsewhere)
+     out = to_sketch(img, sigma=sigma, k=k, gamma=gamma, epsilon=epsilon, phi=phi, area_min=area_min)
+     return out
+
+ def inference(sk):
+     return predict_img(gen, sk, hnt=None)
+
+ title = "To Line Art"
+ description = "Line art colorization showcase."
+ article = "GitHub Repo"
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column():
+             image = gr.Image(type="pil", value='examples/Genshin-Impact-anime.jpg')
+             to_lineart_button = gr.Button("To Lineart")
+
+             gr.Examples(
+                 examples=glob.glob('examples/*.jpg'),
+                 inputs=image,
+                 outputs=image,
+                 fn=None,
+                 cache_examples=False,
+             )
+
+         with gr.Column():
+             sigma = gr.Slider(0.1, 0.5, value=0.3, step=0.1, label='σ')
+             k = gr.Slider(1.0, 8.0, value=4.5, step=0.5, label='k')
+             gamma = gr.Slider(0.05, 1.0, value=0.95, step=0.05, label='γ')
+             epsilon = gr.Slider(-2, 2, value=-1, step=0.5, label='ε')
+             phi = gr.Slider(10, 20, label='φ', value=15)
+             min_area = gr.Slider(1, 5, value=2, step=1, label='Minimal Area')
+
+         with gr.Column():
+             lineart = gr.Image(type="pil", image_mode='L')
+             inpaint_button = gr.Button("Inpaint")
+
+     to_lineart_button.click(convert_to_lineart, inputs=[image, sigma, k, gamma, epsilon, phi, min_area], outputs=lineart)
+     inpaint_button.click(inference, inputs=lineart, outputs=lineart)
+
+ demo.launch()
model.py ADDED
@@ -0,0 +1,192 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class ResNeXtBottleneck(nn.Module):
+     def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
+         super(ResNeXtBottleneck, self).__init__()
+
+         D = out_channels // 2
+         self.out_channels = out_channels
+         self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
+         self.conv_conv = nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate, groups=cardinality, bias=False)
+         self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
+         self.shortcut = nn.Sequential()
+
+         if stride != 1:
+             self.shortcut.add_module('shortcut', nn.AvgPool2d(2, stride=2))
+
+     def forward(self, x):
+         bottleneck = self.conv_reduce.forward(x)
+         bottleneck = F.leaky_relu(bottleneck, 0.2, True)
+         bottleneck = self.conv_conv.forward(bottleneck)
+         bottleneck = F.leaky_relu(bottleneck, 0.2, True)
+         bottleneck = self.conv_expand.forward(bottleneck)
+         x = self.shortcut.forward(x)
+
+         return x + bottleneck
+
+
+ class Generator(nn.Module):
+     def __init__(self, ngf=64, feat=True):
+         super(Generator, self).__init__()
+         self.feat = feat
+         if feat:
+             add_channels = 512
+         else:
+             add_channels = 0
+
+         self.toH = self._block(4, ngf, kernel_size=7, stride=1, padding=3)
+         self.to0 = self._block(1, ngf // 2, kernel_size=3, stride=1, padding=1)
+         self.to1 = self._block(ngf // 2, ngf, kernel_size=4, stride=2, padding=1)
+         self.to2 = self._block(ngf, ngf * 2, kernel_size=4, stride=2, padding=1)
+         self.to3 = self._block(ngf * 3, ngf * 4, kernel_size=4, stride=2, padding=1)
+         self.to4 = self._block(ngf * 4, ngf * 8, kernel_size=4, stride=2, padding=1)
+
+         tunnel4 = nn.Sequential(*[ResNeXtBottleneck(ngf * 8, ngf * 8, cardinality=32, dilate=1) for _ in range(20)])
+
+         self.tunnel4 = nn.Sequential(self._block(ngf * 8 + add_channels, ngf * 8, kernel_size=3, stride=1, padding=1),
+                                      tunnel4,
+                                      nn.Conv2d(ngf * 8, ngf * 16, kernel_size=3, stride=1, padding=1),
+                                      nn.PixelShuffle(2),
+                                      nn.LeakyReLU(0.2, True))
+
+         depth = 2
+
+         tunnel = [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=1) for _ in range(depth)]
+         tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=2) for _ in range(depth)]
+         tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=4) for _ in range(depth)]
+         tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=2),
+                    ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=1)]
+         tunnel3 = nn.Sequential(*tunnel)
+
+         self.tunnel3 = nn.Sequential(self._block(ngf * 8, ngf * 4, kernel_size=3, stride=1, padding=1),
+                                      tunnel3,
+                                      nn.Conv2d(ngf * 4, ngf * 8, kernel_size=3, stride=1, padding=1),
+                                      nn.PixelShuffle(2),
+                                      nn.LeakyReLU(0.2, True))
+
+         tunnel = [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=1) for _ in range(depth)]
+         tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=2) for _ in range(depth)]
+         tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=4) for _ in range(depth)]
+         tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=2),
+                    ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=1)]
+         tunnel2 = nn.Sequential(*tunnel)
+
+         self.tunnel2 = nn.Sequential(self._block(ngf * 4, ngf * 2, kernel_size=3, stride=1, padding=1),
+                                      tunnel2,
+                                      nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=1, padding=1),
+                                      nn.PixelShuffle(2),
+                                      nn.LeakyReLU(0.2, True))
+
+         tunnel = [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=1)]
+         tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=2)]
+         tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=4)]
+         tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=2),
+                    ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=1)]
+         tunnel1 = nn.Sequential(*tunnel)
+
+         self.tunnel1 = nn.Sequential(self._block(ngf * 2, ngf, kernel_size=3, stride=1, padding=1),
+                                      tunnel1,
+                                      nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=1, padding=1),
+                                      nn.PixelShuffle(2),
+                                      nn.LeakyReLU(0.2, True))
+
+         self.exit = nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)
+
+     def _block(self, in_channels, out_channels, kernel_size, stride, padding):
+         return nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
+             nn.LeakyReLU(0.2, True)
+         )
+
+
+     def forward(self, sketch, hint, sketch_feat):
+         hint = self.toH(hint)
+
+         x0 = self.to0(sketch)
+         x1 = self.to1(x0)
+         x2 = self.to2(x1)
+         x3 = self.to3(torch.cat([x2, hint], 1))
+         x4 = self.to4(x3)
+
+         if self.feat:
+             x = self.tunnel4(torch.cat([x4, sketch_feat], 1))
+             x = self.tunnel3(torch.cat([x, x3], 1))
+             x = self.tunnel2(torch.cat([x, x2], 1))
+             x = self.tunnel1(torch.cat([x, x1], 1))
+             x = torch.tanh(self.exit(torch.cat([x, x0], 1)))
+         else:
+             x = self.tunnel4(x4)
+             x = self.tunnel3(torch.cat([x, x3], 1))
+             x = self.tunnel2(torch.cat([x, x2], 1))
+             x = self.tunnel1(torch.cat([x, x1], 1))
+             x = torch.tanh(self.exit(torch.cat([x, x0], 1)))
+         return x
+
+
+ class Discriminator(nn.Module):
+     def __init__(self, ndf=64, feat=True):
+         super(Discriminator, self).__init__()
+         self.feat = feat
+
+         if feat:
+             add_channels = ndf * 8
+             ks = 4
+         else:
+             add_channels = 0
+             ks = 3
+
+         self.feed = nn.Sequential(
+             self._block(3, ndf, kernel_size=7, stride=1, padding=1),
+             self._block(ndf, ndf, kernel_size=4, stride=2, padding=1),
+
+             ResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1),
+             ResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1, stride=2),
+             self._block(ndf, ndf * 2, kernel_size=1, stride=1, padding=0),
+
+             ResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1),
+             ResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1, stride=2),
+             self._block(ndf * 2, ndf * 4, kernel_size=1, stride=1, padding=0),
+
+             ResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1),
+             ResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1, stride=2)
+         )
+
+         self.feed2 = nn.Sequential(
+             self._block(ndf * 4 + add_channels, ndf * 8, kernel_size=3, stride=1, padding=1),
+
+             ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
+             ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1, stride=2),
+             ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
+             ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1, stride=2),
+             ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
+             ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1, stride=2),
+             ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
+
+             self._block(ndf * 8, ndf * 8, kernel_size=ks, stride=1, padding=0),
+         )
+
+         self.out = nn.Linear(512, 1)
+
+     def _block(self, in_channels, out_channels, kernel_size, stride, padding):
+         return nn.Sequential(
+             nn.Conv2d(in_channels,
+                       out_channels,
+                       kernel_size,
+                       stride,
+                       padding,
+                       bias=False),
+             nn.LeakyReLU(0.2, True)
+         )
+
+     def forward(self, color, sketch_feat=None):
+         x = self.feed(color)
+
+         if self.feat:
+             x = self.feed2(torch.cat([x, sketch_feat], 1))
+         else:
+             x = self.feed2(x)
+
+         out = self.out(x.view(color.size(0), -1))
+         return out
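As a quick sanity check of the architecture above, here is a sketch using `feat=False` (the configuration `predict_img` relies on, with no local features network): a 1-channel sketch whose sides are divisible by 16, together with a 4-channel quarter-resolution hint, comes back as a 3-channel image of the input size.

```python
# Shape sanity check for Generator with feat=False (random weights, CPU).
import torch
from model import Generator

gen = Generator(ngf=64, feat=False)
sketch = torch.randn(1, 1, 256, 256)   # grayscale line art, sides divisible by 16
hint = torch.zeros(1, 4, 64, 64)       # empty color hints at 1/4 resolution
with torch.no_grad():
    out = gen(sketch, hint, sketch_feat=None)
print(out.shape)  # torch.Size([1, 3, 256, 256])
```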
model/model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:08d8483543bbc2fcae2d86a89422dc7be566c239ba99b6b889a7dc29a289a01c
+ size 78392739
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch
+ torchvision
+ opencv-python
xdog.py ADDED
@@ -0,0 +1,66 @@
+ import cv2
+ import numpy as np
+ from PIL import Image
+ import requests
+
+ # Difference of Gaussians applied to the input image
+ def dog(img, size=(0, 0), k=1.6, sigma=0.5, gamma=1):
+     img1 = cv2.GaussianBlur(img, size, sigma)
+     img2 = cv2.GaussianBlur(img, size, sigma * k)
+     return (img1 - gamma * img2)
+
+ # Threshold the DoG image: dog(sigma, k) > 0 -> 255, otherwise 0
+ def edge_dog(img, sigma=0.5, k=200, gamma=0.98):
+     aux = dog(img, sigma=sigma, k=k, gamma=gamma)
+     for i in range(0, aux.shape[0]):
+         for j in range(0, aux.shape[1]):
+             if aux[i, j] > 0:
+                 aux[i, j] = 255
+             else:
+                 aux[i, j] = 0
+     return aux
+
+ # garygrossi XDoG version
+ def xdog_garygrossi(img, sigma=0.5, k=200, gamma=0.98, epsilon=0.1, phi=10):
+     aux = dog(img, sigma=sigma, k=k, gamma=gamma) / 255
+     for i in range(0, aux.shape[0]):
+         for j in range(0, aux.shape[1]):
+             if aux[i, j] >= epsilon:
+                 aux[i, j] = 1
+             else:
+                 ht = np.tanh(phi * (aux[i][j] - epsilon))
+                 aux[i][j] = 1 + ht
+     return aux * 255
+
+ def hatchBlend(image):
+     xdogImage = xdog(image, sigma=1, k=200, gamma=0.5, epsilon=-0.5, phi=10)
+     hatchTexture = cv2.imread('./imgs/hatch.jpg', cv2.IMREAD_GRAYSCALE)
+     hatchTexture = cv2.resize(hatchTexture, (image.shape[1], image.shape[0]))
+     alpha = 0.120
+     return (1 - alpha) * xdogImage + alpha * hatchTexture
+
+ # version of XDoG inspired by the original article
+ def xdog(img, sigma=0.5, k=1.6, gamma=1, epsilon=1, phi=1):
+     aux = dog(img, sigma=sigma, k=k, gamma=gamma) / 255
+     for i in range(0, aux.shape[0]):
+         for j in range(0, aux.shape[1]):
+             if aux[i, j] < epsilon:
+                 aux[i, j] = 1 * 255
+             else:
+                 aux[i, j] = 255 * (1 + np.tanh(phi * (aux[i, j])))
+     return aux
+
+ def to_sketch(img_orig, sigma=0.3, k=4.5, gamma=0.95, epsilon=-1, phi=10e15, area_min=2):
+     img_cnts = []
+     img = cv2.cvtColor(np.array(img_orig), cv2.COLOR_RGB2GRAY)
+     img_xdog = xdog(img, sigma=sigma, k=k, gamma=gamma, epsilon=epsilon, phi=phi).astype(np.uint8)
+     new_img = np.zeros_like(img_xdog)
+     thresh = cv2.threshold(img_xdog, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)[1]
+     cnts = cv2.findContours(thresh.copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+     cnts = cnts[0] if len(cnts) == 2 else cnts[1]
+     for c in cnts:
+         area = cv2.contourArea(c)
+         if area > area_min:
+             img_cnts.append(c)
+
+     return Image.fromarray(255 - cv2.drawContours(new_img, img_cnts, -1, (255, 255, 255), -1))
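A short usage sketch for `to_sketch`, run on one of the example images bundled with this commit; the parameters below are simply the function's defaults, and the output filename is arbitrary.

```python
# Hedged example: extract line art from a bundled example image with XDoG.
from PIL import Image
from xdog import to_sketch

img = Image.open('examples/anime2.jpg').convert('RGB')
sketch = to_sketch(img, sigma=0.3, k=4.5, gamma=0.95, epsilon=-1, phi=10e15, area_min=2)
sketch.save('anime2_lineart.jpg')  # arbitrary output name
```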