Upload 47 files
Browse files- .dockerignore +6 -0
- .gitignore +10 -0
- Dockerfile +11 -0
- configs/train_config.json +10 -0
- configs/xdog_config.json +8 -0
- dataset/__pycache__/datasets.cpython-310.pyc +0 -0
- dataset/__pycache__/datasets.cpython-36.pyc +0 -0
- dataset/__pycache__/datasets.cpython-39.pyc +0 -0
- dataset/datasets.py +91 -0
- denoising/denoiser.py +113 -0
- denoising/functions.py +101 -0
- denoising/models.py +422 -0
- denoising/models/.gitkeep +0 -0
- denoising/utils.py +66 -0
- drawing.py +165 -0
- inference.py +215 -0
- model/__pycache__/extractor.cpython-310.pyc +0 -0
- model/__pycache__/extractor.cpython-36.pyc +0 -0
- model/__pycache__/extractor.cpython-39.pyc +0 -0
- model/__pycache__/models.cpython-310.pyc +0 -0
- model/__pycache__/models.cpython-36.pyc +0 -0
- model/__pycache__/models.cpython-39.pyc +0 -0
- model/extractor.pth +3 -0
- model/extractor.py +127 -0
- model/models.py +422 -0
- model/vgg16-397923af.pth +3 -0
- readme.md +22 -0
- requirements.txt +10 -0
- run_drawing.sh +1 -0
- static/js/draw.js +120 -0
- static/temp_images/.gitkeep +0 -0
- templates/drawing.html +206 -0
- templates/submit.html +11 -0
- templates/upload.html +20 -0
- train.py +293 -0
- train/bw/blackclover_cl268.png +0 -0
- train/bw/dfm_blackclover_cl268.png +0 -0
- train/color/blackclover_cl268.png +0 -0
- train/real_manga/blackclover_cl268.png +0 -0
- train/real_manga/dfm_blackclover_cl268.png +0 -0
- utils/__pycache__/utils.cpython-310.pyc +0 -0
- utils/__pycache__/utils.cpython-36.pyc +0 -0
- utils/__pycache__/utils.cpython-39.pyc +0 -0
- utils/dataset_utils.py +141 -0
- utils/utils.py +102 -0
- utils/xdog.py +68 -0
- web.py +108 -0
.dockerignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.ipynb
|
2 |
+
|
3 |
+
model/*.pth
|
4 |
+
|
5 |
+
temp_colorization/
|
6 |
+
__pycache__/
|
.gitignore
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.ipynb
|
2 |
+
*.pth
|
3 |
+
*.zip
|
4 |
+
|
5 |
+
__pycache__/
|
6 |
+
temp_colorization/
|
7 |
+
|
8 |
+
static/temp_images/*
|
9 |
+
|
10 |
+
!.gitkeep
|
Dockerfile
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-runtime
|
2 |
+
|
3 |
+
RUN apt-get update && apt-get install -y libglib2.0-0 libsm6 libxext6 libxrender-dev
|
4 |
+
|
5 |
+
COPY . .
|
6 |
+
|
7 |
+
RUN pip install --no-cache-dir -r ./requirements.txt
|
8 |
+
|
9 |
+
EXPOSE 5000
|
10 |
+
|
11 |
+
CMD gunicorn --timeout 200 -w 3 -b 0.0.0.0:5000 drawing:app
|
configs/train_config.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"generator_lr" : 1e-4,
|
3 |
+
"discriminator_lr" : 4e-4,
|
4 |
+
"epochs" : 15,
|
5 |
+
"lr_decrease_epoch" : 10,
|
6 |
+
"finetuning_generator_lr" : 1e-6,
|
7 |
+
"finetuning_iterations" : 3500,
|
8 |
+
"batch_size" : 1,
|
9 |
+
"number_of_mults" : 1
|
10 |
+
}
|
configs/xdog_config.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"sigma" : 0.5,
|
3 |
+
"k" : 8,
|
4 |
+
"phi" : 89.25,
|
5 |
+
"gamma" : 0.95,
|
6 |
+
"eps" : -0.1,
|
7 |
+
"mult" : 7
|
8 |
+
}
|
dataset/__pycache__/datasets.cpython-310.pyc
ADDED
Binary file (3.21 kB). View file
|
|
dataset/__pycache__/datasets.cpython-36.pyc
ADDED
Binary file (3.23 kB). View file
|
|
dataset/__pycache__/datasets.cpython-39.pyc
ADDED
Binary file (3.56 kB). View file
|
|
dataset/datasets.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
import torchvision.transforms as transforms
|
6 |
+
from utils.utils import generate_mask
|
7 |
+
|
8 |
+
class TrainDataset(torch.utils.data.Dataset):
|
9 |
+
def __init__(self, data_path, transform=None):
|
10 |
+
self.data = os.listdir(os.path.join(data_path, 'color'))
|
11 |
+
self.data_path = data_path
|
12 |
+
self.transform = transform
|
13 |
+
self.ToTensor = transforms.ToTensor()
|
14 |
+
|
15 |
+
def __len__(self):
|
16 |
+
return len(self.data)
|
17 |
+
|
18 |
+
def __getitem__(self, idx):
|
19 |
+
image_name = self.data[idx]
|
20 |
+
|
21 |
+
color_img = Image.open(os.path.join(self.data_path, 'color', image_name)).convert('RGB')
|
22 |
+
bw_name = self.data[idx]
|
23 |
+
dfm_name = 'dfm_' + self.data[idx]
|
24 |
+
|
25 |
+
bw_img = Image.open(os.path.join(self.data_path, 'bw', bw_name)).convert('L')
|
26 |
+
dfm_img = Image.open(os.path.join(self.data_path, 'bw', dfm_name)).convert('L')
|
27 |
+
|
28 |
+
color_img = np.array(color_img)
|
29 |
+
bw_img = np.array(bw_img)
|
30 |
+
dfm_img = np.array(dfm_img)
|
31 |
+
|
32 |
+
bw_img = np.expand_dims(bw_img, 2)
|
33 |
+
dfm_img = np.expand_dims(dfm_img, 2)
|
34 |
+
bw_img = np.concatenate([bw_img, dfm_img], axis=2)
|
35 |
+
|
36 |
+
if self.transform:
|
37 |
+
result = self.transform(image=color_img, mask=bw_img)
|
38 |
+
color_img = result['image']
|
39 |
+
bw_img = result['mask']
|
40 |
+
|
41 |
+
color_img = self.ToTensor(color_img)
|
42 |
+
bw_img = self.ToTensor(bw_img)
|
43 |
+
color_img = (color_img - 0.5) / 0.5 # Normalización de color_img
|
44 |
+
|
45 |
+
mask = generate_mask(bw_img.shape[1], bw_img.shape[2])
|
46 |
+
hint = torch.cat((color_img * mask, mask), 0)
|
47 |
+
|
48 |
+
return bw_img, bw_img, color_img, hint
|
49 |
+
|
50 |
+
class FineTuningDataset(torch.utils.data.Dataset):
|
51 |
+
def __init__(self, data_path, transform=None, mult_amount=1):
|
52 |
+
self.data = [x for x in os.listdir(os.path.join(data_path, 'real_manga')) if x.find('_dfm') == -1]
|
53 |
+
self.color_data = [x for x in os.listdir(os.path.join(data_path, 'color'))]
|
54 |
+
self.data_path = data_path
|
55 |
+
self.transform = transform
|
56 |
+
self.mults_amount = mult_amount
|
57 |
+
|
58 |
+
np.random.shuffle(self.color_data)
|
59 |
+
self.ToTensor = transforms.ToTensor()
|
60 |
+
|
61 |
+
def __len__(self):
|
62 |
+
return len(self.data)
|
63 |
+
|
64 |
+
def __getitem__(self, idx):
|
65 |
+
image_name = self.data[idx]
|
66 |
+
|
67 |
+
color_img = Image.open(os.path.join(self.data_path, 'color', image_name)).convert('RGB')
|
68 |
+
bw_name = self.data[idx]
|
69 |
+
dfm_name = 'dfm_' + self.data[idx]
|
70 |
+
|
71 |
+
bw_img = Image.open(os.path.join(self.data_path, 'bw', bw_name)).convert('L')
|
72 |
+
dfm_img = Image.open(os.path.join(self.data_path, 'bw', dfm_name)).convert('L')
|
73 |
+
|
74 |
+
color_img = np.array(color_img)
|
75 |
+
bw_img = np.array(bw_img)
|
76 |
+
dfm_img = np.array(dfm_img)
|
77 |
+
|
78 |
+
bw_img = np.expand_dims(bw_img, 2)
|
79 |
+
dfm_img = np.expand_dims(dfm_img, 2)
|
80 |
+
bw_img = np.concatenate([bw_img, dfm_img], axis=2)
|
81 |
+
|
82 |
+
if self.transform:
|
83 |
+
result = self.transform(image=color_img, mask=bw_img)
|
84 |
+
color_img = result['image']
|
85 |
+
bw_img = result['mask']
|
86 |
+
|
87 |
+
color_img = self.ToTensor(color_img)
|
88 |
+
bw_img = self.ToTensor(bw_img)
|
89 |
+
color_img = (color_img - 0.5) / 0.5 # Normalización de color_img
|
90 |
+
|
91 |
+
return bw_img, color_img # Devuelve bw_img una vez y color_img
|
denoising/denoiser.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Denoise an image with the FFDNet denoising method
|
3 |
+
|
4 |
+
Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
|
5 |
+
|
6 |
+
This program is free software: you can use, modify and/or
|
7 |
+
redistribute it under the terms of the GNU General Public
|
8 |
+
License as published by the Free Software Foundation, either
|
9 |
+
version 3 of the License, or (at your option) any later
|
10 |
+
version. You should have received a copy of this license along
|
11 |
+
this program. If not, see <http://www.gnu.org/licenses/>.
|
12 |
+
"""
|
13 |
+
import os
|
14 |
+
import argparse
|
15 |
+
import time
|
16 |
+
import numpy as np
|
17 |
+
import cv2
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
from torch.autograd import Variable
|
21 |
+
from denoising.models import FFDNet
|
22 |
+
from denoising.utils import normalize, variable_to_cv2_image, remove_dataparallel_wrapper, is_rgb
|
23 |
+
|
24 |
+
class FFDNetDenoiser:
|
25 |
+
def __init__(self, _device, _sigma = 25, _weights_dir = 'denoising/models/', _in_ch = 3):
|
26 |
+
self.sigma = _sigma / 255
|
27 |
+
self.weights_dir = _weights_dir
|
28 |
+
self.channels = _in_ch
|
29 |
+
self.device = _device
|
30 |
+
|
31 |
+
self.model = FFDNet(num_input_channels = _in_ch)
|
32 |
+
self.load_weights()
|
33 |
+
self.model.eval()
|
34 |
+
|
35 |
+
|
36 |
+
def load_weights(self):
|
37 |
+
weights_name = 'net_rgb.pth' if self.channels == 3 else 'net_gray.pth'
|
38 |
+
weights_path = os.path.join(self.weights_dir, weights_name)
|
39 |
+
if self.device == 'cuda':
|
40 |
+
state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
|
41 |
+
device_ids = [0]
|
42 |
+
self.model = nn.DataParallel(self.model, device_ids=device_ids).cuda()
|
43 |
+
else:
|
44 |
+
state_dict = torch.load(weights_path, map_location='cpu')
|
45 |
+
# CPU mode: remove the DataParallel wrapper
|
46 |
+
state_dict = remove_dataparallel_wrapper(state_dict)
|
47 |
+
self.model.load_state_dict(state_dict)
|
48 |
+
|
49 |
+
def get_denoised_image(self, imorig, sigma = None):
|
50 |
+
|
51 |
+
if sigma is not None:
|
52 |
+
cur_sigma = sigma / 255
|
53 |
+
else:
|
54 |
+
cur_sigma = self.sigma
|
55 |
+
|
56 |
+
if len(imorig.shape) < 3 or imorig.shape[2] == 1:
|
57 |
+
imorig = np.repeat(np.expand_dims(imorig, 2), 3, 2)
|
58 |
+
|
59 |
+
if (max(imorig.shape[0], imorig.shape[1]) > 1200):
|
60 |
+
ratio = max(imorig.shape[0], imorig.shape[1]) / 1200
|
61 |
+
imorig = cv2.resize(imorig, (int(imorig.shape[1] / ratio), int(imorig.shape[0] / ratio)), interpolation = cv2.INTER_AREA)
|
62 |
+
|
63 |
+
imorig = imorig.transpose(2, 0, 1)
|
64 |
+
|
65 |
+
if (imorig.max() > 1.2):
|
66 |
+
imorig = normalize(imorig)
|
67 |
+
imorig = np.expand_dims(imorig, 0)
|
68 |
+
|
69 |
+
# Handle odd sizes
|
70 |
+
expanded_h = False
|
71 |
+
expanded_w = False
|
72 |
+
sh_im = imorig.shape
|
73 |
+
if sh_im[2]%2 == 1:
|
74 |
+
expanded_h = True
|
75 |
+
imorig = np.concatenate((imorig, imorig[:, :, -1, :][:, :, np.newaxis, :]), axis=2)
|
76 |
+
|
77 |
+
if sh_im[3]%2 == 1:
|
78 |
+
expanded_w = True
|
79 |
+
imorig = np.concatenate((imorig, imorig[:, :, :, -1][:, :, :, np.newaxis]), axis=3)
|
80 |
+
|
81 |
+
|
82 |
+
imorig = torch.Tensor(imorig)
|
83 |
+
|
84 |
+
|
85 |
+
# Sets data type according to CPU or GPU modes
|
86 |
+
if self.device == 'cuda':
|
87 |
+
dtype = torch.cuda.FloatTensor
|
88 |
+
else:
|
89 |
+
dtype = torch.FloatTensor
|
90 |
+
|
91 |
+
imnoisy = imorig.clone()
|
92 |
+
|
93 |
+
|
94 |
+
with torch.no_grad():
|
95 |
+
imorig, imnoisy = imorig.type(dtype), imnoisy.type(dtype)
|
96 |
+
nsigma = torch.FloatTensor([cur_sigma]).type(dtype)
|
97 |
+
|
98 |
+
|
99 |
+
# Estimate noise and subtract it to the input image
|
100 |
+
im_noise_estim = self.model(imnoisy, nsigma)
|
101 |
+
outim = torch.clamp(imnoisy-im_noise_estim, 0., 1.)
|
102 |
+
|
103 |
+
if expanded_h:
|
104 |
+
imorig = imorig[:, :, :-1, :]
|
105 |
+
outim = outim[:, :, :-1, :]
|
106 |
+
imnoisy = imnoisy[:, :, :-1, :]
|
107 |
+
|
108 |
+
if expanded_w:
|
109 |
+
imorig = imorig[:, :, :, :-1]
|
110 |
+
outim = outim[:, :, :, :-1]
|
111 |
+
imnoisy = imnoisy[:, :, :, :-1]
|
112 |
+
|
113 |
+
return variable_to_cv2_image(outim)
|
denoising/functions.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Functions implementing custom NN layers
|
3 |
+
|
4 |
+
Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
|
5 |
+
|
6 |
+
This program is free software: you can use, modify and/or
|
7 |
+
redistribute it under the terms of the GNU General Public
|
8 |
+
License as published by the Free Software Foundation, either
|
9 |
+
version 3 of the License, or (at your option) any later
|
10 |
+
version. You should have received a copy of this license along
|
11 |
+
this program. If not, see <http://www.gnu.org/licenses/>.
|
12 |
+
"""
|
13 |
+
import torch
|
14 |
+
from torch.autograd import Function, Variable
|
15 |
+
|
16 |
+
def concatenate_input_noise_map(input, noise_sigma):
|
17 |
+
r"""Implements the first layer of FFDNet. This function returns a
|
18 |
+
torch.autograd.Variable composed of the concatenation of the downsampled
|
19 |
+
input image and the noise map. Each image of the batch of size CxHxW gets
|
20 |
+
converted to an array of size 4*CxH/2xW/2. Each of the pixels of the
|
21 |
+
non-overlapped 2x2 patches of the input image are placed in the new array
|
22 |
+
along the first dimension.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
input: batch containing CxHxW images
|
26 |
+
noise_sigma: the value of the pixels of the CxH/2xW/2 noise map
|
27 |
+
"""
|
28 |
+
# noise_sigma is a list of length batch_size
|
29 |
+
N, C, H, W = input.size()
|
30 |
+
dtype = input.type()
|
31 |
+
sca = 2
|
32 |
+
sca2 = sca*sca
|
33 |
+
Cout = sca2*C
|
34 |
+
Hout = H//sca
|
35 |
+
Wout = W//sca
|
36 |
+
idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]
|
37 |
+
|
38 |
+
# Fill the downsampled image with zeros
|
39 |
+
if 'cuda' in dtype:
|
40 |
+
downsampledfeatures = torch.cuda.FloatTensor(N, Cout, Hout, Wout).fill_(0)
|
41 |
+
else:
|
42 |
+
downsampledfeatures = torch.FloatTensor(N, Cout, Hout, Wout).fill_(0)
|
43 |
+
|
44 |
+
# Build the CxH/2xW/2 noise map
|
45 |
+
noise_map = noise_sigma.view(N, 1, 1, 1).repeat(1, C, Hout, Wout)
|
46 |
+
|
47 |
+
# Populate output
|
48 |
+
for idx in range(sca2):
|
49 |
+
downsampledfeatures[:, idx:Cout:sca2, :, :] = \
|
50 |
+
input[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca]
|
51 |
+
|
52 |
+
# concatenate de-interleaved mosaic with noise map
|
53 |
+
return torch.cat((noise_map, downsampledfeatures), 1)
|
54 |
+
|
55 |
+
class UpSampleFeaturesFunction(Function):
|
56 |
+
r"""Extends PyTorch's modules by implementing a torch.autograd.Function.
|
57 |
+
This class implements the forward and backward methods of the last layer
|
58 |
+
of FFDNet. It basically performs the inverse of
|
59 |
+
concatenate_input_noise_map(): it converts each of the images of a
|
60 |
+
batch of size CxH/2xW/2 to images of size C/4xHxW
|
61 |
+
"""
|
62 |
+
@staticmethod
|
63 |
+
def forward(ctx, input):
|
64 |
+
N, Cin, Hin, Win = input.size()
|
65 |
+
dtype = input.type()
|
66 |
+
sca = 2
|
67 |
+
sca2 = sca*sca
|
68 |
+
Cout = Cin//sca2
|
69 |
+
Hout = Hin*sca
|
70 |
+
Wout = Win*sca
|
71 |
+
idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]
|
72 |
+
|
73 |
+
assert (Cin%sca2 == 0), 'Invalid input dimensions: number of channels should be divisible by 4'
|
74 |
+
|
75 |
+
result = torch.zeros((N, Cout, Hout, Wout)).type(dtype)
|
76 |
+
for idx in range(sca2):
|
77 |
+
result[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca] = input[:, idx:Cin:sca2, :, :]
|
78 |
+
|
79 |
+
return result
|
80 |
+
|
81 |
+
@staticmethod
|
82 |
+
def backward(ctx, grad_output):
|
83 |
+
N, Cg_out, Hg_out, Wg_out = grad_output.size()
|
84 |
+
dtype = grad_output.data.type()
|
85 |
+
sca = 2
|
86 |
+
sca2 = sca*sca
|
87 |
+
Cg_in = sca2*Cg_out
|
88 |
+
Hg_in = Hg_out//sca
|
89 |
+
Wg_in = Wg_out//sca
|
90 |
+
idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]
|
91 |
+
|
92 |
+
# Build output
|
93 |
+
grad_input = torch.zeros((N, Cg_in, Hg_in, Wg_in)).type(dtype)
|
94 |
+
# Populate output
|
95 |
+
for idx in range(sca2):
|
96 |
+
grad_input[:, idx:Cg_in:sca2, :, :] = grad_output.data[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca]
|
97 |
+
|
98 |
+
return Variable(grad_input)
|
99 |
+
|
100 |
+
# Alias functions
|
101 |
+
upsamplefeatures = UpSampleFeaturesFunction.apply
|
denoising/models.py
ADDED
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchvision.models as M
|
5 |
+
import math
|
6 |
+
from torch import Tensor
|
7 |
+
from torch.nn import Parameter
|
8 |
+
|
9 |
+
'''https://github.com/orashi/AlacGAN/blob/master/models/standard.py'''
|
10 |
+
|
11 |
+
def l2normalize(v, eps=1e-12):
|
12 |
+
return v / (v.norm() + eps)
|
13 |
+
|
14 |
+
|
15 |
+
class SpectralNorm(nn.Module):
|
16 |
+
def __init__(self, module, name='weight', power_iterations=1):
|
17 |
+
super(SpectralNorm, self).__init__()
|
18 |
+
self.module = module
|
19 |
+
self.name = name
|
20 |
+
self.power_iterations = power_iterations
|
21 |
+
if not self._made_params():
|
22 |
+
self._make_params()
|
23 |
+
|
24 |
+
def _update_u_v(self):
|
25 |
+
u = getattr(self.module, self.name + "_u")
|
26 |
+
v = getattr(self.module, self.name + "_v")
|
27 |
+
w = getattr(self.module, self.name + "_bar")
|
28 |
+
|
29 |
+
height = w.data.shape[0]
|
30 |
+
for _ in range(self.power_iterations):
|
31 |
+
v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
|
32 |
+
u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
|
33 |
+
|
34 |
+
# sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
|
35 |
+
sigma = u.dot(w.view(height, -1).mv(v))
|
36 |
+
setattr(self.module, self.name, w / sigma.expand_as(w))
|
37 |
+
|
38 |
+
def _made_params(self):
|
39 |
+
try:
|
40 |
+
u = getattr(self.module, self.name + "_u")
|
41 |
+
v = getattr(self.module, self.name + "_v")
|
42 |
+
w = getattr(self.module, self.name + "_bar")
|
43 |
+
return True
|
44 |
+
except AttributeError:
|
45 |
+
return False
|
46 |
+
|
47 |
+
|
48 |
+
def _make_params(self):
|
49 |
+
w = getattr(self.module, self.name)
|
50 |
+
height = w.data.shape[0]
|
51 |
+
width = w.view(height, -1).data.shape[1]
|
52 |
+
|
53 |
+
u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
|
54 |
+
v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
|
55 |
+
u.data = l2normalize(u.data)
|
56 |
+
v.data = l2normalize(v.data)
|
57 |
+
w_bar = Parameter(w.data)
|
58 |
+
|
59 |
+
del self.module._parameters[self.name]
|
60 |
+
|
61 |
+
self.module.register_parameter(self.name + "_u", u)
|
62 |
+
self.module.register_parameter(self.name + "_v", v)
|
63 |
+
self.module.register_parameter(self.name + "_bar", w_bar)
|
64 |
+
|
65 |
+
|
66 |
+
def forward(self, *args):
|
67 |
+
self._update_u_v()
|
68 |
+
return self.module.forward(*args)
|
69 |
+
|
70 |
+
class Selayer(nn.Module):
|
71 |
+
def __init__(self, inplanes):
|
72 |
+
super(Selayer, self).__init__()
|
73 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
74 |
+
self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
|
75 |
+
self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
|
76 |
+
self.relu = nn.ReLU(inplace=True)
|
77 |
+
self.sigmoid = nn.Sigmoid()
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
out = self.global_avgpool(x)
|
81 |
+
out = self.conv1(out)
|
82 |
+
out = self.relu(out)
|
83 |
+
out = self.conv2(out)
|
84 |
+
out = self.sigmoid(out)
|
85 |
+
|
86 |
+
return x * out
|
87 |
+
|
88 |
+
class SelayerSpectr(nn.Module):
|
89 |
+
def __init__(self, inplanes):
|
90 |
+
super(SelayerSpectr, self).__init__()
|
91 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
92 |
+
self.conv1 = SpectralNorm(nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1))
|
93 |
+
self.conv2 = SpectralNorm(nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1))
|
94 |
+
self.relu = nn.ReLU(inplace=True)
|
95 |
+
self.sigmoid = nn.Sigmoid()
|
96 |
+
|
97 |
+
def forward(self, x):
|
98 |
+
out = self.global_avgpool(x)
|
99 |
+
out = self.conv1(out)
|
100 |
+
out = self.relu(out)
|
101 |
+
out = self.conv2(out)
|
102 |
+
out = self.sigmoid(out)
|
103 |
+
|
104 |
+
return x * out
|
105 |
+
|
106 |
+
class ResNeXtBottleneck(nn.Module):
|
107 |
+
def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
|
108 |
+
super(ResNeXtBottleneck, self).__init__()
|
109 |
+
D = out_channels // 2
|
110 |
+
self.out_channels = out_channels
|
111 |
+
self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
|
112 |
+
self.conv_conv = nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
|
113 |
+
groups=cardinality,
|
114 |
+
bias=False)
|
115 |
+
self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
|
116 |
+
self.shortcut = nn.Sequential()
|
117 |
+
if stride != 1:
|
118 |
+
self.shortcut.add_module('shortcut',
|
119 |
+
nn.AvgPool2d(2, stride=2))
|
120 |
+
|
121 |
+
self.selayer = Selayer(out_channels)
|
122 |
+
|
123 |
+
def forward(self, x):
|
124 |
+
bottleneck = self.conv_reduce.forward(x)
|
125 |
+
bottleneck = F.leaky_relu(bottleneck, 0.2, True)
|
126 |
+
bottleneck = self.conv_conv.forward(bottleneck)
|
127 |
+
bottleneck = F.leaky_relu(bottleneck, 0.2, True)
|
128 |
+
bottleneck = self.conv_expand.forward(bottleneck)
|
129 |
+
bottleneck = self.selayer(bottleneck)
|
130 |
+
|
131 |
+
x = self.shortcut.forward(x)
|
132 |
+
return x + bottleneck
|
133 |
+
|
134 |
+
class SpectrResNeXtBottleneck(nn.Module):
|
135 |
+
def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
|
136 |
+
super(SpectrResNeXtBottleneck, self).__init__()
|
137 |
+
D = out_channels // 2
|
138 |
+
self.out_channels = out_channels
|
139 |
+
self.conv_reduce = SpectralNorm(nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False))
|
140 |
+
self.conv_conv = SpectralNorm(nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
|
141 |
+
groups=cardinality,
|
142 |
+
bias=False))
|
143 |
+
self.conv_expand = SpectralNorm(nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False))
|
144 |
+
self.shortcut = nn.Sequential()
|
145 |
+
if stride != 1:
|
146 |
+
self.shortcut.add_module('shortcut',
|
147 |
+
nn.AvgPool2d(2, stride=2))
|
148 |
+
|
149 |
+
self.selayer = SelayerSpectr(out_channels)
|
150 |
+
|
151 |
+
def forward(self, x):
|
152 |
+
bottleneck = self.conv_reduce.forward(x)
|
153 |
+
bottleneck = F.leaky_relu(bottleneck, 0.2, True)
|
154 |
+
bottleneck = self.conv_conv.forward(bottleneck)
|
155 |
+
bottleneck = F.leaky_relu(bottleneck, 0.2, True)
|
156 |
+
bottleneck = self.conv_expand.forward(bottleneck)
|
157 |
+
bottleneck = self.selayer(bottleneck)
|
158 |
+
|
159 |
+
x = self.shortcut.forward(x)
|
160 |
+
return x + bottleneck
|
161 |
+
|
162 |
+
class FeatureConv(nn.Module):
|
163 |
+
def __init__(self, input_dim=512, output_dim=512):
|
164 |
+
super(FeatureConv, self).__init__()
|
165 |
+
|
166 |
+
no_bn = True
|
167 |
+
|
168 |
+
seq = []
|
169 |
+
seq.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
|
170 |
+
if not no_bn: seq.append(nn.BatchNorm2d(output_dim))
|
171 |
+
seq.append(nn.ReLU(inplace=True))
|
172 |
+
seq.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
|
173 |
+
if not no_bn: seq.append(nn.BatchNorm2d(output_dim))
|
174 |
+
seq.append(nn.ReLU(inplace=True))
|
175 |
+
seq.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
|
176 |
+
seq.append(nn.ReLU(inplace=True))
|
177 |
+
|
178 |
+
self.network = nn.Sequential(*seq)
|
179 |
+
|
180 |
+
def forward(self, x):
|
181 |
+
return self.network(x)
|
182 |
+
|
183 |
+
class Generator(nn.Module):
|
184 |
+
def __init__(self, ngf=64):
|
185 |
+
super(Generator, self).__init__()
|
186 |
+
|
187 |
+
self.feature_conv = FeatureConv()
|
188 |
+
|
189 |
+
self.to0 = self._make_encoder_block_first(6, 32)
|
190 |
+
self.to1 = self._make_encoder_block(32, 64)
|
191 |
+
self.to2 = self._make_encoder_block(64, 128)
|
192 |
+
self.to3 = self._make_encoder_block(128, 256)
|
193 |
+
self.to4 = self._make_encoder_block(256, 512)
|
194 |
+
|
195 |
+
self.deconv_for_decoder = nn.Sequential(
|
196 |
+
nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1), # output is 64 * 64
|
197 |
+
nn.LeakyReLU(0.2),
|
198 |
+
nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1), # output is 128 * 128
|
199 |
+
nn.LeakyReLU(0.2),
|
200 |
+
nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1), # output is 256 * 256
|
201 |
+
nn.LeakyReLU(0.2),
|
202 |
+
nn.ConvTranspose2d(32, 3, 3, stride=1, padding=1, output_padding=0), # output is 256 * 256
|
203 |
+
nn.Tanh(),
|
204 |
+
)
|
205 |
+
|
206 |
+
tunnel4 = nn.Sequential(*[ResNeXtBottleneck(ngf * 8, ngf * 8, cardinality=32, dilate=1) for _ in range(20)])
|
207 |
+
|
208 |
+
self.tunnel4 = nn.Sequential(nn.Conv2d(ngf * 8 + 512, ngf * 8, kernel_size=3, stride=1, padding=1),
|
209 |
+
nn.LeakyReLU(0.2, True),
|
210 |
+
tunnel4,
|
211 |
+
nn.Conv2d(ngf * 8, ngf * 4 * 4, kernel_size=3, stride=1, padding=1),
|
212 |
+
nn.PixelShuffle(2),
|
213 |
+
nn.LeakyReLU(0.2, True)
|
214 |
+
) # 64
|
215 |
+
|
216 |
+
depth = 2
|
217 |
+
tunnel = [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=1) for _ in range(depth)]
|
218 |
+
tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=2) for _ in range(depth)]
|
219 |
+
tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=4) for _ in range(depth)]
|
220 |
+
tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=2),
|
221 |
+
ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=1)]
|
222 |
+
tunnel3 = nn.Sequential(*tunnel)
|
223 |
+
|
224 |
+
self.tunnel3 = nn.Sequential(nn.Conv2d(ngf * 8, ngf * 4, kernel_size=3, stride=1, padding=1),
|
225 |
+
nn.LeakyReLU(0.2, True),
|
226 |
+
tunnel3,
|
227 |
+
nn.Conv2d(ngf * 4, ngf * 2 * 4, kernel_size=3, stride=1, padding=1),
|
228 |
+
nn.PixelShuffle(2),
|
229 |
+
nn.LeakyReLU(0.2, True)
|
230 |
+
) # 128
|
231 |
+
|
232 |
+
tunnel = [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=1) for _ in range(depth)]
|
233 |
+
tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=2) for _ in range(depth)]
|
234 |
+
tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=4) for _ in range(depth)]
|
235 |
+
tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=2),
|
236 |
+
ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=1)]
|
237 |
+
tunnel2 = nn.Sequential(*tunnel)
|
238 |
+
|
239 |
+
self.tunnel2 = nn.Sequential(nn.Conv2d(ngf * 4, ngf * 2, kernel_size=3, stride=1, padding=1),
|
240 |
+
nn.LeakyReLU(0.2, True),
|
241 |
+
tunnel2,
|
242 |
+
nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=1, padding=1),
|
243 |
+
nn.PixelShuffle(2),
|
244 |
+
nn.LeakyReLU(0.2, True)
|
245 |
+
)
|
246 |
+
|
247 |
+
tunnel = [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=1)]
|
248 |
+
tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=2)]
|
249 |
+
tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=4)]
|
250 |
+
tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=2),
|
251 |
+
ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=1)]
|
252 |
+
tunnel1 = nn.Sequential(*tunnel)
|
253 |
+
|
254 |
+
self.tunnel1 = nn.Sequential(nn.Conv2d(ngf * 2, ngf, kernel_size=3, stride=1, padding=1),
|
255 |
+
nn.LeakyReLU(0.2, True),
|
256 |
+
tunnel1,
|
257 |
+
nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=1, padding=1),
|
258 |
+
nn.PixelShuffle(2),
|
259 |
+
nn.LeakyReLU(0.2, True)
|
260 |
+
)
|
261 |
+
|
262 |
+
self.exit = nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)
|
263 |
+
|
264 |
+
|
265 |
+
def _make_encoder_block(self, inplanes, planes):
|
266 |
+
return nn.Sequential(
|
267 |
+
nn.Conv2d(inplanes, planes, 3, 2, 1),
|
268 |
+
nn.LeakyReLU(0.2),
|
269 |
+
nn.Conv2d(planes, planes, 3, 1, 1),
|
270 |
+
nn.LeakyReLU(0.2),
|
271 |
+
)
|
272 |
+
|
273 |
+
def _make_encoder_block_first(self, inplanes, planes):
|
274 |
+
return nn.Sequential(
|
275 |
+
nn.Conv2d(inplanes, planes, 3, 1, 1),
|
276 |
+
nn.LeakyReLU(0.2),
|
277 |
+
nn.Conv2d(planes, planes, 3, 1, 1),
|
278 |
+
nn.LeakyReLU(0.2),
|
279 |
+
)
|
280 |
+
|
281 |
+
def forward(self, sketch, sketch_feat):
|
282 |
+
|
283 |
+
x0 = self.to0(sketch)
|
284 |
+
x1 = self.to1(x0)
|
285 |
+
x2 = self.to2(x1)
|
286 |
+
x3 = self.to3(x2)
|
287 |
+
x4 = self.to4(x3)
|
288 |
+
|
289 |
+
sketch_feat = self.feature_conv(sketch_feat)
|
290 |
+
|
291 |
+
out = self.tunnel4(torch.cat([x4, sketch_feat], 1))
|
292 |
+
|
293 |
+
|
294 |
+
|
295 |
+
|
296 |
+
x = self.tunnel3(torch.cat([out, x3], 1))
|
297 |
+
x = self.tunnel2(torch.cat([x, x2], 1))
|
298 |
+
x = self.tunnel1(torch.cat([x, x1], 1))
|
299 |
+
x = torch.tanh(self.exit(torch.cat([x, x0], 1)))
|
300 |
+
|
301 |
+
decoder_output = self.deconv_for_decoder(out)
|
302 |
+
|
303 |
+
return x, decoder_output
|
304 |
+
'''
|
305 |
+
class Colorizer(nn.Module):
|
306 |
+
def __init__(self, extractor_path = 'model/model.pth'):
|
307 |
+
super(Colorizer, self).__init__()
|
308 |
+
|
309 |
+
self.generator = Generator()
|
310 |
+
self.extractor = se_resnext_half(dump_path=extractor_path, num_classes=370, input_channels=1)
|
311 |
+
|
312 |
+
def extractor_eval(self):
|
313 |
+
for param in self.extractor.parameters():
|
314 |
+
param.requires_grad = False
|
315 |
+
|
316 |
+
def extractor_train(self):
|
317 |
+
for param in extractor.parameters():
|
318 |
+
param.requires_grad = True
|
319 |
+
|
320 |
+
def forward(self, x, extractor_grad = False):
|
321 |
+
|
322 |
+
if extractor_grad:
|
323 |
+
features = self.extractor(x[:, 0:1])
|
324 |
+
else:
|
325 |
+
with torch.no_grad():
|
326 |
+
features = self.extractor(x[:, 0:1]).detach()
|
327 |
+
|
328 |
+
fake, guide = self.generator(x, features)
|
329 |
+
|
330 |
+
return fake, guide
|
331 |
+
'''
|
332 |
+
|
333 |
+
class Colorizer(nn.Module):
|
334 |
+
def __init__(self, generator_model, extractor_model):
|
335 |
+
super(Colorizer, self).__init__()
|
336 |
+
|
337 |
+
self.generator = generator_model
|
338 |
+
self.extractor = extractor_model
|
339 |
+
|
340 |
+
def load_generator_weights(self, gen_weights):
|
341 |
+
self.generator.load_state_dict(gen_weights)
|
342 |
+
|
343 |
+
def load_extractor_weights(self, ext_weights):
|
344 |
+
self.extractor.load_state_dict(ext_weights)
|
345 |
+
|
346 |
+
def extractor_eval(self):
|
347 |
+
for param in self.extractor.parameters():
|
348 |
+
param.requires_grad = False
|
349 |
+
self.extractor.eval()
|
350 |
+
|
351 |
+
def extractor_train(self):
|
352 |
+
for param in extractor.parameters():
|
353 |
+
param.requires_grad = True
|
354 |
+
self.extractor.train()
|
355 |
+
|
356 |
+
def forward(self, x, extractor_grad = False):
|
357 |
+
|
358 |
+
if extractor_grad:
|
359 |
+
features = self.extractor(x[:, 0:1])
|
360 |
+
else:
|
361 |
+
with torch.no_grad():
|
362 |
+
features = self.extractor(x[:, 0:1]).detach()
|
363 |
+
|
364 |
+
fake, guide = self.generator(x, features)
|
365 |
+
|
366 |
+
return fake, guide
|
367 |
+
|
368 |
+
class Discriminator(nn.Module):
|
369 |
+
def __init__(self, ndf=64):
|
370 |
+
super(Discriminator, self).__init__()
|
371 |
+
|
372 |
+
self.feed = nn.Sequential(SpectralNorm(nn.Conv2d(3, 64, 3, 1, 1)),
|
373 |
+
nn.LeakyReLU(0.2, True),
|
374 |
+
SpectralNorm(nn.Conv2d(64, 64, 3, 2, 0)),
|
375 |
+
nn.LeakyReLU(0.2, True),
|
376 |
+
|
377 |
+
|
378 |
+
|
379 |
+
|
380 |
+
SpectrResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1),
|
381 |
+
SpectrResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1, stride=2), # 128
|
382 |
+
SpectralNorm(nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=False)),
|
383 |
+
nn.LeakyReLU(0.2, True),
|
384 |
+
|
385 |
+
SpectrResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1),
|
386 |
+
SpectrResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1, stride=2), # 64
|
387 |
+
SpectralNorm(nn.Conv2d(ndf * 2, ndf * 4, kernel_size=1, stride=1, padding=0, bias=False)),
|
388 |
+
nn.LeakyReLU(0.2, True),
|
389 |
+
|
390 |
+
SpectrResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1),
|
391 |
+
SpectrResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1, stride=2), # 32,
|
392 |
+
SpectralNorm(nn.Conv2d(ndf * 4, ndf * 8, kernel_size=1, stride=1, padding=1, bias=False)),
|
393 |
+
nn.LeakyReLU(0.2, True),
|
394 |
+
SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
|
395 |
+
SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1, stride=2), # 16
|
396 |
+
SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
|
397 |
+
SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
|
398 |
+
nn.AdaptiveAvgPool2d((1, 1))
|
399 |
+
)
|
400 |
+
|
401 |
+
self.out = nn.Linear(512, 1)
|
402 |
+
|
403 |
+
def forward(self, color):
|
404 |
+
x = self.feed(color)
|
405 |
+
|
406 |
+
out = self.out(x.view(color.size(0), -1))
|
407 |
+
return out
|
408 |
+
|
409 |
+
class Content(nn.Module):
|
410 |
+
def __init__(self, path):
|
411 |
+
super(Content, self).__init__()
|
412 |
+
vgg16 = M.vgg16()
|
413 |
+
vgg16.load_state_dict(torch.load(path))
|
414 |
+
vgg16.features = nn.Sequential(
|
415 |
+
*list(vgg16.features.children())[:9]
|
416 |
+
)
|
417 |
+
self.model = vgg16.features
|
418 |
+
self.register_buffer('mean', torch.FloatTensor([0.485 - 0.5, 0.456 - 0.5, 0.406 - 0.5]).view(1, 3, 1, 1))
|
419 |
+
self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
420 |
+
|
421 |
+
def forward(self, images):
|
422 |
+
return self.model((images.mul(0.5) - self.mean) / self.std)
|
denoising/models/.gitkeep
ADDED
File without changes
|
denoising/utils.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Different utilities such as orthogonalization of weights, initialization of
|
3 |
+
loggers, etc
|
4 |
+
|
5 |
+
Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
|
6 |
+
|
7 |
+
This program is free software: you can use, modify and/or
|
8 |
+
redistribute it under the terms of the GNU General Public
|
9 |
+
License as published by the Free Software Foundation, either
|
10 |
+
version 3 of the License, or (at your option) any later
|
11 |
+
version. You should have received a copy of this license along
|
12 |
+
this program. If not, see <http://www.gnu.org/licenses/>.
|
13 |
+
"""
|
14 |
+
import numpy as np
|
15 |
+
import cv2
|
16 |
+
|
17 |
+
|
18 |
+
def variable_to_cv2_image(varim):
|
19 |
+
r"""Converts a torch.autograd.Variable to an OpenCV image
|
20 |
+
|
21 |
+
Args:
|
22 |
+
varim: a torch.autograd.Variable
|
23 |
+
"""
|
24 |
+
nchannels = varim.size()[1]
|
25 |
+
if nchannels == 1:
|
26 |
+
res = (varim.data.cpu().numpy()[0, 0, :]*255.).clip(0, 255).astype(np.uint8)
|
27 |
+
elif nchannels == 3:
|
28 |
+
res = varim.data.cpu().numpy()[0]
|
29 |
+
res = cv2.cvtColor(res.transpose(1, 2, 0), cv2.COLOR_RGB2BGR)
|
30 |
+
res = (res*255.).clip(0, 255).astype(np.uint8)
|
31 |
+
else:
|
32 |
+
raise Exception('Number of color channels not supported')
|
33 |
+
return res
|
34 |
+
|
35 |
+
|
36 |
+
def normalize(data):
|
37 |
+
return np.float32(data/255.)
|
38 |
+
|
39 |
+
def remove_dataparallel_wrapper(state_dict):
|
40 |
+
r"""Converts a DataParallel model to a normal one by removing the "module."
|
41 |
+
wrapper in the module dictionary
|
42 |
+
|
43 |
+
Args:
|
44 |
+
state_dict: a torch.nn.DataParallel state dictionary
|
45 |
+
"""
|
46 |
+
from collections import OrderedDict
|
47 |
+
|
48 |
+
new_state_dict = OrderedDict()
|
49 |
+
for k, vl in state_dict.items():
|
50 |
+
name = k[7:] # remove 'module.' of DataParallel
|
51 |
+
new_state_dict[name] = vl
|
52 |
+
|
53 |
+
return new_state_dict
|
54 |
+
|
55 |
+
def is_rgb(im_path):
|
56 |
+
r""" Returns True if the image in im_path is an RGB image
|
57 |
+
"""
|
58 |
+
from skimage.io import imread
|
59 |
+
rgb = False
|
60 |
+
im = imread(im_path)
|
61 |
+
if (len(im.shape) == 3):
|
62 |
+
if not(np.allclose(im[...,0], im[...,1]) and np.allclose(im[...,2], im[...,1])):
|
63 |
+
rgb = True
|
64 |
+
print("rgb: {}".format(rgb))
|
65 |
+
print("im shape: {}".format(im.shape))
|
66 |
+
return rgb
|
drawing.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from datetime import datetime
|
3 |
+
import base64
|
4 |
+
import random
|
5 |
+
import string
|
6 |
+
import shutil
|
7 |
+
import torch
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import numpy as np
|
10 |
+
from flask import Flask, request, jsonify, abort, redirect, url_for, render_template, send_file, Response
|
11 |
+
from flask_wtf import FlaskForm
|
12 |
+
from wtforms import StringField, FileField, BooleanField, DecimalField
|
13 |
+
from wtforms.validators import DataRequired
|
14 |
+
from flask import after_this_request
|
15 |
+
|
16 |
+
from model.models import Colorizer, Generator
|
17 |
+
from model.extractor import get_seresnext_extractor
|
18 |
+
from utils.xdog import XDoGSketcher
|
19 |
+
from utils.utils import open_json
|
20 |
+
from denoising.denoiser import FFDNetDenoiser
|
21 |
+
from inference import process_image_with_hint
|
22 |
+
from utils.utils import resize_pad
|
23 |
+
from utils.dataset_utils import get_sketch
|
24 |
+
|
25 |
+
def generate_id(size=25, chars=string.ascii_letters + string.digits):
|
26 |
+
return ''.join(random.SystemRandom().choice(chars) for _ in range(size))
|
27 |
+
|
28 |
+
def generate_unique_id(current_ids = set()):
|
29 |
+
id_t = generate_id()
|
30 |
+
while id_t in current_ids:
|
31 |
+
id_t = generate_id()
|
32 |
+
|
33 |
+
current_ids.add(id_t)
|
34 |
+
|
35 |
+
return id_t
|
36 |
+
|
37 |
+
app = Flask(__name__)
|
38 |
+
app.config.update(dict(
|
39 |
+
SECRET_KEY="lol kek",
|
40 |
+
WTF_CSRF_SECRET_KEY="cheburek"
|
41 |
+
))
|
42 |
+
|
43 |
+
if torch.cuda.is_available():
|
44 |
+
device = 'cuda'
|
45 |
+
else:
|
46 |
+
device = 'cpu'
|
47 |
+
|
48 |
+
colorizer = torch.jit.load('./model/colorizer.zip', map_location=torch.device(device))
|
49 |
+
|
50 |
+
sketcher = XDoGSketcher()
|
51 |
+
xdog_config = open_json('configs/xdog_config.json')
|
52 |
+
for key in xdog_config.keys():
|
53 |
+
if key in sketcher.params:
|
54 |
+
sketcher.params[key] = xdog_config[key]
|
55 |
+
|
56 |
+
denoiser = FFDNetDenoiser(device)
|
57 |
+
|
58 |
+
color_args = {'colorizer':colorizer, 'sketcher':sketcher, 'device':device, 'dfm' : True, 'auto_hint' : False, 'ignore_gray' : False, 'denoiser' : denoiser, 'denoiser_sigma' : 25}
|
59 |
+
|
60 |
+
|
61 |
+
class SubmitForm(FlaskForm):
|
62 |
+
file = FileField(validators=[DataRequired(), ])
|
63 |
+
|
64 |
+
def preprocess_image(file_id, ext):
|
65 |
+
directory_path = os.path.join('static', 'temp_images', file_id)
|
66 |
+
original_path = os.path.join(directory_path, 'original') + ext
|
67 |
+
original_image = plt.imread(original_path)
|
68 |
+
|
69 |
+
resized_image, _ = resize_pad(original_image)
|
70 |
+
resized_image = denoiser.get_denoised_image(resized_image, 25)
|
71 |
+
bw, dfm = get_sketch(resized_image, sketcher, True)
|
72 |
+
|
73 |
+
resized_name = 'resized_' + str(resized_image.shape[0]) + '_' + str(resized_image.shape[1]) + '.png'
|
74 |
+
plt.imsave(os.path.join(directory_path, resized_name), resized_image)
|
75 |
+
plt.imsave(os.path.join(directory_path, 'bw.png'), bw, cmap = 'gray')
|
76 |
+
plt.imsave(os.path.join(directory_path, 'dfm.png'), dfm, cmap = 'gray')
|
77 |
+
os.remove(original_path)
|
78 |
+
|
79 |
+
empty_hint = np.zeros((resized_image.shape[0], resized_image.shape[1], 4), dtype = np.float32)
|
80 |
+
plt.imsave(os.path.join(directory_path, 'hint.png'), empty_hint)
|
81 |
+
|
82 |
+
@app.route('/', methods=['GET', 'POST'])
|
83 |
+
def upload():
|
84 |
+
form = SubmitForm()
|
85 |
+
if form.validate_on_submit():
|
86 |
+
input_data = form.file.data
|
87 |
+
|
88 |
+
_, ext = os.path.splitext(input_data.filename)
|
89 |
+
|
90 |
+
if ext not in ('.jpg', '.png', '.jpeg'):
|
91 |
+
return abort(400)
|
92 |
+
|
93 |
+
file_id = generate_unique_id()
|
94 |
+
directory = os.path.join('static', 'temp_images', file_id)
|
95 |
+
original_filename = os.path.join(directory, 'original') + ext
|
96 |
+
|
97 |
+
try :
|
98 |
+
os.mkdir(directory)
|
99 |
+
input_data.save(original_filename)
|
100 |
+
|
101 |
+
preprocess_image(file_id, ext)
|
102 |
+
|
103 |
+
return redirect(f'/draw/{file_id}')
|
104 |
+
|
105 |
+
except :
|
106 |
+
print('Failed to colorize')
|
107 |
+
if os.path.exists(directory):
|
108 |
+
shutil.rmtree(directory)
|
109 |
+
return abort(400)
|
110 |
+
|
111 |
+
|
112 |
+
return render_template("upload.html", form = form)
|
113 |
+
|
114 |
+
@app.route('/img/<file_id>')
|
115 |
+
def show_image(file_id):
|
116 |
+
if not os.path.exists(os.path.join('static', 'temp_images', str(file_id))):
|
117 |
+
abort(404)
|
118 |
+
return f'<img src="/static/temp_images/{file_id}/colorized.png?{random. randint(1,1000000)}">'
|
119 |
+
|
120 |
+
def colorize_image(file_id):
|
121 |
+
directory_path = os.path.join('static', 'temp_images', file_id)
|
122 |
+
|
123 |
+
bw = plt.imread(os.path.join(directory_path, 'bw.png'))[..., :1]
|
124 |
+
dfm = plt.imread(os.path.join(directory_path, 'dfm.png'))[..., :1]
|
125 |
+
hint = plt.imread(os.path.join(directory_path, 'hint.png'))
|
126 |
+
|
127 |
+
return process_image_with_hint(bw, dfm, hint, color_args)
|
128 |
+
|
129 |
+
@app.route('/colorize', methods=['POST'])
|
130 |
+
def colorize():
|
131 |
+
|
132 |
+
file_id = request.form['save_file_id']
|
133 |
+
file_id = file_id[file_id.rfind('/') + 1:]
|
134 |
+
|
135 |
+
img_data = request.form['save_image']
|
136 |
+
img_data = img_data[img_data.find(',') + 1:]
|
137 |
+
|
138 |
+
directory_path = os.path.join('static', 'temp_images', file_id)
|
139 |
+
|
140 |
+
with open(os.path.join(directory_path, 'hint.png'), "wb") as im:
|
141 |
+
im.write(base64.decodestring(str.encode(img_data)))
|
142 |
+
|
143 |
+
result = colorize_image(file_id)
|
144 |
+
|
145 |
+
plt.imsave(os.path.join(directory_path, 'colorized.png'), result)
|
146 |
+
|
147 |
+
src_path = f'../static/temp_images/{file_id}/colorized.png?{random. randint(1,1000000)}'
|
148 |
+
|
149 |
+
return src_path
|
150 |
+
|
151 |
+
@app.route('/draw/<file_id>', methods=['GET', 'POST'])
|
152 |
+
def paintapp(file_id):
|
153 |
+
if request.method == 'GET':
|
154 |
+
|
155 |
+
directory_path = os.path.join('static', 'temp_images', str(file_id))
|
156 |
+
if not os.path.exists(directory_path):
|
157 |
+
abort(404)
|
158 |
+
|
159 |
+
resized_name = [x for x in os.listdir(directory_path) if x.startswith('resized_')][0]
|
160 |
+
|
161 |
+
split = os.path.splitext(resized_name)[0].split('_')
|
162 |
+
width = int(split[2])
|
163 |
+
height = int(split[1])
|
164 |
+
|
165 |
+
return render_template("drawing.html", height = height, width = width, img_path = os.path.join('temp_images', str(file_id), resized_name))
|
inference.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
from utils.dataset_utils import get_sketch
|
5 |
+
from utils.utils import resize_pad, generate_mask, extract_cbr, create_cbz, sorted_alphanumeric, subfolder_image_search, remove_folder
|
6 |
+
from torchvision.transforms import ToTensor
|
7 |
+
import os
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import argparse
|
10 |
+
from model.models import Colorizer, Generator
|
11 |
+
from model.extractor import get_seresnext_extractor
|
12 |
+
from utils.xdog import XDoGSketcher
|
13 |
+
from utils.utils import open_json
|
14 |
+
import sys
|
15 |
+
from denoising.denoiser import FFDNetDenoiser
|
16 |
+
|
17 |
+
def colorize_without_hint(inp, color_args):
|
18 |
+
i_hint = torch.zeros(1, 4, inp.shape[2], inp.shape[3]).float().to(color_args['device'])
|
19 |
+
|
20 |
+
with torch.no_grad():
|
21 |
+
fake_color, _ = color_args['colorizer'](torch.cat([inp, i_hint], 1))
|
22 |
+
|
23 |
+
if color_args['auto_hint']:
|
24 |
+
mask = generate_mask(fake_color.shape[2], fake_color.shape[3], full = False, prob = 1, sigma = color_args['auto_hint_sigma']).unsqueeze(0)
|
25 |
+
mask = mask.to(color_args['device'])
|
26 |
+
|
27 |
+
|
28 |
+
if color_args['ignore_gray']:
|
29 |
+
diff1 = torch.abs(fake_color[:, 0] - fake_color[:, 1])
|
30 |
+
diff2 = torch.abs(fake_color[:, 0] - fake_color[:, 2])
|
31 |
+
diff3 = torch.abs(fake_color[:, 1] - fake_color[:, 2])
|
32 |
+
mask = ((mask + ((diff1 + diff2 + diff3) > 60 / 255).float().unsqueeze(1)) == 2).float()
|
33 |
+
|
34 |
+
|
35 |
+
i_hint = torch.cat([fake_color * mask, mask], 1)
|
36 |
+
|
37 |
+
with torch.no_grad():
|
38 |
+
fake_color, _ = color_args['colorizer'](torch.cat([inp, i_hint], 1))
|
39 |
+
|
40 |
+
return fake_color
|
41 |
+
|
42 |
+
|
43 |
+
def process_image(image, color_args, to_tensor = ToTensor()):
|
44 |
+
image, pad = resize_pad(image)
|
45 |
+
|
46 |
+
if color_args['denoiser'] is not None:
|
47 |
+
image = color_args['denoiser'].get_denoised_image(image, color_args['denoiser_sigma'])
|
48 |
+
|
49 |
+
bw, dfm = get_sketch(image, color_args['sketcher'], color_args['dfm'])
|
50 |
+
|
51 |
+
bw = to_tensor(bw).unsqueeze(0).to(color_args['device'])
|
52 |
+
dfm = to_tensor(dfm).unsqueeze(0).to(color_args['device'])
|
53 |
+
|
54 |
+
output = colorize_without_hint(torch.cat([bw, dfm], 1), color_args)
|
55 |
+
result = output[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5
|
56 |
+
|
57 |
+
if pad[0] != 0:
|
58 |
+
result = result[:-pad[0]]
|
59 |
+
if pad[1] != 0:
|
60 |
+
result = result[:, :-pad[1]]
|
61 |
+
|
62 |
+
return result
|
63 |
+
|
64 |
+
def colorize_with_hint(inp, color_args):
|
65 |
+
with torch.no_grad():
|
66 |
+
fake_color, _ = color_args['colorizer'](inp)
|
67 |
+
|
68 |
+
return fake_color
|
69 |
+
|
70 |
+
def process_image_with_hint(bw, dfm, hint, color_args, to_tensor = ToTensor()):
|
71 |
+
bw = to_tensor(bw).unsqueeze(0).to(color_args['device'])
|
72 |
+
dfm = to_tensor(dfm).unsqueeze(0).to(color_args['device'])
|
73 |
+
|
74 |
+
i_hint = (torch.FloatTensor(hint[..., :3]).permute(2, 0, 1) - 0.5) / 0.5
|
75 |
+
mask = torch.FloatTensor(hint[..., 3:]).permute(2, 0, 1)
|
76 |
+
i_hint = torch.cat([i_hint * mask, mask], 0).unsqueeze(0).to(color_args['device'])
|
77 |
+
|
78 |
+
output = colorize_with_hint(torch.cat([bw, dfm, i_hint], 1), color_args)
|
79 |
+
result = output[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5
|
80 |
+
|
81 |
+
return result
|
82 |
+
|
83 |
+
def colorize_single_image(file_path, save_path, color_args):
|
84 |
+
try:
|
85 |
+
image = plt.imread(file_path)
|
86 |
+
|
87 |
+
colorization = process_image(image, color_args)
|
88 |
+
|
89 |
+
plt.imsave(save_path, colorization)
|
90 |
+
|
91 |
+
return True
|
92 |
+
except KeyboardInterrupt:
|
93 |
+
sys.exit(0)
|
94 |
+
except:
|
95 |
+
print('Failed to colorize {}'.format(file_path))
|
96 |
+
return False
|
97 |
+
|
98 |
+
def colorize_images(source_path, target_path, color_args):
|
99 |
+
images = os.listdir(source_path)
|
100 |
+
|
101 |
+
for image_name in images:
|
102 |
+
file_path = os.path.join(source_path, image_name)
|
103 |
+
|
104 |
+
name, ext = os.path.splitext(image_name)
|
105 |
+
if (ext != '.png'):
|
106 |
+
image_name = name + '.png'
|
107 |
+
|
108 |
+
save_path = os.path.join(target_path, image_name)
|
109 |
+
colorize_single_image(file_path, save_path, color_args)
|
110 |
+
|
111 |
+
def colorize_cbr(file_path, color_args):
|
112 |
+
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
113 |
+
temp_path = 'temp_colorization'
|
114 |
+
|
115 |
+
if not os.path.exists(temp_path):
|
116 |
+
os.makedirs(temp_path)
|
117 |
+
extract_cbr(file_path, temp_path)
|
118 |
+
|
119 |
+
images = subfolder_image_search(temp_path)
|
120 |
+
|
121 |
+
result_images = []
|
122 |
+
for image_path in images:
|
123 |
+
save_path = image_path
|
124 |
+
|
125 |
+
path, ext = os.path.splitext(save_path)
|
126 |
+
if (ext != '.png'):
|
127 |
+
save_path = path + '.png'
|
128 |
+
|
129 |
+
res_flag = colorize_single_image(image_path, save_path, color_args)
|
130 |
+
|
131 |
+
result_images.append(save_path if res_flag else image_path)
|
132 |
+
|
133 |
+
|
134 |
+
result_name = os.path.join(os.path.dirname(file_path), file_name + '_colorized.cbz')
|
135 |
+
|
136 |
+
create_cbz(result_name, result_images)
|
137 |
+
|
138 |
+
remove_folder(temp_path)
|
139 |
+
|
140 |
+
return result_name
|
141 |
+
|
142 |
+
def parse_args():
|
143 |
+
parser = argparse.ArgumentParser()
|
144 |
+
parser.add_argument("-p", "--path", required=True)
|
145 |
+
parser.add_argument("-gen", "--generator", default = 'model/generator.pth')
|
146 |
+
parser.add_argument("-ext", "--extractor", default = 'model/extractor.pth')
|
147 |
+
parser.add_argument("-s", "--sigma", type = float, default = 0.003)
|
148 |
+
parser.add_argument('-g', '--gpu', dest = 'gpu', action = 'store_true')
|
149 |
+
parser.add_argument('-ah', '--auto', dest = 'autohint', action = 'store_true')
|
150 |
+
parser.add_argument('-ig', '--ignore_grey', dest = 'ignore', action = 'store_true')
|
151 |
+
parser.add_argument('-nd', '--no_denoise', dest = 'denoiser', action = 'store_false')
|
152 |
+
parser.add_argument("-ds", "--denoiser_sigma", type = int, default = 25)
|
153 |
+
parser.set_defaults(gpu = False)
|
154 |
+
parser.set_defaults(autohint = False)
|
155 |
+
parser.set_defaults(ignore = False)
|
156 |
+
parser.set_defaults(denoiser = True)
|
157 |
+
args = parser.parse_args()
|
158 |
+
|
159 |
+
return args
|
160 |
+
|
161 |
+
|
162 |
+
if __name__ == "__main__":
|
163 |
+
|
164 |
+
args = parse_args()
|
165 |
+
|
166 |
+
if args.gpu:
|
167 |
+
device = 'cuda'
|
168 |
+
else:
|
169 |
+
device = 'cpu'
|
170 |
+
|
171 |
+
generator = Generator()
|
172 |
+
generator.load_state_dict(torch.load(args.generator))
|
173 |
+
|
174 |
+
extractor = get_seresnext_extractor()
|
175 |
+
extractor.load_state_dict(torch.load(args.extractor))
|
176 |
+
|
177 |
+
colorizer = Colorizer(generator, extractor)
|
178 |
+
colorizer = colorizer.eval().to(device)
|
179 |
+
|
180 |
+
sketcher = XDoGSketcher()
|
181 |
+
xdog_config = open_json('configs/xdog_config.json')
|
182 |
+
for key in xdog_config.keys():
|
183 |
+
if key in sketcher.params:
|
184 |
+
sketcher.params[key] = xdog_config[key]
|
185 |
+
|
186 |
+
denoiser = None
|
187 |
+
if args.denoiser:
|
188 |
+
denoiser = FFDNetDenoiser(device, args.denoiser_sigma)
|
189 |
+
|
190 |
+
color_args = {'colorizer':colorizer, 'sketcher':sketcher, 'auto_hint':args.autohint, 'auto_hint_sigma':args.sigma,\
|
191 |
+
'ignore_gray':args.ignore, 'device':device, 'dfm' : True, 'denoiser':denoiser, 'denoiser_sigma' : args.denoiser_sigma}
|
192 |
+
|
193 |
+
|
194 |
+
if os.path.isdir(args.path):
|
195 |
+
colorization_path = os.path.join(args.path, 'colorization')
|
196 |
+
if not os.path.exists(colorization_path):
|
197 |
+
os.makedirs(colorization_path)
|
198 |
+
|
199 |
+
colorize_images(args.path, colorization_path, color_args)
|
200 |
+
|
201 |
+
elif os.path.isfile(args.path):
|
202 |
+
|
203 |
+
split = os.path.splitext(args.path)
|
204 |
+
|
205 |
+
if split[1].lower() in ('.cbr', '.cbz', '.rar', '.zip'):
|
206 |
+
colorize_cbr(args.path, color_args)
|
207 |
+
elif split[1].lower() in ('.jpg', '.png', ',jpeg'):
|
208 |
+
new_image_path = split[0] + '_colorized' + '.png'
|
209 |
+
|
210 |
+
colorize_single_image(args.path, new_image_path, color_args)
|
211 |
+
else:
|
212 |
+
print('Wrong format')
|
213 |
+
else:
|
214 |
+
print('Wrong path')
|
215 |
+
|
model/__pycache__/extractor.cpython-310.pyc
ADDED
Binary file (3.97 kB). View file
|
|
model/__pycache__/extractor.cpython-36.pyc
ADDED
Binary file (3.9 kB). View file
|
|
model/__pycache__/extractor.cpython-39.pyc
ADDED
Binary file (3.95 kB). View file
|
|
model/__pycache__/models.cpython-310.pyc
ADDED
Binary file (13 kB). View file
|
|
model/__pycache__/models.cpython-36.pyc
ADDED
Binary file (14 kB). View file
|
|
model/__pycache__/models.cpython-39.pyc
ADDED
Binary file (13.5 kB). View file
|
|
model/extractor.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ee3c59f02ac8c59298fd9b819fa33d2efa168847e15e4be39b35c286f7c18607
|
3 |
+
size 6340842
|
model/extractor.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import math
|
4 |
+
|
5 |
+
'''https://github.com/blandocs/Tag2Pix/blob/master/model/pretrained.py'''
|
6 |
+
|
7 |
+
# Pretrained version
|
8 |
+
class Selayer(nn.Module):
|
9 |
+
def __init__(self, inplanes):
|
10 |
+
super(Selayer, self).__init__()
|
11 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
12 |
+
self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
|
13 |
+
self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
|
14 |
+
self.relu = nn.ReLU(inplace=True)
|
15 |
+
self.sigmoid = nn.Sigmoid()
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
out = self.global_avgpool(x)
|
19 |
+
out = self.conv1(out)
|
20 |
+
out = self.relu(out)
|
21 |
+
out = self.conv2(out)
|
22 |
+
out = self.sigmoid(out)
|
23 |
+
|
24 |
+
return x * out
|
25 |
+
|
26 |
+
|
27 |
+
class BottleneckX_Origin(nn.Module):
|
28 |
+
expansion = 4
|
29 |
+
|
30 |
+
def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None):
|
31 |
+
super(BottleneckX_Origin, self).__init__()
|
32 |
+
self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
|
33 |
+
self.bn1 = nn.BatchNorm2d(planes * 2)
|
34 |
+
|
35 |
+
self.conv2 = nn.Conv2d(planes * 2, planes * 2, kernel_size=3, stride=stride,
|
36 |
+
padding=1, groups=cardinality, bias=False)
|
37 |
+
self.bn2 = nn.BatchNorm2d(planes * 2)
|
38 |
+
|
39 |
+
self.conv3 = nn.Conv2d(planes * 2, planes * 4, kernel_size=1, bias=False)
|
40 |
+
self.bn3 = nn.BatchNorm2d(planes * 4)
|
41 |
+
|
42 |
+
self.selayer = Selayer(planes * 4)
|
43 |
+
|
44 |
+
self.relu = nn.ReLU(inplace=True)
|
45 |
+
self.downsample = downsample
|
46 |
+
self.stride = stride
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
residual = x
|
50 |
+
|
51 |
+
out = self.conv1(x)
|
52 |
+
out = self.bn1(out)
|
53 |
+
out = self.relu(out)
|
54 |
+
|
55 |
+
out = self.conv2(out)
|
56 |
+
out = self.bn2(out)
|
57 |
+
out = self.relu(out)
|
58 |
+
|
59 |
+
out = self.conv3(out)
|
60 |
+
out = self.bn3(out)
|
61 |
+
|
62 |
+
out = self.selayer(out)
|
63 |
+
|
64 |
+
if self.downsample is not None:
|
65 |
+
residual = self.downsample(x)
|
66 |
+
|
67 |
+
out += residual
|
68 |
+
out = self.relu(out)
|
69 |
+
|
70 |
+
return out
|
71 |
+
|
72 |
+
class SEResNeXt_extractor(nn.Module):
|
73 |
+
def __init__(self, block, layers, input_channels=3, cardinality=32):
|
74 |
+
super(SEResNeXt_extractor, self).__init__()
|
75 |
+
self.cardinality = cardinality
|
76 |
+
self.inplanes = 64
|
77 |
+
self.input_channels = input_channels
|
78 |
+
|
79 |
+
self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3,
|
80 |
+
bias=False)
|
81 |
+
self.bn1 = nn.BatchNorm2d(64)
|
82 |
+
self.relu = nn.ReLU(inplace=True)
|
83 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
84 |
+
|
85 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
86 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
87 |
+
|
88 |
+
for m in self.modules():
|
89 |
+
if isinstance(m, nn.Conv2d):
|
90 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
91 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
92 |
+
if m.bias is not None:
|
93 |
+
m.bias.data.zero_()
|
94 |
+
elif isinstance(m, nn.BatchNorm2d):
|
95 |
+
m.weight.data.fill_(1)
|
96 |
+
m.bias.data.zero_()
|
97 |
+
|
98 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
99 |
+
downsample = None
|
100 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
101 |
+
downsample = nn.Sequential(
|
102 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
103 |
+
kernel_size=1, stride=stride, bias=False),
|
104 |
+
nn.BatchNorm2d(planes * block.expansion),
|
105 |
+
)
|
106 |
+
|
107 |
+
layers = []
|
108 |
+
layers.append(block(self.inplanes, planes, self.cardinality, stride, downsample))
|
109 |
+
self.inplanes = planes * block.expansion
|
110 |
+
for i in range(1, blocks):
|
111 |
+
layers.append(block(self.inplanes, planes, self.cardinality))
|
112 |
+
|
113 |
+
return nn.Sequential(*layers)
|
114 |
+
|
115 |
+
def forward(self, x):
|
116 |
+
x = self.conv1(x)
|
117 |
+
x = self.bn1(x)
|
118 |
+
x = self.relu(x)
|
119 |
+
x = self.maxpool(x)
|
120 |
+
|
121 |
+
x = self.layer1(x)
|
122 |
+
x = self.layer2(x)
|
123 |
+
|
124 |
+
return x
|
125 |
+
|
126 |
+
def get_seresnext_extractor():
|
127 |
+
return SEResNeXt_extractor(BottleneckX_Origin, [3, 4, 6, 3], 1)
|
model/models.py
ADDED
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchvision.models as M
|
5 |
+
import math
|
6 |
+
from torch import Tensor
|
7 |
+
from torch.nn import Parameter
|
8 |
+
|
9 |
+
'''https://github.com/orashi/AlacGAN/blob/master/models/standard.py'''
|
10 |
+
|
11 |
+
def l2normalize(v, eps=1e-12):
|
12 |
+
return v / (v.norm() + eps)
|
13 |
+
|
14 |
+
|
15 |
+
class SpectralNorm(nn.Module):
|
16 |
+
def __init__(self, module, name='weight', power_iterations=1):
|
17 |
+
super(SpectralNorm, self).__init__()
|
18 |
+
self.module = module
|
19 |
+
self.name = name
|
20 |
+
self.power_iterations = power_iterations
|
21 |
+
if not self._made_params():
|
22 |
+
self._make_params()
|
23 |
+
|
24 |
+
def _update_u_v(self):
|
25 |
+
u = getattr(self.module, self.name + "_u")
|
26 |
+
v = getattr(self.module, self.name + "_v")
|
27 |
+
w = getattr(self.module, self.name + "_bar")
|
28 |
+
|
29 |
+
height = w.data.shape[0]
|
30 |
+
for _ in range(self.power_iterations):
|
31 |
+
v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
|
32 |
+
u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
|
33 |
+
|
34 |
+
# sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
|
35 |
+
sigma = u.dot(w.view(height, -1).mv(v))
|
36 |
+
setattr(self.module, self.name, w / sigma.expand_as(w))
|
37 |
+
|
38 |
+
def _made_params(self):
|
39 |
+
try:
|
40 |
+
u = getattr(self.module, self.name + "_u")
|
41 |
+
v = getattr(self.module, self.name + "_v")
|
42 |
+
w = getattr(self.module, self.name + "_bar")
|
43 |
+
return True
|
44 |
+
except AttributeError:
|
45 |
+
return False
|
46 |
+
|
47 |
+
|
48 |
+
def _make_params(self):
|
49 |
+
w = getattr(self.module, self.name)
|
50 |
+
height = w.data.shape[0]
|
51 |
+
width = w.view(height, -1).data.shape[1]
|
52 |
+
|
53 |
+
u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
|
54 |
+
v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
|
55 |
+
u.data = l2normalize(u.data)
|
56 |
+
v.data = l2normalize(v.data)
|
57 |
+
w_bar = Parameter(w.data)
|
58 |
+
|
59 |
+
del self.module._parameters[self.name]
|
60 |
+
|
61 |
+
self.module.register_parameter(self.name + "_u", u)
|
62 |
+
self.module.register_parameter(self.name + "_v", v)
|
63 |
+
self.module.register_parameter(self.name + "_bar", w_bar)
|
64 |
+
|
65 |
+
|
66 |
+
def forward(self, *args):
|
67 |
+
self._update_u_v()
|
68 |
+
return self.module.forward(*args)
|
69 |
+
|
70 |
+
class Selayer(nn.Module):
|
71 |
+
def __init__(self, inplanes):
|
72 |
+
super(Selayer, self).__init__()
|
73 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
74 |
+
self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
|
75 |
+
self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
|
76 |
+
self.relu = nn.ReLU(inplace=True)
|
77 |
+
self.sigmoid = nn.Sigmoid()
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
out = self.global_avgpool(x)
|
81 |
+
out = self.conv1(out)
|
82 |
+
out = self.relu(out)
|
83 |
+
out = self.conv2(out)
|
84 |
+
out = self.sigmoid(out)
|
85 |
+
|
86 |
+
return x * out
|
87 |
+
|
88 |
+
class SelayerSpectr(nn.Module):
|
89 |
+
def __init__(self, inplanes):
|
90 |
+
super(SelayerSpectr, self).__init__()
|
91 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
92 |
+
self.conv1 = SpectralNorm(nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1))
|
93 |
+
self.conv2 = SpectralNorm(nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1))
|
94 |
+
self.relu = nn.ReLU(inplace=True)
|
95 |
+
self.sigmoid = nn.Sigmoid()
|
96 |
+
|
97 |
+
def forward(self, x):
|
98 |
+
out = self.global_avgpool(x)
|
99 |
+
out = self.conv1(out)
|
100 |
+
out = self.relu(out)
|
101 |
+
out = self.conv2(out)
|
102 |
+
out = self.sigmoid(out)
|
103 |
+
|
104 |
+
return x * out
|
105 |
+
|
106 |
+
class ResNeXtBottleneck(nn.Module):
|
107 |
+
def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
|
108 |
+
super(ResNeXtBottleneck, self).__init__()
|
109 |
+
D = out_channels // 2
|
110 |
+
self.out_channels = out_channels
|
111 |
+
self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
|
112 |
+
self.conv_conv = nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
|
113 |
+
groups=cardinality,
|
114 |
+
bias=False)
|
115 |
+
self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
|
116 |
+
self.shortcut = nn.Sequential()
|
117 |
+
if stride != 1:
|
118 |
+
self.shortcut.add_module('shortcut',
|
119 |
+
nn.AvgPool2d(2, stride=2))
|
120 |
+
|
121 |
+
self.selayer = Selayer(out_channels)
|
122 |
+
|
123 |
+
def forward(self, x):
|
124 |
+
bottleneck = self.conv_reduce.forward(x)
|
125 |
+
bottleneck = F.leaky_relu(bottleneck, 0.2, True)
|
126 |
+
bottleneck = self.conv_conv.forward(bottleneck)
|
127 |
+
bottleneck = F.leaky_relu(bottleneck, 0.2, True)
|
128 |
+
bottleneck = self.conv_expand.forward(bottleneck)
|
129 |
+
bottleneck = self.selayer(bottleneck)
|
130 |
+
|
131 |
+
x = self.shortcut.forward(x)
|
132 |
+
return x + bottleneck
|
133 |
+
|
134 |
+
class SpectrResNeXtBottleneck(nn.Module):
|
135 |
+
def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
|
136 |
+
super(SpectrResNeXtBottleneck, self).__init__()
|
137 |
+
D = out_channels // 2
|
138 |
+
self.out_channels = out_channels
|
139 |
+
self.conv_reduce = SpectralNorm(nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False))
|
140 |
+
self.conv_conv = SpectralNorm(nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
|
141 |
+
groups=cardinality,
|
142 |
+
bias=False))
|
143 |
+
self.conv_expand = SpectralNorm(nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False))
|
144 |
+
self.shortcut = nn.Sequential()
|
145 |
+
if stride != 1:
|
146 |
+
self.shortcut.add_module('shortcut',
|
147 |
+
nn.AvgPool2d(2, stride=2))
|
148 |
+
|
149 |
+
self.selayer = SelayerSpectr(out_channels)
|
150 |
+
|
151 |
+
def forward(self, x):
|
152 |
+
bottleneck = self.conv_reduce.forward(x)
|
153 |
+
bottleneck = F.leaky_relu(bottleneck, 0.2, True)
|
154 |
+
bottleneck = self.conv_conv.forward(bottleneck)
|
155 |
+
bottleneck = F.leaky_relu(bottleneck, 0.2, True)
|
156 |
+
bottleneck = self.conv_expand.forward(bottleneck)
|
157 |
+
bottleneck = self.selayer(bottleneck)
|
158 |
+
|
159 |
+
x = self.shortcut.forward(x)
|
160 |
+
return x + bottleneck
|
161 |
+
|
162 |
+
class FeatureConv(nn.Module):
|
163 |
+
def __init__(self, input_dim=512, output_dim=512):
|
164 |
+
super(FeatureConv, self).__init__()
|
165 |
+
|
166 |
+
no_bn = True
|
167 |
+
|
168 |
+
seq = []
|
169 |
+
seq.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
|
170 |
+
if not no_bn: seq.append(nn.BatchNorm2d(output_dim))
|
171 |
+
seq.append(nn.ReLU(inplace=True))
|
172 |
+
seq.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
|
173 |
+
if not no_bn: seq.append(nn.BatchNorm2d(output_dim))
|
174 |
+
seq.append(nn.ReLU(inplace=True))
|
175 |
+
seq.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
|
176 |
+
seq.append(nn.ReLU(inplace=True))
|
177 |
+
|
178 |
+
self.network = nn.Sequential(*seq)
|
179 |
+
|
180 |
+
def forward(self, x):
|
181 |
+
return self.network(x)
|
182 |
+
|
183 |
+
class Generator(nn.Module):
|
184 |
+
def __init__(self, ngf=64):
|
185 |
+
super(Generator, self).__init__()
|
186 |
+
|
187 |
+
self.feature_conv = FeatureConv()
|
188 |
+
|
189 |
+
self.to0 = self._make_encoder_block_first(6, 32)
|
190 |
+
self.to1 = self._make_encoder_block(32, 64)
|
191 |
+
self.to2 = self._make_encoder_block(64, 128)
|
192 |
+
self.to3 = self._make_encoder_block(128, 256)
|
193 |
+
self.to4 = self._make_encoder_block(256, 512)
|
194 |
+
|
195 |
+
self.deconv_for_decoder = nn.Sequential(
|
196 |
+
nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1), # output is 64 * 64
|
197 |
+
nn.LeakyReLU(0.2),
|
198 |
+
nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1), # output is 128 * 128
|
199 |
+
nn.LeakyReLU(0.2),
|
200 |
+
nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1), # output is 256 * 256
|
201 |
+
nn.LeakyReLU(0.2),
|
202 |
+
nn.ConvTranspose2d(32, 3, 3, stride=1, padding=1, output_padding=0), # output is 256 * 256
|
203 |
+
nn.Tanh(),
|
204 |
+
)
|
205 |
+
|
206 |
+
tunnel4 = nn.Sequential(*[ResNeXtBottleneck(ngf * 8, ngf * 8, cardinality=32, dilate=1) for _ in range(20)])
|
207 |
+
|
208 |
+
self.tunnel4 = nn.Sequential(nn.Conv2d(ngf * 8 + 512, ngf * 8, kernel_size=3, stride=1, padding=1),
|
209 |
+
nn.LeakyReLU(0.2, True),
|
210 |
+
tunnel4,
|
211 |
+
nn.Conv2d(ngf * 8, ngf * 4 * 4, kernel_size=3, stride=1, padding=1),
|
212 |
+
nn.PixelShuffle(2),
|
213 |
+
nn.LeakyReLU(0.2, True)
|
214 |
+
) # 64
|
215 |
+
|
216 |
+
depth = 2
|
217 |
+
tunnel = [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=1) for _ in range(depth)]
|
218 |
+
tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=2) for _ in range(depth)]
|
219 |
+
tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=4) for _ in range(depth)]
|
220 |
+
tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=2),
|
221 |
+
ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=1)]
|
222 |
+
tunnel3 = nn.Sequential(*tunnel)
|
223 |
+
|
224 |
+
self.tunnel3 = nn.Sequential(nn.Conv2d(ngf * 8, ngf * 4, kernel_size=3, stride=1, padding=1),
|
225 |
+
nn.LeakyReLU(0.2, True),
|
226 |
+
tunnel3,
|
227 |
+
nn.Conv2d(ngf * 4, ngf * 2 * 4, kernel_size=3, stride=1, padding=1),
|
228 |
+
nn.PixelShuffle(2),
|
229 |
+
nn.LeakyReLU(0.2, True)
|
230 |
+
) # 128
|
231 |
+
|
232 |
+
tunnel = [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=1) for _ in range(depth)]
|
233 |
+
tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=2) for _ in range(depth)]
|
234 |
+
tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=4) for _ in range(depth)]
|
235 |
+
tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=2),
|
236 |
+
ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=1)]
|
237 |
+
tunnel2 = nn.Sequential(*tunnel)
|
238 |
+
|
239 |
+
self.tunnel2 = nn.Sequential(nn.Conv2d(ngf * 4, ngf * 2, kernel_size=3, stride=1, padding=1),
|
240 |
+
nn.LeakyReLU(0.2, True),
|
241 |
+
tunnel2,
|
242 |
+
nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=1, padding=1),
|
243 |
+
nn.PixelShuffle(2),
|
244 |
+
nn.LeakyReLU(0.2, True)
|
245 |
+
)
|
246 |
+
|
247 |
+
tunnel = [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=1)]
|
248 |
+
tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=2)]
|
249 |
+
tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=4)]
|
250 |
+
tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=2),
|
251 |
+
ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=1)]
|
252 |
+
tunnel1 = nn.Sequential(*tunnel)
|
253 |
+
|
254 |
+
self.tunnel1 = nn.Sequential(nn.Conv2d(ngf * 2, ngf, kernel_size=3, stride=1, padding=1),
|
255 |
+
nn.LeakyReLU(0.2, True),
|
256 |
+
tunnel1,
|
257 |
+
nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=1, padding=1),
|
258 |
+
nn.PixelShuffle(2),
|
259 |
+
nn.LeakyReLU(0.2, True)
|
260 |
+
)
|
261 |
+
|
262 |
+
self.exit = nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)
|
263 |
+
|
264 |
+
|
265 |
+
def _make_encoder_block(self, inplanes, planes):
|
266 |
+
return nn.Sequential(
|
267 |
+
nn.Conv2d(inplanes, planes, 3, 2, 1),
|
268 |
+
nn.LeakyReLU(0.2),
|
269 |
+
nn.Conv2d(planes, planes, 3, 1, 1),
|
270 |
+
nn.LeakyReLU(0.2),
|
271 |
+
)
|
272 |
+
|
273 |
+
def _make_encoder_block_first(self, inplanes, planes):
|
274 |
+
return nn.Sequential(
|
275 |
+
nn.Conv2d(inplanes, planes, 3, 1, 1),
|
276 |
+
nn.LeakyReLU(0.2),
|
277 |
+
nn.Conv2d(planes, planes, 3, 1, 1),
|
278 |
+
nn.LeakyReLU(0.2),
|
279 |
+
)
|
280 |
+
|
281 |
+
def forward(self, sketch, sketch_feat):
|
282 |
+
|
283 |
+
x0 = self.to0(sketch)
|
284 |
+
x1 = self.to1(x0)
|
285 |
+
x2 = self.to2(x1)
|
286 |
+
x3 = self.to3(x2)
|
287 |
+
x4 = self.to4(x3)
|
288 |
+
|
289 |
+
sketch_feat = self.feature_conv(sketch_feat)
|
290 |
+
|
291 |
+
out = self.tunnel4(torch.cat([x4, sketch_feat], 1))
|
292 |
+
|
293 |
+
|
294 |
+
|
295 |
+
|
296 |
+
x = self.tunnel3(torch.cat([out, x3], 1))
|
297 |
+
x = self.tunnel2(torch.cat([x, x2], 1))
|
298 |
+
x = self.tunnel1(torch.cat([x, x1], 1))
|
299 |
+
x = torch.tanh(self.exit(torch.cat([x, x0], 1)))
|
300 |
+
|
301 |
+
decoder_output = self.deconv_for_decoder(out)
|
302 |
+
|
303 |
+
return x, decoder_output
|
304 |
+
'''
|
305 |
+
class Colorizer(nn.Module):
|
306 |
+
def __init__(self, extractor_path = 'model/model.pth'):
|
307 |
+
super(Colorizer, self).__init__()
|
308 |
+
|
309 |
+
self.generator = Generator()
|
310 |
+
self.extractor = se_resnext_half(dump_path=extractor_path, num_classes=370, input_channels=1)
|
311 |
+
|
312 |
+
def extractor_eval(self):
|
313 |
+
for param in self.extractor.parameters():
|
314 |
+
param.requires_grad = False
|
315 |
+
|
316 |
+
def extractor_train(self):
|
317 |
+
for param in extractor.parameters():
|
318 |
+
param.requires_grad = True
|
319 |
+
|
320 |
+
def forward(self, x, extractor_grad = False):
|
321 |
+
|
322 |
+
if extractor_grad:
|
323 |
+
features = self.extractor(x[:, 0:1])
|
324 |
+
else:
|
325 |
+
with torch.no_grad():
|
326 |
+
features = self.extractor(x[:, 0:1]).detach()
|
327 |
+
|
328 |
+
fake, guide = self.generator(x, features)
|
329 |
+
|
330 |
+
return fake, guide
|
331 |
+
'''
|
332 |
+
|
333 |
+
class Colorizer(nn.Module):
|
334 |
+
def __init__(self, generator_model, extractor_model):
|
335 |
+
super(Colorizer, self).__init__()
|
336 |
+
|
337 |
+
self.generator = generator_model
|
338 |
+
self.extractor = extractor_model
|
339 |
+
|
340 |
+
def load_generator_weights(self, gen_weights):
|
341 |
+
self.generator.load_state_dict(gen_weights)
|
342 |
+
|
343 |
+
def load_extractor_weights(self, ext_weights):
|
344 |
+
self.extractor.load_state_dict(ext_weights)
|
345 |
+
|
346 |
+
def extractor_eval(self):
|
347 |
+
for param in self.extractor.parameters():
|
348 |
+
param.requires_grad = False
|
349 |
+
self.extractor.eval()
|
350 |
+
|
351 |
+
def extractor_train(self):
|
352 |
+
for param in extractor.parameters():
|
353 |
+
param.requires_grad = True
|
354 |
+
self.extractor.train()
|
355 |
+
|
356 |
+
def forward(self, x, extractor_grad = False):
|
357 |
+
|
358 |
+
if extractor_grad:
|
359 |
+
features = self.extractor(x[:, 0:1])
|
360 |
+
else:
|
361 |
+
with torch.no_grad():
|
362 |
+
features = self.extractor(x[:, 0:1]).detach()
|
363 |
+
|
364 |
+
fake, guide = self.generator(x, features)
|
365 |
+
|
366 |
+
return fake, guide
|
367 |
+
|
368 |
+
class Discriminator(nn.Module):
|
369 |
+
def __init__(self, ndf=64):
|
370 |
+
super(Discriminator, self).__init__()
|
371 |
+
|
372 |
+
self.feed = nn.Sequential(SpectralNorm(nn.Conv2d(3, 64, 3, 1, 1)),
|
373 |
+
nn.LeakyReLU(0.2, True),
|
374 |
+
SpectralNorm(nn.Conv2d(64, 64, 3, 2, 0)),
|
375 |
+
nn.LeakyReLU(0.2, True),
|
376 |
+
|
377 |
+
|
378 |
+
|
379 |
+
|
380 |
+
SpectrResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1),
|
381 |
+
SpectrResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1, stride=2), # 128
|
382 |
+
SpectralNorm(nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=False)),
|
383 |
+
nn.LeakyReLU(0.2, True),
|
384 |
+
|
385 |
+
SpectrResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1),
|
386 |
+
SpectrResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1, stride=2), # 64
|
387 |
+
SpectralNorm(nn.Conv2d(ndf * 2, ndf * 4, kernel_size=1, stride=1, padding=0, bias=False)),
|
388 |
+
nn.LeakyReLU(0.2, True),
|
389 |
+
|
390 |
+
SpectrResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1),
|
391 |
+
SpectrResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1, stride=2), # 32,
|
392 |
+
SpectralNorm(nn.Conv2d(ndf * 4, ndf * 8, kernel_size=1, stride=1, padding=1, bias=False)),
|
393 |
+
nn.LeakyReLU(0.2, True),
|
394 |
+
SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
|
395 |
+
SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1, stride=2), # 16
|
396 |
+
SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
|
397 |
+
SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
|
398 |
+
nn.AdaptiveAvgPool2d((1, 1))
|
399 |
+
)
|
400 |
+
|
401 |
+
self.out = nn.Linear(512, 1)
|
402 |
+
|
403 |
+
def forward(self, color):
|
404 |
+
x = self.feed(color)
|
405 |
+
|
406 |
+
out = self.out(x.view(color.size(0), -1))
|
407 |
+
return out
|
408 |
+
|
409 |
+
class Content(nn.Module):
|
410 |
+
def __init__(self, path):
|
411 |
+
super(Content, self).__init__()
|
412 |
+
vgg16 = M.vgg16()
|
413 |
+
vgg16.load_state_dict(torch.load(path))
|
414 |
+
vgg16.features = nn.Sequential(
|
415 |
+
*list(vgg16.features.children())[:9]
|
416 |
+
)
|
417 |
+
self.model = vgg16.features
|
418 |
+
self.register_buffer('mean', torch.FloatTensor([0.485 - 0.5, 0.456 - 0.5, 0.406 - 0.5]).view(1, 3, 1, 1))
|
419 |
+
self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
420 |
+
|
421 |
+
def forward(self, images):
|
422 |
+
return self.model((images.mul(0.5) - self.mean) / self.std)
|
model/vgg16-397923af.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:397923af8e79cdbb6a7127f12361acd7a2f83e06b05044ddf496e83de57a5bf0
|
3 |
+
size 553433881
|
readme.md
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
UPD. See the [improved version](https://github.com/qweasdd/manga-colorization-v2).
|
2 |
+
|
3 |
+
# Automatic colorization
|
4 |
+
|
5 |
+
1. Download [generator](https://drive.google.com/file/d/1Oo6ycphJ3sUOpDCDoG29NA5pbhQVCevY/view?usp=sharing), [extractor](https://drive.google.com/file/d/12cbNyJcCa1zI2EBz6nea3BXl21Fm73Bt/view?usp=sharing) and [denoiser ](https://drive.google.com/file/d/161oyQcYpdkVdw8gKz_MA8RD-Wtg9XDp3/view?usp=sharing) weights. Put generator and extractor weights in `model` and denoiser weights in `denoising/models`.
|
6 |
+
2. To colorize image, folder of images, `.cbz` or `.cbr` file, use the following command:
|
7 |
+
```
|
8 |
+
$ python inference.py -p "path to file or folder"
|
9 |
+
```
|
10 |
+
|
11 |
+
# Manual colorization with color hints
|
12 |
+
|
13 |
+
1. Download [colorizer](https://drive.google.com/file/d/1BERrMl9e7cKsk9m2L0q1yO4k7blNhEWC/view?usp=sharing) and [denoiser ](https://drive.google.com/file/d/161oyQcYpdkVdw8gKz_MA8RD-Wtg9XDp3/view?usp=sharing) weights. Put colorizer weights in `model` and denoiser weights in `denoising/models`.
|
14 |
+
2. Run gunicorn server with:
|
15 |
+
```
|
16 |
+
$ ./run_drawing.sh
|
17 |
+
```
|
18 |
+
3. Open `localhost:5000` with a browser.
|
19 |
+
|
20 |
+
# References
|
21 |
+
1. Extractor weights are taken from https://github.com/blandocs/Tag2Pix/releases/download/release/model.pth
|
22 |
+
2. Denoiser weights are taken from http://www.ipol.im/pub/art/2019/231.
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
flask==1.1.1
|
2 |
+
gunicorn
|
3 |
+
numpy==1.16.6
|
4 |
+
flask_wtf==0.14.3
|
5 |
+
matplotlib==3.1.1
|
6 |
+
opencv-python==4.1.2.30
|
7 |
+
snowy
|
8 |
+
scipy==1.3.3
|
9 |
+
scikit-image==0.15.0
|
10 |
+
patool==1.12
|
run_drawing.sh
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
gunicorn --worker-class gevent --timeout 150 -w 1 -b 0.0.0.0:5000 drawing:app
|
static/js/draw.js
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
var canvas = document.getElementById('draw_canvas');
|
2 |
+
var ctx = canvas.getContext('2d');
|
3 |
+
var canvasWidth = canvas.width;
|
4 |
+
var canvasHeight = canvas.height;
|
5 |
+
var prevX, prevY;
|
6 |
+
|
7 |
+
var result_canvas = document.getElementById('result');
|
8 |
+
var result_ctx = result_canvas.getContext('2d');
|
9 |
+
result_canvas.width = canvas.width;
|
10 |
+
result_canvas.height = canvas.height;
|
11 |
+
|
12 |
+
var color_indicator = document.getElementById('color');
|
13 |
+
ctx.fillStyle = 'black';
|
14 |
+
color_indicator.value = '#000000';
|
15 |
+
|
16 |
+
var cur_id = window.location.pathname.substring(window.location.pathname.lastIndexOf('/') + 1);
|
17 |
+
|
18 |
+
function getRandomInt(max) {
|
19 |
+
return Math.floor(Math.random() * Math.floor(max));
|
20 |
+
}
|
21 |
+
|
22 |
+
var init_hint = new Image();
|
23 |
+
init_hint.addEventListener('load', function() {
|
24 |
+
ctx.drawImage(init_hint, 0, 0);
|
25 |
+
});
|
26 |
+
init_hint.src = '../static/temp_images/' + cur_id + '/hint.png?' + getRandomInt(100000).toString();
|
27 |
+
|
28 |
+
result_canvas.addEventListener('load', function(e) {
|
29 |
+
var img = new Image();
|
30 |
+
img.addEventListener('load', function() {
|
31 |
+
ctx.drawImage(img, 0, 0);
|
32 |
+
}, false);
|
33 |
+
console.log(window.location.pathname);
|
34 |
+
})
|
35 |
+
|
36 |
+
|
37 |
+
canvas.onload = function (e) {
|
38 |
+
var img = new Image();
|
39 |
+
img.addEventListener('load', function() {
|
40 |
+
ctx.drawImage(img, 0, 0);
|
41 |
+
}, false);
|
42 |
+
console.log(window.location.pathname);
|
43 |
+
//img.src = ;
|
44 |
+
}
|
45 |
+
|
46 |
+
function reset() {
|
47 |
+
ctx.clearRect(0, 0, canvasWidth, canvasHeight);
|
48 |
+
}
|
49 |
+
|
50 |
+
function getMousePos(canvas, evt) {
|
51 |
+
var rect = canvas.getBoundingClientRect();
|
52 |
+
return {
|
53 |
+
x: (evt.clientX - rect.left) / (rect.right - rect.left) * canvas.width,
|
54 |
+
y: (evt.clientY - rect.top) / (rect.bottom - rect.top) * canvas.height
|
55 |
+
};
|
56 |
+
}
|
57 |
+
|
58 |
+
function colorize() {
|
59 |
+
var file_id = document.location.pathname;
|
60 |
+
var image = canvas.toDataURL();
|
61 |
+
|
62 |
+
$.post("/colorize", { save_file_id: file_id, save_image: image}).done(function( data ) {
|
63 |
+
//console.log(document.location.origin + '/img/' + data)
|
64 |
+
//window.open(document.location.origin + '/img/' + data, '_blank');
|
65 |
+
//result.src = data;
|
66 |
+
var img = new Image();
|
67 |
+
img.addEventListener('load', function() {
|
68 |
+
result_ctx.drawImage(img, 0, 0);
|
69 |
+
}, false);
|
70 |
+
img.src = data;
|
71 |
+
});
|
72 |
+
}
|
73 |
+
|
74 |
+
canvas.addEventListener('mousedown', function(e) {
|
75 |
+
var mousePos = getMousePos(canvas, e);
|
76 |
+
if (e.button == 0) {
|
77 |
+
ctx.fillRect(mousePos['x'], mousePos['y'], 1, 1);
|
78 |
+
}
|
79 |
+
|
80 |
+
if (e.button == 2) {
|
81 |
+
prevX = mousePos['x']
|
82 |
+
prevY = mousePos['y']
|
83 |
+
}
|
84 |
+
|
85 |
+
})
|
86 |
+
|
87 |
+
canvas.addEventListener('mouseup', function(e) {
|
88 |
+
if (e.button == 2) {
|
89 |
+
var mousePos = getMousePos(canvas, e);
|
90 |
+
var diff_width = mousePos['x'] - prevX;
|
91 |
+
var diff_height = mousePos['y'] - prevY;
|
92 |
+
|
93 |
+
ctx.clearRect(prevX, prevY, diff_width, diff_height);
|
94 |
+
}
|
95 |
+
})
|
96 |
+
|
97 |
+
|
98 |
+
canvas.addEventListener('contextmenu', function(evt) {
|
99 |
+
evt.preventDefault();
|
100 |
+
})
|
101 |
+
|
102 |
+
function color(color_value){
|
103 |
+
ctx.fillStyle = color_value;
|
104 |
+
color_indicator.value = color_value;
|
105 |
+
}
|
106 |
+
|
107 |
+
color_indicator.oninput = function() {
|
108 |
+
color(this.value);
|
109 |
+
}
|
110 |
+
|
111 |
+
function rgbToHex(rgb){
|
112 |
+
return '#' + ((rgb[0] << 16) | (rgb[1] << 8) | rgb[2]).toString(16);
|
113 |
+
};
|
114 |
+
|
115 |
+
result_canvas.addEventListener('click', function(e) {
|
116 |
+
if (e.button == 0) {
|
117 |
+
var cur_pixel = result_ctx.getImageData(e.offsetX, e.offsetY, 1, 1).data;
|
118 |
+
color(rgbToHex(cur_pixel));
|
119 |
+
}
|
120 |
+
})
|
static/temp_images/.gitkeep
ADDED
File without changes
|
templates/drawing.html
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<meta charset="UTF-8">
|
5 |
+
<title>Colorization app</title>
|
6 |
+
<style>
|
7 |
+
.back{
|
8 |
+
height:100%;
|
9 |
+
width:100%;
|
10 |
+
position: absolute;
|
11 |
+
background-color: yellow;
|
12 |
+
top:0px;
|
13 |
+
padding: 10px;
|
14 |
+
}
|
15 |
+
|
16 |
+
#draw_canvas {
|
17 |
+
height: 200%;
|
18 |
+
border: 3px solid black;
|
19 |
+
background-image: linear-gradient(rgba(60,60,60,.85), rgba(60,60,60,.85)), url(../static/{{img_path}});
|
20 |
+
background-color: #c7b39b;
|
21 |
+
background-size: 100%;
|
22 |
+
}
|
23 |
+
</style>
|
24 |
+
|
25 |
+
</head>
|
26 |
+
<body>
|
27 |
+
<div align="left" style ="margin : 5px; margin-left : 0px">
|
28 |
+
<input type="button" onclick="location.href='/';" value="Home" />
|
29 |
+
</div>
|
30 |
+
<p style="margin: 0; padding: 0">
|
31 |
+
Left click - colorize, right hold - remove with rectangle, left click on result - use corresponding color.
|
32 |
+
|
33 |
+
</p>
|
34 |
+
<hr style ="margin: 0; padding: 0">
|
35 |
+
|
36 |
+
<p><table>
|
37 |
+
<tr>
|
38 |
+
<td>
|
39 |
+
<table>
|
40 |
+
<tr>
|
41 |
+
<td><button style="background-color: #000000; height: 20px; width: 20px;" onclick="color('#000000')"></button>
|
42 |
+
<td><button style="background-color: #B0171F; height: 20px; width: 20px;" onclick="color('#B0171F')"></button>
|
43 |
+
</tr>
|
44 |
+
<tr>
|
45 |
+
<td><button style="background-color: #DA70D6; height: 20px; width: 20px;" onclick="color('#DA70D6')"></button>
|
46 |
+
<td><button style="background-color: #8A2BE2; height: 20px; width: 20px;" onclick="color('#8A2BE2')"></button>
|
47 |
+
</tr>
|
48 |
+
<tr>
|
49 |
+
<td><button style="background-color: #0000FF; height: 20px; width: 20px;" onclick="color('#0000FF')"></button>
|
50 |
+
<td><button style="background-color: #4876FF; height: 20px; width: 20px;" onclick="color('#4876FF')"></button>
|
51 |
+
</tr>
|
52 |
+
<tr>
|
53 |
+
<td><button style="background-color: #CAE1FF; height: 20px; width: 20px;" onclick="color('#CAE1FF')"></button>
|
54 |
+
<td><button style="background-color: #6E7B8B; height: 20px; width: 20px;" onclick="color('#6E7B8B')"></button>
|
55 |
+
</tr>
|
56 |
+
<tr>
|
57 |
+
<td><button style="background-color: #00C78C; height: 20px; width: 20px;" onclick="color('#00C78C')"></button>
|
58 |
+
<td><button style="background-color: #00FA9A; height: 20px; width: 20px;" onclick="color('#00FA9A')"></button>
|
59 |
+
</tr>
|
60 |
+
<tr>
|
61 |
+
<td><button style="background-color: #00FF7F; height: 20px; width: 20px;" onclick="color('#00FF7F')"></button>
|
62 |
+
<td><button style="background-color: #00C957; height: 20px; width: 20px;" onclick="color('#00C957')"></button>
|
63 |
+
</tr>
|
64 |
+
<tr>
|
65 |
+
<td><button style="background-color: #3D9140; height: 20px; width: 20px;" onclick="color('#3D9140')"></button>
|
66 |
+
<td><button style="background-color: #32CD32; height: 20px; width: 20px;" onclick="color('#32CD32')"></button>
|
67 |
+
</tr>
|
68 |
+
<tr>
|
69 |
+
<td><button style="background-color: #00EE00; height: 20px; width: 20px;" onclick="color('#00EE00')"></button>
|
70 |
+
|
71 |
+
<td><button style="background-color: #008B00; height: 20px; width: 20px;" onclick="color('#008B00')"></button>
|
72 |
+
</tr>
|
73 |
+
<tr>
|
74 |
+
<td><button style="background-color: #76EE00; height: 20px; width: 20px;" onclick="color('#76EE00')"></button>
|
75 |
+
|
76 |
+
<td><button style="background-color: #CAFF70; height: 20px; width: 20px;" onclick="color('#CAFF70')"></button>
|
77 |
+
</tr>
|
78 |
+
<tr>
|
79 |
+
<td><button style="background-color: #FFFF00; height: 20px; width: 20px;" onclick="color('#FFFF00')"></button>
|
80 |
+
|
81 |
+
<td><button style="background-color: #CDCD00; height: 20px; width: 20px;" onclick="color('#CDCD00')"></button>
|
82 |
+
</tr>
|
83 |
+
<tr>
|
84 |
+
<td><button style="background-color: #FFF68F; height: 20px; width: 20px;" onclick="color('#FFF68F')"></button>
|
85 |
+
|
86 |
+
<td><button style="background-color: #FFFACD; height: 20px; width: 20px;" onclick="color('#FFFACD')"></button>
|
87 |
+
</tr>
|
88 |
+
<tr>
|
89 |
+
<td><button style="background-color: #FFEC8B; height: 20px; width: 20px;" onclick="color('#FFEC8B')"></button>
|
90 |
+
|
91 |
+
<td><button style="background-color: #FFD700; height: 20px; width: 20px;" onclick="color('#FFD700')"></button>
|
92 |
+
</tr>
|
93 |
+
<tr>
|
94 |
+
<td><button style="background-color: #F5DEB3; height: 20px; width: 20px;" onclick="color('#F5DEB3')"></button>
|
95 |
+
|
96 |
+
<td><button style="background-color: #FFE4B5; height: 20px; width: 20px;" onclick="color('#FFE4B5')"></button>
|
97 |
+
</tr>
|
98 |
+
<tr>
|
99 |
+
<td><button style="background-color: #EECFA1; height: 20px; width: 20px;" onclick="color('#EECFA1')"></button>
|
100 |
+
|
101 |
+
<td><button style="background-color: #FF9912; height: 20px; width: 20px;" onclick="color('#FF9912')"></button>
|
102 |
+
</tr>
|
103 |
+
<tr>
|
104 |
+
<td><button style="background-color: #8E388E; height: 20px; width: 20px;" onclick="color('#8E388E')"></button>
|
105 |
+
|
106 |
+
<td><button style="background-color: #7171C6; height: 20px; width: 20px;" onclick="color('#7171C6')"></button>
|
107 |
+
</tr>
|
108 |
+
|
109 |
+
<tr>
|
110 |
+
<td><button style="background-color: #7D9EC0; height: 20px; width: 20px;" onclick="color('#7D9EC0')"></button>
|
111 |
+
|
112 |
+
<td><button style="background-color: #388E8E; height: 20px; width: 20px;" onclick="color('#388E8E')"></button>
|
113 |
+
|
114 |
+
</tr>
|
115 |
+
|
116 |
+
<tr>
|
117 |
+
<td><button style="background-color: #71C671; height: 20px; width: 20px;" onclick="color('#71C671')"></button>
|
118 |
+
|
119 |
+
<td><button style="background-color: #8E8E38; height: 20px; width: 20px;" onclick="color('#8E8E38')"></button>
|
120 |
+
</tr>
|
121 |
+
<tr>
|
122 |
+
<td><button style="background-color: #C5C1AA; height: 20px; width: 20px;" onclick="color('#C5C1AA')"></button>
|
123 |
+
|
124 |
+
<td><button style="background-color: #C67171; height: 20px; width: 20px;" onclick="color('#C67171')"></button>
|
125 |
+
</tr>
|
126 |
+
<tr>
|
127 |
+
<td><button style="background-color: #555555; height: 20px; width: 20px;" onclick="color('#555555')"></button>
|
128 |
+
<td><button style="background-color: #848484; height: 20px; width: 20px;" onclick="color('#848484')"></button>
|
129 |
+
</tr>
|
130 |
+
<tr>
|
131 |
+
<td><button style="background-color: #FFFFFF; height: 20px; width: 20px;" onclick="color('#FFFFFF')"></button>
|
132 |
+
<td><button style="background-color: #EE0000; height: 20px; width: 20px;" onclick="color('#EE0000')"></button>
|
133 |
+
</tr>
|
134 |
+
<tr>
|
135 |
+
<td><button style="background-color: #FF4040; height: 20px; width: 20px;" onclick="color('#FF4040')"></button>
|
136 |
+
<td><button style="background-color: #EE6363; height: 20px; width: 20px;" onclick="color('#EE6363')"></button>
|
137 |
+
</tr>
|
138 |
+
<tr>
|
139 |
+
<td><button style="background-color: #FFC1C1; height: 20px; width: 20px;" onclick="color('#FFC1C1')"></button>
|
140 |
+
<td><button style="background-color: #FF7256; height: 20px; width: 20px;" onclick="color('#FF7256')"></button>
|
141 |
+
</tr>
|
142 |
+
<tr>
|
143 |
+
<td><button style="background-color: #FF4500; height: 20px; width: 20px;" onclick="color('#FF4500')"></button>
|
144 |
+
<td><button style="background-color: #F4A460; height: 20px; width: 20px;" onclick="color('#F4A460')"></button>
|
145 |
+
</tr>
|
146 |
+
<tr>
|
147 |
+
<td><button style="background-color: #FF8000; height: 20px; width: 20px;" onclick="color('FF8000')"></button>
|
148 |
+
<td><button style="background-color: #FFD700; height: 20px; width: 20px;" onclick="color('#FFD700')"></button>
|
149 |
+
</tr>
|
150 |
+
<tr>
|
151 |
+
<td><button style="background-color: #8B864E; height: 20px; width: 20px;" onclick="color('#8B864E')"></button>
|
152 |
+
<td><button style="background-color: #9ACD32; height: 20px; width: 20px;" onclick="color('#9ACD32')"></button>
|
153 |
+
</tr>
|
154 |
+
<tr>
|
155 |
+
<td><button style="background-color: #66CD00; height: 20px; width: 20px;" onclick="color('#66CD00')"></button>
|
156 |
+
<td><button style="background-color: #BDFCC9; height: 20px; width: 20px;" onclick="color('#BDFCC9')"></button>
|
157 |
+
</tr>
|
158 |
+
<tr>
|
159 |
+
<td><button style="background-color: #76EEC6; height: 20px; width: 20px;" onclick="color('#76EEC6')"></button>
|
160 |
+
<td><button style="background-color: #40E0D0; height: 20px; width: 20px;" onclick="color('#40E0D0')"></button>
|
161 |
+
</tr>
|
162 |
+
<tr>
|
163 |
+
<td><button style="background-color: #E0EEEE; height: 20px; width: 20px;" onclick="color('#E0EEEE')"></button>
|
164 |
+
<td><button style="background-color: #98F5FF; height: 20px; width: 20px;" onclick="color('#98F5FF')"></button>
|
165 |
+
</tr>
|
166 |
+
<tr>
|
167 |
+
<td><button style="background-color: #33A1C9; height: 20px; width: 20px;" onclick="color('#33A1C9')"></button>
|
168 |
+
<td><button style="background-color: #F0F8FF; height: 20px; width: 20px;" onclick="color('#F0F8FF')"></button>
|
169 |
+
</tr>
|
170 |
+
<tr>
|
171 |
+
<td><button style="background-color: #4682B4; height: 20px; width: 20px;" onclick="color('#4682B4')"></button>
|
172 |
+
<td><button style="background-color: #C6E2FF; height: 20px; width: 20px;" onclick="color('#C6E2FF')"></button>
|
173 |
+
</tr>
|
174 |
+
<tr>
|
175 |
+
<td><button style="background-color: #9B30FF; height: 20px; width: 20px;" onclick="color('#9B30FF')"></button>
|
176 |
+
<td><button style="background-color: #EE82EE; height: 20px; width: 20px;" onclick="color('#EE82EE')"></button>
|
177 |
+
</tr>
|
178 |
+
<tr>
|
179 |
+
<td><button style="background-color: #FFC0CB; height: 20px; width: 20px;" onclick="color('#FFC0CB')"></button>
|
180 |
+
<td><button style="background-color: #7CFC00; height: 20px; width: 20px;" onclick="color('#7CFC00')"></button>
|
181 |
+
</tr>
|
182 |
+
<tr>
|
183 |
+
<input type="color" id="color">
|
184 |
+
</tr>
|
185 |
+
</table>
|
186 |
+
</td>
|
187 |
+
<td>
|
188 |
+
<div style="width: 1150px; height: 800px; overflow: auto"><canvas align = "center" id="draw_canvas" width="{{width}}" height="{{height}}"></canvas></div>
|
189 |
+
</td>
|
190 |
+
<td>
|
191 |
+
<canvas id='result'></canvas>
|
192 |
+
</td>
|
193 |
+
</tr>
|
194 |
+
</table></p>
|
195 |
+
|
196 |
+
<button style="height: 20px; width: 80px" onclick="reset()">Clear</button>
|
197 |
+
|
198 |
+
<button style="height: 20px; width: 80px" onclick="colorize()" >Colorize</button>
|
199 |
+
|
200 |
+
|
201 |
+
|
202 |
+
<script src="http://code.jquery.com/jquery-1.8.3.js"></script>
|
203 |
+
<script src="/static/js/draw.js">
|
204 |
+
</script>
|
205 |
+
</body>
|
206 |
+
</html>
|
templates/submit.html
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<form method="POST" action="/" enctype="multipart/form-data" target="_blank">
|
2 |
+
{{ form.hidden_tag() }}
|
3 |
+
{{ form.file.label }} {{ form.file(size=20) }}
|
4 |
+
{{ form.denoise.label }} {{ form.denoise(size=5) }}
|
5 |
+
{{ form.denoise_sigma.label }} {{ form.denoise_sigma(size=5) }}
|
6 |
+
{{ form.autohint.label }} {{ form.autohint(size=5) }}
|
7 |
+
{{ form.autohint_sigma.label }} {{ form.autohint_sigma(size=5) }}
|
8 |
+
{{ form.ignore_gray.label }} {{ form.ignore_gray(size=5) }}
|
9 |
+
<input type="submit" value="Colorize">
|
10 |
+
|
11 |
+
</form>
|
templates/upload.html
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<style>
|
2 |
+
form{
|
3 |
+
position:fixed;
|
4 |
+
top:50%;
|
5 |
+
left:45%;
|
6 |
+
width:1250px;
|
7 |
+
}
|
8 |
+
</style>
|
9 |
+
|
10 |
+
|
11 |
+
<form id="form" method="POST" action="/" enctype="multipart/form-data">
|
12 |
+
{{ form.hidden_tag() }}
|
13 |
+
{{ form.file(size=20) }}
|
14 |
+
</form>
|
15 |
+
|
16 |
+
<script>
|
17 |
+
document.getElementById("file").onchange = function() {
|
18 |
+
document.getElementById("form").submit();
|
19 |
+
};
|
20 |
+
</script>
|
train.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.optim as optim
|
4 |
+
import numpy as np
|
5 |
+
import albumentations as albu
|
6 |
+
import argparse
|
7 |
+
import datetime
|
8 |
+
|
9 |
+
from utils.utils import open_json, weights_init, weights_init_spectr, generate_mask
|
10 |
+
from model.models import Colorizer, Generator, Content, Discriminator
|
11 |
+
from model.extractor import get_seresnext_extractor
|
12 |
+
from dataset.datasets import TrainDataset, FineTuningDataset
|
13 |
+
from PIL import Image
|
14 |
+
|
15 |
+
def parse_args():
|
16 |
+
parser = argparse.ArgumentParser()
|
17 |
+
parser.add_argument("-p", "--path", required=True, help = "dataset path")
|
18 |
+
parser.add_argument('-ft', '--fine_tuning', dest = 'fine_tuning', action = 'store_true')
|
19 |
+
parser.add_argument('-g', '--gpu', dest = 'gpu', action = 'store_true')
|
20 |
+
parser.set_defaults(fine_tuning = False)
|
21 |
+
parser.set_defaults(gpu = False)
|
22 |
+
args = parser.parse_args()
|
23 |
+
|
24 |
+
return args
|
25 |
+
|
26 |
+
def get_transforms():
|
27 |
+
return albu.Compose([albu.RandomCrop(512, 512, always_apply = True), albu.HorizontalFlip(p = 0.5)], p = 1.)
|
28 |
+
|
29 |
+
def get_dataloaders(data_path, transforms, batch_size, fine_tuning, mult_number):
|
30 |
+
train_dataset = TrainDataset(data_path, transforms)
|
31 |
+
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
|
32 |
+
|
33 |
+
if fine_tuning:
|
34 |
+
finetuning_dataset = FineTuningDataset(data_path, transforms)
|
35 |
+
finetuning_dataloader = torch.utils.data.DataLoader(finetuning_dataset, batch_size = batch_size, shuffle = True)
|
36 |
+
|
37 |
+
return train_dataloader, finetuning_dataloader
|
38 |
+
|
39 |
+
def get_models(device):
|
40 |
+
generator = Generator()
|
41 |
+
extractor = get_seresnext_extractor()
|
42 |
+
colorizer = Colorizer(generator, extractor)
|
43 |
+
|
44 |
+
colorizer.extractor_eval()
|
45 |
+
colorizer = colorizer.to(device)
|
46 |
+
|
47 |
+
discriminator = Discriminator().to(device)
|
48 |
+
|
49 |
+
content = Content('model/vgg16-397923af.pth').eval().to(device)
|
50 |
+
for param in content.parameters():
|
51 |
+
param.requires_grad = False
|
52 |
+
|
53 |
+
return colorizer, discriminator, content
|
54 |
+
|
55 |
+
def set_weights(colorizer, discriminator):
|
56 |
+
colorizer.generator.apply(weights_init)
|
57 |
+
colorizer.load_extractor_weights(torch.load('model/extractor.pth'))
|
58 |
+
|
59 |
+
discriminator.apply(weights_init_spectr)
|
60 |
+
|
61 |
+
def generator_loss(disc_output, true_labels, main_output, guide_output, real_image, content_gen, content_true, dist_loss = nn.L1Loss(), content_dist_loss = nn.MSELoss(), class_loss = nn.BCEWithLogitsLoss()):
|
62 |
+
sim_loss_full = dist_loss(main_output, real_image)
|
63 |
+
sim_loss_guide = dist_loss(guide_output, real_image)
|
64 |
+
|
65 |
+
adv_loss = class_loss(disc_output, true_labels)
|
66 |
+
|
67 |
+
content_loss = content_dist_loss(content_gen, content_true)
|
68 |
+
|
69 |
+
sum_loss = 10 * (sim_loss_full + 0.9 * sim_loss_guide) + adv_loss + content_loss
|
70 |
+
|
71 |
+
return sum_loss
|
72 |
+
|
73 |
+
def get_optimizers(colorizer, discriminator, generator_lr, discriminator_lr):
|
74 |
+
optimizerG = optim.Adam(colorizer.generator.parameters(), lr = generator_lr, betas=(0.5, 0.9))
|
75 |
+
optimizerD = optim.Adam(discriminator.parameters(), lr = discriminator_lr, betas=(0.5, 0.9))
|
76 |
+
|
77 |
+
return optimizerG, optimizerD
|
78 |
+
|
79 |
+
def generator_step(inputs, colorizer, discriminator, content, loss_function, optimizer, device, white_penalty = True):
|
80 |
+
for p in discriminator.parameters():
|
81 |
+
p.requires_grad = False
|
82 |
+
for p in colorizer.generator.parameters():
|
83 |
+
p.requires_grad = True
|
84 |
+
|
85 |
+
colorizer.generator.zero_grad()
|
86 |
+
|
87 |
+
bw, color, hint, dfm = inputs
|
88 |
+
bw, color, hint, dfm = bw.to(device), color.to(device), hint.to(device), dfm.to(device)
|
89 |
+
|
90 |
+
fake, guide = colorizer(torch.cat([bw, dfm, hint], 1))
|
91 |
+
|
92 |
+
logits_fake = discriminator(fake)
|
93 |
+
y_real = torch.ones((bw.size(0), 1), device = device)
|
94 |
+
|
95 |
+
content_fake = content(fake)
|
96 |
+
with torch.no_grad():
|
97 |
+
content_true = content(color)
|
98 |
+
|
99 |
+
generator_loss = loss_function(logits_fake, y_real, fake, guide, color, content_fake, content_true)
|
100 |
+
|
101 |
+
if white_penalty:
|
102 |
+
mask = (~((color > 0.85).float().sum(dim = 1) == 3).unsqueeze(1).repeat((1, 3, 1, 1 ))).float()
|
103 |
+
white_zones = mask * (fake + 1) / 2
|
104 |
+
white_penalty = (torch.pow(white_zones.sum(dim = 1), 2).sum(dim = (1, 2)) / (mask.sum(dim = (1, 2, 3)) + 1)).mean()
|
105 |
+
|
106 |
+
generator_loss += white_penalty
|
107 |
+
|
108 |
+
generator_loss.backward()
|
109 |
+
|
110 |
+
optimizer.step()
|
111 |
+
|
112 |
+
return generator_loss.item()
|
113 |
+
|
114 |
+
def discriminator_step(inputs, colorizer, discriminator, optimizer, device, loss_function = nn.BCEWithLogitsLoss()):
|
115 |
+
|
116 |
+
for p in discriminator.parameters():
|
117 |
+
p.requires_grad = True
|
118 |
+
for p in colorizer.generator.parameters():
|
119 |
+
p.requires_grad = False
|
120 |
+
|
121 |
+
discriminator.zero_grad()
|
122 |
+
|
123 |
+
bw, color, hint, dfm = inputs
|
124 |
+
bw, color, hint, dfm = bw.to(device), color.to(device), hint.to(device), dfm.to(device)
|
125 |
+
|
126 |
+
y_real = torch.full((bw.size(0), 1), 0.9, device = device)
|
127 |
+
|
128 |
+
y_fake = torch.zeros((bw.size(0), 1), device = device)
|
129 |
+
|
130 |
+
with torch.no_grad():
|
131 |
+
fake_color, _ = colorizer(torch.cat([bw, dfm, hint], 1))
|
132 |
+
fake_color.detach()
|
133 |
+
|
134 |
+
logits_fake = discriminator(fake_color)
|
135 |
+
logits_real = discriminator(color)
|
136 |
+
|
137 |
+
fake_loss = loss_function(logits_fake, y_fake)
|
138 |
+
real_loss = loss_function(logits_real, y_real)
|
139 |
+
|
140 |
+
discriminator_loss = real_loss + fake_loss
|
141 |
+
|
142 |
+
discriminator_loss.backward()
|
143 |
+
optimizer.step()
|
144 |
+
|
145 |
+
return discriminator_loss.item()
|
146 |
+
|
147 |
+
def decrease_lr(optimizer, rate):
|
148 |
+
for group in optimizer.param_groups:
|
149 |
+
group['lr'] /= rate
|
150 |
+
|
151 |
+
def set_lr(optimizer, value):
|
152 |
+
for group in optimizer.param_groups:
|
153 |
+
group['lr'] = value
|
154 |
+
|
155 |
+
def train(colorizer, discriminator, content, dataloader, epochs, colorizer_optimizer, discriminator_optimizer, lr_decay_epoch = -1, device = 'cpu'):
|
156 |
+
colorizer.generator.train()
|
157 |
+
discriminator.train()
|
158 |
+
|
159 |
+
disc_step = True
|
160 |
+
|
161 |
+
for epoch in range(epochs):
|
162 |
+
if (epoch == lr_decay_epoch):
|
163 |
+
decrease_lr(colorizer_optimizer, 10)
|
164 |
+
decrease_lr(discriminator_optimizer, 10)
|
165 |
+
|
166 |
+
sum_disc_loss = 0
|
167 |
+
sum_gen_loss = 0
|
168 |
+
|
169 |
+
for n, inputs in enumerate(dataloader):
|
170 |
+
if n % 5 == 0:
|
171 |
+
print(datetime.datetime.now().time())
|
172 |
+
print('Step : %d Discr loss: %.4f Gen loss : %.4f \n'%(n, sum_disc_loss / (n // 2 + 1), sum_gen_loss / (n // 2 + 1)))
|
173 |
+
|
174 |
+
|
175 |
+
if disc_step:
|
176 |
+
step_loss = discriminator_step(inputs, colorizer, discriminator, discriminator_optimizer, device)
|
177 |
+
sum_disc_loss += step_loss
|
178 |
+
else:
|
179 |
+
step_loss = generator_step(inputs, colorizer, discriminator, content, generator_loss, colorizer_optimizer, device)
|
180 |
+
sum_gen_loss += step_loss
|
181 |
+
|
182 |
+
disc_step = disc_step ^ True
|
183 |
+
|
184 |
+
|
185 |
+
print(datetime.datetime.now().time())
|
186 |
+
print('Epoch : %d Discr loss: %.4f Gen loss : %.4f \n'%(epoch, sum_disc_loss / (n // 2 + 1), sum_gen_loss / (n // 2 + 1)))
|
187 |
+
|
188 |
+
|
189 |
+
def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_optimizer, device, loss_function = nn.BCEWithLogitsLoss()):
|
190 |
+
|
191 |
+
for p in discriminator.parameters():
|
192 |
+
p.requires_grad = True
|
193 |
+
for p in colorizer.generator.parameters():
|
194 |
+
p.requires_grad = False
|
195 |
+
|
196 |
+
for cur_disc_step in range(5):
|
197 |
+
discriminator.zero_grad()
|
198 |
+
|
199 |
+
bw, dfm, color_for_real = data_iter.next()
|
200 |
+
bw, dfm, color_for_real = bw.to(device), dfm.to(device), color_for_real.to(device)
|
201 |
+
|
202 |
+
y_real = torch.full((bw.size(0), 1), 0.9, device = device)
|
203 |
+
y_fake = torch.zeros((bw.size(0), 1), device = device)
|
204 |
+
|
205 |
+
empty_hint = torch.zeros(bw.shape[0], 4, bw.shape[2] , bw.shape[3] ).float().to(device)
|
206 |
+
|
207 |
+
with torch.no_grad():
|
208 |
+
fake_color_manga, _ = colorizer(torch.cat([bw, dfm, empty_hint ], 1))
|
209 |
+
fake_color_manga.detach()
|
210 |
+
|
211 |
+
logits_fake = discriminator(fake_color_manga)
|
212 |
+
logits_real = discriminator(color_for_real)
|
213 |
+
|
214 |
+
fake_loss = loss_function(logits_fake, y_fake)
|
215 |
+
real_loss = loss_function(logits_real, y_real)
|
216 |
+
discriminator_loss = real_loss + fake_loss
|
217 |
+
|
218 |
+
discriminator_loss.backward()
|
219 |
+
disc_optimizer.step()
|
220 |
+
|
221 |
+
|
222 |
+
for p in discriminator.parameters():
|
223 |
+
p.requires_grad = False
|
224 |
+
for p in colorizer.generator.parameters():
|
225 |
+
p.requires_grad = True
|
226 |
+
|
227 |
+
colorizer.generator.zero_grad()
|
228 |
+
|
229 |
+
bw, dfm, _ = data_iter.next()
|
230 |
+
bw, dfm = bw.to(device), dfm.to(device)
|
231 |
+
|
232 |
+
y_real = torch.ones((bw.size(0), 1), device = device)
|
233 |
+
|
234 |
+
empty_hint = torch.zeros(bw.shape[0], 4, bw.shape[2] , bw.shape[3]).float().to(device)
|
235 |
+
|
236 |
+
fake_manga, _ = colorizer(torch.cat([bw, dfm, empty_hint], 1))
|
237 |
+
|
238 |
+
logits_fake = discriminator(fake_manga)
|
239 |
+
adv_loss = loss_function(logits_fake, y_real)
|
240 |
+
|
241 |
+
generator_loss = adv_loss
|
242 |
+
|
243 |
+
generator_loss.backward()
|
244 |
+
gen_optimizer.step()
|
245 |
+
|
246 |
+
|
247 |
+
|
248 |
+
def fine_tuning(colorizer, discriminator, content, dataloader, iterations, colorizer_optimizer, discriminator_optimizer, data_iter, device = 'cpu'):
|
249 |
+
colorizer.generator.train()
|
250 |
+
discriminator.train()
|
251 |
+
|
252 |
+
disc_step = True
|
253 |
+
|
254 |
+
for n, inputs in enumerate(dataloader):
|
255 |
+
|
256 |
+
if n == iterations:
|
257 |
+
return
|
258 |
+
|
259 |
+
if disc_step:
|
260 |
+
discriminator_step(inputs, colorizer, discriminator, discriminator_optimizer, device)
|
261 |
+
else:
|
262 |
+
generator_step(inputs, colorizer, discriminator, content, generator_loss, colorizer_optimizer, device)
|
263 |
+
|
264 |
+
disc_step = disc_step ^ True
|
265 |
+
|
266 |
+
if n % 10 == 5:
|
267 |
+
fine_tuning_step(data_iter, colorizer, discriminator, colorizer_optimizer, discriminator_optimizer, device)
|
268 |
+
|
269 |
+
if __name__ == '__main__':
|
270 |
+
args = parse_args()
|
271 |
+
config = open_json('configs/train_config.json')
|
272 |
+
|
273 |
+
if args.gpu:
|
274 |
+
device = 'cuda'
|
275 |
+
else:
|
276 |
+
device = 'cpu'
|
277 |
+
|
278 |
+
augmentations = get_transforms()
|
279 |
+
|
280 |
+
train_dataloader, ft_dataloader = get_dataloaders(args.path, augmentations, config['batch_size'], args.fine_tuning, config['number_of_mults'])
|
281 |
+
|
282 |
+
colorizer, discriminator, content = get_models(device)
|
283 |
+
set_weights(colorizer, discriminator)
|
284 |
+
|
285 |
+
gen_optimizer, disc_optimizer = get_optimizers(colorizer, discriminator, config['generator_lr'], config['discriminator_lr'])
|
286 |
+
|
287 |
+
train(colorizer, discriminator, content, train_dataloader, config['epochs'], gen_optimizer, disc_optimizer, config['lr_decrease_epoch'], device)
|
288 |
+
|
289 |
+
if args.fine_tuning:
|
290 |
+
set_lr(gen_optimizer, config["finetuning_generator_lr"])
|
291 |
+
fine_tuning(colorizer, discriminator, content, train_dataloader, config['finetuning_iterations'], gen_optimizer, disc_optimizer, iter(ft_dataloader), device)
|
292 |
+
|
293 |
+
torch.save(colorizer.generator.state_dict(), str(datetime.datetime.now().time()))
|
train/bw/blackclover_cl268.png
ADDED
train/bw/dfm_blackclover_cl268.png
ADDED
train/color/blackclover_cl268.png
ADDED
train/real_manga/blackclover_cl268.png
ADDED
train/real_manga/dfm_blackclover_cl268.png
ADDED
utils/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.78 kB). View file
|
|
utils/__pycache__/utils.cpython-36.pyc
ADDED
Binary file (3.8 kB). View file
|
|
utils/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (3.81 kB). View file
|
|
utils/dataset_utils.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import cv2
|
4 |
+
import snowy
|
5 |
+
import os
|
6 |
+
|
7 |
+
|
8 |
+
def get_resized_image(img, size):
|
9 |
+
if len(img.shape) == 2:
|
10 |
+
img = np.repeat(np.expand_dims(img, 2), 3, 2)
|
11 |
+
|
12 |
+
if (img.shape[0] < img.shape[1]):
|
13 |
+
height = img.shape[0]
|
14 |
+
ratio = height / size
|
15 |
+
width = int(np.ceil(img.shape[1] / ratio))
|
16 |
+
img = cv2.resize(img, (width, size), interpolation = cv2.INTER_AREA)
|
17 |
+
else:
|
18 |
+
width = img.shape[1]
|
19 |
+
ratio = width / size
|
20 |
+
height = int(np.ceil(img.shape[0] / ratio))
|
21 |
+
img = cv2.resize(img, (size, height), interpolation = cv2.INTER_AREA)
|
22 |
+
|
23 |
+
if (img.dtype == 'float32'):
|
24 |
+
np.clip(img, 0, 1, out = img)
|
25 |
+
|
26 |
+
return img
|
27 |
+
|
28 |
+
|
29 |
+
def get_sketch_image(img, sketcher, mult_val):
|
30 |
+
|
31 |
+
if mult_val:
|
32 |
+
sketch_image = sketcher.get_sketch_with_resize(img, mult = mult_val)
|
33 |
+
else:
|
34 |
+
sketch_image = sketcher.get_sketch_with_resize(img)
|
35 |
+
|
36 |
+
return sketch_image
|
37 |
+
|
38 |
+
|
39 |
+
def get_dfm_image(sketch):
|
40 |
+
dfm_image = snowy.unitize(snowy.generate_sdf(np.expand_dims(1 - sketch, 2) != 0)).squeeze()
|
41 |
+
return dfm_image
|
42 |
+
|
43 |
+
def get_sketch(image, sketcher, dfm, mult = None):
|
44 |
+
sketch_image = get_sketch_image(image, sketcher, mult)
|
45 |
+
|
46 |
+
dfm_image = None
|
47 |
+
|
48 |
+
if dfm:
|
49 |
+
dfm_image = get_dfm_image(sketch_image)
|
50 |
+
|
51 |
+
sketch_image = (sketch_image * 255).astype('uint8')
|
52 |
+
|
53 |
+
if dfm:
|
54 |
+
dfm_image = (dfm_image * 255).astype('uint8')
|
55 |
+
|
56 |
+
return sketch_image, dfm_image
|
57 |
+
|
58 |
+
def get_sketches(image, sketcher, mult_list, dfm):
|
59 |
+
for mult in mult_list:
|
60 |
+
yield get_sketch(image, sketcher, dfm, mult)
|
61 |
+
|
62 |
+
|
63 |
+
def create_resized_dataset(source_path, target_path, side_size):
|
64 |
+
images = os.listdir(source_path)
|
65 |
+
|
66 |
+
for image_name in images:
|
67 |
+
|
68 |
+
new_image_name = image_name[:image_name.rfind('.')] + '.png'
|
69 |
+
new_path = os.path.join(target_path, new_image_name)
|
70 |
+
|
71 |
+
if not os.path.exists(new_path):
|
72 |
+
try:
|
73 |
+
image = cv2.imread(os.path.join(source_path, image_name))
|
74 |
+
|
75 |
+
if image is None:
|
76 |
+
raise Exception()
|
77 |
+
|
78 |
+
image = get_resized_image(image, side_size)
|
79 |
+
|
80 |
+
cv2.imwrite(new_path, image)
|
81 |
+
except:
|
82 |
+
print('Failed to process {}'.format(image_name))
|
83 |
+
|
84 |
+
|
85 |
+
def create_sketches_dataset(source_path, target_path, sketcher, mult_list, dfm = False):
|
86 |
+
|
87 |
+
images = os.listdir(source_path)
|
88 |
+
for image_name in images:
|
89 |
+
try:
|
90 |
+
image = cv2.imread(os.path.join(source_path, image_name))
|
91 |
+
|
92 |
+
if image is None:
|
93 |
+
raise Exception()
|
94 |
+
|
95 |
+
for number, (sketch_image, dfm_image) in enumerate(get_sketches(image, sketcher, mult_list, dfm)):
|
96 |
+
new_sketch_name = image_name[:image_name.rfind('.')] + '_' + str(number) + '.png'
|
97 |
+
cv2.imwrite(os.path.join(target_path, new_sketch_name), sketch_image)
|
98 |
+
|
99 |
+
if dfm:
|
100 |
+
dfm_name = image_name[:image_name.rfind('.')] + '_' + str(number) + '_dfm.png'
|
101 |
+
cv2.imwrite(os.path.join(target_path, dfm_name), dfm_image)
|
102 |
+
|
103 |
+
except:
|
104 |
+
print('Failed to process {}'.format(image_name))
|
105 |
+
|
106 |
+
|
107 |
+
def create_dataset(source_path, target_path, sketcher, mult_list, side_size, dfm = False):
|
108 |
+
images = os.listdir(source_path)
|
109 |
+
|
110 |
+
color_path = os.path.join(target_path, 'color')
|
111 |
+
sketch_path = os.path.join(target_path, 'bw')
|
112 |
+
|
113 |
+
if not os.path.exists(color_path):
|
114 |
+
os.makedirs(color_path)
|
115 |
+
|
116 |
+
if not os.path.exists(sketch_path):
|
117 |
+
os.makedirs(sketch_path)
|
118 |
+
|
119 |
+
for image_name in images:
|
120 |
+
new_image_name = image_name[:image_name.rfind('.')] + '.png'
|
121 |
+
|
122 |
+
try:
|
123 |
+
image = cv2.imread(os.path.join(source_path, image_name))
|
124 |
+
|
125 |
+
if image is None:
|
126 |
+
raise Exception()
|
127 |
+
|
128 |
+
resized_image = get_resized_image(image, side_size)
|
129 |
+
cv2.imwrite(os.path.join(color_path, new_image_name), resized_image)
|
130 |
+
|
131 |
+
for number, (sketch_image, dfm_image) in enumerate(get_sketches(resized_image, sketcher, mult_list, dfm)):
|
132 |
+
new_sketch_name = image_name[:image_name.rfind('.')] + '_' + str(number) + '.png'
|
133 |
+
cv2.imwrite(os.path.join(sketch_path, new_sketch_name), sketch_image)
|
134 |
+
|
135 |
+
if dfm:
|
136 |
+
dfm_name = image_name[:image_name.rfind('.')] + '_' + str(number) + '_dfm.png'
|
137 |
+
cv2.imwrite(os.path.join(sketch_path, dfm_name), dfm_image)
|
138 |
+
|
139 |
+
except:
|
140 |
+
print('Failed to process {}'.format(image_name))
|
141 |
+
|
utils/utils.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
import scipy.stats as stats
|
5 |
+
import cv2
|
6 |
+
import json
|
7 |
+
import patoolib
|
8 |
+
import re
|
9 |
+
from pathlib import Path
|
10 |
+
from shutil import rmtree
|
11 |
+
|
12 |
+
def weights_init(m):
|
13 |
+
classname = m.__class__.__name__
|
14 |
+
if classname.find('Conv2d') != -1:
|
15 |
+
nn.init.xavier_uniform_(m.weight.data)
|
16 |
+
|
17 |
+
def weights_init_spectr(m):
|
18 |
+
classname = m.__class__.__name__
|
19 |
+
if classname.find('Conv2d') != -1:
|
20 |
+
nn.init.xavier_uniform_(m.weight_bar.data)
|
21 |
+
|
22 |
+
def generate_mask(height, width, mu = 1, sigma = 0.0005, prob = 0.5, full = True, full_prob = 0.01):
|
23 |
+
X = stats.truncnorm((0 - mu) / sigma, (1 - mu) / sigma, loc=mu, scale=sigma)
|
24 |
+
|
25 |
+
if full:
|
26 |
+
if (np.random.binomial(1, p = full_prob) == 1):
|
27 |
+
return torch.ones(1, height, width).float()
|
28 |
+
|
29 |
+
if np.random.binomial(1, p = prob) == 1:
|
30 |
+
mask = torch.rand(1, height, width).ge(X.rvs(1)[0]).float()
|
31 |
+
else:
|
32 |
+
mask = torch.zeros(1, height, width).float()
|
33 |
+
|
34 |
+
return mask
|
35 |
+
|
36 |
+
def resize_pad(img, size = 512):
|
37 |
+
|
38 |
+
if len(img.shape) == 2:
|
39 |
+
img = np.expand_dims(img, 2)
|
40 |
+
|
41 |
+
if img.shape[2] == 1:
|
42 |
+
img = np.repeat(img, 3, 2)
|
43 |
+
|
44 |
+
if img.shape[2] == 4:
|
45 |
+
img = img[:, :, :3]
|
46 |
+
|
47 |
+
pad = None
|
48 |
+
|
49 |
+
if (img.shape[0] < img.shape[1]):
|
50 |
+
height = img.shape[0]
|
51 |
+
ratio = height / size
|
52 |
+
width = int(np.ceil(img.shape[1] / ratio))
|
53 |
+
img = cv2.resize(img, (width, size), interpolation = cv2.INTER_AREA)
|
54 |
+
|
55 |
+
new_width = width
|
56 |
+
while (new_width % 32 != 0):
|
57 |
+
new_width += 1
|
58 |
+
|
59 |
+
pad = (0, new_width - width)
|
60 |
+
|
61 |
+
img = np.pad(img, ((0, 0), (0, pad[1]), (0, 0)), 'maximum')
|
62 |
+
else:
|
63 |
+
width = img.shape[1]
|
64 |
+
ratio = width / size
|
65 |
+
height = int(np.ceil(img.shape[0] / ratio))
|
66 |
+
img = cv2.resize(img, (size, height), interpolation = cv2.INTER_AREA)
|
67 |
+
|
68 |
+
new_height = height
|
69 |
+
while (new_height % 32 != 0):
|
70 |
+
new_height += 1
|
71 |
+
|
72 |
+
pad = (new_height - height, 0)
|
73 |
+
|
74 |
+
img = np.pad(img, ((0, pad[0]), (0, 0), (0, 0)), 'maximum')
|
75 |
+
|
76 |
+
if (img.dtype == 'float32'):
|
77 |
+
np.clip(img, 0, 1, out = img)
|
78 |
+
|
79 |
+
return img, pad
|
80 |
+
|
81 |
+
def open_json(file):
|
82 |
+
with open(file) as json_file:
|
83 |
+
data = json.load(json_file)
|
84 |
+
|
85 |
+
return data
|
86 |
+
|
87 |
+
def extract_cbr(file, out_dir):
|
88 |
+
patoolib.extract_archive(file, outdir = out_dir, verbosity = 1, interactive = False)
|
89 |
+
|
90 |
+
def create_cbz(file_path, files):
|
91 |
+
patoolib.create_archive(file_path, files, verbosity = 1, interactive = False)
|
92 |
+
|
93 |
+
def subfolder_image_search(start_folder):
|
94 |
+
return [x.as_posix() for x in Path(start_folder).rglob("*.[pPjJ][nNpP][gG]")]
|
95 |
+
|
96 |
+
def remove_folder(folder_path):
|
97 |
+
rmtree(folder_path)
|
98 |
+
|
99 |
+
def sorted_alphanumeric(data):
|
100 |
+
convert = lambda text: int(text) if text.isdigit() else text.lower()
|
101 |
+
alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ]
|
102 |
+
return sorted(data, key=alphanum_key)
|
utils/xdog.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from cv2 import resize, INTER_LANCZOS4, INTER_AREA
|
2 |
+
from skimage.color import rgb2gray
|
3 |
+
import numpy as np
|
4 |
+
from scipy.ndimage.filters import gaussian_filter
|
5 |
+
from skimage.filters import threshold_otsu
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
|
8 |
+
class XDoGSketcher:
|
9 |
+
|
10 |
+
def __init__(self, gamma = 0.95, phi = 89.25, eps = -0.1, k = 8, sigma = 0.5, mult = 1):
|
11 |
+
self.params = {}
|
12 |
+
self.params['gamma'] = gamma
|
13 |
+
self.params['phi'] = phi
|
14 |
+
self.params['eps'] = eps
|
15 |
+
self.params['k'] = k
|
16 |
+
self.params['sigma'] = sigma
|
17 |
+
|
18 |
+
self.params['mult'] = mult
|
19 |
+
|
20 |
+
def _xdog(self, im, **transform_params):
|
21 |
+
# Source : https://github.com/CemalUnal/XDoG-Filter
|
22 |
+
# Reference : XDoG: An eXtended difference-of-Gaussians compendium including advanced image stylization
|
23 |
+
# Link : http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.365.151&rep=rep1&type=pdf
|
24 |
+
|
25 |
+
if im.shape[2] == 3:
|
26 |
+
im = rgb2gray(im)
|
27 |
+
|
28 |
+
imf1 = gaussian_filter(im, transform_params['sigma'])
|
29 |
+
imf2 = gaussian_filter(im, transform_params['sigma'] * transform_params['k'])
|
30 |
+
imdiff = imf1 - transform_params['gamma'] * imf2
|
31 |
+
imdiff = (imdiff < transform_params['eps']) * 1.0 \
|
32 |
+
+ (imdiff >= transform_params['eps']) * (1.0 + np.tanh(transform_params['phi'] * imdiff))
|
33 |
+
imdiff -= imdiff.min()
|
34 |
+
imdiff /= imdiff.max()
|
35 |
+
|
36 |
+
|
37 |
+
th = threshold_otsu(imdiff)
|
38 |
+
imdiff = imdiff >= th
|
39 |
+
|
40 |
+
imdiff = imdiff.astype('float32')
|
41 |
+
|
42 |
+
return imdiff
|
43 |
+
|
44 |
+
|
45 |
+
def get_sketch(self, image, **kwargs):
|
46 |
+
current_params = self.params.copy()
|
47 |
+
|
48 |
+
for key in kwargs.keys():
|
49 |
+
if key in current_params.keys():
|
50 |
+
current_params[key] = kwargs[key]
|
51 |
+
|
52 |
+
result_image = self._xdog(image, **current_params)
|
53 |
+
|
54 |
+
return result_image
|
55 |
+
|
56 |
+
def get_sketch_with_resize(self, image, **kwargs):
|
57 |
+
if 'mult' in kwargs.keys():
|
58 |
+
mult = kwargs['mult']
|
59 |
+
else:
|
60 |
+
mult = self.params['mult']
|
61 |
+
|
62 |
+
temp_image = resize(image, (image.shape[1] * mult, image.shape[0] * mult), interpolation = INTER_LANCZOS4)
|
63 |
+
temp_image = self.get_sketch(temp_image, **kwargs)
|
64 |
+
image = resize(temp_image, (image.shape[1], image.shape[0]), interpolation = INTER_AREA)
|
65 |
+
|
66 |
+
return image
|
67 |
+
|
68 |
+
|
web.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flask import Flask, request, jsonify, abort, redirect, url_for, render_template, send_file
|
2 |
+
from flask_wtf import FlaskForm
|
3 |
+
from wtforms import StringField, FileField, BooleanField, DecimalField
|
4 |
+
from wtforms.validators import DataRequired
|
5 |
+
from flask import after_this_request
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
import os
|
10 |
+
from model.models import Colorizer, Generator
|
11 |
+
from model.extractor import get_seresnext_extractor
|
12 |
+
from utils.xdog import XDoGSketcher
|
13 |
+
from utils.utils import open_json
|
14 |
+
from denoising.denoiser import FFDNetDenoiser
|
15 |
+
from datetime import datetime
|
16 |
+
|
17 |
+
from inference import colorize_single_image, colorize_images, colorize_cbr
|
18 |
+
|
19 |
+
if torch.cuda.is_available():
|
20 |
+
device = 'cuda'
|
21 |
+
else:
|
22 |
+
device = 'cpu'
|
23 |
+
|
24 |
+
generator = Generator()
|
25 |
+
generator.load_state_dict(torch.load('model/generator.pth'))
|
26 |
+
|
27 |
+
extractor = get_seresnext_extractor()
|
28 |
+
extractor.load_state_dict(torch.load('model/extractor.pth'))
|
29 |
+
|
30 |
+
colorizer = Colorizer(generator, extractor)
|
31 |
+
colorizer = colorizer.eval().to(device)
|
32 |
+
|
33 |
+
sketcher = XDoGSketcher()
|
34 |
+
xdog_config = open_json('configs/xdog_config.json')
|
35 |
+
for key in xdog_config.keys():
|
36 |
+
if key in sketcher.params:
|
37 |
+
sketcher.params[key] = xdog_config[key]
|
38 |
+
|
39 |
+
denoiser = FFDNetDenoiser(device)
|
40 |
+
|
41 |
+
|
42 |
+
app = Flask(__name__)
|
43 |
+
app.config.update(dict(
|
44 |
+
SECRET_KEY="lol kek",
|
45 |
+
WTF_CSRF_SECRET_KEY="cheburek"
|
46 |
+
))
|
47 |
+
|
48 |
+
color_args = {'colorizer':colorizer, 'sketcher':sketcher, 'device':device, 'dfm' : True}
|
49 |
+
|
50 |
+
class SubmitForm(FlaskForm):
|
51 |
+
file = FileField(validators=[DataRequired()])
|
52 |
+
denoise = BooleanField(default = 'checked')
|
53 |
+
denoise_sigma = DecimalField(label = 'Denoise sigma', validators=[DataRequired()], default = 25, places = None)
|
54 |
+
autohint = BooleanField(default = None)
|
55 |
+
autohint_sigma = DecimalField(label = 'Autohint sigma', validators=[DataRequired()], default= 0.0003, places = None)
|
56 |
+
ignore_gray = BooleanField(label = 'Ignore gray autohint', default = None)
|
57 |
+
|
58 |
+
@app.route('/img/<path>')
|
59 |
+
def show_image(path):
|
60 |
+
return f'<img src="/static/{path}">'
|
61 |
+
|
62 |
+
@app.route('/', methods=('GET', 'POST'))
|
63 |
+
def submit_data():
|
64 |
+
form = SubmitForm()
|
65 |
+
if form.validate_on_submit():
|
66 |
+
|
67 |
+
input_data = form.file.data
|
68 |
+
|
69 |
+
_, ext = os.path.splitext(input_data.filename)
|
70 |
+
filename = str(datetime.now()) + ext
|
71 |
+
|
72 |
+
input_data.save(filename)
|
73 |
+
|
74 |
+
color_args['auto_hint'] = form.autohint.data
|
75 |
+
color_args['auto_hint_sigma'] = float(form.autohint_sigma.data)
|
76 |
+
color_args['ignore_gray'] = form.ignore_gray.data
|
77 |
+
color_args['denoiser'] = None
|
78 |
+
|
79 |
+
if form.denoise.data:
|
80 |
+
color_args['denoiser'] = denoiser
|
81 |
+
color_args['denoiser_sigma'] = float(form.denoise_sigma.data)
|
82 |
+
|
83 |
+
if ext.lower() in ('.cbr', '.cbz', '.rar', '.zip'):
|
84 |
+
result_name = colorize_cbr(filename, color_args)
|
85 |
+
os.remove(filename)
|
86 |
+
|
87 |
+
@after_this_request
|
88 |
+
def remove_file(response):
|
89 |
+
try:
|
90 |
+
os.remove(result_name)
|
91 |
+
except Exception as error:
|
92 |
+
app.logger.error("Error removing or closing downloaded file handle", error)
|
93 |
+
return response
|
94 |
+
|
95 |
+
return send_file(result_name, mimetype='application/vnd.comicbook-rar', attachment_filename=result_name, as_attachment=True)
|
96 |
+
|
97 |
+
elif ext.lower() in ('.jpg', '.png', ',jpeg'):
|
98 |
+
random_name = str(datetime.now()) + '.png'
|
99 |
+
new_image_path = os.path.join('static', random_name)
|
100 |
+
|
101 |
+
colorize_single_image(filename, new_image_path, color_args)
|
102 |
+
os.remove(filename)
|
103 |
+
|
104 |
+
return redirect(f'/img/{random_name}')
|
105 |
+
else:
|
106 |
+
return 'Wrong format'
|
107 |
+
|
108 |
+
return render_template('submit.html', form=form)
|