Keiser41 committed on
Commit 212d7be
1 Parent(s): d3745c8

Upload 246 files

This view is limited to 50 files because the commit contains too many changes.
.dockerignore ADDED
@@ -0,0 +1,6 @@
+ *.ipynb
+
+ model/*.pth
+
+ temp_colorization/
+ __pycache__/
.gitattributes CHANGED
@@ -33,3 +33,103 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ train/11_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/110_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/111_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/113_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/12_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/122_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/127_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/13_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/131_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/137_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/138_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/139_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/144_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/145_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/146_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/15_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/11.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/110.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/1100_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/111.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/1110_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/113.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/1130_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/12.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/122.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/1220_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/127.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/1270_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/13.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/131.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/1310_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/137.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/1370_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/138.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/1380_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/139.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/1390_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/144.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/1440_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/145.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/1450_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/146.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/1460_dfm.png filter=lfs diff=lfs merge=lfs -text
+ train/bw/15.png filter=lfs diff=lfs merge=lfs -text
+ train/color/11.png filter=lfs diff=lfs merge=lfs -text
+ train/color/110.png filter=lfs diff=lfs merge=lfs -text
+ train/color/111.png filter=lfs diff=lfs merge=lfs -text
+ train/color/112.png filter=lfs diff=lfs merge=lfs -text
+ train/color/113.png filter=lfs diff=lfs merge=lfs -text
+ train/color/114.png filter=lfs diff=lfs merge=lfs -text
+ train/color/115.png filter=lfs diff=lfs merge=lfs -text
+ train/color/116.png filter=lfs diff=lfs merge=lfs -text
+ train/color/117.png filter=lfs diff=lfs merge=lfs -text
+ train/color/119.png filter=lfs diff=lfs merge=lfs -text
+ train/color/12.png filter=lfs diff=lfs merge=lfs -text
+ train/color/120.png filter=lfs diff=lfs merge=lfs -text
+ train/color/121.png filter=lfs diff=lfs merge=lfs -text
+ train/color/122.png filter=lfs diff=lfs merge=lfs -text
+ train/color/124.png filter=lfs diff=lfs merge=lfs -text
+ train/color/125.png filter=lfs diff=lfs merge=lfs -text
+ train/color/126.png filter=lfs diff=lfs merge=lfs -text
+ train/color/127.png filter=lfs diff=lfs merge=lfs -text
+ train/color/128.png filter=lfs diff=lfs merge=lfs -text
+ train/color/129.png filter=lfs diff=lfs merge=lfs -text
+ train/color/13.png filter=lfs diff=lfs merge=lfs -text
+ train/color/130.png filter=lfs diff=lfs merge=lfs -text
+ train/color/131.png filter=lfs diff=lfs merge=lfs -text
+ train/color/132.png filter=lfs diff=lfs merge=lfs -text
+ train/color/133.png filter=lfs diff=lfs merge=lfs -text
+ train/color/136.png filter=lfs diff=lfs merge=lfs -text
+ train/color/137.png filter=lfs diff=lfs merge=lfs -text
+ train/color/138.png filter=lfs diff=lfs merge=lfs -text
+ train/color/139.png filter=lfs diff=lfs merge=lfs -text
+ train/color/14.png filter=lfs diff=lfs merge=lfs -text
+ train/color/140.png filter=lfs diff=lfs merge=lfs -text
+ train/color/141.png filter=lfs diff=lfs merge=lfs -text
+ train/color/142.png filter=lfs diff=lfs merge=lfs -text
+ train/color/143.png filter=lfs diff=lfs merge=lfs -text
+ train/color/144.png filter=lfs diff=lfs merge=lfs -text
+ train/color/145.png filter=lfs diff=lfs merge=lfs -text
+ train/color/146.png filter=lfs diff=lfs merge=lfs -text
+ train/color/147.png filter=lfs diff=lfs merge=lfs -text
+ train/color/148.png filter=lfs diff=lfs merge=lfs -text
+ train/color/149.png filter=lfs diff=lfs merge=lfs -text
+ train/color/15.png filter=lfs diff=lfs merge=lfs -text
+ train/color/150.png filter=lfs diff=lfs merge=lfs -text
+ train/color/16.png filter=lfs diff=lfs merge=lfs -text
+ train/color/17.png filter=lfs diff=lfs merge=lfs -text
+ train/color/18.png filter=lfs diff=lfs merge=lfs -text
+ train/color/19.png filter=lfs diff=lfs merge=lfs -text
+ train/real_manga/28e234c2.png filter=lfs diff=lfs merge=lfs -text
+ train/real_manga/3bc5fa00.png filter=lfs diff=lfs merge=lfs -text
+ train/real_manga/43d013cb.png filter=lfs diff=lfs merge=lfs -text
+ train/real_manga/4e5f7c7c.png filter=lfs diff=lfs merge=lfs -text
+ train/real_manga/6fc6ccf3.png filter=lfs diff=lfs merge=lfs -text
+ train/real_manga/9dec4d8f.png filter=lfs diff=lfs merge=lfs -text
+ train/real_manga/aba89a1c.png filter=lfs diff=lfs merge=lfs -text
+ train/real_manga/b4b63009.png filter=lfs diff=lfs merge=lfs -text
+ train/real_manga/cff0a3bf.png filter=lfs diff=lfs merge=lfs -text
+ train/real_manga/d65293ce.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,10 @@
+ *.ipynb
+ *.pth
+ *.zip
+
+ __pycache__/
+ temp_colorization/
+
+ static/temp_images/*
+
+ !.gitkeep
Dockerfile ADDED
@@ -0,0 +1,11 @@
+ FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-runtime
+
+ RUN apt-get update && apt-get install -y libglib2.0-0 libsm6 libxext6 libxrender-dev
+
+ COPY . .
+
+ RUN pip install --no-cache-dir -r ./requirements.txt
+
+ EXPOSE 5000
+
+ CMD gunicorn --timeout 200 -w 3 -b 0.0.0.0:5000 drawing:app
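
A minimal way to build and serve this image (the `manga-colorizer` tag is an illustrative name, not something the repository defines):

    docker build -t manga-colorizer .
    docker run --rm -p 5000:5000 manga-colorizer

The second command maps the gunicorn port declared by EXPOSE to the host, so the drawing UI is reachable at http://localhost:5000.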
configs/train_config.json ADDED
@@ -0,0 +1,10 @@
+ {
+     "generator_lr" : 1e-4,
+     "discriminator_lr" : 4e-4,
+     "epochs" : 15,
+     "lr_decrease_epoch" : 10,
+     "finetuning_generator_lr" : 1e-6,
+     "finetuning_iterations" : 3500,
+     "batch_size" : 4,
+     "number_of_mults" : 1
+ }
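
The training script itself is not part of this truncated view, but the config is read via `open_json` from utils.utils elsewhere in the commit. A loose sketch of how these values would plug in (the choice of Adam is an assumption, not taken from the repo):

    import torch
    from utils.utils import open_json
    from model.models import Generator, Discriminator

    cfg = open_json('configs/train_config.json')
    generator = Generator()
    discriminator = Discriminator()
    # 'lr_decrease_epoch' and the 'finetuning_*' keys presumably drive the LR
    # schedule and the later real-manga fine-tuning stage.
    gen_opt = torch.optim.Adam(generator.parameters(), lr=cfg['generator_lr'])
    disc_opt = torch.optim.Adam(discriminator.parameters(), lr=cfg['discriminator_lr'])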
configs/xdog_config.json ADDED
@@ -0,0 +1,8 @@
+ {
+     "sigma" : 0.5,
+     "k" : 8,
+     "phi" : 89.25,
+     "gamma" : 0.95,
+     "eps" : -0.1,
+     "mult" : 7
+ }
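
These are parameters for the XDoG (eXtended Difference-of-Gaussians) sketcher in utils/xdog.py, which is not shown in this view. For orientation, a minimal sketch of the standard XDoG formulation these names map onto (SciPy-based; the repository's own implementation may differ, e.g. in how the "mult" key is used):

    import numpy as np
    from scipy.ndimage import gaussian_filter

    def xdog(gray, sigma=0.5, k=8, gamma=0.95, eps=-0.1, phi=89.25):
        # Difference of two Gaussian blurs, the second one k-times wider.
        d = gaussian_filter(gray, sigma) - gamma * gaussian_filter(gray, sigma * k)
        # Hard white above the eps threshold, soft tanh ramp below it.
        return np.where(d >= eps, 1.0, 1.0 + np.tanh(phi * (d - eps)))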
dataset/__pycache__/datasets.cpython-39.pyc ADDED
Binary file (3.56 kB).
 
dataset/datasets.py ADDED
@@ -0,0 +1,107 @@
+ import torch
+ import os
+ import torchvision.transforms as transforms
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+ from utils.utils import generate_mask
+
+
+ class TrainDataset(torch.utils.data.Dataset):
+     def __init__(self, data_path, transform = None, mults_amount = 1):
+         self.data = os.listdir(os.path.join(data_path, 'color'))
+         self.data_path = data_path
+         self.transform = transform
+         self.mults_amount = mults_amount
+
+         self.ToTensor = transforms.ToTensor()
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         image_name = self.data[idx]
+
+         color_img = plt.imread(os.path.join(self.data_path, 'color', image_name))
+
+         if self.mults_amount > 1:
+             mult_number = np.random.choice(range(self.mults_amount))
+
+             bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
+             dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '_dfm.png'
+         else:
+             bw_name = self.data[idx]
+             dfm_name = os.path.splitext(self.data[idx])[0] + '0_dfm.png'
+
+         bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', bw_name)), 2)
+         dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', dfm_name)), 2)
+
+         # Stack the sketch and its distance-field map so one joint spatial
+         # transform keeps them aligned with the color target.
+         bw_img = np.concatenate([bw_img, dfm_img], axis = 2)
+
+         if self.transform:
+             result = self.transform(image = color_img, mask = bw_img)
+             color_img = result['image']
+             bw_img = result['mask']
+
+         dfm_img = bw_img[:, :, 1]
+         bw_img = bw_img[:, :, 0]
+
+         color_img = self.ToTensor(color_img)
+         bw_img = self.ToTensor(bw_img)
+         dfm_img = self.ToTensor(dfm_img)
+
+         color_img = (color_img - 0.5) / 0.5
+
+         mask = generate_mask(bw_img.shape[1], bw_img.shape[2])
+         hint = torch.cat((color_img * mask, mask), 0)
+
+         return bw_img, color_img, hint, dfm_img
+
+
+ class FineTuningDataset(torch.utils.data.Dataset):
+     def __init__(self, data_path, transform = None, mult_amount = 1):
+         self.data = [x for x in os.listdir(os.path.join(data_path, 'real_manga')) if x.find('_dfm') == -1]
+         self.color_data = [x for x in os.listdir(os.path.join(data_path, 'color'))]
+         self.data_path = data_path
+         self.transform = transform
+         self.mults_amount = mult_amount
+
+         np.random.shuffle(self.color_data)
+
+         self.ToTensor = transforms.ToTensor()
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         color_img = plt.imread(os.path.join(self.data_path, 'color', self.color_data[idx]))
+
+         image_name = self.data[idx]
+         if self.mults_amount > 1:
+             mult_number = np.random.choice(range(self.mults_amount))
+
+             # Fixed: the original interpolated str(self.mults_amount) here,
+             # which never varies; the sampled variant index was clearly intended.
+             bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
+             dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '_dfm.png'
+         else:
+             bw_name = self.data[idx]
+             dfm_name = os.path.splitext(self.data[idx])[0] + '_dfm.png'
+
+         bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', image_name)), 2)
+         dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', dfm_name)), 2)
+
+         if self.transform:
+             result = self.transform(image = color_img)
+             color_img = result['image']
+
+             result = self.transform(image = bw_img, mask = dfm_img)
+             bw_img = result['image']
+             dfm_img = result['mask']
+
+         color_img = self.ToTensor(color_img)
+         bw_img = self.ToTensor(bw_img)
+         dfm_img = self.ToTensor(dfm_img)
+
+         color_img = (color_img - 0.5) / 0.5
+
+         return bw_img, dfm_img, color_img
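
The `transform(image=..., mask=...)` call returning a dict matches the albumentations calling convention, so that library is a reasonable guess for the augmentation pipeline. A hedged usage sketch under that assumption (crop size and flip probability are illustrative, not taken from the repo):

    import albumentations as A
    from torch.utils.data import DataLoader
    from dataset.datasets import TrainDataset

    # Joint spatial augmentation of the color target and the sketch/DFM stack.
    transform = A.Compose([A.RandomCrop(512, 512), A.HorizontalFlip(p=0.5)])

    dataset = TrainDataset('train', transform=transform, mults_amount=1)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    bw, color, hint, dfm = next(iter(loader))  # sketch, target, masked color hint, distance map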
denoising/denoiser.py ADDED
@@ -0,0 +1,113 @@
+ """
+ Denoise an image with the FFDNet denoising method
+
+ Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
+
+ This program is free software: you can use, modify and/or
+ redistribute it under the terms of the GNU General Public
+ License as published by the Free Software Foundation, either
+ version 3 of the License, or (at your option) any later
+ version. You should have received a copy of this license along
+ this program. If not, see <http://www.gnu.org/licenses/>.
+ """
+ import os
+ import argparse
+ import time
+ import numpy as np
+ import cv2
+ import torch
+ import torch.nn as nn
+ from torch.autograd import Variable
+ from denoising.models import FFDNet
+ from denoising.utils import normalize, variable_to_cv2_image, remove_dataparallel_wrapper, is_rgb
+
+ class FFDNetDenoiser:
+     def __init__(self, _device, _sigma = 25, _weights_dir = 'denoising/models/', _in_ch = 3):
+         self.sigma = _sigma / 255
+         self.weights_dir = _weights_dir
+         self.channels = _in_ch
+         self.device = _device
+
+         self.model = FFDNet(num_input_channels = _in_ch)
+         self.load_weights()
+         self.model.eval()
+
+     def load_weights(self):
+         weights_name = 'net_rgb.pth' if self.channels == 3 else 'net_gray.pth'
+         weights_path = os.path.join(self.weights_dir, weights_name)
+         if self.device == 'cuda':
+             state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
+             device_ids = [0]
+             self.model = nn.DataParallel(self.model, device_ids=device_ids).cuda()
+         else:
+             state_dict = torch.load(weights_path, map_location='cpu')
+             # CPU mode: remove the DataParallel wrapper
+             state_dict = remove_dataparallel_wrapper(state_dict)
+         self.model.load_state_dict(state_dict)
+
+     def get_denoised_image(self, imorig, sigma = None):
+         if sigma is not None:
+             cur_sigma = sigma / 255
+         else:
+             cur_sigma = self.sigma
+
+         if len(imorig.shape) < 3 or imorig.shape[2] == 1:
+             imorig = np.repeat(np.expand_dims(imorig, 2), 3, 2)
+
+         if (max(imorig.shape[0], imorig.shape[1]) > 1200):
+             ratio = max(imorig.shape[0], imorig.shape[1]) / 1200
+             imorig = cv2.resize(imorig, (int(imorig.shape[1] / ratio), int(imorig.shape[0] / ratio)), interpolation = cv2.INTER_AREA)
+
+         imorig = imorig.transpose(2, 0, 1)
+
+         if (imorig.max() > 1.2):
+             imorig = normalize(imorig)
+         imorig = np.expand_dims(imorig, 0)
+
+         # Handle odd sizes
+         expanded_h = False
+         expanded_w = False
+         sh_im = imorig.shape
+         if sh_im[2] % 2 == 1:
+             expanded_h = True
+             imorig = np.concatenate((imorig, imorig[:, :, -1, :][:, :, np.newaxis, :]), axis=2)
+
+         if sh_im[3] % 2 == 1:
+             expanded_w = True
+             imorig = np.concatenate((imorig, imorig[:, :, :, -1][:, :, :, np.newaxis]), axis=3)
+
+         imorig = torch.Tensor(imorig)
+
+         # Sets data type according to CPU or GPU modes
+         if self.device == 'cuda':
+             dtype = torch.cuda.FloatTensor
+         else:
+             dtype = torch.FloatTensor
+
+         imnoisy = imorig.clone()
+
+         with torch.no_grad():
+             imorig, imnoisy = imorig.type(dtype), imnoisy.type(dtype)
+             nsigma = torch.FloatTensor([cur_sigma]).type(dtype)
+
+             # Estimate the noise and subtract it from the input image
+             im_noise_estim = self.model(imnoisy, nsigma)
+             outim = torch.clamp(imnoisy - im_noise_estim, 0., 1.)
+
+         if expanded_h:
+             imorig = imorig[:, :, :-1, :]
+             outim = outim[:, :, :-1, :]
+             imnoisy = imnoisy[:, :, :-1, :]
+
+         if expanded_w:
+             imorig = imorig[:, :, :, :-1]
+             outim = outim[:, :, :, :-1]
+             imnoisy = imnoisy[:, :, :, :-1]
+
+         return variable_to_cv2_image(outim)
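
A minimal usage sketch for the class above, assuming the FFDNet weights (net_rgb.pth) are present under denoising/models/ and with an illustrative file name. Note that inputs larger than 1200 px on the long side are downscaled, and the return value comes back through variable_to_cv2_image, i.e. as an OpenCV-style BGR uint8 array:

    import matplotlib.pyplot as plt
    from denoising.denoiser import FFDNetDenoiser

    denoiser = FFDNetDenoiser('cpu')                     # or 'cuda'
    page = plt.imread('page.png')                        # HxW or HxWx3, any value scale
    clean = denoiser.get_denoised_image(page, sigma=25)  # BGR uint8 result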
denoising/functions.py ADDED
@@ -0,0 +1,101 @@
+ """
+ Functions implementing custom NN layers
+
+ Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
+
+ This program is free software: you can use, modify and/or
+ redistribute it under the terms of the GNU General Public
+ License as published by the Free Software Foundation, either
+ version 3 of the License, or (at your option) any later
+ version. You should have received a copy of this license along
+ this program. If not, see <http://www.gnu.org/licenses/>.
+ """
+ import torch
+ from torch.autograd import Function, Variable
+
+ def concatenate_input_noise_map(input, noise_sigma):
+     r"""Implements the first layer of FFDNet. This function returns a
+     torch.autograd.Variable composed of the concatenation of the downsampled
+     input image and the noise map. Each image of the batch of size CxHxW gets
+     converted to an array of size 4*CxH/2xW/2. Each of the pixels of the
+     non-overlapped 2x2 patches of the input image are placed in the new array
+     along the first dimension.
+
+     Args:
+         input: batch containing CxHxW images
+         noise_sigma: the value of the pixels of the CxH/2xW/2 noise map
+     """
+     # noise_sigma is a list of length batch_size
+     N, C, H, W = input.size()
+     dtype = input.type()
+     sca = 2
+     sca2 = sca * sca
+     Cout = sca2 * C
+     Hout = H // sca
+     Wout = W // sca
+     idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]
+
+     # Fill the downsampled image with zeros
+     if 'cuda' in dtype:
+         downsampledfeatures = torch.cuda.FloatTensor(N, Cout, Hout, Wout).fill_(0)
+     else:
+         downsampledfeatures = torch.FloatTensor(N, Cout, Hout, Wout).fill_(0)
+
+     # Build the CxH/2xW/2 noise map
+     noise_map = noise_sigma.view(N, 1, 1, 1).repeat(1, C, Hout, Wout)
+
+     # Populate output
+     for idx in range(sca2):
+         downsampledfeatures[:, idx:Cout:sca2, :, :] = \
+             input[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca]
+
+     # concatenate de-interleaved mosaic with noise map
+     return torch.cat((noise_map, downsampledfeatures), 1)
+
+ class UpSampleFeaturesFunction(Function):
+     r"""Extends PyTorch's modules by implementing a torch.autograd.Function.
+     This class implements the forward and backward methods of the last layer
+     of FFDNet. It basically performs the inverse of
+     concatenate_input_noise_map(): it converts each of the images of a
+     batch of size CxH/2xW/2 to images of size C/4xHxW
+     """
+     @staticmethod
+     def forward(ctx, input):
+         N, Cin, Hin, Win = input.size()
+         dtype = input.type()
+         sca = 2
+         sca2 = sca * sca
+         Cout = Cin // sca2
+         Hout = Hin * sca
+         Wout = Win * sca
+         idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]
+
+         assert (Cin % sca2 == 0), 'Invalid input dimensions: number of channels should be divisible by 4'
+
+         result = torch.zeros((N, Cout, Hout, Wout)).type(dtype)
+         for idx in range(sca2):
+             result[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca] = input[:, idx:Cin:sca2, :, :]
+
+         return result
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         N, Cg_out, Hg_out, Wg_out = grad_output.size()
+         dtype = grad_output.data.type()
+         sca = 2
+         sca2 = sca * sca
+         Cg_in = sca2 * Cg_out
+         Hg_in = Hg_out // sca
+         Wg_in = Wg_out // sca
+         idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]
+
+         # Build output
+         grad_input = torch.zeros((N, Cg_in, Hg_in, Wg_in)).type(dtype)
+         # Populate output
+         for idx in range(sca2):
+             grad_input[:, idx:Cg_in:sca2, :, :] = grad_output.data[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca]
+
+         return Variable(grad_input)
+
+ # Alias functions
+ upsamplefeatures = UpSampleFeaturesFunction.apply
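
Since concatenate_input_noise_map and upsamplefeatures are exact inverses on the mosaic channels, a quick shape and round-trip check makes the layout concrete (a C=3 input yields 3 noise-map channels plus 12 mosaic channels):

    import torch
    from denoising.functions import concatenate_input_noise_map, upsamplefeatures

    x = torch.rand(1, 3, 64, 64)
    packed = concatenate_input_noise_map(x, torch.tensor([0.1]))
    print(packed.shape)                         # torch.Size([1, 15, 32, 32])
    restored = upsamplefeatures(packed[:, 3:])  # drop the noise map, invert the mosaic
    print(torch.allclose(restored, x))          # True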
denoising/models.py ADDED
@@ -0,0 +1,100 @@
+ """
+ Definition of the FFDNet model and its custom layers
+
+ Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
+
+ This program is free software: you can use, modify and/or
+ redistribute it under the terms of the GNU General Public
+ License as published by the Free Software Foundation, either
+ version 3 of the License, or (at your option) any later
+ version. You should have received a copy of this license along
+ this program. If not, see <http://www.gnu.org/licenses/>.
+ """
+ import torch.nn as nn
+ from torch.autograd import Variable
+ import denoising.functions as functions
+
+ class UpSampleFeatures(nn.Module):
+     r"""Implements the last layer of FFDNet
+     """
+     def __init__(self):
+         super(UpSampleFeatures, self).__init__()
+
+     def forward(self, x):
+         return functions.upsamplefeatures(x)
+
+ class IntermediateDnCNN(nn.Module):
+     r"""Implements the middle part of the FFDNet architecture, which
+     is basically a DnCNN net
+     """
+     def __init__(self, input_features, middle_features, num_conv_layers):
+         super(IntermediateDnCNN, self).__init__()
+         self.kernel_size = 3
+         self.padding = 1
+         self.input_features = input_features
+         self.num_conv_layers = num_conv_layers
+         self.middle_features = middle_features
+         if self.input_features == 5:
+             self.output_features = 4    # Grayscale image
+         elif self.input_features == 15:
+             self.output_features = 12   # RGB image
+         else:
+             raise Exception('Invalid number of input features')
+
+         layers = []
+         layers.append(nn.Conv2d(in_channels=self.input_features,
+                                 out_channels=self.middle_features,
+                                 kernel_size=self.kernel_size,
+                                 padding=self.padding,
+                                 bias=False))
+         layers.append(nn.ReLU(inplace=True))
+         for _ in range(self.num_conv_layers - 2):
+             layers.append(nn.Conv2d(in_channels=self.middle_features,
+                                     out_channels=self.middle_features,
+                                     kernel_size=self.kernel_size,
+                                     padding=self.padding,
+                                     bias=False))
+             layers.append(nn.BatchNorm2d(self.middle_features))
+             layers.append(nn.ReLU(inplace=True))
+         layers.append(nn.Conv2d(in_channels=self.middle_features,
+                                 out_channels=self.output_features,
+                                 kernel_size=self.kernel_size,
+                                 padding=self.padding,
+                                 bias=False))
+         # (sic) attribute name kept as-is: renaming it would invalidate the
+         # keys of any saved state_dict.
+         self.itermediate_dncnn = nn.Sequential(*layers)
+
+     def forward(self, x):
+         out = self.itermediate_dncnn(x)
+         return out
+
+ class FFDNet(nn.Module):
+     r"""Implements the FFDNet architecture
+     """
+     def __init__(self, num_input_channels):
+         super(FFDNet, self).__init__()
+         self.num_input_channels = num_input_channels
+         if self.num_input_channels == 1:
+             # Grayscale image
+             self.num_feature_maps = 64
+             self.num_conv_layers = 15
+             self.downsampled_channels = 5
+             self.output_features = 4
+         elif self.num_input_channels == 3:
+             # RGB image
+             self.num_feature_maps = 96
+             self.num_conv_layers = 12
+             self.downsampled_channels = 15
+             self.output_features = 12
+         else:
+             raise Exception('Invalid number of input features')
+
+         self.intermediate_dncnn = IntermediateDnCNN(
+             input_features=self.downsampled_channels,
+             middle_features=self.num_feature_maps,
+             num_conv_layers=self.num_conv_layers)
+         self.upsamplefeatures = UpSampleFeatures()
+
+     def forward(self, x, noise_sigma):
+         concat_noise_x = functions.concatenate_input_noise_map(x.data, noise_sigma.data)
+         concat_noise_x = Variable(concat_noise_x)
+         h_dncnn = self.intermediate_dncnn(concat_noise_x)
+         pred_noise = self.upsamplefeatures(h_dncnn)
+         return pred_noise
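
A shape check of the full model (weights untrained here, so the numbers are meaningless; FFDNet predicts the noise, which the denoiser above subtracts from the input):

    import torch
    from denoising.models import FFDNet

    net = FFDNet(num_input_channels=3).eval()
    x = torch.rand(1, 3, 64, 64)                 # H and W must be even
    with torch.no_grad():
        noise = net(x, torch.tensor([25 / 255.]))
    print(noise.shape)                           # torch.Size([1, 3, 64, 64])
    denoised = torch.clamp(x - noise, 0., 1.)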
denoising/models/.gitkeep ADDED
File without changes
denoising/utils.py ADDED
@@ -0,0 +1,66 @@
+ """
+ Different utilities such as orthogonalization of weights, initialization of
+ loggers, etc
+
+ Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
+
+ This program is free software: you can use, modify and/or
+ redistribute it under the terms of the GNU General Public
+ License as published by the Free Software Foundation, either
+ version 3 of the License, or (at your option) any later
+ version. You should have received a copy of this license along
+ this program. If not, see <http://www.gnu.org/licenses/>.
+ """
+ import numpy as np
+ import cv2
+
+
+ def variable_to_cv2_image(varim):
+     r"""Converts a torch.autograd.Variable to an OpenCV image
+
+     Args:
+         varim: a torch.autograd.Variable
+     """
+     nchannels = varim.size()[1]
+     if nchannels == 1:
+         res = (varim.data.cpu().numpy()[0, 0, :] * 255.).clip(0, 255).astype(np.uint8)
+     elif nchannels == 3:
+         res = varim.data.cpu().numpy()[0]
+         res = cv2.cvtColor(res.transpose(1, 2, 0), cv2.COLOR_RGB2BGR)
+         res = (res * 255.).clip(0, 255).astype(np.uint8)
+     else:
+         raise Exception('Number of color channels not supported')
+     return res
+
+
+ def normalize(data):
+     return np.float32(data / 255.)
+
+ def remove_dataparallel_wrapper(state_dict):
+     r"""Converts a DataParallel model to a normal one by removing the "module."
+     wrapper in the module dictionary
+
+     Args:
+         state_dict: a torch.nn.DataParallel state dictionary
+     """
+     from collections import OrderedDict
+
+     new_state_dict = OrderedDict()
+     for k, vl in state_dict.items():
+         name = k[7:]  # remove 'module.' of DataParallel
+         new_state_dict[name] = vl
+
+     return new_state_dict
+
+ def is_rgb(im_path):
+     r"""Returns True if the image in im_path is an RGB image
+     """
+     from skimage.io import imread
+     rgb = False
+     im = imread(im_path)
+     if (len(im.shape) == 3):
+         if not (np.allclose(im[..., 0], im[..., 1]) and np.allclose(im[..., 2], im[..., 1])):
+             rgb = True
+     print("rgb: {}".format(rgb))
+     print("im shape: {}".format(im.shape))
+     return rgb
drawing.py ADDED
@@ -0,0 +1,165 @@
+ import os
+ from datetime import datetime
+ import base64
+ import random
+ import string
+ import shutil
+ import torch
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from flask import Flask, request, jsonify, abort, redirect, url_for, render_template, send_file, Response
+ from flask_wtf import FlaskForm
+ from wtforms import StringField, FileField, BooleanField, DecimalField
+ from wtforms.validators import DataRequired
+ from flask import after_this_request
+
+ from model.models import Colorizer, Generator
+ from model.extractor import get_seresnext_extractor
+ from utils.xdog import XDoGSketcher
+ from utils.utils import open_json
+ from denoising.denoiser import FFDNetDenoiser
+ from inference import process_image_with_hint
+ from utils.utils import resize_pad
+ from utils.dataset_utils import get_sketch
+
+ def generate_id(size=25, chars=string.ascii_letters + string.digits):
+     return ''.join(random.SystemRandom().choice(chars) for _ in range(size))
+
+ def generate_unique_id(current_ids = set()):
+     id_t = generate_id()
+     while id_t in current_ids:
+         id_t = generate_id()
+
+     current_ids.add(id_t)
+
+     return id_t
+
+ app = Flask(__name__)
+ app.config.update(dict(
+     SECRET_KEY="lol kek",
+     WTF_CSRF_SECRET_KEY="cheburek"
+ ))
+
+ if torch.cuda.is_available():
+     device = 'cuda'
+ else:
+     device = 'cpu'
+
+ colorizer = torch.jit.load('./model/colorizer.zip', map_location=torch.device(device))
+
+ sketcher = XDoGSketcher()
+ xdog_config = open_json('configs/xdog_config.json')
+ for key in xdog_config.keys():
+     if key in sketcher.params:
+         sketcher.params[key] = xdog_config[key]
+
+ denoiser = FFDNetDenoiser(device)
+
+ color_args = {'colorizer':colorizer, 'sketcher':sketcher, 'device':device, 'dfm' : True, 'auto_hint' : False, 'ignore_gray' : False, 'denoiser' : denoiser, 'denoiser_sigma' : 25}
+
+
+ class SubmitForm(FlaskForm):
+     file = FileField(validators=[DataRequired(), ])
+
+ def preprocess_image(file_id, ext):
+     directory_path = os.path.join('static', 'temp_images', file_id)
+     original_path = os.path.join(directory_path, 'original') + ext
+     original_image = plt.imread(original_path)
+
+     resized_image, _ = resize_pad(original_image)
+     resized_image = denoiser.get_denoised_image(resized_image, 25)
+     bw, dfm = get_sketch(resized_image, sketcher, True)
+
+     resized_name = 'resized_' + str(resized_image.shape[0]) + '_' + str(resized_image.shape[1]) + '.png'
+     plt.imsave(os.path.join(directory_path, resized_name), resized_image)
+     plt.imsave(os.path.join(directory_path, 'bw.png'), bw, cmap = 'gray')
+     plt.imsave(os.path.join(directory_path, 'dfm.png'), dfm, cmap = 'gray')
+     os.remove(original_path)
+
+     empty_hint = np.zeros((resized_image.shape[0], resized_image.shape[1], 4), dtype = np.float32)
+     plt.imsave(os.path.join(directory_path, 'hint.png'), empty_hint)
+
+ @app.route('/', methods=['GET', 'POST'])
+ def upload():
+     form = SubmitForm()
+     if form.validate_on_submit():
+         input_data = form.file.data
+
+         _, ext = os.path.splitext(input_data.filename)
+
+         if ext not in ('.jpg', '.png', '.jpeg'):
+             return abort(400)
+
+         file_id = generate_unique_id()
+         directory = os.path.join('static', 'temp_images', file_id)
+         original_filename = os.path.join(directory, 'original') + ext
+
+         try:
+             os.mkdir(directory)
+             input_data.save(original_filename)
+
+             preprocess_image(file_id, ext)
+
+             return redirect(f'/draw/{file_id}')
+
+         except:
+             print('Failed to colorize')
+             if os.path.exists(directory):
+                 shutil.rmtree(directory)
+             return abort(400)
+
+     return render_template("upload.html", form = form)
+
+ @app.route('/img/<file_id>')
+ def show_image(file_id):
+     if not os.path.exists(os.path.join('static', 'temp_images', str(file_id))):
+         abort(404)
+     return f'<img src="/static/temp_images/{file_id}/colorized.png?{random.randint(1, 1000000)}">'
+
+ def colorize_image(file_id):
+     directory_path = os.path.join('static', 'temp_images', file_id)
+
+     bw = plt.imread(os.path.join(directory_path, 'bw.png'))[..., :1]
+     dfm = plt.imread(os.path.join(directory_path, 'dfm.png'))[..., :1]
+     hint = plt.imread(os.path.join(directory_path, 'hint.png'))
+
+     return process_image_with_hint(bw, dfm, hint, color_args)
+
+ @app.route('/colorize', methods=['POST'])
+ def colorize():
+     file_id = request.form['save_file_id']
+     file_id = file_id[file_id.rfind('/') + 1:]
+
+     img_data = request.form['save_image']
+     img_data = img_data[img_data.find(',') + 1:]
+
+     directory_path = os.path.join('static', 'temp_images', file_id)
+
+     with open(os.path.join(directory_path, 'hint.png'), "wb") as im:
+         # NB: base64.decodestring was removed in Python 3.9; decodebytes is
+         # the surviving name for the same decoder.
+         im.write(base64.decodebytes(str.encode(img_data)))
+
+     result = colorize_image(file_id)
+
+     plt.imsave(os.path.join(directory_path, 'colorized.png'), result)
+
+     src_path = f'../static/temp_images/{file_id}/colorized.png?{random.randint(1, 1000000)}'
+
+     return src_path
+
+ @app.route('/draw/<file_id>', methods=['GET', 'POST'])
+ def paintapp(file_id):
+     if request.method == 'GET':
+         directory_path = os.path.join('static', 'temp_images', str(file_id))
+         if not os.path.exists(directory_path):
+             abort(404)
+
+         resized_name = [x for x in os.listdir(directory_path) if x.startswith('resized_')][0]
+
+         split = os.path.splitext(resized_name)[0].split('_')
+         width = int(split[2])
+         height = int(split[1])
+
+         return render_template("drawing.html", height = height, width = width, img_path = os.path.join('temp_images', str(file_id), resized_name))
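
To serve this app outside Docker, the Dockerfile's own command works verbatim from the repository root (it expects model/colorizer.zip and the FFDNet weights under denoising/models/ to be present):

    gunicorn --timeout 200 -w 3 -b 0.0.0.0:5000 drawing:app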
inference.py ADDED
@@ -0,0 +1,215 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from utils.dataset_utils import get_sketch
+ from utils.utils import resize_pad, generate_mask, extract_cbr, create_cbz, sorted_alphanumeric, subfolder_image_search, remove_folder
+ from torchvision.transforms import ToTensor
+ import os
+ import matplotlib.pyplot as plt
+ import argparse
+ from model.models import Colorizer, Generator
+ from model.extractor import get_seresnext_extractor
+ from utils.xdog import XDoGSketcher
+ from utils.utils import open_json
+ import sys
+ from denoising.denoiser import FFDNetDenoiser
+
+ def colorize_without_hint(inp, color_args):
+     i_hint = torch.zeros(1, 4, inp.shape[2], inp.shape[3]).float().to(color_args['device'])
+
+     with torch.no_grad():
+         fake_color, _ = color_args['colorizer'](torch.cat([inp, i_hint], 1))
+
+     if color_args['auto_hint']:
+         mask = generate_mask(fake_color.shape[2], fake_color.shape[3], full = False, prob = 1, sigma = color_args['auto_hint_sigma']).unsqueeze(0)
+         mask = mask.to(color_args['device'])
+
+         if color_args['ignore_gray']:
+             diff1 = torch.abs(fake_color[:, 0] - fake_color[:, 1])
+             diff2 = torch.abs(fake_color[:, 0] - fake_color[:, 2])
+             diff3 = torch.abs(fake_color[:, 1] - fake_color[:, 2])
+             mask = ((mask + ((diff1 + diff2 + diff3) > 60 / 255).float().unsqueeze(1)) == 2).float()
+
+         i_hint = torch.cat([fake_color * mask, mask], 1)
+
+         with torch.no_grad():
+             fake_color, _ = color_args['colorizer'](torch.cat([inp, i_hint], 1))
+
+     return fake_color
+
+
+ def process_image(image, color_args, to_tensor = ToTensor()):
+     image, pad = resize_pad(image)
+
+     if color_args['denoiser'] is not None:
+         image = color_args['denoiser'].get_denoised_image(image, color_args['denoiser_sigma'])
+
+     bw, dfm = get_sketch(image, color_args['sketcher'], color_args['dfm'])
+
+     bw = to_tensor(bw).unsqueeze(0).to(color_args['device'])
+     dfm = to_tensor(dfm).unsqueeze(0).to(color_args['device'])
+
+     output = colorize_without_hint(torch.cat([bw, dfm], 1), color_args)
+     result = output[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5
+
+     if pad[0] != 0:
+         result = result[:-pad[0]]
+     if pad[1] != 0:
+         result = result[:, :-pad[1]]
+
+     return result
+
+ def colorize_with_hint(inp, color_args):
+     with torch.no_grad():
+         fake_color, _ = color_args['colorizer'](inp)
+
+     return fake_color
+
+ def process_image_with_hint(bw, dfm, hint, color_args, to_tensor = ToTensor()):
+     bw = to_tensor(bw).unsqueeze(0).to(color_args['device'])
+     dfm = to_tensor(dfm).unsqueeze(0).to(color_args['device'])
+
+     i_hint = (torch.FloatTensor(hint[..., :3]).permute(2, 0, 1) - 0.5) / 0.5
+     mask = torch.FloatTensor(hint[..., 3:]).permute(2, 0, 1)
+     i_hint = torch.cat([i_hint * mask, mask], 0).unsqueeze(0).to(color_args['device'])
+
+     output = colorize_with_hint(torch.cat([bw, dfm, i_hint], 1), color_args)
+     result = output[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5
+
+     return result
+
+ def colorize_single_image(file_path, save_path, color_args):
+     try:
+         image = plt.imread(file_path)
+
+         colorization = process_image(image, color_args)
+
+         plt.imsave(save_path, colorization)
+
+         return True
+     except KeyboardInterrupt:
+         sys.exit(0)
+     except:
+         print('Failed to colorize {}'.format(file_path))
+         return False
+
+ def colorize_images(source_path, target_path, color_args):
+     images = os.listdir(source_path)
+
+     for image_name in images:
+         file_path = os.path.join(source_path, image_name)
+
+         name, ext = os.path.splitext(image_name)
+         if (ext != '.png'):
+             image_name = name + '.png'
+
+         save_path = os.path.join(target_path, image_name)
+         colorize_single_image(file_path, save_path, color_args)
+
+ def colorize_cbr(file_path, color_args):
+     file_name = os.path.splitext(os.path.basename(file_path))[0]
+     temp_path = 'temp_colorization'
+
+     if not os.path.exists(temp_path):
+         os.makedirs(temp_path)
+     extract_cbr(file_path, temp_path)
+
+     images = subfolder_image_search(temp_path)
+
+     result_images = []
+     for image_path in images:
+         save_path = image_path
+
+         path, ext = os.path.splitext(save_path)
+         if (ext != '.png'):
+             save_path = path + '.png'
+
+         res_flag = colorize_single_image(image_path, save_path, color_args)
+
+         result_images.append(save_path if res_flag else image_path)
+
+     result_name = os.path.join(os.path.dirname(file_path), file_name + '_colorized.cbz')
+
+     create_cbz(result_name, result_images)
+
+     remove_folder(temp_path)
+
+     return result_name
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("-p", "--path", required=True)
+     parser.add_argument("-gen", "--generator", default = 'model/generator.pth')
+     parser.add_argument("-ext", "--extractor", default = 'model/extractor.pth')
+     parser.add_argument("-s", "--sigma", type = float, default = 0.003)
+     parser.add_argument('-g', '--gpu', dest = 'gpu', action = 'store_true')
+     parser.add_argument('-ah', '--auto', dest = 'autohint', action = 'store_true')
+     parser.add_argument('-ig', '--ignore_grey', dest = 'ignore', action = 'store_true')
+     parser.add_argument('-nd', '--no_denoise', dest = 'denoiser', action = 'store_false')
+     parser.add_argument("-ds", "--denoiser_sigma", type = int, default = 25)
+     parser.set_defaults(gpu = False)
+     parser.set_defaults(autohint = False)
+     parser.set_defaults(ignore = False)
+     parser.set_defaults(denoiser = True)
+     args = parser.parse_args()
+
+     return args
+
+
+ if __name__ == "__main__":
+
+     args = parse_args()
+
+     if args.gpu:
+         device = 'cuda'
+     else:
+         device = 'cpu'
+
+     generator = Generator()
+     generator.load_state_dict(torch.load(args.generator))
+
+     extractor = get_seresnext_extractor()
+     extractor.load_state_dict(torch.load(args.extractor))
+
+     colorizer = Colorizer(generator, extractor)
+     colorizer = colorizer.eval().to(device)
+
+     sketcher = XDoGSketcher()
+     xdog_config = open_json('configs/xdog_config.json')
+     for key in xdog_config.keys():
+         if key in sketcher.params:
+             sketcher.params[key] = xdog_config[key]
+
+     denoiser = None
+     if args.denoiser:
+         denoiser = FFDNetDenoiser(device, args.denoiser_sigma)
+
+     color_args = {'colorizer':colorizer, 'sketcher':sketcher, 'auto_hint':args.autohint, 'auto_hint_sigma':args.sigma,
+                   'ignore_gray':args.ignore, 'device':device, 'dfm' : True, 'denoiser':denoiser, 'denoiser_sigma' : args.denoiser_sigma}
+
+     if os.path.isdir(args.path):
+         colorization_path = os.path.join(args.path, 'colorization')
+         if not os.path.exists(colorization_path):
+             os.makedirs(colorization_path)
+
+         colorize_images(args.path, colorization_path, color_args)
+
+     elif os.path.isfile(args.path):
+
+         split = os.path.splitext(args.path)
+
+         if split[1].lower() in ('.cbr', '.cbz', '.rar', '.zip'):
+             colorize_cbr(args.path, color_args)
+         elif split[1].lower() in ('.jpg', '.png', '.jpeg'):  # fixed: was ',jpeg'
+             new_image_path = split[0] + '_colorized' + '.png'
+
+             colorize_single_image(args.path, new_image_path, color_args)
+         else:
+             print('Wrong format')
+     else:
+         print('Wrong path')
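
The argument parser above gives the batch entry points; a few illustrative invocations (file names are placeholders):

    python inference.py -p page.png -g     # one image on the GPU, saved as page_colorized.png
    python inference.py -p chapter.cbz     # a CBR/CBZ archive -> chapter_colorized.cbz
    python inference.py -p pages/ -nd      # a folder (output in pages/colorization/), skipping denoising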
model/__pycache__/extractor.cpython-39.pyc ADDED
Binary file (3.95 kB).
 
model/__pycache__/models.cpython-39.pyc ADDED
Binary file (13.5 kB).
 
model/extractor.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ee3c59f02ac8c59298fd9b819fa33d2efa168847e15e4be39b35c286f7c18607
+ size 6340842
model/extractor.py ADDED
@@ -0,0 +1,127 @@
+ import torch
+ import torch.nn as nn
+ import math
+
+ '''https://github.com/blandocs/Tag2Pix/blob/master/model/pretrained.py'''
+
+ # Pretrained version
+ class Selayer(nn.Module):
+     def __init__(self, inplanes):
+         super(Selayer, self).__init__()
+         self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+         self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
+         self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
+         self.relu = nn.ReLU(inplace=True)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         out = self.global_avgpool(x)
+         out = self.conv1(out)
+         out = self.relu(out)
+         out = self.conv2(out)
+         out = self.sigmoid(out)
+
+         return x * out
+
+
+ class BottleneckX_Origin(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None):
+         super(BottleneckX_Origin, self).__init__()
+         self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes * 2)
+
+         self.conv2 = nn.Conv2d(planes * 2, planes * 2, kernel_size=3, stride=stride,
+                                padding=1, groups=cardinality, bias=False)
+         self.bn2 = nn.BatchNorm2d(planes * 2)
+
+         self.conv3 = nn.Conv2d(planes * 2, planes * 4, kernel_size=1, bias=False)
+         self.bn3 = nn.BatchNorm2d(planes * 4)
+
+         self.selayer = Selayer(planes * 4)
+
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         out = self.selayer(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+ class SEResNeXt_extractor(nn.Module):
+     def __init__(self, block, layers, input_channels=3, cardinality=32):
+         super(SEResNeXt_extractor, self).__init__()
+         self.cardinality = cardinality
+         self.inplanes = 64
+         self.input_channels = input_channels
+
+         self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3,
+                                bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+         self.layer1 = self._make_layer(block, 64, layers[0])
+         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                 m.weight.data.normal_(0, math.sqrt(2. / n))
+                 if m.bias is not None:
+                     m.bias.data.zero_()
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+
+     def _make_layer(self, block, planes, blocks, stride=1):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(self.inplanes, planes * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(block(self.inplanes, planes, self.cardinality, stride, downsample))
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(self.inplanes, planes, self.cardinality))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.maxpool(x)
+
+         x = self.layer1(x)
+         x = self.layer2(x)
+
+         return x
+
+ def get_seresnext_extractor():
+     return SEResNeXt_extractor(BottleneckX_Origin, [3, 4, 6, 3], 1)
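
The factory hands back a single-channel feature extractor truncated after layer2 (inference.py loads model/extractor.pth into it); a quick check of what the colorizer receives from it:

    import torch
    from model.extractor import get_seresnext_extractor

    extractor = get_seresnext_extractor().eval()
    with torch.no_grad():
        feats = extractor(torch.rand(1, 1, 512, 512))
    print(feats.shape)  # torch.Size([1, 512, 64, 64]): stride-8 features, 128 * expansion(4) channels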
model/models.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.models as M
5
+ import math
6
+ from torch import Tensor
7
+ from torch.nn import Parameter
8
+
9
+ '''https://github.com/orashi/AlacGAN/blob/master/models/standard.py'''
10
+
11
+ def l2normalize(v, eps=1e-12):
12
+ return v / (v.norm() + eps)
13
+
14
+
15
+ class SpectralNorm(nn.Module):
16
+ def __init__(self, module, name='weight', power_iterations=1):
17
+ super(SpectralNorm, self).__init__()
18
+ self.module = module
19
+ self.name = name
20
+ self.power_iterations = power_iterations
21
+ if not self._made_params():
22
+ self._make_params()
23
+
24
+ def _update_u_v(self):
25
+ u = getattr(self.module, self.name + "_u")
26
+ v = getattr(self.module, self.name + "_v")
27
+ w = getattr(self.module, self.name + "_bar")
28
+
29
+ height = w.data.shape[0]
30
+ for _ in range(self.power_iterations):
31
+ v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
32
+ u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
33
+
34
+ # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
35
+ sigma = u.dot(w.view(height, -1).mv(v))
36
+ setattr(self.module, self.name, w / sigma.expand_as(w))
37
+
38
+ def _made_params(self):
39
+ try:
40
+ u = getattr(self.module, self.name + "_u")
41
+ v = getattr(self.module, self.name + "_v")
42
+ w = getattr(self.module, self.name + "_bar")
43
+ return True
44
+ except AttributeError:
45
+ return False
46
+
47
+
48
+ def _make_params(self):
49
+ w = getattr(self.module, self.name)
50
+ height = w.data.shape[0]
51
+ width = w.view(height, -1).data.shape[1]
52
+
53
+ u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
54
+ v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
55
+ u.data = l2normalize(u.data)
56
+ v.data = l2normalize(v.data)
57
+ w_bar = Parameter(w.data)
58
+
59
+ del self.module._parameters[self.name]
60
+
61
+ self.module.register_parameter(self.name + "_u", u)
62
+ self.module.register_parameter(self.name + "_v", v)
63
+ self.module.register_parameter(self.name + "_bar", w_bar)
64
+
65
+
66
+ def forward(self, *args):
67
+ self._update_u_v()
68
+ return self.module.forward(*args)
69
+
70
+ class Selayer(nn.Module):
71
+ def __init__(self, inplanes):
72
+ super(Selayer, self).__init__()
73
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
74
+ self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
75
+ self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
76
+ self.relu = nn.ReLU(inplace=True)
77
+ self.sigmoid = nn.Sigmoid()
78
+
79
+ def forward(self, x):
80
+ out = self.global_avgpool(x)
81
+ out = self.conv1(out)
82
+ out = self.relu(out)
83
+ out = self.conv2(out)
84
+ out = self.sigmoid(out)
85
+
86
+ return x * out
87
+
88
+ class SelayerSpectr(nn.Module):
89
+ def __init__(self, inplanes):
90
+ super(SelayerSpectr, self).__init__()
91
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
92
+ self.conv1 = SpectralNorm(nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1))
93
+ self.conv2 = SpectralNorm(nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1))
94
+ self.relu = nn.ReLU(inplace=True)
95
+ self.sigmoid = nn.Sigmoid()
96
+
97
+ def forward(self, x):
98
+ out = self.global_avgpool(x)
99
+ out = self.conv1(out)
100
+ out = self.relu(out)
101
+ out = self.conv2(out)
102
+ out = self.sigmoid(out)
103
+
104
+ return x * out
105
+
106
+ class ResNeXtBottleneck(nn.Module):
107
+ def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
108
+ super(ResNeXtBottleneck, self).__init__()
109
+ D = out_channels // 2
110
+ self.out_channels = out_channels
111
+ self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
112
+ self.conv_conv = nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
113
+ groups=cardinality,
114
+ bias=False)
115
+ self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
116
+ self.shortcut = nn.Sequential()
117
+ if stride != 1:
118
+ self.shortcut.add_module('shortcut',
119
+ nn.AvgPool2d(2, stride=2))
120
+
121
+ self.selayer = Selayer(out_channels)
122
+
123
+ def forward(self, x):
124
+ bottleneck = self.conv_reduce.forward(x)
125
+ bottleneck = F.leaky_relu(bottleneck, 0.2, True)
126
+ bottleneck = self.conv_conv.forward(bottleneck)
127
+ bottleneck = F.leaky_relu(bottleneck, 0.2, True)
128
+ bottleneck = self.conv_expand.forward(bottleneck)
129
+ bottleneck = self.selayer(bottleneck)
130
+
131
+ x = self.shortcut.forward(x)
132
+ return x + bottleneck
133
+
134
+ class SpectrResNeXtBottleneck(nn.Module):
135
+ def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
136
+ super(SpectrResNeXtBottleneck, self).__init__()
137
+ D = out_channels // 2
138
+ self.out_channels = out_channels
139
+ self.conv_reduce = SpectralNorm(nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False))
140
+ self.conv_conv = SpectralNorm(nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
141
+ groups=cardinality,
142
+ bias=False))
143
+ self.conv_expand = SpectralNorm(nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False))
144
+ self.shortcut = nn.Sequential()
145
+ if stride != 1:
146
+ self.shortcut.add_module('shortcut',
147
+ nn.AvgPool2d(2, stride=2))
148
+
149
+ self.selayer = SelayerSpectr(out_channels)
150
+
151
+ def forward(self, x):
152
+ bottleneck = self.conv_reduce.forward(x)
153
+ bottleneck = F.leaky_relu(bottleneck, 0.2, True)
154
+ bottleneck = self.conv_conv.forward(bottleneck)
155
+ bottleneck = F.leaky_relu(bottleneck, 0.2, True)
156
+ bottleneck = self.conv_expand.forward(bottleneck)
157
+ bottleneck = self.selayer(bottleneck)
158
+
159
+ x = self.shortcut.forward(x)
160
+ return x + bottleneck
161
+
162
+ class FeatureConv(nn.Module):
163
+ def __init__(self, input_dim=512, output_dim=512):
164
+ super(FeatureConv, self).__init__()
165
+
166
+ no_bn = True
167
+
168
+ seq = []
169
+ seq.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
170
+ if not no_bn: seq.append(nn.BatchNorm2d(output_dim))
171
+ seq.append(nn.ReLU(inplace=True))
172
+ seq.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
173
+ if not no_bn: seq.append(nn.BatchNorm2d(output_dim))
174
+ seq.append(nn.ReLU(inplace=True))
175
+ seq.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
176
+ seq.append(nn.ReLU(inplace=True))
177
+
178
+ self.network = nn.Sequential(*seq)
179
+
180
+ def forward(self, x):
181
+ return self.network(x)
182
+
183
+ class Generator(nn.Module):
184
+ def __init__(self, ngf=64):
185
+         super(Generator, self).__init__()
+
+         self.feature_conv = FeatureConv()
+
+         self.to0 = self._make_encoder_block_first(6, 32)
+         self.to1 = self._make_encoder_block(32, 64)
+         self.to2 = self._make_encoder_block(64, 128)
+         self.to3 = self._make_encoder_block(128, 256)
+         self.to4 = self._make_encoder_block(256, 512)
+
+         self.deconv_for_decoder = nn.Sequential(
+             nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),  # output is 64 * 64
+             nn.LeakyReLU(0.2),
+             nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),  # output is 128 * 128
+             nn.LeakyReLU(0.2),
+             nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),  # output is 256 * 256
+             nn.LeakyReLU(0.2),
+             nn.ConvTranspose2d(32, 3, 3, stride=1, padding=1, output_padding=0),  # output is 256 * 256
+             nn.Tanh(),
+         )
+
+         tunnel4 = nn.Sequential(*[ResNeXtBottleneck(ngf * 8, ngf * 8, cardinality=32, dilate=1) for _ in range(20)])
+
+         self.tunnel4 = nn.Sequential(nn.Conv2d(ngf * 8 + 512, ngf * 8, kernel_size=3, stride=1, padding=1),
+                                      nn.LeakyReLU(0.2, True),
+                                      tunnel4,
+                                      nn.Conv2d(ngf * 8, ngf * 4 * 4, kernel_size=3, stride=1, padding=1),
+                                      nn.PixelShuffle(2),
+                                      nn.LeakyReLU(0.2, True)
+                                      )  # 64
+
+         depth = 2
+         tunnel = [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=1) for _ in range(depth)]
+         tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=2) for _ in range(depth)]
+         tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=4) for _ in range(depth)]
+         tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=2),
+                    ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=1)]
+         tunnel3 = nn.Sequential(*tunnel)
+
+         self.tunnel3 = nn.Sequential(nn.Conv2d(ngf * 8, ngf * 4, kernel_size=3, stride=1, padding=1),
+                                      nn.LeakyReLU(0.2, True),
+                                      tunnel3,
+                                      nn.Conv2d(ngf * 4, ngf * 2 * 4, kernel_size=3, stride=1, padding=1),
+                                      nn.PixelShuffle(2),
+                                      nn.LeakyReLU(0.2, True)
+                                      )  # 128
+
+         tunnel = [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=1) for _ in range(depth)]
+         tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=2) for _ in range(depth)]
+         tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=4) for _ in range(depth)]
+         tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=2),
+                    ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=1)]
+         tunnel2 = nn.Sequential(*tunnel)
+
+         self.tunnel2 = nn.Sequential(nn.Conv2d(ngf * 4, ngf * 2, kernel_size=3, stride=1, padding=1),
+                                      nn.LeakyReLU(0.2, True),
+                                      tunnel2,
+                                      nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=1, padding=1),
+                                      nn.PixelShuffle(2),
+                                      nn.LeakyReLU(0.2, True)
+                                      )
+
+         tunnel = [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=1)]
+         tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=2)]
+         tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=4)]
+         tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=2),
+                    ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=1)]
+         tunnel1 = nn.Sequential(*tunnel)
+
+         self.tunnel1 = nn.Sequential(nn.Conv2d(ngf * 2, ngf, kernel_size=3, stride=1, padding=1),
+                                      nn.LeakyReLU(0.2, True),
+                                      tunnel1,
+                                      nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=1, padding=1),
+                                      nn.PixelShuffle(2),
+                                      nn.LeakyReLU(0.2, True)
+                                      )
+
+         self.exit = nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)
+
+     def _make_encoder_block(self, inplanes, planes):
+         return nn.Sequential(
+             nn.Conv2d(inplanes, planes, 3, 2, 1),
+             nn.LeakyReLU(0.2),
+             nn.Conv2d(planes, planes, 3, 1, 1),
+             nn.LeakyReLU(0.2),
+         )
+
+     def _make_encoder_block_first(self, inplanes, planes):
+         return nn.Sequential(
+             nn.Conv2d(inplanes, planes, 3, 1, 1),
+             nn.LeakyReLU(0.2),
+             nn.Conv2d(planes, planes, 3, 1, 1),
+             nn.LeakyReLU(0.2),
+         )
+
+     def forward(self, sketch, sketch_feat):
+         x0 = self.to0(sketch)
+         x1 = self.to1(x0)
+         x2 = self.to2(x1)
+         x3 = self.to3(x2)
+         x4 = self.to4(x3)
+
+         sketch_feat = self.feature_conv(sketch_feat)
+
+         out = self.tunnel4(torch.cat([x4, sketch_feat], 1))
+
+         x = self.tunnel3(torch.cat([out, x3], 1))
+         x = self.tunnel2(torch.cat([x, x2], 1))
+         x = self.tunnel1(torch.cat([x, x1], 1))
+         x = torch.tanh(self.exit(torch.cat([x, x0], 1)))
+
+         decoder_output = self.deconv_for_decoder(out)
+
+         return x, decoder_output
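Before the `Colorizer` wrappers below, it is worth pinning down the shapes, since the U-Net-style skip connections only line up for one width. A bookkeeping sketch (my annotation, not part of the upload; `ngf = 64` is inferred from `ngf * 8 + 512` matching the 512-channel encoder output, as the constructor default is not visible in this excerpt):

```python
# Shape walk-through of Generator.forward for a 256x256, 6-channel sketch
# (bw + dfm + 4-channel hint), assuming ngf = 64.
#
# encoder:  to0   6 ->  32 @ 256x256 (stride 1)
#           to1  32 ->  64 @ 128x128 (each later block downsamples by 2)
#           to2  64 -> 128 @  64x64
#           to3 128 -> 256 @  32x32
#           to4 256 -> 512 @  16x16
#
# decoder (each tunnel ends in PixelShuffle(2), doubling resolution):
#   tunnel4: cat(x4, sketch_feat) = 512 + 512 = ngf*8 + 512 -> 256 @ 32x32
#   tunnel3: cat(out, x3) = 256 + 256 = ngf*8               -> 128 @ 64x64
#   tunnel2: cat(x, x2)   = 128 + 128 = ngf*4               ->  64 @ 128x128
#   tunnel1: cat(x, x1)   =  64 +  64 = ngf*2               ->  32 @ 256x256
#   exit:    cat(x, x0)   =  32 +  32 = ngf                 ->   3 (tanh image)
```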
+ '''
+ class Colorizer(nn.Module):
+     def __init__(self, extractor_path = 'model/model.pth'):
+         super(Colorizer, self).__init__()
+
+         self.generator = Generator()
+         self.extractor = se_resnext_half(dump_path=extractor_path, num_classes=370, input_channels=1)
+
+     def extractor_eval(self):
+         for param in self.extractor.parameters():
+             param.requires_grad = False
+
+     def extractor_train(self):
+         for param in self.extractor.parameters():
+             param.requires_grad = True
+
+     def forward(self, x, extractor_grad = False):
+         if extractor_grad:
+             features = self.extractor(x[:, 0:1])
+         else:
+             with torch.no_grad():
+                 features = self.extractor(x[:, 0:1]).detach()
+
+         fake, guide = self.generator(x, features)
+
+         return fake, guide
+ '''
+
+ class Colorizer(nn.Module):
+     def __init__(self, generator_model, extractor_model):
+         super(Colorizer, self).__init__()
+
+         self.generator = generator_model
+         self.extractor = extractor_model
+
+     def load_generator_weights(self, gen_weights):
+         self.generator.load_state_dict(gen_weights)
+
+     def load_extractor_weights(self, ext_weights):
+         self.extractor.load_state_dict(ext_weights)
+
+     def extractor_eval(self):
+         for param in self.extractor.parameters():
+             param.requires_grad = False
+         self.extractor.eval()
+
+     def extractor_train(self):
+         for param in self.extractor.parameters():
+             param.requires_grad = True
+         self.extractor.train()
+
+     def forward(self, x, extractor_grad = False):
+         # the first channel is the grayscale page; the extractor only sees that
+         if extractor_grad:
+             features = self.extractor(x[:, 0:1])
+         else:
+             with torch.no_grad():
+                 features = self.extractor(x[:, 0:1]).detach()
+
+         fake, guide = self.generator(x, features)
+
+         return fake, guide
+
+ class Discriminator(nn.Module):
+     def __init__(self, ndf=64):
+         super(Discriminator, self).__init__()
+
+         self.feed = nn.Sequential(SpectralNorm(nn.Conv2d(3, 64, 3, 1, 1)),
+                                   nn.LeakyReLU(0.2, True),
+                                   SpectralNorm(nn.Conv2d(64, 64, 3, 2, 0)),
+                                   nn.LeakyReLU(0.2, True),
+
+                                   SpectrResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1),
+                                   SpectrResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1, stride=2),  # 128
+                                   SpectralNorm(nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=False)),
+                                   nn.LeakyReLU(0.2, True),
+
+                                   SpectrResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1),
+                                   SpectrResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1, stride=2),  # 64
+                                   SpectralNorm(nn.Conv2d(ndf * 2, ndf * 4, kernel_size=1, stride=1, padding=0, bias=False)),
+                                   nn.LeakyReLU(0.2, True),
+
+                                   SpectrResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1),
+                                   SpectrResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1, stride=2),  # 32
+                                   SpectralNorm(nn.Conv2d(ndf * 4, ndf * 8, kernel_size=1, stride=1, padding=1, bias=False)),
+                                   nn.LeakyReLU(0.2, True),
+
+                                   SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
+                                   SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1, stride=2),  # 16
+                                   SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
+                                   SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
+                                   nn.AdaptiveAvgPool2d((1, 1))
+                                   )
+
+         self.out = nn.Linear(512, 1)
+
+     def forward(self, color):
+         x = self.feed(color)
+
+         out = self.out(x.view(color.size(0), -1))
+         return out
+
+ class Content(nn.Module):
+     def __init__(self, path):
+         super(Content, self).__init__()
+         vgg16 = M.vgg16()
+         vgg16.load_state_dict(torch.load(path))
+         vgg16.features = nn.Sequential(
+             *list(vgg16.features.children())[:9]
+         )
+         self.model = vgg16.features
+         # shifted ImageNet mean: inputs arrive in [-1, 1], and
+         # (x * 0.5) - (mean - 0.5) = (x + 1) / 2 - mean, i.e. the usual VGG normalization
+         self.register_buffer('mean', torch.FloatTensor([0.485 - 0.5, 0.456 - 0.5, 0.406 - 0.5]).view(1, 3, 1, 1))
+         self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+     def forward(self, images):
+         return self.model((images.mul(0.5) - self.mean) / self.std)
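To make the data flow concrete: `Colorizer` splits the first (grayscale) channel off for the extractor and hands the full 6-channel stack plus the extracted features to the generator. A minimal smoke test with toy stand-ins (my sketch, assuming the `Colorizer` class above is importable from `model.models`; the real `Generator`/SE-ResNeXt pair is wired up in `train.py` below):

```python
# Minimal sketch of the Colorizer contract, using toy stand-ins instead of
# the repo's Generator and SE-ResNeXt extractor.
import torch
import torch.nn as nn
from model.models import Colorizer

class ToyExtractor(nn.Module):
    """Stands in for the SE-ResNeXt extractor: 1-channel page -> feature map."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 512, 3, stride=16, padding=1)

    def forward(self, x):
        return self.conv(x)

class ToyGenerator(nn.Module):
    """Stands in for Generator: (6-ch sketch, features) -> (image, guide)."""
    def __init__(self):
        super().__init__()
        self.head = nn.Conv2d(6, 3, 3, padding=1)

    def forward(self, sketch, sketch_feat):
        out = torch.tanh(self.head(sketch))  # this toy ignores the features
        return out, out

colorizer = Colorizer(ToyGenerator(), ToyExtractor())
x = torch.randn(2, 6, 256, 256)   # bw (1) + dfm (1) + hint (4) channels
fake, guide = colorizer(x)        # extractor runs under no_grad by default
print(fake.shape, guide.shape)    # torch.Size([2, 3, 256, 256]) twice
```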
model/vgg16-397923af.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:397923af8e79cdbb6a7127f12361acd7a2f83e06b05044ddf496e83de57a5bf0
+ size 553433881
readme.md ADDED
@@ -0,0 +1,22 @@
+ UPD. See the [improved version](https://github.com/qweasdd/manga-colorization-v2).
+
+ # Automatic colorization
+
+ 1. Download the [generator](https://drive.google.com/file/d/1Oo6ycphJ3sUOpDCDoG29NA5pbhQVCevY/view?usp=sharing), [extractor](https://drive.google.com/file/d/12cbNyJcCa1zI2EBz6nea3BXl21Fm73Bt/view?usp=sharing) and [denoiser](https://drive.google.com/file/d/161oyQcYpdkVdw8gKz_MA8RD-Wtg9XDp3/view?usp=sharing) weights. Put the generator and extractor weights in `model` and the denoiser weights in `denoising/models`.
+ 2. To colorize an image, a folder of images, or a `.cbz`/`.cbr` file, use the following command:
+ ```
+ $ python inference.py -p "path to file or folder"
+ ```
+
+ # Manual colorization with color hints
+
+ 1. Download the [colorizer](https://drive.google.com/file/d/1BERrMl9e7cKsk9m2L0q1yO4k7blNhEWC/view?usp=sharing) and [denoiser](https://drive.google.com/file/d/161oyQcYpdkVdw8gKz_MA8RD-Wtg9XDp3/view?usp=sharing) weights. Put the colorizer weights in `model` and the denoiser weights in `denoising/models`.
+ 2. Run the gunicorn server with:
+ ```
+ $ ./run_drawing.sh
+ ```
+ 3. Open `localhost:5000` in a browser.
+
+ # References
+ 1. Extractor weights are taken from https://github.com/blandocs/Tag2Pix/releases/download/release/model.pth
+ 2. Denoiser weights are taken from http://www.ipol.im/pub/art/2019/231.
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ flask==1.1.1
+ gunicorn
+ numpy==1.16.6
+ flask_wtf==0.14.3
+ matplotlib==3.1.1
+ opencv-python==4.1.2.30
+ snowy
+ scipy==1.3.3
+ scikit-image==0.15.0
+ patool==1.12
run_drawing.sh ADDED
@@ -0,0 +1 @@
+ gunicorn --worker-class gevent --timeout 150 -w 1 -b 0.0.0.0:5000 drawing:app
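Two things worth noting about this command: `--worker-class gevent` additionally requires the `gevent` package, which `requirements.txt` above does not pin, and `drawing:app` expects a module `drawing.py` exposing a Flask `app` object (that module is not shown in this view). A trivial stand-in that the command would accept, purely for illustration:

```python
# drawing.py -- hypothetical minimal stand-in; the real module renders the
# templates below and hosts the /colorize endpoint used by draw.js.
from flask import Flask

app = Flask(__name__)

@app.route('/')
def index():
    return 'colorization app is up'
```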
static/js/draw.js ADDED
@@ -0,0 +1,120 @@
+ var canvas = document.getElementById('draw_canvas');
+ var ctx = canvas.getContext('2d');
+ var canvasWidth = canvas.width;
+ var canvasHeight = canvas.height;
+ var prevX, prevY;
+
+ var result_canvas = document.getElementById('result');
+ var result_ctx = result_canvas.getContext('2d');
+ result_canvas.width = canvas.width;
+ result_canvas.height = canvas.height;
+
+ var color_indicator = document.getElementById('color');
+ ctx.fillStyle = 'black';
+ color_indicator.value = '#000000';
+
+ var cur_id = window.location.pathname.substring(window.location.pathname.lastIndexOf('/') + 1);
+
+ function getRandomInt(max) {
+     return Math.floor(Math.random() * Math.floor(max));
+ }
+
+ // load any previously saved hint; the random query string busts the cache
+ var init_hint = new Image();
+ init_hint.addEventListener('load', function() {
+     ctx.drawImage(init_hint, 0, 0);
+ });
+ init_hint.src = '../static/temp_images/' + cur_id + '/hint.png?' + getRandomInt(100000).toString();
+
+ result_canvas.addEventListener('load', function(e) {
+     var img = new Image();
+     img.addEventListener('load', function() {
+         ctx.drawImage(img, 0, 0);
+     }, false);
+     console.log(window.location.pathname);
+ })
+
+ canvas.onload = function (e) {
+     var img = new Image();
+     img.addEventListener('load', function() {
+         ctx.drawImage(img, 0, 0);
+     }, false);
+     console.log(window.location.pathname);
+     //img.src = ;
+ }
+
+ function reset() {
+     ctx.clearRect(0, 0, canvasWidth, canvasHeight);
+ }
+
+ function getMousePos(canvas, evt) {
+     var rect = canvas.getBoundingClientRect();
+     return {
+         x: (evt.clientX - rect.left) / (rect.right - rect.left) * canvas.width,
+         y: (evt.clientY - rect.top) / (rect.bottom - rect.top) * canvas.height
+     };
+ }
+
+ function colorize() {
+     var file_id = document.location.pathname;
+     var image = canvas.toDataURL();
+
+     $.post("/colorize", { save_file_id: file_id, save_image: image }).done(function(data) {
+         //console.log(document.location.origin + '/img/' + data)
+         //window.open(document.location.origin + '/img/' + data, '_blank');
+         //result.src = data;
+         var img = new Image();
+         img.addEventListener('load', function() {
+             result_ctx.drawImage(img, 0, 0);
+         }, false);
+         img.src = data;
+     });
+ }
+
+ canvas.addEventListener('mousedown', function(e) {
+     var mousePos = getMousePos(canvas, e);
+     if (e.button == 0) {
+         ctx.fillRect(mousePos['x'], mousePos['y'], 1, 1);
+     }
+
+     if (e.button == 2) {
+         prevX = mousePos['x'];
+         prevY = mousePos['y'];
+     }
+ })
+
+ canvas.addEventListener('mouseup', function(e) {
+     if (e.button == 2) {
+         var mousePos = getMousePos(canvas, e);
+         var diff_width = mousePos['x'] - prevX;
+         var diff_height = mousePos['y'] - prevY;
+
+         ctx.clearRect(prevX, prevY, diff_width, diff_height);
+     }
+ })
+
+ canvas.addEventListener('contextmenu', function(evt) {
+     evt.preventDefault();
+ })
+
+ function color(color_value) {
+     ctx.fillStyle = color_value;
+     color_indicator.value = color_value;
+ }
+
+ color_indicator.oninput = function() {
+     color(this.value);
+ }
+
+ function rgbToHex(rgb) {
+     // pad to six digits so e.g. pure blue becomes "#0000ff" rather than "#ff"
+     return '#' + ((rgb[0] << 16) | (rgb[1] << 8) | rgb[2]).toString(16).padStart(6, '0');
+ };
+
+ result_canvas.addEventListener('click', function(e) {
+     if (e.button == 0) {
+         var cur_pixel = result_ctx.getImageData(e.offsetX, e.offsetY, 1, 1).data;
+         color(rgbToHex(cur_pixel));
+     }
+ })
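`colorize()` above defines the client side of the contract: a form POST with `save_file_id` (the page path) and `save_image` (the hint canvas as a PNG data URL), answered with a URL the browser assigns straight to `img.src`. The server half lives in `drawing.py`, which is not shown in this view; a hedged sketch of what it must accept (field names come from draw.js, everything else is an illustrative assumption):

```python
# Hypothetical server half of draw.js's $.post("/colorize", ...); the real
# handler also runs the colorizer model on the page plus decoded hints.
import base64

from flask import Flask, request

app = Flask(__name__)

@app.route('/colorize', methods=['POST'])
def colorize():
    file_id = request.form['save_file_id']   # e.g. "/draw/<id>"
    data_url = request.form['save_image']    # "data:image/png;base64,...."
    hint_png = base64.b64decode(data_url.split(',', 1)[1])

    # ... decode hint_png, run the model on (bw page, hints), save result ...

    # draw.js loads the returned string as the result image URL
    return '/static/temp_images/result.png'
```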
static/temp_images/.gitkeep ADDED
File without changes
templates/drawing.html ADDED
@@ -0,0 +1,206 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+     <meta charset="UTF-8">
+     <title>Colorization app</title>
+     <style>
+         .back {
+             height: 100%;
+             width: 100%;
+             position: absolute;
+             background-color: yellow;
+             top: 0px;
+             padding: 10px;
+         }
+
+         #draw_canvas {
+             height: 200%;
+             border: 3px solid black;
+             background-image: linear-gradient(rgba(60,60,60,.85), rgba(60,60,60,.85)), url(../static/{{img_path}});
+             background-color: #c7b39b;
+             background-size: 100%;
+         }
+     </style>
+ </head>
+ <body>
+     <div align="left" style="margin: 5px; margin-left: 0px">
+         <input type="button" onclick="location.href='/';" value="Home" />
+     </div>
+     <p style="margin: 0; padding: 0">
+         Left click paints a color hint, holding the right button erases a rectangle, and left-clicking the result picks up the corresponding color.
+     </p>
+     <hr style="margin: 0; padding: 0">
+
+     <p><table>
+         <tr>
+             <td>
+                 <table>
+                     <tr>
+                         <td><button style="background-color: #000000; height: 20px; width: 20px;" onclick="color('#000000')"></button>
+                         <td><button style="background-color: #B0171F; height: 20px; width: 20px;" onclick="color('#B0171F')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #DA70D6; height: 20px; width: 20px;" onclick="color('#DA70D6')"></button>
+                         <td><button style="background-color: #8A2BE2; height: 20px; width: 20px;" onclick="color('#8A2BE2')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #0000FF; height: 20px; width: 20px;" onclick="color('#0000FF')"></button>
+                         <td><button style="background-color: #4876FF; height: 20px; width: 20px;" onclick="color('#4876FF')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #CAE1FF; height: 20px; width: 20px;" onclick="color('#CAE1FF')"></button>
+                         <td><button style="background-color: #6E7B8B; height: 20px; width: 20px;" onclick="color('#6E7B8B')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #00C78C; height: 20px; width: 20px;" onclick="color('#00C78C')"></button>
+                         <td><button style="background-color: #00FA9A; height: 20px; width: 20px;" onclick="color('#00FA9A')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #00FF7F; height: 20px; width: 20px;" onclick="color('#00FF7F')"></button>
+                         <td><button style="background-color: #00C957; height: 20px; width: 20px;" onclick="color('#00C957')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #3D9140; height: 20px; width: 20px;" onclick="color('#3D9140')"></button>
+                         <td><button style="background-color: #32CD32; height: 20px; width: 20px;" onclick="color('#32CD32')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #00EE00; height: 20px; width: 20px;" onclick="color('#00EE00')"></button>
+                         <td><button style="background-color: #008B00; height: 20px; width: 20px;" onclick="color('#008B00')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #76EE00; height: 20px; width: 20px;" onclick="color('#76EE00')"></button>
+                         <td><button style="background-color: #CAFF70; height: 20px; width: 20px;" onclick="color('#CAFF70')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #FFFF00; height: 20px; width: 20px;" onclick="color('#FFFF00')"></button>
+                         <td><button style="background-color: #CDCD00; height: 20px; width: 20px;" onclick="color('#CDCD00')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #FFF68F; height: 20px; width: 20px;" onclick="color('#FFF68F')"></button>
+                         <td><button style="background-color: #FFFACD; height: 20px; width: 20px;" onclick="color('#FFFACD')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #FFEC8B; height: 20px; width: 20px;" onclick="color('#FFEC8B')"></button>
+                         <td><button style="background-color: #FFD700; height: 20px; width: 20px;" onclick="color('#FFD700')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #F5DEB3; height: 20px; width: 20px;" onclick="color('#F5DEB3')"></button>
+                         <td><button style="background-color: #FFE4B5; height: 20px; width: 20px;" onclick="color('#FFE4B5')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #EECFA1; height: 20px; width: 20px;" onclick="color('#EECFA1')"></button>
+                         <td><button style="background-color: #FF9912; height: 20px; width: 20px;" onclick="color('#FF9912')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #8E388E; height: 20px; width: 20px;" onclick="color('#8E388E')"></button>
+                         <td><button style="background-color: #7171C6; height: 20px; width: 20px;" onclick="color('#7171C6')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #7D9EC0; height: 20px; width: 20px;" onclick="color('#7D9EC0')"></button>
+                         <td><button style="background-color: #388E8E; height: 20px; width: 20px;" onclick="color('#388E8E')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #71C671; height: 20px; width: 20px;" onclick="color('#71C671')"></button>
+                         <td><button style="background-color: #8E8E38; height: 20px; width: 20px;" onclick="color('#8E8E38')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #C5C1AA; height: 20px; width: 20px;" onclick="color('#C5C1AA')"></button>
+                         <td><button style="background-color: #C67171; height: 20px; width: 20px;" onclick="color('#C67171')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #555555; height: 20px; width: 20px;" onclick="color('#555555')"></button>
+                         <td><button style="background-color: #848484; height: 20px; width: 20px;" onclick="color('#848484')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #FFFFFF; height: 20px; width: 20px;" onclick="color('#FFFFFF')"></button>
+                         <td><button style="background-color: #EE0000; height: 20px; width: 20px;" onclick="color('#EE0000')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #FF4040; height: 20px; width: 20px;" onclick="color('#FF4040')"></button>
+                         <td><button style="background-color: #EE6363; height: 20px; width: 20px;" onclick="color('#EE6363')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #FFC1C1; height: 20px; width: 20px;" onclick="color('#FFC1C1')"></button>
+                         <td><button style="background-color: #FF7256; height: 20px; width: 20px;" onclick="color('#FF7256')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #FF4500; height: 20px; width: 20px;" onclick="color('#FF4500')"></button>
+                         <td><button style="background-color: #F4A460; height: 20px; width: 20px;" onclick="color('#F4A460')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #FF8000; height: 20px; width: 20px;" onclick="color('#FF8000')"></button>
+                         <td><button style="background-color: #FFD700; height: 20px; width: 20px;" onclick="color('#FFD700')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #8B864E; height: 20px; width: 20px;" onclick="color('#8B864E')"></button>
+                         <td><button style="background-color: #9ACD32; height: 20px; width: 20px;" onclick="color('#9ACD32')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #66CD00; height: 20px; width: 20px;" onclick="color('#66CD00')"></button>
+                         <td><button style="background-color: #BDFCC9; height: 20px; width: 20px;" onclick="color('#BDFCC9')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #76EEC6; height: 20px; width: 20px;" onclick="color('#76EEC6')"></button>
+                         <td><button style="background-color: #40E0D0; height: 20px; width: 20px;" onclick="color('#40E0D0')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #E0EEEE; height: 20px; width: 20px;" onclick="color('#E0EEEE')"></button>
+                         <td><button style="background-color: #98F5FF; height: 20px; width: 20px;" onclick="color('#98F5FF')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #33A1C9; height: 20px; width: 20px;" onclick="color('#33A1C9')"></button>
+                         <td><button style="background-color: #F0F8FF; height: 20px; width: 20px;" onclick="color('#F0F8FF')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #4682B4; height: 20px; width: 20px;" onclick="color('#4682B4')"></button>
+                         <td><button style="background-color: #C6E2FF; height: 20px; width: 20px;" onclick="color('#C6E2FF')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #9B30FF; height: 20px; width: 20px;" onclick="color('#9B30FF')"></button>
+                         <td><button style="background-color: #EE82EE; height: 20px; width: 20px;" onclick="color('#EE82EE')"></button>
+                     </tr>
+                     <tr>
+                         <td><button style="background-color: #FFC0CB; height: 20px; width: 20px;" onclick="color('#FFC0CB')"></button>
+                         <td><button style="background-color: #7CFC00; height: 20px; width: 20px;" onclick="color('#7CFC00')"></button>
+                     </tr>
+                     <tr>
+                         <input type="color" id="color">
+                     </tr>
+                 </table>
+             </td>
+             <td>
+                 <div style="width: 1150px; height: 800px; overflow: auto"><canvas align="center" id="draw_canvas" width="{{width}}" height="{{height}}"></canvas></div>
+             </td>
+             <td>
+                 <canvas id='result'></canvas>
+             </td>
+         </tr>
+     </table></p>
+
+     <button style="height: 20px; width: 80px" onclick="reset()">Clear</button>
+     <button style="height: 20px; width: 80px" onclick="colorize()">Colorize</button>
+
+     <script src="http://code.jquery.com/jquery-1.8.3.js"></script>
+     <script src="/static/js/draw.js"></script>
+ </body>
+ </html>
templates/submit.html ADDED
@@ -0,0 +1,11 @@
+ <form method="POST" action="/" enctype="multipart/form-data" target="_blank">
+     {{ form.hidden_tag() }}
+     {{ form.file.label }} {{ form.file(size=20) }}
+     {{ form.denoise.label }} {{ form.denoise(size=5) }}
+     {{ form.denoise_sigma.label }} {{ form.denoise_sigma(size=5) }}
+     {{ form.autohint.label }} {{ form.autohint(size=5) }}
+     {{ form.autohint_sigma.label }} {{ form.autohint_sigma(size=5) }}
+     {{ form.ignore_gray.label }} {{ form.ignore_gray(size=5) }}
+     <input type="submit" value="Colorize">
+ </form>
templates/upload.html ADDED
@@ -0,0 +1,20 @@
+ <style>
+     form {
+         position: fixed;
+         top: 50%;
+         left: 45%;
+         width: 1250px;
+     }
+ </style>
+
+ <form id="form" method="POST" action="/" enctype="multipart/form-data">
+     {{ form.hidden_tag() }}
+     {{ form.file(size=20) }}
+ </form>
+
+ <script>
+     document.getElementById("file").onchange = function() {
+         document.getElementById("form").submit();
+     };
+ </script>
train.py ADDED
@@ -0,0 +1,294 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import numpy as np
+ import albumentations as albu
+ import argparse
+ import datetime
+
+ from utils.utils import open_json, weights_init, weights_init_spectr, generate_mask
+ from model.models import Colorizer, Generator, Content, Discriminator
+ from model.extractor import get_seresnext_extractor
+ from dataset.datasets import TrainDataset, FineTuningDataset
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("-p", "--path", required=True, help="dataset path")
+     parser.add_argument('-ft', '--fine_tuning', dest='fine_tuning', action='store_true')
+     parser.add_argument('-g', '--gpu', dest='gpu', action='store_true')
+     parser.set_defaults(fine_tuning=False)
+     parser.set_defaults(gpu=False)
+     args = parser.parse_args()
+
+     return args
+
+ def get_transforms():
+     return albu.Compose([albu.RandomCrop(512, 512, always_apply=True), albu.HorizontalFlip(p=0.5)], p=1.)
+
+ def get_dataloaders(data_path, transforms, batch_size, fine_tuning, mult_number):
+     train_dataset = TrainDataset(data_path, transforms, mult_number)
+     train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+
+     finetuning_dataloader = None  # stays None unless fine-tuning data is requested
+     if fine_tuning:
+         finetuning_dataset = FineTuningDataset(data_path, transforms)
+         finetuning_dataloader = torch.utils.data.DataLoader(finetuning_dataset, batch_size=batch_size, shuffle=True)
+
+     return train_dataloader, finetuning_dataloader
+
+ def get_models(device):
+     generator = Generator()
+     extractor = get_seresnext_extractor()
+     colorizer = Colorizer(generator, extractor)
+
+     colorizer.extractor_eval()
+     colorizer = colorizer.to(device)
+
+     discriminator = Discriminator().to(device)
+
+     content = Content('model/vgg16-397923af.pth').eval().to(device)
+     for param in content.parameters():
+         param.requires_grad = False
+
+     return colorizer, discriminator, content
+
+ def set_weights(colorizer, discriminator):
+     colorizer.generator.apply(weights_init)
+     colorizer.load_extractor_weights(torch.load('model/extractor.pth'))
+
+     discriminator.apply(weights_init_spectr)
+
+ def generator_loss(disc_output, true_labels, main_output, guide_output, real_image, content_gen, content_true, dist_loss=nn.L1Loss(), content_dist_loss=nn.MSELoss(), class_loss=nn.BCEWithLogitsLoss()):
+     # L1 reconstruction on both the full output and the decoder guide,
+     # plus adversarial BCE and a VGG content term
+     sim_loss_full = dist_loss(main_output, real_image)
+     sim_loss_guide = dist_loss(guide_output, real_image)
+
+     adv_loss = class_loss(disc_output, true_labels)
+
+     content_loss = content_dist_loss(content_gen, content_true)
+
+     sum_loss = 10 * (sim_loss_full + 0.9 * sim_loss_guide) + adv_loss + content_loss
+
+     return sum_loss
+
+ def get_optimizers(colorizer, discriminator, generator_lr, discriminator_lr):
+     optimizerG = optim.Adam(colorizer.generator.parameters(), lr=generator_lr, betas=(0.5, 0.9))
+     optimizerD = optim.Adam(discriminator.parameters(), lr=discriminator_lr, betas=(0.5, 0.9))
+
+     return optimizerG, optimizerD
+
+ def generator_step(inputs, colorizer, discriminator, content, loss_function, optimizer, device, white_penalty=True):
+     for p in discriminator.parameters():
+         p.requires_grad = False
+     for p in colorizer.generator.parameters():
+         p.requires_grad = True
+
+     colorizer.generator.zero_grad()
+
+     bw, color, hint, dfm = inputs
+     bw, color, hint, dfm = bw.to(device), color.to(device), hint.to(device), dfm.to(device)
+
+     fake, guide = colorizer(torch.cat([bw, dfm, hint], 1))
+
+     logits_fake = discriminator(fake)
+     y_real = torch.ones((bw.size(0), 1), device=device)
+
+     content_fake = content(fake)
+     with torch.no_grad():
+         content_true = content(color)
+
+     generator_loss = loss_function(logits_fake, y_real, fake, guide, color, content_fake, content_true)
+
+     if white_penalty:
+         # penalize bright output in regions where the real image is not
+         # near-white (i.e. not all three channels > 0.85)
+         mask = (~((color > 0.85).float().sum(dim=1) == 3).unsqueeze(1).repeat((1, 3, 1, 1))).float()
+         white_zones = mask * (fake + 1) / 2
+         white_penalty = (torch.pow(white_zones.sum(dim=1), 2).sum(dim=(1, 2)) / (mask.sum(dim=(1, 2, 3)) + 1)).mean()
+
+         generator_loss += white_penalty
+
+     generator_loss.backward()
+
+     optimizer.step()
+
+     return generator_loss.item()
+
+ def discriminator_step(inputs, colorizer, discriminator, optimizer, device, loss_function=nn.BCEWithLogitsLoss()):
+     for p in discriminator.parameters():
+         p.requires_grad = True
+     for p in colorizer.generator.parameters():
+         p.requires_grad = False
+
+     discriminator.zero_grad()
+
+     bw, color, hint, dfm = inputs
+     bw, color, hint, dfm = bw.to(device), color.to(device), hint.to(device), dfm.to(device)
+
+     y_real = torch.full((bw.size(0), 1), 0.9, device=device)  # one-sided label smoothing
+     y_fake = torch.zeros((bw.size(0), 1), device=device)
+
+     with torch.no_grad():
+         fake_color, _ = colorizer(torch.cat([bw, dfm, hint], 1))
+     fake_color = fake_color.detach()
+
+     logits_fake = discriminator(fake_color)
+     logits_real = discriminator(color)
+
+     fake_loss = loss_function(logits_fake, y_fake)
+     real_loss = loss_function(logits_real, y_real)
+
+     discriminator_loss = real_loss + fake_loss
+
+     discriminator_loss.backward()
+     optimizer.step()
+
+     return discriminator_loss.item()
+
+ def decrease_lr(optimizer, rate):
+     for group in optimizer.param_groups:
+         group['lr'] /= rate
+
+ def set_lr(optimizer, value):
+     for group in optimizer.param_groups:
+         group['lr'] = value
+
+ def train(colorizer, discriminator, content, dataloader, epochs, colorizer_optimizer, discriminator_optimizer, lr_decay_epoch=-1, device='cpu'):
+     colorizer.generator.train()
+     discriminator.train()
+
+     disc_step = True
+
+     for epoch in range(epochs):
+         if epoch == lr_decay_epoch:
+             decrease_lr(colorizer_optimizer, 10)
+             decrease_lr(discriminator_optimizer, 10)
+
+         sum_disc_loss = 0
+         sum_gen_loss = 0
+
+         for n, inputs in enumerate(dataloader):
+             if n % 5 == 0:
+                 print(datetime.datetime.now().time())
+                 print('Step : %d Discr loss: %.4f Gen loss : %.4f \n' % (n, sum_disc_loss / (n // 2 + 1), sum_gen_loss / (n // 2 + 1)))
+
+             # alternate one discriminator step with one generator step
+             if disc_step:
+                 step_loss = discriminator_step(inputs, colorizer, discriminator, discriminator_optimizer, device)
+                 sum_disc_loss += step_loss
+             else:
+                 step_loss = generator_step(inputs, colorizer, discriminator, content, generator_loss, colorizer_optimizer, device)
+                 sum_gen_loss += step_loss
+
+             disc_step = not disc_step
+
+         print(datetime.datetime.now().time())
+         print('Epoch : %d Discr loss: %.4f Gen loss : %.4f \n' % (epoch, sum_disc_loss / (n // 2 + 1), sum_gen_loss / (n // 2 + 1)))
+
+
+ def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_optimizer, device, loss_function=nn.BCEWithLogitsLoss()):
+     for p in discriminator.parameters():
+         p.requires_grad = True
+     for p in colorizer.generator.parameters():
+         p.requires_grad = False
+
+     for cur_disc_step in range(5):
+         discriminator.zero_grad()
+
+         bw, dfm, color_for_real = next(data_iter)
+         bw, dfm, color_for_real = bw.to(device), dfm.to(device), color_for_real.to(device)
+
+         y_real = torch.full((bw.size(0), 1), 0.9, device=device)
+         y_fake = torch.zeros((bw.size(0), 1), device=device)
+
+         empty_hint = torch.zeros(bw.shape[0], 4, bw.shape[2], bw.shape[3]).float().to(device)
+
+         with torch.no_grad():
+             fake_color_manga, _ = colorizer(torch.cat([bw, dfm, empty_hint], 1))
+         fake_color_manga = fake_color_manga.detach()
+
+         logits_fake = discriminator(fake_color_manga)
+         logits_real = discriminator(color_for_real)
+
+         fake_loss = loss_function(logits_fake, y_fake)
+         real_loss = loss_function(logits_real, y_real)
+         discriminator_loss = real_loss + fake_loss
+
+         discriminator_loss.backward()
+         disc_optimizer.step()
+
+     for p in discriminator.parameters():
+         p.requires_grad = False
+     for p in colorizer.generator.parameters():
+         p.requires_grad = True
+
+     colorizer.generator.zero_grad()
+
+     bw, dfm, _ = next(data_iter)
+     bw, dfm = bw.to(device), dfm.to(device)
+
+     y_real = torch.ones((bw.size(0), 1), device=device)
+
+     empty_hint = torch.zeros(bw.shape[0], 4, bw.shape[2], bw.shape[3]).float().to(device)
+
+     fake_manga, _ = colorizer(torch.cat([bw, dfm, empty_hint], 1))
+
+     logits_fake = discriminator(fake_manga)
+     adv_loss = loss_function(logits_fake, y_real)
+
+     generator_loss = adv_loss
+
+     generator_loss.backward()
+     gen_optimizer.step()
+
+
+ def fine_tuning(colorizer, discriminator, content, dataloader, iterations, colorizer_optimizer, discriminator_optimizer, data_iter, device='cpu'):
+     colorizer.generator.train()
+     discriminator.train()
+
+     disc_step = True
+
+     for n, inputs in enumerate(dataloader):
+         if n == iterations:
+             return
+
+         if disc_step:
+             discriminator_step(inputs, colorizer, discriminator, discriminator_optimizer, device)
+         else:
+             generator_step(inputs, colorizer, discriminator, content, generator_loss, colorizer_optimizer, device)
+
+         disc_step = not disc_step
+
+         # periodically run an extra adversarial pass on real (hint-free) manga
+         if n % 10 == 5:
+             fine_tuning_step(data_iter, colorizer, discriminator, colorizer_optimizer, discriminator_optimizer, device)
+
+ if __name__ == '__main__':
+     args = parse_args()
+     config = open_json('configs/train_config.json')
+
+     if args.gpu:
+         device = 'cuda'
+     else:
+         device = 'cpu'
+
+     augmentations = get_transforms()
+
+     train_dataloader, ft_dataloader = get_dataloaders(args.path, augmentations, config['batch_size'], args.fine_tuning, config['number_of_mults'])
+
+     colorizer, discriminator, content = get_models(device)
+     set_weights(colorizer, discriminator)
+
+     gen_optimizer, disc_optimizer = get_optimizers(colorizer, discriminator, config['generator_lr'], config['discriminator_lr'])
+
+     train(colorizer, discriminator, content, train_dataloader, config['epochs'], gen_optimizer, disc_optimizer, config['lr_decrease_epoch'], device)
+
+     if args.fine_tuning:
+         set_lr(gen_optimizer, config["finetuning_generator_lr"])
+         fine_tuning(colorizer, discriminator, content, train_dataloader, config['finetuning_iterations'], gen_optimizer, disc_optimizer, iter(ft_dataloader), device)
+
+     torch.save(colorizer.generator.state_dict(), str(datetime.datetime.now().time()))
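train.py reads every hyperparameter from `configs/train_config.json`, which is not shown in this view. A config with exactly the keys the script accesses would look like this (the key names come from the code above; the values are illustrative guesses, not the authors' settings):

```json
{
    "batch_size": 4,
    "number_of_mults": 2,
    "generator_lr": 0.0001,
    "discriminator_lr": 0.0004,
    "epochs": 20,
    "lr_decrease_epoch": 15,
    "finetuning_generator_lr": 0.00001,
    "finetuning_iterations": 2000
}
```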
train/110_dfm.png ADDED

Git LFS Details

  • SHA256: 27e3161ab5d726acf99904a845cde3c4dc74534d1ce5608a7e37ae0818c8aa0c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.08 MB
train/111_dfm.png ADDED

Git LFS Details

  • SHA256: 217966c360f19eb533060cc48b30aee7ea574dcb5e324044158be7e93e67aa1f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
train/112_dfm.png ADDED
train/113_dfm.png ADDED

Git LFS Details

  • SHA256: c8a8e98c1a38150c55742ca5c1e3fc515170a545ebd1b79bafba76e10cd5f738
  • Pointer size: 132 Bytes
  • Size of remote file: 1.11 MB
train/114_dfm.png ADDED
train/115_dfm.png ADDED
train/116_dfm.png ADDED
train/117_dfm.png ADDED
train/118_dfm.png ADDED
train/119_dfm.png ADDED
train/11_dfm.png ADDED

Git LFS Details

  • SHA256: dac26bfeecf844f8c562c72aa587a7a15e791950e0711df2562c0629a2f842ed
  • Pointer size: 132 Bytes
  • Size of remote file: 1.31 MB
train/120_dfm.png ADDED
train/121_dfm.png ADDED
train/122_dfm.png ADDED

Git LFS Details

  • SHA256: 92317a01e04588b1fb713ee677b39e9c0e542d63068aa761505c1f99f1cf1a77
  • Pointer size: 132 Bytes
  • Size of remote file: 1.33 MB
train/123_dfm.png ADDED
train/124_dfm.png ADDED
train/125_dfm.png ADDED
train/126_dfm.png ADDED
train/127_dfm.png ADDED

Git LFS Details

  • SHA256: a2b9c44e5ede86bf6f789bf132158e2b5b253fca28cdbfd14c3aa6b3e24ccf37
  • Pointer size: 132 Bytes
  • Size of remote file: 1.19 MB
train/128_dfm.png ADDED