Spaces:
Runtime error
n1kkqt committed
Commit • 6e7b2f8
Parent(s): 8ff6b46
Add app
- README.md +22 -12
- data_utils.py +280 -0
- examples/Genshin-Impact-anime.jpg +0 -0
- examples/Screenshot 2023-02-07 140049.jpg +0 -0
- examples/anime2.jpg +0 -0
- gradio_app.py +50 -0
- model.py +192 -0
- model/model.pth +3 -0
- requirements.txt +3 -0
- xdog.py +66 -0
README.md
CHANGED
@@ -1,12 +1,22 @@
# Line Art Colorization

This project is a minimalistic implementation of AlacGAN, based on the paper "User-Guided Deep Anime Line Art Colorization with Conditional Adversarial Networks" (https://arxiv.org/pdf/1808.03240.pdf) and its GitHub repository.

## Colab example to play with the model (just run!)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jInaIELLo-Y1M8MnIgA-7aXSWRYlHnC9?usp=sharing)

## Differences from the original implementation

1. Fewer line-thickness variants (this did not make the model perform significantly worse)
2. Different images in the dataset
3. No local features network
4. All image pairs are produced with the XDoG algorithm, whereas the paper also used real line art images for training

Because of these differences the results are slightly worse, but the model trained significantly faster and collecting the data did not take long.

## Model weights

https://download938.mediafire.com/nd1xp1xdgitg/aig8n36f4vrne6t/gen_373000.pth

## Colorization examples

All of these images were colorized by the AlacGAN neural network.

![Results of colorization](https://i.imgur.com/qngw4BI.png)
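A minimal inference sketch, assuming the repository layout from this commit (the pickled generator in model/model.pth, loaded the same way gradio_app.py loads it) and a hypothetical local line-art image sketch.jpg:

import torch
from PIL import Image
from data_utils import predict_img, device

# Load the pickled generator committed in this Space; map_location keeps CPU-only machines working.
gen = torch.load('model/model.pth', map_location=device)
gen.eval()

sk = Image.open('sketch.jpg').convert('L')   # hypothetical local line-art image
colored = predict_img(gen, sk)               # no user hints: an all-zero hint tensor is used
colored.save('colored.jpg')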
data_utils.py
ADDED
@@ -0,0 +1,280 @@
import os
import math
import random
import numbers
import requests
import shutil
import numpy as np
import scipy.stats as stats
from PIL import Image
from tqdm.auto import tqdm

from xdog import to_sketch

import torch
import torch.nn as nn
import torch.utils.data as data
from torch.utils.data.sampler import Sampler

from torchvision import transforms
from torchvision.transforms import Resize, CenterCrop

# Truncated normal used to sample the fraction of hint pixels that get masked out
mu, sigma = 1, 0.005
X = stats.truncnorm((0 - mu) / sigma, (1 - mu) / sigma, loc=mu, scale=sigma)

denormalize = transforms.Compose([transforms.Normalize(mean=[0., 0., 0.],
                                                        std=[1 / 0.5, 1 / 0.5, 1 / 0.5]),
                                  transforms.Normalize(mean=[-0.5, -0.5, -0.5],
                                                       std=[1., 1., 1.])])

etrans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5), (0.5))
])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def predict_img(gen, sk, hnt=None):
    # Convert the sketch to a normalized tensor and pad it so both sides are multiples of 16
    sk = etrans(sk)

    pad_w = 16 - sk.shape[1] % 16 if sk.shape[1] % 16 != 0 else 0
    pad_h = 16 - sk.shape[2] % 16 if sk.shape[2] % 16 != 0 else 0
    pad = nn.ZeroPad2d((pad_h, 0, pad_w, 0))
    sk = pad(sk)

    sk = sk.unsqueeze(0)
    sk = sk.to(device)

    # Without user hints, feed an empty (all-zero) hint tensor at quarter resolution
    if hnt is None:
        hnt = torch.zeros((1, 4, sk.shape[2] // 4, sk.shape[3] // 4))

    hnt = hnt.to(device)

    img_gen = gen(sk, hnt, sketch_feat=None).squeeze(0)
    img_gen = denormalize(img_gen) * 255
    img_gen = img_gen.permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
    # Crop the padding back off before returning the PIL image
    return Image.fromarray(img_gen[pad_w:, pad_h:])

def files(img_path, img_size=512):
    img_path = os.path.abspath(img_path)
    line_widths = sorted([el for el in os.listdir(os.path.join(img_path, 'pics_sketch')) if el != '.ipynb_checkpoints'])
    images_names = sorted([el for el in os.listdir(os.path.join(img_path, 'pics_sketch', line_widths[0])) if '.jpg' in el])

    # Keep only images that are at least img_size x img_size
    images_names = [el for el in images_names if np.all(np.array(Image.open(os.path.join(img_path, 'pics', el)).size) >= np.array([img_size, img_size]))]

    images_color = [os.path.join(img_path, 'pics', el) for el in images_names]
    images_sketch = {line_width: [os.path.join(img_path, 'pics_sketch', line_width, el) for el in images_names] for line_width in line_widths}
    return images_color, images_sketch

def mask_gen(img_size=512, bs=4):
    maskS = img_size // 4

    # Half of the batch gets randomly sampled hint masks, the other half gets empty masks
    mask1 = torch.cat([torch.rand(1, 1, maskS, maskS).ge(X.rvs(1)[0]).float() for _ in range(bs // 2)], 0)
    mask2 = torch.cat([torch.zeros(1, 1, maskS, maskS).float() for _ in range(bs // 2)], 0)
    mask = torch.cat([mask1, mask2], 0)
    return mask

def jitter(x):
    ran = random.uniform(0.7, 1)
    return x * ran + 1 - ran

def make_trans(img_size):
    vtrans = transforms.Compose([
        RandomSizedCrop(img_size // 4, Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    ctrans = transforms.Compose([
        transforms.Resize(img_size, Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    strans = transforms.Compose([
        transforms.Resize(img_size, Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Lambda(jitter),
        transforms.Normalize((0.5), (0.5))
    ])

    return vtrans, ctrans, strans

class RandomCrop(object):
    """Crops the given PIL.Image at a random location to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size)
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img1, img2):
        w, h = img1.size
        th, tw = self.size
        if w == tw and h == th:  # nothing to crop; avoids ValueError: empty range for randrange()
            return img1, img2

        if w == tw:
            x1 = 0
            y1 = random.randint(0, h - th)
            return img1.crop((x1, y1, x1 + tw, y1 + th)), img2.crop((x1, y1, x1 + tw, y1 + th))

        elif h == th:
            x1 = random.randint(0, w - tw)
            y1 = 0
            return img1.crop((x1, y1, x1 + tw, y1 + th)), img2.crop((x1, y1, x1 + tw, y1 + th))

        else:
            x1 = random.randint(0, w - tw)
            y1 = random.randint(0, h - th)
            return img1.crop((x1, y1, x1 + tw, y1 + th)), img2.crop((x1, y1, x1 + tw, y1 + th))

class RandomSizedCrop(object):
    """Randomly crops the given PIL.Image to 90-100% of the original area
    with an aspect ratio between 7/8 and 8/7, then resizes the crop to a
    square of the given size.
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BICUBIC
    """

    def __init__(self, size, interpolation=Image.BICUBIC):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.9, 1.) * area
            aspect_ratio = random.uniform(7. / 8, 8. / 7)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                x1 = random.randint(0, img.size[0] - w)
                y1 = random.randint(0, img.size[1] - h)

                img = img.crop((x1, y1, x1 + w, y1 + h))
                assert (img.size == (w, h))

                return img.resize((self.size, self.size), self.interpolation)

        # Fallback: resize the smaller edge, then center crop
        resize = Resize(self.size, interpolation=self.interpolation)
        crop = CenterCrop(self.size)
        return crop(resize(img))


class ImageFolder(data.Dataset):
    def __init__(self, img_path, img_size):

        self.images_color, self.images_sketch = files(img_path, img_size)
        if (any([len(self.images_sketch[key]) == 0 for key in self.images_sketch])) or (len(self.images_color) == 0):
            raise (RuntimeError("Found 0 images in one of the folders."))
        if any([len(self.images_sketch[key]) != len(self.images_color) for key in self.images_sketch]):
            raise (RuntimeError("The number of sketches is not equal to the number of colorized images."))
        self.img_path = img_path
        self.img_size = img_size
        self.vtrans, self.ctrans, self.strans = make_trans(img_size)

    def __getitem__(self, index):
        color = Image.open(self.images_color[index]).convert('RGB')

        random_line_width = random.choice(list(self.images_sketch.keys()))
        sketch = Image.open(self.images_sketch[random_line_width][index]).convert('L')
        # the image can be smaller than img_size, fix!
        color, sketch = RandomCrop(self.img_size)(color, sketch)
        if random.random() < 0.5:
            color, sketch = color.transpose(Image.FLIP_LEFT_RIGHT), sketch.transpose(Image.FLIP_LEFT_RIGHT)

        color, color_down, sketch = self.ctrans(color), self.vtrans(color), self.strans(sketch)

        return color, color_down, sketch

    def __len__(self):
        return len(self.images_color)


class GivenIterationSampler(Sampler):
    def __init__(self, dataset, total_iter, batch_size, diter, last_iter=-1):
        self.dataset = dataset
        self.total_iter = total_iter
        self.batch_size = batch_size
        self.diter = diter
        self.last_iter = last_iter

        self.total_size = self.total_iter * self.batch_size * (self.diter + 1)

        self.indices = self.gen_new_list()
        self.call = 0

    def __iter__(self):
        return iter(self.indices[(self.last_iter + 1) * self.batch_size * (self.diter + 1):])

    def gen_new_list(self):
        # each process shuffles the full list with the same seed
        np.random.seed(0)

        indices = np.arange(len(self.dataset))
        indices = indices[:self.total_size]
        num_repeat = (self.total_size - 1) // indices.shape[0] + 1
        indices = np.tile(indices, num_repeat)
        indices = indices[:self.total_size]

        np.random.shuffle(indices)
        assert len(indices) == self.total_size
        return indices

    def __len__(self):
        # last_iter is intentionally ignored here: __len__ is only used for display,
        # the correct remaining size is handled by the dataloader
        return self.total_size

def get_dataloader(img_path, img_size=512, seed=0, total_iter=250000, bs=4, diters=1, last_iter=-1):

    random.seed(seed)

    train_dataset = ImageFolder(img_path=img_path, img_size=img_size)

    train_sampler = GivenIterationSampler(train_dataset, total_iter, bs, diters, last_iter=last_iter)

    return data.DataLoader(train_dataset, batch_size=bs, shuffle=False, pin_memory=True, num_workers=4, sampler=train_sampler)


def get_data(links, img_path='alacgan_data', line_widths=[0.3, 0.5]):
    c = 0

    for line_width in line_widths:
        lw = str(line_width)
        if lw not in os.listdir(os.path.join(img_path, 'pics_sketch')):
            os.mkdir(os.path.join(img_path, 'pics_sketch', lw))
        else:
            shutil.rmtree(os.path.join(img_path, 'pics_sketch', lw))
            os.mkdir(os.path.join(img_path, 'pics_sketch', lw))

    for link in tqdm(links):
        img_orig = Image.open(requests.get(link, stream=True).raw).convert('RGB')
        img_orig.save(os.path.join(img_path, 'pics', str(c) + '.jpg'), 'JPEG')
        for line_width in line_widths:
            sketch_test = to_sketch(img_orig, sigma=line_width, k=5, gamma=0.96, epsilon=-1, phi=10e15, area_min=2)
            sketch_test.save(os.path.join(img_path, 'pics_sketch', str(line_width), str(c) + '.jpg'), 'JPEG')

        c += 1
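A minimal sketch of driving the training data pipeline above, assuming an alacgan_data/ folder laid out the way files() expects (pics/ with color images and pics_sketch/<line width>/ with matching XDoG sketches):

from data_utils import get_dataloader, mask_gen

# Small iteration budget just to pull one batch for inspection
loader = get_dataloader('alacgan_data', img_size=512, total_iter=100, bs=4, diters=1)
color, color_down, sketch = next(iter(loader))

# Hint masks at quarter resolution: half of the batch random, half empty
hint_mask = mask_gen(img_size=512, bs=4)

print(color.shape, color_down.shape, sketch.shape, hint_mask.shape)
# expected: (4, 3, 512, 512) (4, 3, 128, 128) (4, 1, 512, 512) (4, 1, 128, 128)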
examples/Genshin-Impact-anime.jpg
ADDED
examples/Screenshot 2023-02-07 140049.jpg
ADDED
examples/anime2.jpg
ADDED
gradio_app.py
ADDED
@@ -0,0 +1,50 @@
import gradio as gr
from xdog import to_sketch
from model import Generator, ResNeXtBottleneck
import torch
from data_utils import *
import glob

# map_location keeps the Space working on CPU-only hardware
gen = torch.load('model/model.pth', map_location=device)

def convert_to_lineart(img, sigma, k, gamma, epsilon, phi, area_min):
    phi = 10 * phi
    out = to_sketch(img, sigma=sigma, k=k, gamma=gamma, epsilon=epsilon, phi=phi, area_min=area_min)
    return out

def inference(sk):
    return predict_img(gen, sk, hnt=None)

title = "To Line Art"
description = "Line art colorization showcase."
article = "Github Repo"

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(type="pil", value='examples/Genshin-Impact-anime.jpg')
            to_lineart_button = gr.Button("To Lineart")

            gr.Examples(
                examples=glob.glob('examples/*.jpg'),
                inputs=image,
                outputs=image,
                fn=None,
                cache_examples=False,
            )

        with gr.Column():
            sigma = gr.Slider(0.1, 0.5, value=0.3, step=0.1, label='σ')
            k = gr.Slider(1.0, 8.0, value=4.5, step=0.5, label='k')
            gamma = gr.Slider(0.05, 1.0, value=0.95, step=0.05, label='γ')
            epsilon = gr.Slider(-2, 2, value=-1, step=0.5, label='ε')
            phi = gr.Slider(10, 20, value=15, label='φ')
            min_area = gr.Slider(1, 5, value=2, step=1, label='Minimal Area')

        with gr.Column():
            lineart = gr.Image(type="pil", image_mode='L')
            inpaint_button = gr.Button("Inpaint")

    to_lineart_button.click(convert_to_lineart, inputs=[image, sigma, k, gamma, epsilon, phi, min_area], outputs=lineart)
    inpaint_button.click(inference, inputs=lineart, outputs=lineart)

demo.launch()
model.py
ADDED
@@ -0,0 +1,192 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResNeXtBottleneck(nn.Module):
    def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
        super(ResNeXtBottleneck, self).__init__()

        D = out_channels // 2
        self.out_channels = out_channels
        self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_conv = nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate, groups=cardinality, bias=False)
        self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.shortcut = nn.Sequential()

        if stride != 1:
            self.shortcut.add_module('shortcut', nn.AvgPool2d(2, stride=2))

    def forward(self, x):
        bottleneck = self.conv_reduce.forward(x)
        bottleneck = F.leaky_relu(bottleneck, 0.2, True)
        bottleneck = self.conv_conv.forward(bottleneck)
        bottleneck = F.leaky_relu(bottleneck, 0.2, True)
        bottleneck = self.conv_expand.forward(bottleneck)
        x = self.shortcut.forward(x)

        return x + bottleneck


class Generator(nn.Module):
    def __init__(self, ngf=64, feat=True):
        super(Generator, self).__init__()
        self.feat = feat
        if feat:
            add_channels = 512
        else:
            add_channels = 0

        self.toH = self._block(4, ngf, kernel_size=7, stride=1, padding=3)
        self.to0 = self._block(1, ngf // 2, kernel_size=3, stride=1, padding=1)
        self.to1 = self._block(ngf // 2, ngf, kernel_size=4, stride=2, padding=1)
        self.to2 = self._block(ngf, ngf * 2, kernel_size=4, stride=2, padding=1)
        self.to3 = self._block(ngf * 3, ngf * 4, kernel_size=4, stride=2, padding=1)
        self.to4 = self._block(ngf * 4, ngf * 8, kernel_size=4, stride=2, padding=1)

        tunnel4 = nn.Sequential(*[ResNeXtBottleneck(ngf * 8, ngf * 8, cardinality=32, dilate=1) for _ in range(20)])

        self.tunnel4 = nn.Sequential(self._block(ngf * 8 + add_channels, ngf * 8, kernel_size=3, stride=1, padding=1),
                                     tunnel4,
                                     nn.Conv2d(ngf * 8, ngf * 16, kernel_size=3, stride=1, padding=1),
                                     nn.PixelShuffle(2),
                                     nn.LeakyReLU(0.2, True))

        depth = 2

        tunnel = [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=1) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=2) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=4) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=2),
                   ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=1)]
        tunnel3 = nn.Sequential(*tunnel)

        self.tunnel3 = nn.Sequential(self._block(ngf * 8, ngf * 4, kernel_size=3, stride=1, padding=1),
                                     tunnel3,
                                     nn.Conv2d(ngf * 4, ngf * 8, kernel_size=3, stride=1, padding=1),
                                     nn.PixelShuffle(2),
                                     nn.LeakyReLU(0.2, True))

        tunnel = [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=1) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=2) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=4) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=2),
                   ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=1)]
        tunnel2 = nn.Sequential(*tunnel)

        self.tunnel2 = nn.Sequential(self._block(ngf * 4, ngf * 2, kernel_size=3, stride=1, padding=1),
                                     tunnel2,
                                     nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=1, padding=1),
                                     nn.PixelShuffle(2),
                                     nn.LeakyReLU(0.2, True))

        tunnel = [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=1)]
        tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=2)]
        tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=4)]
        tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=2),
                   ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=1)]
        tunnel1 = nn.Sequential(*tunnel)

        self.tunnel1 = nn.Sequential(self._block(ngf * 2, ngf, kernel_size=3, stride=1, padding=1),
                                     tunnel1,
                                     nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=1, padding=1),
                                     nn.PixelShuffle(2),
                                     nn.LeakyReLU(0.2, True))

        self.exit = nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.LeakyReLU(0.2, True)
        )

    def forward(self, sketch, hint, sketch_feat):
        hint = self.toH(hint)

        x0 = self.to0(sketch)
        x1 = self.to1(x0)
        x2 = self.to2(x1)
        x3 = self.to3(torch.cat([x2, hint], 1))
        x4 = self.to4(x3)

        if self.feat:
            x = self.tunnel4(torch.cat([x4, sketch_feat], 1))
        else:
            x = self.tunnel4(x4)
        x = self.tunnel3(torch.cat([x, x3], 1))
        x = self.tunnel2(torch.cat([x, x2], 1))
        x = self.tunnel1(torch.cat([x, x1], 1))
        x = torch.tanh(self.exit(torch.cat([x, x0], 1)))
        return x


class Discriminator(nn.Module):
    def __init__(self, ndf=64, feat=True):
        super(Discriminator, self).__init__()
        self.feat = feat

        if feat:
            add_channels = ndf * 8
            ks = 4
        else:
            add_channels = 0
            ks = 3

        self.feed = nn.Sequential(
            self._block(3, ndf, kernel_size=7, stride=1, padding=1),
            self._block(ndf, ndf, kernel_size=4, stride=2, padding=1),

            ResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1),
            ResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1, stride=2),
            self._block(ndf, ndf * 2, kernel_size=1, stride=1, padding=0),

            ResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1),
            ResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1, stride=2),
            self._block(ndf * 2, ndf * 4, kernel_size=1, stride=1, padding=0),

            ResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1),
            ResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1, stride=2)
        )

        self.feed2 = nn.Sequential(
            self._block(ndf * 4 + add_channels, ndf * 8, kernel_size=3, stride=1, padding=1),

            ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
            ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1, stride=2),
            ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
            ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1, stride=2),
            ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
            ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1, stride=2),
            ResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),

            self._block(ndf * 8, ndf * 8, kernel_size=ks, stride=1, padding=0),
        )

        self.out = nn.Linear(512, 1)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.LeakyReLU(0.2, True)
        )

    def forward(self, color, sketch_feat=None):
        x = self.feed(color)

        if self.feat:
            x = self.feed2(torch.cat([x, sketch_feat], 1))
        else:
            x = self.feed2(x)

        out = self.out(x.view(color.size(0), -1))
        return out
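A quick shape sanity check for the Generator, a sketch that assumes the feat=False configuration predict_img() relies on (it passes sketch_feat=None):

import torch
from model import Generator

gen = Generator(ngf=64, feat=False)
sketch = torch.randn(1, 1, 256, 256)   # grayscale line art; sides divisible by 16
hint = torch.zeros(1, 4, 64, 64)       # RGB hint + mask at quarter resolution
with torch.no_grad():
    out = gen(sketch, hint, sketch_feat=None)
print(out.shape)  # torch.Size([1, 3, 256, 256])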
model/model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:08d8483543bbc2fcae2d86a89422dc7be566c239ba99b6b889a7dc29a289a01c
size 78392739
requirements.txt
ADDED
@@ -0,0 +1,3 @@
torch
torchvision
opencv-python
xdog.py
ADDED
@@ -0,0 +1,66 @@
import cv2
import numpy as np
from PIL import Image
import requests

# Difference of Gaussians applied to img input
def dog(img, size=(0, 0), k=1.6, sigma=0.5, gamma=1):
    img1 = cv2.GaussianBlur(img, size, sigma)
    img2 = cv2.GaussianBlur(img, size, sigma * k)
    return (img1 - gamma * img2)

# Threshold the dog image, with dog(sigma,k) > 0 ? 1(255):0(0)
def edge_dog(img, sigma=0.5, k=200, gamma=0.98):
    aux = dog(img, sigma=sigma, k=k, gamma=gamma)
    for i in range(0, aux.shape[0]):
        for j in range(0, aux.shape[1]):
            if (aux[i, j] > 0):
                aux[i, j] = 255
            else:
                aux[i, j] = 0
    return aux

# garygrossi xdog version
def xdog_garygrossi(img, sigma=0.5, k=200, gamma=0.98, epsilon=0.1, phi=10):
    aux = dog(img, sigma=sigma, k=k, gamma=gamma) / 255
    for i in range(0, aux.shape[0]):
        for j in range(0, aux.shape[1]):
            if (aux[i, j] >= epsilon):
                aux[i, j] = 1
            else:
                ht = np.tanh(phi * (aux[i][j] - epsilon))
                aux[i][j] = 1 + ht
    return aux * 255

def hatchBlend(image):
    xdogImage = xdog(image, sigma=1, k=200, gamma=0.5, epsilon=-0.5, phi=10)
    hatchTexture = cv2.imread('./imgs/hatch.jpg', cv2.IMREAD_GRAYSCALE)
    hatchTexture = cv2.resize(hatchTexture, (image.shape[1], image.shape[0]))
    alpha = 0.120
    return (1 - alpha) * xdogImage + alpha * hatchTexture

# version of xdog inspired by article
def xdog(img, sigma=0.5, k=1.6, gamma=1, epsilon=1, phi=1):
    aux = dog(img, sigma=sigma, k=k, gamma=gamma) / 255
    for i in range(0, aux.shape[0]):
        for j in range(0, aux.shape[1]):
            if (aux[i, j] < epsilon):
                aux[i, j] = 1 * 255
            else:
                aux[i, j] = 255 * (1 + np.tanh(phi * (aux[i, j])))
    return aux

def to_sketch(img_orig, sigma=0.3, k=4.5, gamma=0.95, epsilon=-1, phi=10e15, area_min=2):
    img_cnts = []
    img = cv2.cvtColor(np.array(img_orig), cv2.COLOR_RGB2GRAY)
    img_xdog = xdog(img, sigma=sigma, k=k, gamma=gamma, epsilon=epsilon, phi=phi).astype(np.uint8)
    new_img = np.zeros_like(img_xdog)
    thresh = cv2.threshold(img_xdog, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)[1]
    cnts = cv2.findContours(thresh.copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    cnts = cnts[0] if len(cnts) == 2 else cnts[1]
    # Drop contours whose area is below area_min to remove speckle noise
    for c in cnts:
        area = cv2.contourArea(c)
        if area > area_min:
            img_cnts.append(c)

    return Image.fromarray(255 - cv2.drawContours(new_img, img_cnts, -1, (255, 255, 255), -1))
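A small sketch-extraction example for to_sketch, using the same XDoG parameters that get_data() in data_utils.py passes for line width 0.3; the input is assumed to be one of the bundled example images:

from PIL import Image
from xdog import to_sketch

img = Image.open('examples/anime2.jpg').convert('RGB')
# Parameters mirror the 0.3 line-width setting used when building training pairs
sketch = to_sketch(img, sigma=0.3, k=5, gamma=0.96, epsilon=-1, phi=10e15, area_min=2)
sketch.save('anime2_sketch.jpg')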