Keiser41 committed on
Commit
62456b0
1 Parent(s): 6bdd788

Upload 47 files

Files changed (47)
  1. .dockerignore +6 -0
  2. .gitignore +10 -0
  3. Dockerfile +11 -0
  4. configs/train_config.json +10 -0
  5. configs/xdog_config.json +8 -0
  6. dataset/__pycache__/datasets.cpython-310.pyc +0 -0
  7. dataset/__pycache__/datasets.cpython-36.pyc +0 -0
  8. dataset/__pycache__/datasets.cpython-39.pyc +0 -0
  9. dataset/datasets.py +91 -0
  10. denoising/denoiser.py +113 -0
  11. denoising/functions.py +101 -0
  12. denoising/models.py +422 -0
  13. denoising/models/.gitkeep +0 -0
  14. denoising/utils.py +66 -0
  15. drawing.py +165 -0
  16. inference.py +215 -0
  17. model/__pycache__/extractor.cpython-310.pyc +0 -0
  18. model/__pycache__/extractor.cpython-36.pyc +0 -0
  19. model/__pycache__/extractor.cpython-39.pyc +0 -0
  20. model/__pycache__/models.cpython-310.pyc +0 -0
  21. model/__pycache__/models.cpython-36.pyc +0 -0
  22. model/__pycache__/models.cpython-39.pyc +0 -0
  23. model/extractor.pth +3 -0
  24. model/extractor.py +127 -0
  25. model/models.py +422 -0
  26. model/vgg16-397923af.pth +3 -0
  27. readme.md +22 -0
  28. requirements.txt +10 -0
  29. run_drawing.sh +1 -0
  30. static/js/draw.js +120 -0
  31. static/temp_images/.gitkeep +0 -0
  32. templates/drawing.html +206 -0
  33. templates/submit.html +11 -0
  34. templates/upload.html +20 -0
  35. train.py +293 -0
  36. train/bw/blackclover_cl268.png +0 -0
  37. train/bw/dfm_blackclover_cl268.png +0 -0
  38. train/color/blackclover_cl268.png +0 -0
  39. train/real_manga/blackclover_cl268.png +0 -0
  40. train/real_manga/dfm_blackclover_cl268.png +0 -0
  41. utils/__pycache__/utils.cpython-310.pyc +0 -0
  42. utils/__pycache__/utils.cpython-36.pyc +0 -0
  43. utils/__pycache__/utils.cpython-39.pyc +0 -0
  44. utils/dataset_utils.py +141 -0
  45. utils/utils.py +102 -0
  46. utils/xdog.py +68 -0
  47. web.py +108 -0
.dockerignore ADDED
@@ -0,0 +1,6 @@
+ *.ipynb
+
+ model/*.pth
+
+ temp_colorization/
+ __pycache__/
.gitignore ADDED
@@ -0,0 +1,10 @@
+ *.ipynb
+ *.pth
+ *.zip
+
+ __pycache__/
+ temp_colorization/
+
+ static/temp_images/*
+
+ !.gitkeep
Dockerfile ADDED
@@ -0,0 +1,11 @@
+ FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-runtime
+
+ RUN apt-get update && apt-get install -y libglib2.0-0 libsm6 libxext6 libxrender-dev
+
+ COPY . .
+
+ RUN pip install --no-cache-dir -r ./requirements.txt
+
+ EXPOSE 5000
+
+ CMD gunicorn --timeout 200 -w 3 -b 0.0.0.0:5000 drawing:app
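The image serves the Flask app defined in drawing.py through gunicorn. A typical build-and-run sequence would be docker build -t manga-colorizer . followed by docker run -p 5000:5000 manga-colorizer (the image name is illustrative). Note that .dockerignore excludes model/*.pth, so any weights the app needs must reach the container some other way, e.g. a volume mount.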
configs/train_config.json ADDED
@@ -0,0 +1,10 @@
+ {
+     "generator_lr" : 1e-4,
+     "discriminator_lr" : 4e-4,
+     "epochs" : 15,
+     "lr_decrease_epoch" : 10,
+     "finetuning_generator_lr" : 1e-6,
+     "finetuning_iterations" : 3500,
+     "batch_size" : 1,
+     "number_of_mults" : 1
+ }
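These hyperparameters are presumably read by train.py via the repo's utils.utils.open_json helper (the same one inference.py uses for the XDoG config); a minimal standard-library equivalent:

import json

with open('configs/train_config.json') as f:
    cfg = json.load(f)

print(cfg['generator_lr'], cfg['discriminator_lr'], cfg['epochs'])  # 0.0001 0.0004 15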
configs/xdog_config.json ADDED
@@ -0,0 +1,8 @@
+ {
+     "sigma" : 0.5,
+     "k" : 8,
+     "phi" : 89.25,
+     "gamma" : 0.95,
+     "eps" : -0.1,
+     "mult" : 7
+ }
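These values parameterize the XDoG (eXtended Difference-of-Gaussians) sketcher in utils/xdog.py. A minimal sketch of the standard XDoG operator these keys usually drive; the exact formulation inside XDoGSketcher may differ, and "mult" is an extra knob only the sketcher itself interprets:

import numpy as np
from scipy.ndimage import gaussian_filter

def xdog(gray, sigma=0.5, k=8, gamma=0.95, eps=-0.1, phi=89.25):
    # gray: float image in [0, 1]; difference of two Gaussians, the second k times wider
    d = gaussian_filter(gray, sigma) - gamma * gaussian_filter(gray, sigma * k)
    # soft threshold: white above eps, smooth tanh falloff below
    return np.where(d >= eps, 1.0, 1.0 + np.tanh(phi * (d - eps)))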
dataset/__pycache__/datasets.cpython-310.pyc ADDED
Binary file (3.21 kB).
 
dataset/__pycache__/datasets.cpython-36.pyc ADDED
Binary file (3.23 kB).
 
dataset/__pycache__/datasets.cpython-39.pyc ADDED
Binary file (3.56 kB).
 
dataset/datasets.py ADDED
@@ -0,0 +1,91 @@
+ from PIL import Image
+ import torch
+ import os
+ import numpy as np
+ import torchvision.transforms as transforms
+ from utils.utils import generate_mask
+
+ class TrainDataset(torch.utils.data.Dataset):
+     def __init__(self, data_path, transform=None):
+         self.data = os.listdir(os.path.join(data_path, 'color'))
+         self.data_path = data_path
+         self.transform = transform
+         self.ToTensor = transforms.ToTensor()
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         image_name = self.data[idx]
+
+         color_img = Image.open(os.path.join(self.data_path, 'color', image_name)).convert('RGB')
+         bw_name = self.data[idx]
+         dfm_name = 'dfm_' + self.data[idx]
+
+         bw_img = Image.open(os.path.join(self.data_path, 'bw', bw_name)).convert('L')
+         dfm_img = Image.open(os.path.join(self.data_path, 'bw', dfm_name)).convert('L')
+
+         color_img = np.array(color_img)
+         bw_img = np.array(bw_img)
+         dfm_img = np.array(dfm_img)
+
+         # stack the sketch and its dfm companion as a two-channel "mask"
+         bw_img = np.expand_dims(bw_img, 2)
+         dfm_img = np.expand_dims(dfm_img, 2)
+         bw_img = np.concatenate([bw_img, dfm_img], axis=2)
+
+         if self.transform:
+             result = self.transform(image=color_img, mask=bw_img)
+             color_img = result['image']
+             bw_img = result['mask']
+
+         color_img = self.ToTensor(color_img)
+         bw_img = self.ToTensor(bw_img)
+         color_img = (color_img - 0.5) / 0.5  # normalize color_img to [-1, 1]
+
+         mask = generate_mask(bw_img.shape[1], bw_img.shape[2])
+         hint = torch.cat((color_img * mask, mask), 0)
+
+         return bw_img, bw_img, color_img, hint
+
+ class FineTuningDataset(torch.utils.data.Dataset):
+     def __init__(self, data_path, transform=None, mult_amount=1):
+         # keep only plain pages; their 'dfm_'-prefixed companions are loaded in __getitem__
+         # (fixed: the original filter looked for '_dfm', which never matches the 'dfm_' prefix)
+         self.data = [x for x in os.listdir(os.path.join(data_path, 'real_manga')) if not x.startswith('dfm_')]
+         self.color_data = [x for x in os.listdir(os.path.join(data_path, 'color'))]
+         self.data_path = data_path
+         self.transform = transform
+         self.mults_amount = mult_amount
+
+         np.random.shuffle(self.color_data)
+         self.ToTensor = transforms.ToTensor()
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         image_name = self.data[idx]
+
+         color_img = Image.open(os.path.join(self.data_path, 'color', image_name)).convert('RGB')
+         bw_name = self.data[idx]
+         dfm_name = 'dfm_' + self.data[idx]
+
+         bw_img = Image.open(os.path.join(self.data_path, 'bw', bw_name)).convert('L')
+         dfm_img = Image.open(os.path.join(self.data_path, 'bw', dfm_name)).convert('L')
+
+         color_img = np.array(color_img)
+         bw_img = np.array(bw_img)
+         dfm_img = np.array(dfm_img)
+
+         bw_img = np.expand_dims(bw_img, 2)
+         dfm_img = np.expand_dims(dfm_img, 2)
+         bw_img = np.concatenate([bw_img, dfm_img], axis=2)
+
+         if self.transform:
+             result = self.transform(image=color_img, mask=bw_img)
+             color_img = result['image']
+             bw_img = result['mask']
+
+         color_img = self.ToTensor(color_img)
+         bw_img = self.ToTensor(bw_img)
+         color_img = (color_img - 0.5) / 0.5  # normalize color_img to [-1, 1]
+
+         return bw_img, color_img  # returns the two-channel sketch and the color target
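The transform(image=..., mask=...) call signature above suggests an albumentations pipeline. A minimal sketch of feeding TrainDataset to a loader under that assumption (crop size illustrative; 'train' matches the layout of the bundled sample data):

import albumentations as A
from torch.utils.data import DataLoader
from dataset.datasets import TrainDataset

transform = A.Compose([A.RandomCrop(512, 512), A.HorizontalFlip()])
dataset = TrainDataset('train', transform=transform)
loader = DataLoader(dataset, batch_size=1, shuffle=True)

bw, _, color, hint = next(iter(loader))  # (N,2,H,W), duplicate, (N,3,H,W), (N,4,H,W)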
denoising/denoiser.py ADDED
@@ -0,0 +1,113 @@
+ """
+ Denoise an image with the FFDNet denoising method
+
+ Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
+
+ This program is free software: you can use, modify and/or
+ redistribute it under the terms of the GNU General Public
+ License as published by the Free Software Foundation, either
+ version 3 of the License, or (at your option) any later
+ version. You should have received a copy of this license along
+ this program. If not, see <http://www.gnu.org/licenses/>.
+ """
+ import os
+ import argparse
+ import time
+ import numpy as np
+ import cv2
+ import torch
+ import torch.nn as nn
+ from torch.autograd import Variable
+ from denoising.models import FFDNet
+ from denoising.utils import normalize, variable_to_cv2_image, remove_dataparallel_wrapper, is_rgb
+
+ class FFDNetDenoiser:
+     def __init__(self, _device, _sigma=25, _weights_dir='denoising/models/', _in_ch=3):
+         self.sigma = _sigma / 255
+         self.weights_dir = _weights_dir
+         self.channels = _in_ch
+         self.device = _device
+
+         self.model = FFDNet(num_input_channels=_in_ch)
+         self.load_weights()
+         self.model.eval()
+
+     def load_weights(self):
+         weights_name = 'net_rgb.pth' if self.channels == 3 else 'net_gray.pth'
+         weights_path = os.path.join(self.weights_dir, weights_name)
+         if self.device == 'cuda':
+             state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
+             device_ids = [0]
+             self.model = nn.DataParallel(self.model, device_ids=device_ids).cuda()
+         else:
+             state_dict = torch.load(weights_path, map_location='cpu')
+             # CPU mode: remove the DataParallel wrapper
+             state_dict = remove_dataparallel_wrapper(state_dict)
+         self.model.load_state_dict(state_dict)
+
+     def get_denoised_image(self, imorig, sigma=None):
+         if sigma is not None:
+             cur_sigma = sigma / 255
+         else:
+             cur_sigma = self.sigma
+
+         if len(imorig.shape) < 3 or imorig.shape[2] == 1:
+             imorig = np.repeat(np.expand_dims(imorig, 2), 3, 2)
+
+         if max(imorig.shape[0], imorig.shape[1]) > 1200:
+             ratio = max(imorig.shape[0], imorig.shape[1]) / 1200
+             imorig = cv2.resize(imorig, (int(imorig.shape[1] / ratio), int(imorig.shape[0] / ratio)), interpolation=cv2.INTER_AREA)
+
+         imorig = imorig.transpose(2, 0, 1)
+
+         if imorig.max() > 1.2:
+             imorig = normalize(imorig)
+         imorig = np.expand_dims(imorig, 0)
+
+         # Handle odd sizes: pad by repeating the last row/column
+         expanded_h = False
+         expanded_w = False
+         sh_im = imorig.shape
+         if sh_im[2] % 2 == 1:
+             expanded_h = True
+             imorig = np.concatenate((imorig, imorig[:, :, -1, :][:, :, np.newaxis, :]), axis=2)
+
+         if sh_im[3] % 2 == 1:
+             expanded_w = True
+             imorig = np.concatenate((imorig, imorig[:, :, :, -1][:, :, :, np.newaxis]), axis=3)
+
+         imorig = torch.Tensor(imorig)
+
+         # Set data type according to CPU or GPU mode
+         if self.device == 'cuda':
+             dtype = torch.cuda.FloatTensor
+         else:
+             dtype = torch.FloatTensor
+
+         imnoisy = imorig.clone()
+
+         with torch.no_grad():
+             imorig, imnoisy = imorig.type(dtype), imnoisy.type(dtype)
+             nsigma = torch.FloatTensor([cur_sigma]).type(dtype)
+
+             # Estimate the noise and subtract it from the input image
+             im_noise_estim = self.model(imnoisy, nsigma)
+             outim = torch.clamp(imnoisy - im_noise_estim, 0., 1.)
+
+         if expanded_h:
+             imorig = imorig[:, :, :-1, :]
+             outim = outim[:, :, :-1, :]
+             imnoisy = imnoisy[:, :, :-1, :]
+
+         if expanded_w:
+             imorig = imorig[:, :, :, :-1]
+             outim = outim[:, :, :, :-1]
+             imnoisy = imnoisy[:, :, :, :-1]
+
+         return variable_to_cv2_image(outim)
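A minimal usage sketch (paths illustrative; the pretrained net_rgb.pth weights are expected under denoising/models/). Note that the return value is a uint8 array in OpenCV's BGR channel order, since variable_to_cv2_image converts RGB to BGR:

import cv2
import matplotlib.pyplot as plt
from denoising.denoiser import FFDNetDenoiser

denoiser = FFDNetDenoiser('cpu')
img = plt.imread('train/color/blackclover_cl268.png')[..., :3]   # RGB float in [0, 1]
clean_bgr = denoiser.get_denoised_image(img, sigma=25)           # sigma on the 0-255 scale
plt.imsave('denoised.png', cv2.cvtColor(clean_bgr, cv2.COLOR_BGR2RGB))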
denoising/functions.py ADDED
@@ -0,0 +1,101 @@
+ """
+ Functions implementing custom NN layers
+
+ Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
+
+ This program is free software: you can use, modify and/or
+ redistribute it under the terms of the GNU General Public
+ License as published by the Free Software Foundation, either
+ version 3 of the License, or (at your option) any later
+ version. You should have received a copy of this license along
+ this program. If not, see <http://www.gnu.org/licenses/>.
+ """
+ import torch
+ from torch.autograd import Function, Variable
+
+ def concatenate_input_noise_map(input, noise_sigma):
+     r"""Implements the first layer of FFDNet. This function returns a
+     torch.autograd.Variable composed of the concatenation of the downsampled
+     input image and the noise map. Each image of the batch of size CxHxW gets
+     converted to an array of size 4*CxH/2xW/2. Each of the pixels of the
+     non-overlapped 2x2 patches of the input image are placed in the new array
+     along the first dimension.
+
+     Args:
+         input: batch containing CxHxW images
+         noise_sigma: the value of the pixels of the CxH/2xW/2 noise map
+     """
+     # noise_sigma is a list of length batch_size
+     N, C, H, W = input.size()
+     dtype = input.type()
+     sca = 2
+     sca2 = sca * sca
+     Cout = sca2 * C
+     Hout = H // sca
+     Wout = W // sca
+     idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]
+
+     # Fill the downsampled image with zeros
+     if 'cuda' in dtype:
+         downsampledfeatures = torch.cuda.FloatTensor(N, Cout, Hout, Wout).fill_(0)
+     else:
+         downsampledfeatures = torch.FloatTensor(N, Cout, Hout, Wout).fill_(0)
+
+     # Build the CxH/2xW/2 noise map
+     noise_map = noise_sigma.view(N, 1, 1, 1).repeat(1, C, Hout, Wout)
+
+     # Populate output
+     for idx in range(sca2):
+         downsampledfeatures[:, idx:Cout:sca2, :, :] = \
+             input[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca]
+
+     # Concatenate the de-interleaved mosaic with the noise map
+     return torch.cat((noise_map, downsampledfeatures), 1)
+
+ class UpSampleFeaturesFunction(Function):
+     r"""Extends PyTorch's modules by implementing a torch.autograd.Function.
+     This class implements the forward and backward methods of the last layer
+     of FFDNet. It basically performs the inverse of
+     concatenate_input_noise_map(): it converts each of the images of a
+     batch of size CxH/2xW/2 to images of size C/4xHxW
+     """
+     @staticmethod
+     def forward(ctx, input):
+         N, Cin, Hin, Win = input.size()
+         dtype = input.type()
+         sca = 2
+         sca2 = sca * sca
+         Cout = Cin // sca2
+         Hout = Hin * sca
+         Wout = Win * sca
+         idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]
+
+         assert (Cin % sca2 == 0), 'Invalid input dimensions: number of channels should be divisible by 4'
+
+         result = torch.zeros((N, Cout, Hout, Wout)).type(dtype)
+         for idx in range(sca2):
+             result[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca] = input[:, idx:Cin:sca2, :, :]
+
+         return result
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         N, Cg_out, Hg_out, Wg_out = grad_output.size()
+         dtype = grad_output.data.type()
+         sca = 2
+         sca2 = sca * sca
+         Cg_in = sca2 * Cg_out
+         Hg_in = Hg_out // sca
+         Wg_in = Wg_out // sca
+         idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]
+
+         # Build output
+         grad_input = torch.zeros((N, Cg_in, Hg_in, Wg_in)).type(dtype)
+         # Populate output
+         for idx in range(sca2):
+             grad_input[:, idx:Cg_in:sca2, :, :] = grad_output.data[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca]
+
+         return Variable(grad_input)
+
+ # Alias functions
+ upsamplefeatures = UpSampleFeaturesFunction.apply
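A quick shape check of the pack/unpack pair defined above (sizes illustrative):

import torch
from denoising.functions import concatenate_input_noise_map, upsamplefeatures

x = torch.randn(2, 3, 64, 64)                  # batch of RGB images
sigma = torch.full((2,), 25 / 255)             # one noise level per image
packed = concatenate_input_noise_map(x, sigma)
print(packed.shape)                            # (2, 15, 32, 32): 3-ch noise map + 12-ch 2x2 mosaic

restored = upsamplefeatures(packed[:, 3:])     # invert the mosaic packing
print(torch.equal(restored, x))                # True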
denoising/models.py ADDED
@@ -0,0 +1,422 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.models as M
+ import math
+ from torch import Tensor
+ from torch.nn import Parameter
+
+ '''https://github.com/orashi/AlacGAN/blob/master/models/standard.py'''
+
+ def l2normalize(v, eps=1e-12):
+     return v / (v.norm() + eps)
+
+
+ class SpectralNorm(nn.Module):
+     def __init__(self, module, name='weight', power_iterations=1):
+         super(SpectralNorm, self).__init__()
+         self.module = module
+         self.name = name
+         self.power_iterations = power_iterations
+         if not self._made_params():
+             self._make_params()
+
+     def _update_u_v(self):
+         u = getattr(self.module, self.name + "_u")
+         v = getattr(self.module, self.name + "_v")
+         w = getattr(self.module, self.name + "_bar")
+
+         height = w.data.shape[0]
+         for _ in range(self.power_iterations):
+             v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
+             u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
+
+         # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
+         sigma = u.dot(w.view(height, -1).mv(v))
+         setattr(self.module, self.name, w / sigma.expand_as(w))
+
+     def _made_params(self):
+         try:
+             u = getattr(self.module, self.name + "_u")
+             v = getattr(self.module, self.name + "_v")
+             w = getattr(self.module, self.name + "_bar")
+             return True
+         except AttributeError:
+             return False
+
+     def _make_params(self):
+         w = getattr(self.module, self.name)
+         height = w.data.shape[0]
+         width = w.view(height, -1).data.shape[1]
+
+         u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
+         v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
+         u.data = l2normalize(u.data)
+         v.data = l2normalize(v.data)
+         w_bar = Parameter(w.data)
+
+         del self.module._parameters[self.name]
+
+         self.module.register_parameter(self.name + "_u", u)
+         self.module.register_parameter(self.name + "_v", v)
+         self.module.register_parameter(self.name + "_bar", w_bar)
+
+     def forward(self, *args):
+         self._update_u_v()
+         return self.module.forward(*args)
+
+ class Selayer(nn.Module):
+     def __init__(self, inplanes):
+         super(Selayer, self).__init__()
+         self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+         self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
+         self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
+         self.relu = nn.ReLU(inplace=True)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         out = self.global_avgpool(x)
+         out = self.conv1(out)
+         out = self.relu(out)
+         out = self.conv2(out)
+         out = self.sigmoid(out)
+
+         return x * out
+
+ class SelayerSpectr(nn.Module):
+     def __init__(self, inplanes):
+         super(SelayerSpectr, self).__init__()
+         self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+         self.conv1 = SpectralNorm(nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1))
+         self.conv2 = SpectralNorm(nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1))
+         self.relu = nn.ReLU(inplace=True)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         out = self.global_avgpool(x)
+         out = self.conv1(out)
+         out = self.relu(out)
+         out = self.conv2(out)
+         out = self.sigmoid(out)
+
+         return x * out
+
+ class ResNeXtBottleneck(nn.Module):
+     def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
+         super(ResNeXtBottleneck, self).__init__()
+         D = out_channels // 2
+         self.out_channels = out_channels
+         self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
+         self.conv_conv = nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
+                                    groups=cardinality,
+                                    bias=False)
+         self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
+         self.shortcut = nn.Sequential()
+         if stride != 1:
+             self.shortcut.add_module('shortcut',
+                                      nn.AvgPool2d(2, stride=2))
+
+         self.selayer = Selayer(out_channels)
+
+     def forward(self, x):
+         bottleneck = self.conv_reduce.forward(x)
+         bottleneck = F.leaky_relu(bottleneck, 0.2, True)
+         bottleneck = self.conv_conv.forward(bottleneck)
+         bottleneck = F.leaky_relu(bottleneck, 0.2, True)
+         bottleneck = self.conv_expand.forward(bottleneck)
+         bottleneck = self.selayer(bottleneck)
+
+         x = self.shortcut.forward(x)
+         return x + bottleneck
+
+ class SpectrResNeXtBottleneck(nn.Module):
+     def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
+         super(SpectrResNeXtBottleneck, self).__init__()
+         D = out_channels // 2
+         self.out_channels = out_channels
+         self.conv_reduce = SpectralNorm(nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False))
+         self.conv_conv = SpectralNorm(nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
+                                                 groups=cardinality,
+                                                 bias=False))
+         self.conv_expand = SpectralNorm(nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False))
+         self.shortcut = nn.Sequential()
+         if stride != 1:
+             self.shortcut.add_module('shortcut',
+                                      nn.AvgPool2d(2, stride=2))
+
+         self.selayer = SelayerSpectr(out_channels)
+
+     def forward(self, x):
+         bottleneck = self.conv_reduce.forward(x)
+         bottleneck = F.leaky_relu(bottleneck, 0.2, True)
+         bottleneck = self.conv_conv.forward(bottleneck)
+         bottleneck = F.leaky_relu(bottleneck, 0.2, True)
+         bottleneck = self.conv_expand.forward(bottleneck)
+         bottleneck = self.selayer(bottleneck)
+
+         x = self.shortcut.forward(x)
+         return x + bottleneck
+
+ class FeatureConv(nn.Module):
+     def __init__(self, input_dim=512, output_dim=512):
+         super(FeatureConv, self).__init__()
+
+         no_bn = True
+
+         seq = []
+         seq.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
+         if not no_bn: seq.append(nn.BatchNorm2d(output_dim))
+         seq.append(nn.ReLU(inplace=True))
+         seq.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
+         if not no_bn: seq.append(nn.BatchNorm2d(output_dim))
+         seq.append(nn.ReLU(inplace=True))
+         seq.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
+         seq.append(nn.ReLU(inplace=True))
+
+         self.network = nn.Sequential(*seq)
+
+     def forward(self, x):
+         return self.network(x)
+
+ class Generator(nn.Module):
+     def __init__(self, ngf=64):
+         super(Generator, self).__init__()
+
+         self.feature_conv = FeatureConv()
+
+         self.to0 = self._make_encoder_block_first(6, 32)
+         self.to1 = self._make_encoder_block(32, 64)
+         self.to2 = self._make_encoder_block(64, 128)
+         self.to3 = self._make_encoder_block(128, 256)
+         self.to4 = self._make_encoder_block(256, 512)
+
+         self.deconv_for_decoder = nn.Sequential(
+             nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),  # output is 64 * 64
+             nn.LeakyReLU(0.2),
+             nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),  # output is 128 * 128
+             nn.LeakyReLU(0.2),
+             nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),  # output is 256 * 256
+             nn.LeakyReLU(0.2),
+             nn.ConvTranspose2d(32, 3, 3, stride=1, padding=1, output_padding=0),  # output is 256 * 256
+             nn.Tanh(),
+         )
+
+         tunnel4 = nn.Sequential(*[ResNeXtBottleneck(ngf * 8, ngf * 8, cardinality=32, dilate=1) for _ in range(20)])
+
+         self.tunnel4 = nn.Sequential(nn.Conv2d(ngf * 8 + 512, ngf * 8, kernel_size=3, stride=1, padding=1),
+                                      nn.LeakyReLU(0.2, True),
+                                      tunnel4,
+                                      nn.Conv2d(ngf * 8, ngf * 4 * 4, kernel_size=3, stride=1, padding=1),
+                                      nn.PixelShuffle(2),
+                                      nn.LeakyReLU(0.2, True)
+                                      )  # 64
+
+         depth = 2
+         tunnel = [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=1) for _ in range(depth)]
+         tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=2) for _ in range(depth)]
+         tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=4) for _ in range(depth)]
+         tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=2),
+                    ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=1)]
+         tunnel3 = nn.Sequential(*tunnel)
+
+         self.tunnel3 = nn.Sequential(nn.Conv2d(ngf * 8, ngf * 4, kernel_size=3, stride=1, padding=1),
+                                      nn.LeakyReLU(0.2, True),
+                                      tunnel3,
+                                      nn.Conv2d(ngf * 4, ngf * 2 * 4, kernel_size=3, stride=1, padding=1),
+                                      nn.PixelShuffle(2),
+                                      nn.LeakyReLU(0.2, True)
+                                      )  # 128
+
+         tunnel = [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=1) for _ in range(depth)]
+         tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=2) for _ in range(depth)]
+         tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=4) for _ in range(depth)]
+         tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=2),
+                    ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=1)]
+         tunnel2 = nn.Sequential(*tunnel)
+
+         self.tunnel2 = nn.Sequential(nn.Conv2d(ngf * 4, ngf * 2, kernel_size=3, stride=1, padding=1),
+                                      nn.LeakyReLU(0.2, True),
+                                      tunnel2,
+                                      nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=1, padding=1),
+                                      nn.PixelShuffle(2),
+                                      nn.LeakyReLU(0.2, True)
+                                      )
+
+         tunnel = [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=1)]
+         tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=2)]
+         tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=4)]
+         tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=2),
+                    ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=1)]
+         tunnel1 = nn.Sequential(*tunnel)
+
+         self.tunnel1 = nn.Sequential(nn.Conv2d(ngf * 2, ngf, kernel_size=3, stride=1, padding=1),
+                                      nn.LeakyReLU(0.2, True),
+                                      tunnel1,
+                                      nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=1, padding=1),
+                                      nn.PixelShuffle(2),
+                                      nn.LeakyReLU(0.2, True)
+                                      )
+
+         self.exit = nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)
+
+     def _make_encoder_block(self, inplanes, planes):
+         return nn.Sequential(
+             nn.Conv2d(inplanes, planes, 3, 2, 1),
+             nn.LeakyReLU(0.2),
+             nn.Conv2d(planes, planes, 3, 1, 1),
+             nn.LeakyReLU(0.2),
+         )
+
+     def _make_encoder_block_first(self, inplanes, planes):
+         return nn.Sequential(
+             nn.Conv2d(inplanes, planes, 3, 1, 1),
+             nn.LeakyReLU(0.2),
+             nn.Conv2d(planes, planes, 3, 1, 1),
+             nn.LeakyReLU(0.2),
+         )
+
+     def forward(self, sketch, sketch_feat):
+         x0 = self.to0(sketch)
+         x1 = self.to1(x0)
+         x2 = self.to2(x1)
+         x3 = self.to3(x2)
+         x4 = self.to4(x3)
+
+         sketch_feat = self.feature_conv(sketch_feat)
+
+         out = self.tunnel4(torch.cat([x4, sketch_feat], 1))
+
+         x = self.tunnel3(torch.cat([out, x3], 1))
+         x = self.tunnel2(torch.cat([x, x2], 1))
+         x = self.tunnel1(torch.cat([x, x1], 1))
+         x = torch.tanh(self.exit(torch.cat([x, x0], 1)))
+
+         decoder_output = self.deconv_for_decoder(out)
+
+         return x, decoder_output
+ '''
+ class Colorizer(nn.Module):
+     def __init__(self, extractor_path='model/model.pth'):
+         super(Colorizer, self).__init__()
+
+         self.generator = Generator()
+         self.extractor = se_resnext_half(dump_path=extractor_path, num_classes=370, input_channels=1)
+
+     def extractor_eval(self):
+         for param in self.extractor.parameters():
+             param.requires_grad = False
+
+     def extractor_train(self):
+         for param in extractor.parameters():
+             param.requires_grad = True
+
+     def forward(self, x, extractor_grad=False):
+         if extractor_grad:
+             features = self.extractor(x[:, 0:1])
+         else:
+             with torch.no_grad():
+                 features = self.extractor(x[:, 0:1]).detach()
+
+         fake, guide = self.generator(x, features)
+
+         return fake, guide
+ '''
+
+ class Colorizer(nn.Module):
+     def __init__(self, generator_model, extractor_model):
+         super(Colorizer, self).__init__()
+
+         self.generator = generator_model
+         self.extractor = extractor_model
+
+     def load_generator_weights(self, gen_weights):
+         self.generator.load_state_dict(gen_weights)
+
+     def load_extractor_weights(self, ext_weights):
+         self.extractor.load_state_dict(ext_weights)
+
+     def extractor_eval(self):
+         for param in self.extractor.parameters():
+             param.requires_grad = False
+         self.extractor.eval()
+
+     def extractor_train(self):
+         for param in self.extractor.parameters():  # fixed: bare `extractor` was a NameError
+             param.requires_grad = True
+         self.extractor.train()
+
+     def forward(self, x, extractor_grad=False):
+         if extractor_grad:
+             features = self.extractor(x[:, 0:1])
+         else:
+             with torch.no_grad():
+                 features = self.extractor(x[:, 0:1]).detach()
+
+         fake, guide = self.generator(x, features)
+
+         return fake, guide
+
+ class Discriminator(nn.Module):
+     def __init__(self, ndf=64):
+         super(Discriminator, self).__init__()
+
+         self.feed = nn.Sequential(SpectralNorm(nn.Conv2d(3, 64, 3, 1, 1)),
+                                   nn.LeakyReLU(0.2, True),
+                                   SpectralNorm(nn.Conv2d(64, 64, 3, 2, 0)),
+                                   nn.LeakyReLU(0.2, True),
+
+                                   SpectrResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1),
+                                   SpectrResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1, stride=2),  # 128
+                                   SpectralNorm(nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=False)),
+                                   nn.LeakyReLU(0.2, True),
+
+                                   SpectrResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1),
+                                   SpectrResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1, stride=2),  # 64
+                                   SpectralNorm(nn.Conv2d(ndf * 2, ndf * 4, kernel_size=1, stride=1, padding=0, bias=False)),
+                                   nn.LeakyReLU(0.2, True),
+
+                                   SpectrResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1),
+                                   SpectrResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1, stride=2),  # 32
+                                   SpectralNorm(nn.Conv2d(ndf * 4, ndf * 8, kernel_size=1, stride=1, padding=1, bias=False)),
+                                   nn.LeakyReLU(0.2, True),
+                                   SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
+                                   SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1, stride=2),  # 16
+                                   SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
+                                   SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
+                                   nn.AdaptiveAvgPool2d((1, 1))
+                                   )
+
+         self.out = nn.Linear(512, 1)
+
+     def forward(self, color):
+         x = self.feed(color)
+
+         out = self.out(x.view(color.size(0), -1))
+         return out
+
+ class Content(nn.Module):
+     def __init__(self, path):
+         super(Content, self).__init__()
+         vgg16 = M.vgg16()
+         vgg16.load_state_dict(torch.load(path))
+         vgg16.features = nn.Sequential(
+             *list(vgg16.features.children())[:9]
+         )
+         self.model = vgg16.features
+         self.register_buffer('mean', torch.FloatTensor([0.485 - 0.5, 0.456 - 0.5, 0.406 - 0.5]).view(1, 3, 1, 1))
+         self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+     def forward(self, images):
+         return self.model((images.mul(0.5) - self.mean) / self.std)
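These are the AlacGAN-derived generator, the Colorizer wrapper, the spectrally normalized discriminator, and a VGG16 content-loss network; inference.py imports the same class definitions from model.models. A minimal shape-check sketch with randomly initialized weights (input size illustrative):

import torch
from model.models import Colorizer, Generator
from model.extractor import get_seresnext_extractor

colorizer = Colorizer(Generator(), get_seresnext_extractor()).eval()

# input channels: bw sketch (1) + dfm sketch (1) + color hint (3) + hint mask (1) = 6
x = torch.randn(1, 6, 512, 512)
with torch.no_grad():
    fake, guide = colorizer(x)
print(fake.shape, guide.shape)  # both torch.Size([1, 3, 512, 512])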
denoising/models/.gitkeep ADDED
File without changes
denoising/utils.py ADDED
@@ -0,0 +1,66 @@
+ """
+ Different utilities such as orthogonalization of weights, initialization of
+ loggers, etc
+
+ Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
+
+ This program is free software: you can use, modify and/or
+ redistribute it under the terms of the GNU General Public
+ License as published by the Free Software Foundation, either
+ version 3 of the License, or (at your option) any later
+ version. You should have received a copy of this license along
+ this program. If not, see <http://www.gnu.org/licenses/>.
+ """
+ import numpy as np
+ import cv2
+
+
+ def variable_to_cv2_image(varim):
+     r"""Converts a torch.autograd.Variable to an OpenCV image
+
+     Args:
+         varim: a torch.autograd.Variable
+     """
+     nchannels = varim.size()[1]
+     if nchannels == 1:
+         res = (varim.data.cpu().numpy()[0, 0, :] * 255.).clip(0, 255).astype(np.uint8)
+     elif nchannels == 3:
+         res = varim.data.cpu().numpy()[0]
+         res = cv2.cvtColor(res.transpose(1, 2, 0), cv2.COLOR_RGB2BGR)
+         res = (res * 255.).clip(0, 255).astype(np.uint8)
+     else:
+         raise Exception('Number of color channels not supported')
+     return res
+
+
+ def normalize(data):
+     return np.float32(data / 255.)
+
+ def remove_dataparallel_wrapper(state_dict):
+     r"""Converts a DataParallel model to a normal one by removing the "module."
+     wrapper in the module dictionary
+
+     Args:
+         state_dict: a torch.nn.DataParallel state dictionary
+     """
+     from collections import OrderedDict
+
+     new_state_dict = OrderedDict()
+     for k, vl in state_dict.items():
+         name = k[7:]  # remove 'module.' of DataParallel
+         new_state_dict[name] = vl
+
+     return new_state_dict
+
+ def is_rgb(im_path):
+     r"""Returns True if the image in im_path is an RGB image
+     """
+     from skimage.io import imread
+     rgb = False
+     im = imread(im_path)
+     if len(im.shape) == 3:
+         if not (np.allclose(im[..., 0], im[..., 1]) and np.allclose(im[..., 2], im[..., 1])):
+             rgb = True
+     print("rgb: {}".format(rgb))
+     print("im shape: {}".format(im.shape))
+     return rgb
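For reference, a tiny illustration of the key renaming remove_dataparallel_wrapper performs (names and values illustrative):

from collections import OrderedDict
from denoising.utils import remove_dataparallel_wrapper

ckpt = OrderedDict([('module.conv1.weight', 0), ('module.conv1.bias', 1)])
print(list(remove_dataparallel_wrapper(ckpt)))  # ['conv1.weight', 'conv1.bias']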
drawing.py ADDED
@@ -0,0 +1,165 @@
+ import os
+ from datetime import datetime
+ import base64
+ import random
+ import string
+ import shutil
+ import torch
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from flask import Flask, request, jsonify, abort, redirect, url_for, render_template, send_file, Response
+ from flask_wtf import FlaskForm
+ from wtforms import StringField, FileField, BooleanField, DecimalField
+ from wtforms.validators import DataRequired
+ from flask import after_this_request
+
+ from model.models import Colorizer, Generator
+ from model.extractor import get_seresnext_extractor
+ from utils.xdog import XDoGSketcher
+ from utils.utils import open_json
+ from denoising.denoiser import FFDNetDenoiser
+ from inference import process_image_with_hint
+ from utils.utils import resize_pad
+ from utils.dataset_utils import get_sketch
+
+ def generate_id(size=25, chars=string.ascii_letters + string.digits):
+     return ''.join(random.SystemRandom().choice(chars) for _ in range(size))
+
+ def generate_unique_id(current_ids=set()):  # mutable default used deliberately as a process-wide registry
+     id_t = generate_id()
+     while id_t in current_ids:
+         id_t = generate_id()
+
+     current_ids.add(id_t)
+
+     return id_t
+
+ app = Flask(__name__)
+ app.config.update(dict(
+     SECRET_KEY="lol kek",
+     WTF_CSRF_SECRET_KEY="cheburek"
+ ))
+
+ if torch.cuda.is_available():
+     device = 'cuda'
+ else:
+     device = 'cpu'
+
+ colorizer = torch.jit.load('./model/colorizer.zip', map_location=torch.device(device))
+
+ sketcher = XDoGSketcher()
+ xdog_config = open_json('configs/xdog_config.json')
+ for key in xdog_config.keys():
+     if key in sketcher.params:
+         sketcher.params[key] = xdog_config[key]
+
+ denoiser = FFDNetDenoiser(device)
+
+ color_args = {'colorizer': colorizer, 'sketcher': sketcher, 'device': device, 'dfm': True, 'auto_hint': False, 'ignore_gray': False, 'denoiser': denoiser, 'denoiser_sigma': 25}
+
+
+ class SubmitForm(FlaskForm):
+     file = FileField(validators=[DataRequired(), ])
+
+ def preprocess_image(file_id, ext):
+     directory_path = os.path.join('static', 'temp_images', file_id)
+     original_path = os.path.join(directory_path, 'original') + ext
+     original_image = plt.imread(original_path)
+
+     resized_image, _ = resize_pad(original_image)
+     resized_image = denoiser.get_denoised_image(resized_image, 25)
+     bw, dfm = get_sketch(resized_image, sketcher, True)
+
+     resized_name = 'resized_' + str(resized_image.shape[0]) + '_' + str(resized_image.shape[1]) + '.png'
+     plt.imsave(os.path.join(directory_path, resized_name), resized_image)
+     plt.imsave(os.path.join(directory_path, 'bw.png'), bw, cmap='gray')
+     plt.imsave(os.path.join(directory_path, 'dfm.png'), dfm, cmap='gray')
+     os.remove(original_path)
+
+     # RGBA hint image: RGB carries the hint colors, alpha carries the hint mask
+     empty_hint = np.zeros((resized_image.shape[0], resized_image.shape[1], 4), dtype=np.float32)
+     plt.imsave(os.path.join(directory_path, 'hint.png'), empty_hint)
+
+ @app.route('/', methods=['GET', 'POST'])
+ def upload():
+     form = SubmitForm()
+     if form.validate_on_submit():
+         input_data = form.file.data
+
+         _, ext = os.path.splitext(input_data.filename)
+
+         if ext not in ('.jpg', '.png', '.jpeg'):
+             return abort(400)
+
+         file_id = generate_unique_id()
+         directory = os.path.join('static', 'temp_images', file_id)
+         original_filename = os.path.join(directory, 'original') + ext
+
+         try:
+             os.mkdir(directory)
+             input_data.save(original_filename)
+
+             preprocess_image(file_id, ext)
+
+             return redirect(f'/draw/{file_id}')
+
+         except:
+             print('Failed to colorize')
+             if os.path.exists(directory):
+                 shutil.rmtree(directory)
+             return abort(400)
+
+     return render_template("upload.html", form=form)
+
+ @app.route('/img/<file_id>')
+ def show_image(file_id):
+     if not os.path.exists(os.path.join('static', 'temp_images', str(file_id))):
+         abort(404)
+     return f'<img src="/static/temp_images/{file_id}/colorized.png?{random.randint(1, 1000000)}">'
+
+ def colorize_image(file_id):
+     directory_path = os.path.join('static', 'temp_images', file_id)
+
+     bw = plt.imread(os.path.join(directory_path, 'bw.png'))[..., :1]
+     dfm = plt.imread(os.path.join(directory_path, 'dfm.png'))[..., :1]
+     hint = plt.imread(os.path.join(directory_path, 'hint.png'))
+
+     return process_image_with_hint(bw, dfm, hint, color_args)
+
+ @app.route('/colorize', methods=['POST'])
+ def colorize():
+     file_id = request.form['save_file_id']
+     file_id = file_id[file_id.rfind('/') + 1:]
+
+     img_data = request.form['save_image']
+     img_data = img_data[img_data.find(',') + 1:]
+
+     directory_path = os.path.join('static', 'temp_images', file_id)
+
+     with open(os.path.join(directory_path, 'hint.png'), "wb") as im:
+         im.write(base64.b64decode(img_data))  # b64decode replaces base64.decodestring, removed in Python 3.9
+
+     result = colorize_image(file_id)
+
+     plt.imsave(os.path.join(directory_path, 'colorized.png'), result)
+
+     src_path = f'../static/temp_images/{file_id}/colorized.png?{random.randint(1, 1000000)}'
+
+     return src_path
+
+ @app.route('/draw/<file_id>', methods=['GET', 'POST'])
+ def paintapp(file_id):
+     if request.method == 'GET':
+         directory_path = os.path.join('static', 'temp_images', str(file_id))
+         if not os.path.exists(directory_path):
+             abort(404)
+
+         resized_name = [x for x in os.listdir(directory_path) if x.startswith('resized_')][0]
+
+         split = os.path.splitext(resized_name)[0].split('_')
+         width = int(split[2])
+         height = int(split[1])
+
+         return render_template("drawing.html", height=height, width=width, img_path=os.path.join('temp_images', str(file_id), resized_name))
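The Dockerfile serves this module with gunicorn (drawing:app); run_drawing.sh presumably does the same locally. The /colorize endpoint expects the hint canvas as a base64 data URL, so an illustrative client (the file id is hypothetical) looks like:

import base64
import requests

with open('hint.png', 'rb') as f:
    data_url = 'data:image/png;base64,' + base64.b64encode(f.read()).decode()

r = requests.post('http://localhost:5000/colorize',
                  data={'save_file_id': 'SOME_FILE_ID', 'save_image': data_url})
print(r.text)  # relative path to the freshly written colorized.png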
inference.py ADDED
@@ -0,0 +1,215 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from utils.dataset_utils import get_sketch
+ from utils.utils import resize_pad, generate_mask, extract_cbr, create_cbz, sorted_alphanumeric, subfolder_image_search, remove_folder
+ from torchvision.transforms import ToTensor
+ import os
+ import matplotlib.pyplot as plt
+ import argparse
+ from model.models import Colorizer, Generator
+ from model.extractor import get_seresnext_extractor
+ from utils.xdog import XDoGSketcher
+ from utils.utils import open_json
+ import sys
+ from denoising.denoiser import FFDNetDenoiser
+
+ def colorize_without_hint(inp, color_args):
+     i_hint = torch.zeros(1, 4, inp.shape[2], inp.shape[3]).float().to(color_args['device'])
+
+     with torch.no_grad():
+         fake_color, _ = color_args['colorizer'](torch.cat([inp, i_hint], 1))
+
+     if color_args['auto_hint']:
+         mask = generate_mask(fake_color.shape[2], fake_color.shape[3], full=False, prob=1, sigma=color_args['auto_hint_sigma']).unsqueeze(0)
+         mask = mask.to(color_args['device'])
+
+         if color_args['ignore_gray']:
+             # keep only hint pixels whose predicted color is far enough from gray
+             diff1 = torch.abs(fake_color[:, 0] - fake_color[:, 1])
+             diff2 = torch.abs(fake_color[:, 0] - fake_color[:, 2])
+             diff3 = torch.abs(fake_color[:, 1] - fake_color[:, 2])
+             mask = ((mask + ((diff1 + diff2 + diff3) > 60 / 255).float().unsqueeze(1)) == 2).float()
+
+         i_hint = torch.cat([fake_color * mask, mask], 1)
+
+         with torch.no_grad():
+             fake_color, _ = color_args['colorizer'](torch.cat([inp, i_hint], 1))
+
+     return fake_color
+
+
+ def process_image(image, color_args, to_tensor=ToTensor()):
+     image, pad = resize_pad(image)
+
+     if color_args['denoiser'] is not None:
+         image = color_args['denoiser'].get_denoised_image(image, color_args['denoiser_sigma'])
+
+     bw, dfm = get_sketch(image, color_args['sketcher'], color_args['dfm'])
+
+     bw = to_tensor(bw).unsqueeze(0).to(color_args['device'])
+     dfm = to_tensor(dfm).unsqueeze(0).to(color_args['device'])
+
+     output = colorize_without_hint(torch.cat([bw, dfm], 1), color_args)
+     result = output[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5
+
+     if pad[0] != 0:
+         result = result[:-pad[0]]
+     if pad[1] != 0:
+         result = result[:, :-pad[1]]
+
+     return result
+
+ def colorize_with_hint(inp, color_args):
+     with torch.no_grad():
+         fake_color, _ = color_args['colorizer'](inp)
+
+     return fake_color
+
+ def process_image_with_hint(bw, dfm, hint, color_args, to_tensor=ToTensor()):
+     bw = to_tensor(bw).unsqueeze(0).to(color_args['device'])
+     dfm = to_tensor(dfm).unsqueeze(0).to(color_args['device'])
+
+     i_hint = (torch.FloatTensor(hint[..., :3]).permute(2, 0, 1) - 0.5) / 0.5
+     mask = torch.FloatTensor(hint[..., 3:]).permute(2, 0, 1)
+     i_hint = torch.cat([i_hint * mask, mask], 0).unsqueeze(0).to(color_args['device'])
+
+     output = colorize_with_hint(torch.cat([bw, dfm, i_hint], 1), color_args)
+     result = output[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5
+
+     return result
+
+ def colorize_single_image(file_path, save_path, color_args):
+     try:
+         image = plt.imread(file_path)
+
+         colorization = process_image(image, color_args)
+
+         plt.imsave(save_path, colorization)
+
+         return True
+     except KeyboardInterrupt:
+         sys.exit(0)
+     except:
+         print('Failed to colorize {}'.format(file_path))
+         return False
+
+ def colorize_images(source_path, target_path, color_args):
+     images = os.listdir(source_path)
+
+     for image_name in images:
+         file_path = os.path.join(source_path, image_name)
+
+         name, ext = os.path.splitext(image_name)
+         if ext != '.png':
+             image_name = name + '.png'
+
+         save_path = os.path.join(target_path, image_name)
+         colorize_single_image(file_path, save_path, color_args)
+
+ def colorize_cbr(file_path, color_args):
+     file_name = os.path.splitext(os.path.basename(file_path))[0]
+     temp_path = 'temp_colorization'
+
+     if not os.path.exists(temp_path):
+         os.makedirs(temp_path)
+     extract_cbr(file_path, temp_path)
+
+     images = subfolder_image_search(temp_path)
+
+     result_images = []
+     for image_path in images:
+         save_path = image_path
+
+         path, ext = os.path.splitext(save_path)
+         if ext != '.png':
+             save_path = path + '.png'
+
+         res_flag = colorize_single_image(image_path, save_path, color_args)
+
+         result_images.append(save_path if res_flag else image_path)
+
+     result_name = os.path.join(os.path.dirname(file_path), file_name + '_colorized.cbz')
+
+     create_cbz(result_name, result_images)
+
+     remove_folder(temp_path)
+
+     return result_name
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("-p", "--path", required=True)
+     parser.add_argument("-gen", "--generator", default='model/generator.pth')
+     parser.add_argument("-ext", "--extractor", default='model/extractor.pth')
+     parser.add_argument("-s", "--sigma", type=float, default=0.003)
+     parser.add_argument('-g', '--gpu', dest='gpu', action='store_true')
+     parser.add_argument('-ah', '--auto', dest='autohint', action='store_true')
+     parser.add_argument('-ig', '--ignore_grey', dest='ignore', action='store_true')
+     parser.add_argument('-nd', '--no_denoise', dest='denoiser', action='store_false')
+     parser.add_argument("-ds", "--denoiser_sigma", type=int, default=25)
+     parser.set_defaults(gpu=False)
+     parser.set_defaults(autohint=False)
+     parser.set_defaults(ignore=False)
+     parser.set_defaults(denoiser=True)
+     args = parser.parse_args()
+
+     return args
+
+
+ if __name__ == "__main__":
+
+     args = parse_args()
+
+     if args.gpu:
+         device = 'cuda'
+     else:
+         device = 'cpu'
+
+     generator = Generator()
+     generator.load_state_dict(torch.load(args.generator))
+
+     extractor = get_seresnext_extractor()
+     extractor.load_state_dict(torch.load(args.extractor))
+
+     colorizer = Colorizer(generator, extractor)
+     colorizer = colorizer.eval().to(device)
+
+     sketcher = XDoGSketcher()
+     xdog_config = open_json('configs/xdog_config.json')
+     for key in xdog_config.keys():
+         if key in sketcher.params:
+             sketcher.params[key] = xdog_config[key]
+
+     denoiser = None
+     if args.denoiser:
+         denoiser = FFDNetDenoiser(device, args.denoiser_sigma)
+
+     color_args = {'colorizer': colorizer, 'sketcher': sketcher, 'auto_hint': args.autohint, 'auto_hint_sigma': args.sigma,
+                   'ignore_gray': args.ignore, 'device': device, 'dfm': True, 'denoiser': denoiser, 'denoiser_sigma': args.denoiser_sigma}
+
+     if os.path.isdir(args.path):
+         colorization_path = os.path.join(args.path, 'colorization')
+         if not os.path.exists(colorization_path):
+             os.makedirs(colorization_path)
+
+         colorize_images(args.path, colorization_path, color_args)
+
+     elif os.path.isfile(args.path):
+
+         split = os.path.splitext(args.path)
+
+         if split[1].lower() in ('.cbr', '.cbz', '.rar', '.zip'):
+             colorize_cbr(args.path, color_args)
+         elif split[1].lower() in ('.jpg', '.png', '.jpeg'):  # fixed typo: was ',jpeg'
+             new_image_path = split[0] + '_colorized' + '.png'
+
+             colorize_single_image(args.path, new_image_path, color_args)
+         else:
+             print('Wrong format')
+     else:
+         print('Wrong path')
+
+
model/__pycache__/extractor.cpython-310.pyc ADDED
Binary file (3.97 kB).
 
model/__pycache__/extractor.cpython-36.pyc ADDED
Binary file (3.9 kB).
 
model/__pycache__/extractor.cpython-39.pyc ADDED
Binary file (3.95 kB).
 
model/__pycache__/models.cpython-310.pyc ADDED
Binary file (13 kB).
 
model/__pycache__/models.cpython-36.pyc ADDED
Binary file (14 kB).
 
model/__pycache__/models.cpython-39.pyc ADDED
Binary file (13.5 kB).
 
model/extractor.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ee3c59f02ac8c59298fd9b819fa33d2efa168847e15e4be39b35c286f7c18607
+ size 6340842
model/extractor.py ADDED
@@ -0,0 +1,127 @@
+ import torch
+ import torch.nn as nn
+ import math
+
+ '''https://github.com/blandocs/Tag2Pix/blob/master/model/pretrained.py'''
+
+ # Pretrained version
+ class Selayer(nn.Module):
+     def __init__(self, inplanes):
+         super(Selayer, self).__init__()
+         self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+         self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
+         self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
+         self.relu = nn.ReLU(inplace=True)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         out = self.global_avgpool(x)
+         out = self.conv1(out)
+         out = self.relu(out)
+         out = self.conv2(out)
+         out = self.sigmoid(out)
+
+         return x * out
+
+
+ class BottleneckX_Origin(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None):
+         super(BottleneckX_Origin, self).__init__()
+         self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes * 2)
+
+         self.conv2 = nn.Conv2d(planes * 2, planes * 2, kernel_size=3, stride=stride,
+                                padding=1, groups=cardinality, bias=False)
+         self.bn2 = nn.BatchNorm2d(planes * 2)
+
+         self.conv3 = nn.Conv2d(planes * 2, planes * 4, kernel_size=1, bias=False)
+         self.bn3 = nn.BatchNorm2d(planes * 4)
+
+         self.selayer = Selayer(planes * 4)
+
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         out = self.selayer(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+ class SEResNeXt_extractor(nn.Module):
+     def __init__(self, block, layers, input_channels=3, cardinality=32):
+         super(SEResNeXt_extractor, self).__init__()
+         self.cardinality = cardinality
+         self.inplanes = 64
+         self.input_channels = input_channels
+
+         self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3,
+                                bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+         self.layer1 = self._make_layer(block, 64, layers[0])
+         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                 m.weight.data.normal_(0, math.sqrt(2. / n))
+                 if m.bias is not None:
+                     m.bias.data.zero_()
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+
+     def _make_layer(self, block, planes, blocks, stride=1):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(self.inplanes, planes * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(block(self.inplanes, planes, self.cardinality, stride, downsample))
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(self.inplanes, planes, self.cardinality))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.maxpool(x)
+
+         x = self.layer1(x)
+         x = self.layer2(x)
+
+         return x
+
+ def get_seresnext_extractor():
+     return SEResNeXt_extractor(BottleneckX_Origin, [3, 4, 6, 3], 1)
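The extractor is just the stem plus the first two stages of a SE-ResNeXt, so it emits 512-channel features at 1/8 of the input resolution; a quick check:

import torch
from model.extractor import get_seresnext_extractor

extractor = get_seresnext_extractor().eval()   # single-channel input per the factory above
x = torch.randn(1, 1, 512, 512)
with torch.no_grad():
    features = extractor(x)
print(features.shape)                          # torch.Size([1, 512, 64, 64])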
model/models.py ADDED
@@ -0,0 +1,422 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.models as M
+ import math
+ from torch import Tensor
+ from torch.nn import Parameter
+
+ '''https://github.com/orashi/AlacGAN/blob/master/models/standard.py'''
+
+ def l2normalize(v, eps=1e-12):
+     return v / (v.norm() + eps)
+
+
+ class SpectralNorm(nn.Module):
+     def __init__(self, module, name='weight', power_iterations=1):
+         super(SpectralNorm, self).__init__()
+         self.module = module
+         self.name = name
+         self.power_iterations = power_iterations
+         if not self._made_params():
+             self._make_params()
+
+     def _update_u_v(self):
+         u = getattr(self.module, self.name + "_u")
+         v = getattr(self.module, self.name + "_v")
+         w = getattr(self.module, self.name + "_bar")
+
+         height = w.data.shape[0]
+         for _ in range(self.power_iterations):
+             v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
+             u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
+
+         # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
+         sigma = u.dot(w.view(height, -1).mv(v))
+         setattr(self.module, self.name, w / sigma.expand_as(w))
+
+     def _made_params(self):
+         try:
+             u = getattr(self.module, self.name + "_u")
+             v = getattr(self.module, self.name + "_v")
+             w = getattr(self.module, self.name + "_bar")
+             return True
+         except AttributeError:
+             return False
+
+     def _make_params(self):
+         w = getattr(self.module, self.name)
+         height = w.data.shape[0]
+         width = w.view(height, -1).data.shape[1]
+
+         u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
+         v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
+         u.data = l2normalize(u.data)
+         v.data = l2normalize(v.data)
+         w_bar = Parameter(w.data)
+
+         del self.module._parameters[self.name]
+
+         self.module.register_parameter(self.name + "_u", u)
+         self.module.register_parameter(self.name + "_v", v)
+         self.module.register_parameter(self.name + "_bar", w_bar)
+
+     def forward(self, *args):
+         self._update_u_v()
+         return self.module.forward(*args)
+
+ class Selayer(nn.Module):
+     def __init__(self, inplanes):
+         super(Selayer, self).__init__()
+         self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+         self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
+         self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
+         self.relu = nn.ReLU(inplace=True)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         out = self.global_avgpool(x)
+         out = self.conv1(out)
+         out = self.relu(out)
+         out = self.conv2(out)
+         out = self.sigmoid(out)
+
+         return x * out
+
+ class SelayerSpectr(nn.Module):
+     def __init__(self, inplanes):
+         super(SelayerSpectr, self).__init__()
+         self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+         self.conv1 = SpectralNorm(nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1))
+         self.conv2 = SpectralNorm(nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1))
+         self.relu = nn.ReLU(inplace=True)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         out = self.global_avgpool(x)
+         out = self.conv1(out)
+         out = self.relu(out)
+         out = self.conv2(out)
+         out = self.sigmoid(out)
+
+         return x * out
+
+ class ResNeXtBottleneck(nn.Module):
+     def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
+         super(ResNeXtBottleneck, self).__init__()
+         D = out_channels // 2
+         self.out_channels = out_channels
+         self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
+         self.conv_conv = nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
+                                    groups=cardinality,
+                                    bias=False)
+         self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
+         self.shortcut = nn.Sequential()
+         if stride != 1:
+             self.shortcut.add_module('shortcut',
+                                      nn.AvgPool2d(2, stride=2))
+
+         self.selayer = Selayer(out_channels)
+
+     def forward(self, x):
+         bottleneck = self.conv_reduce.forward(x)
+         bottleneck = F.leaky_relu(bottleneck, 0.2, True)
+         bottleneck = self.conv_conv.forward(bottleneck)
+         bottleneck = F.leaky_relu(bottleneck, 0.2, True)
+         bottleneck = self.conv_expand.forward(bottleneck)
+         bottleneck = self.selayer(bottleneck)
+
+         x = self.shortcut.forward(x)
+         return x + bottleneck
+
+ class SpectrResNeXtBottleneck(nn.Module):
+     def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
+         super(SpectrResNeXtBottleneck, self).__init__()
+         D = out_channels // 2
+         self.out_channels = out_channels
+         self.conv_reduce = SpectralNorm(nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False))
+         self.conv_conv = SpectralNorm(nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
+                                                 groups=cardinality,
+                                                 bias=False))
+         self.conv_expand = SpectralNorm(nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False))
+         self.shortcut = nn.Sequential()
+         if stride != 1:
+             self.shortcut.add_module('shortcut',
+                                      nn.AvgPool2d(2, stride=2))
+
+         self.selayer = SelayerSpectr(out_channels)
+
+     def forward(self, x):
+         bottleneck = self.conv_reduce.forward(x)
+         bottleneck = F.leaky_relu(bottleneck, 0.2, True)
+         bottleneck = self.conv_conv.forward(bottleneck)
+         bottleneck = F.leaky_relu(bottleneck, 0.2, True)
+         bottleneck = self.conv_expand.forward(bottleneck)
+         bottleneck = self.selayer(bottleneck)
+
+         x = self.shortcut.forward(x)
+         return x + bottleneck
+
+ class FeatureConv(nn.Module):
+     def __init__(self, input_dim=512, output_dim=512):
+         super(FeatureConv, self).__init__()
+
+         no_bn = True
+
+         seq = []
+         seq.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
+ seq.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
170
+ if not no_bn: seq.append(nn.BatchNorm2d(output_dim))
171
+ seq.append(nn.ReLU(inplace=True))
172
+ seq.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
173
+ if not no_bn: seq.append(nn.BatchNorm2d(output_dim))
174
+ seq.append(nn.ReLU(inplace=True))
175
+ seq.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
176
+ seq.append(nn.ReLU(inplace=True))
177
+
178
+ self.network = nn.Sequential(*seq)
179
+
180
+ def forward(self, x):
181
+ return self.network(x)
182
+
183
+ class Generator(nn.Module):
184
+ def __init__(self, ngf=64):
185
+ super(Generator, self).__init__()
186
+
187
+ self.feature_conv = FeatureConv()
188
+
189
+ self.to0 = self._make_encoder_block_first(6, 32)
190
+ self.to1 = self._make_encoder_block(32, 64)
191
+ self.to2 = self._make_encoder_block(64, 128)
192
+ self.to3 = self._make_encoder_block(128, 256)
193
+ self.to4 = self._make_encoder_block(256, 512)
194
+
195
+ self.deconv_for_decoder = nn.Sequential(
196
+ nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1), # output is 64 * 64
197
+ nn.LeakyReLU(0.2),
198
+ nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1), # output is 128 * 128
199
+ nn.LeakyReLU(0.2),
200
+ nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1), # output is 256 * 256
201
+ nn.LeakyReLU(0.2),
202
+ nn.ConvTranspose2d(32, 3, 3, stride=1, padding=1, output_padding=0), # output is 256 * 256
203
+ nn.Tanh(),
204
+ )
205
+
206
+ tunnel4 = nn.Sequential(*[ResNeXtBottleneck(ngf * 8, ngf * 8, cardinality=32, dilate=1) for _ in range(20)])
207
+
208
+ self.tunnel4 = nn.Sequential(nn.Conv2d(ngf * 8 + 512, ngf * 8, kernel_size=3, stride=1, padding=1),
209
+ nn.LeakyReLU(0.2, True),
210
+ tunnel4,
211
+ nn.Conv2d(ngf * 8, ngf * 4 * 4, kernel_size=3, stride=1, padding=1),
212
+ nn.PixelShuffle(2),
213
+ nn.LeakyReLU(0.2, True)
214
+ ) # 64
215
+
216
+ depth = 2
217
+ tunnel = [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=1) for _ in range(depth)]
218
+ tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=2) for _ in range(depth)]
219
+ tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=4) for _ in range(depth)]
220
+ tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=2),
221
+ ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=1)]
222
+ tunnel3 = nn.Sequential(*tunnel)
223
+
224
+ self.tunnel3 = nn.Sequential(nn.Conv2d(ngf * 8, ngf * 4, kernel_size=3, stride=1, padding=1),
225
+ nn.LeakyReLU(0.2, True),
226
+ tunnel3,
227
+ nn.Conv2d(ngf * 4, ngf * 2 * 4, kernel_size=3, stride=1, padding=1),
228
+ nn.PixelShuffle(2),
229
+ nn.LeakyReLU(0.2, True)
230
+ ) # 128
231
+
232
+ tunnel = [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=1) for _ in range(depth)]
233
+ tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=2) for _ in range(depth)]
234
+ tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=4) for _ in range(depth)]
235
+ tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=2),
236
+ ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=1)]
237
+ tunnel2 = nn.Sequential(*tunnel)
238
+
239
+ self.tunnel2 = nn.Sequential(nn.Conv2d(ngf * 4, ngf * 2, kernel_size=3, stride=1, padding=1),
240
+ nn.LeakyReLU(0.2, True),
241
+ tunnel2,
242
+ nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=1, padding=1),
243
+ nn.PixelShuffle(2),
244
+ nn.LeakyReLU(0.2, True)
245
+ )
246
+
247
+ tunnel = [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=1)]
248
+ tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=2)]
249
+ tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=4)]
250
+ tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=2),
251
+ ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=1)]
252
+ tunnel1 = nn.Sequential(*tunnel)
253
+
254
+ self.tunnel1 = nn.Sequential(nn.Conv2d(ngf * 2, ngf, kernel_size=3, stride=1, padding=1),
255
+ nn.LeakyReLU(0.2, True),
256
+ tunnel1,
257
+ nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=1, padding=1),
258
+ nn.PixelShuffle(2),
259
+ nn.LeakyReLU(0.2, True)
260
+ )
261
+
262
+ self.exit = nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)
263
+
264
+
265
+ def _make_encoder_block(self, inplanes, planes):
266
+ return nn.Sequential(
267
+ nn.Conv2d(inplanes, planes, 3, 2, 1),
268
+ nn.LeakyReLU(0.2),
269
+ nn.Conv2d(planes, planes, 3, 1, 1),
270
+ nn.LeakyReLU(0.2),
271
+ )
272
+
273
+ def _make_encoder_block_first(self, inplanes, planes):
274
+ return nn.Sequential(
275
+ nn.Conv2d(inplanes, planes, 3, 1, 1),
276
+ nn.LeakyReLU(0.2),
277
+ nn.Conv2d(planes, planes, 3, 1, 1),
278
+ nn.LeakyReLU(0.2),
279
+ )
280
+
281
+ def forward(self, sketch, sketch_feat):
282
+
283
+ x0 = self.to0(sketch)
284
+ x1 = self.to1(x0)
285
+ x2 = self.to2(x1)
286
+ x3 = self.to3(x2)
287
+ x4 = self.to4(x3)
288
+
289
+ sketch_feat = self.feature_conv(sketch_feat)
290
+
291
+ out = self.tunnel4(torch.cat([x4, sketch_feat], 1))
292
+
293
+
294
+
295
+
296
+ x = self.tunnel3(torch.cat([out, x3], 1))
297
+ x = self.tunnel2(torch.cat([x, x2], 1))
298
+ x = self.tunnel1(torch.cat([x, x1], 1))
299
+ x = torch.tanh(self.exit(torch.cat([x, x0], 1)))
300
+
301
+ decoder_output = self.deconv_for_decoder(out)
302
+
303
+ return x, decoder_output
304
+ '''
305
+ class Colorizer(nn.Module):
306
+ def __init__(self, extractor_path = 'model/model.pth'):
307
+ super(Colorizer, self).__init__()
308
+
309
+ self.generator = Generator()
310
+ self.extractor = se_resnext_half(dump_path=extractor_path, num_classes=370, input_channels=1)
311
+
312
+ def extractor_eval(self):
313
+ for param in self.extractor.parameters():
314
+ param.requires_grad = False
315
+
316
+ def extractor_train(self):
317
+ for param in self.extractor.parameters():
318
+ param.requires_grad = True
319
+
320
+ def forward(self, x, extractor_grad = False):
321
+
322
+ if extractor_grad:
323
+ features = self.extractor(x[:, 0:1])
324
+ else:
325
+ with torch.no_grad():
326
+ features = self.extractor(x[:, 0:1]).detach()
327
+
328
+ fake, guide = self.generator(x, features)
329
+
330
+ return fake, guide
331
+ '''
332
+
333
+ class Colorizer(nn.Module):
334
+ def __init__(self, generator_model, extractor_model):
335
+ super(Colorizer, self).__init__()
336
+
337
+ self.generator = generator_model
338
+ self.extractor = extractor_model
339
+
340
+ def load_generator_weights(self, gen_weights):
341
+ self.generator.load_state_dict(gen_weights)
342
+
343
+ def load_extractor_weights(self, ext_weights):
344
+ self.extractor.load_state_dict(ext_weights)
345
+
346
+ def extractor_eval(self):
347
+ for param in self.extractor.parameters():
348
+ param.requires_grad = False
349
+ self.extractor.eval()
350
+
351
+ def extractor_train(self):
352
+ for param in self.extractor.parameters():
353
+ param.requires_grad = True
354
+ self.extractor.train()
355
+
356
+ def forward(self, x, extractor_grad = False):
357
+
358
+ if extractor_grad:
359
+ features = self.extractor(x[:, 0:1])
360
+ else:
361
+ with torch.no_grad():
362
+ features = self.extractor(x[:, 0:1]).detach()
363
+
364
+ fake, guide = self.generator(x, features)
365
+
366
+ return fake, guide
367
+
368
+ class Discriminator(nn.Module):
369
+ def __init__(self, ndf=64):
370
+ super(Discriminator, self).__init__()
371
+
372
+ self.feed = nn.Sequential(SpectralNorm(nn.Conv2d(3, 64, 3, 1, 1)),
373
+ nn.LeakyReLU(0.2, True),
374
+ SpectralNorm(nn.Conv2d(64, 64, 3, 2, 0)),
375
+ nn.LeakyReLU(0.2, True),
376
+
377
+
378
+
379
+
380
+ SpectrResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1),
381
+ SpectrResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1, stride=2), # 128
382
+ SpectralNorm(nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=False)),
383
+ nn.LeakyReLU(0.2, True),
384
+
385
+ SpectrResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1),
386
+ SpectrResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1, stride=2), # 64
387
+ SpectralNorm(nn.Conv2d(ndf * 2, ndf * 4, kernel_size=1, stride=1, padding=0, bias=False)),
388
+ nn.LeakyReLU(0.2, True),
389
+
390
+ SpectrResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1),
391
+ SpectrResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1, stride=2), # 32,
392
+ SpectralNorm(nn.Conv2d(ndf * 4, ndf * 8, kernel_size=1, stride=1, padding=1, bias=False)),
393
+ nn.LeakyReLU(0.2, True),
394
+ SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
395
+ SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1, stride=2), # 16
396
+ SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
397
+ SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
398
+ nn.AdaptiveAvgPool2d((1, 1))
399
+ )
400
+
401
+ self.out = nn.Linear(512, 1)
402
+
403
+ def forward(self, color):
404
+ x = self.feed(color)
405
+
406
+ out = self.out(x.view(color.size(0), -1))
407
+ return out
408
+
409
+ class Content(nn.Module):
410
+ def __init__(self, path):
411
+ super(Content, self).__init__()
412
+ vgg16 = M.vgg16()
413
+ vgg16.load_state_dict(torch.load(path))
414
+ vgg16.features = nn.Sequential(
415
+ *list(vgg16.features.children())[:9]
416
+ )
417
+ self.model = vgg16.features
418
+ self.register_buffer('mean', torch.FloatTensor([0.485 - 0.5, 0.456 - 0.5, 0.406 - 0.5]).view(1, 3, 1, 1))
419
+ self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
420
+
421
+ def forward(self, images):
422
+ return self.model((images.mul(0.5) - self.mean) / self.std)
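For orientation, here is a minimal forward-pass sketch of how these classes compose, mirroring the wiring in `train.py` and `web.py` below. The 6-channel input layout (1 grayscale channel + 1 distance-field-map channel + 4 hint channels) is inferred from `torch.cat([bw, dfm, hint], 1)` in `train.py` and from `_make_encoder_block_first(6, 32)`, not documented anywhere, so treat it as an assumption; the 512x512 size is only an example (dimensions should be multiples of 32, as produced by `resize_pad`):

```python
import torch

from model.models import Colorizer, Generator
from model.extractor import get_seresnext_extractor

generator = Generator()
extractor = get_seresnext_extractor()
colorizer = Colorizer(generator, extractor).eval()  # load weights as in web.py for real use

bw = torch.zeros(1, 1, 512, 512)    # grayscale page
dfm = torch.zeros(1, 1, 512, 512)   # distance-field map of the sketch
hint = torch.zeros(1, 4, 512, 512)  # color hints; all zeros means "no hints"

with torch.no_grad():
    fake, guide = colorizer(torch.cat([bw, dfm, hint], 1))

print(fake.shape)   # torch.Size([1, 3, 512, 512]) -- tanh output in [-1, 1]
print(guide.shape)  # torch.Size([1, 3, 512, 512]) -- auxiliary decoder output used for the guide loss
```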
model/vgg16-397923af.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:397923af8e79cdbb6a7127f12361acd7a2f83e06b05044ddf496e83de57a5bf0
3
+ size 553433881
readme.md ADDED
@@ -0,0 +1,22 @@
1
+ UPD. See the [improved version](https://github.com/qweasdd/manga-colorization-v2).
2
+
3
+ # Automatic colorization
4
+
5
+ 1. Download the [generator](https://drive.google.com/file/d/1Oo6ycphJ3sUOpDCDoG29NA5pbhQVCevY/view?usp=sharing), [extractor](https://drive.google.com/file/d/12cbNyJcCa1zI2EBz6nea3BXl21Fm73Bt/view?usp=sharing) and [denoiser](https://drive.google.com/file/d/161oyQcYpdkVdw8gKz_MA8RD-Wtg9XDp3/view?usp=sharing) weights. Put the generator and extractor weights in `model` and the denoiser weights in `denoising/models`.
6
+ 2. To colorize an image, a folder of images, or a `.cbz`/`.cbr` file, use the following command:
7
+ ```
8
+ $ python inference.py -p "path to file or folder"
9
+ ```
10
+
11
+ # Manual colorization with color hints
12
+
13
+ 1. Download the [colorizer](https://drive.google.com/file/d/1BERrMl9e7cKsk9m2L0q1yO4k7blNhEWC/view?usp=sharing) and [denoiser](https://drive.google.com/file/d/161oyQcYpdkVdw8gKz_MA8RD-Wtg9XDp3/view?usp=sharing) weights. Put the colorizer weights in `model` and the denoiser weights in `denoising/models`.
14
+ 2. Run the gunicorn server with:
15
+ ```
16
+ $ ./run_drawing.sh
17
+ ```
18
+ 3. Open `localhost:5000` in a browser.
19
+
20
+ # References
21
+ 1. Extractor weights are taken from https://github.com/blandocs/Tag2Pix/releases/download/release/model.pth
22
+ 2. Denoiser weights are taken from http://www.ipol.im/pub/art/2019/231.
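For programmatic use (rather than the CLI or the web app), the sketch below mirrors the model setup in `web.py`; the `color_args` keys are copied from there, and `page.png` / `page_colorized.png` are hypothetical paths:

```python
import torch

from model.models import Colorizer, Generator
from model.extractor import get_seresnext_extractor
from utils.xdog import XDoGSketcher
from utils.utils import open_json
from inference import colorize_single_image

device = 'cuda' if torch.cuda.is_available() else 'cpu'

generator = Generator()
generator.load_state_dict(torch.load('model/generator.pth', map_location='cpu'))
extractor = get_seresnext_extractor()
extractor.load_state_dict(torch.load('model/extractor.pth', map_location='cpu'))
colorizer = Colorizer(generator, extractor).eval().to(device)

# Apply the xdog parameter overrides the same way web.py does.
sketcher = XDoGSketcher()
for key, value in open_json('configs/xdog_config.json').items():
    if key in sketcher.params:
        sketcher.params[key] = value

color_args = {'colorizer': colorizer, 'sketcher': sketcher, 'device': device, 'dfm': True,
              'auto_hint': False, 'auto_hint_sigma': 0.0003, 'ignore_gray': False,
              'denoiser': None}

colorize_single_image('page.png', 'page_colorized.png', color_args)
```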
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ flask==1.1.1
2
+ gunicorn
3
+ numpy==1.16.6
4
+ flask_wtf==0.14.3
5
+ matplotlib==3.1.1
6
+ opencv-python==4.1.2.30
7
+ snowy
8
+ scipy==1.3.3
9
+ scikit-image==0.15.0
10
+ patool==1.12
run_drawing.sh ADDED
@@ -0,0 +1 @@
1
+ gunicorn --worker-class gevent --timeout 150 -w 1 -b 0.0.0.0:5000 drawing:app
static/js/draw.js ADDED
@@ -0,0 +1,120 @@
1
+ var canvas = document.getElementById('draw_canvas');
2
+ var ctx = canvas.getContext('2d');
3
+ var canvasWidth = canvas.width;
4
+ var canvasHeight = canvas.height;
5
+ var prevX, prevY;
6
+
7
+ var result_canvas = document.getElementById('result');
8
+ var result_ctx = result_canvas.getContext('2d');
9
+ result_canvas.width = canvas.width;
10
+ result_canvas.height = canvas.height;
11
+
12
+ var color_indicator = document.getElementById('color');
13
+ ctx.fillStyle = 'black';
14
+ color_indicator.value = '#000000';
15
+
16
+ var cur_id = window.location.pathname.substring(window.location.pathname.lastIndexOf('/') + 1);
17
+
18
+ function getRandomInt(max) {
19
+ return Math.floor(Math.random() * Math.floor(max));
20
+ }
21
+
22
+ var init_hint = new Image();
23
+ init_hint.addEventListener('load', function() {
24
+ ctx.drawImage(init_hint, 0, 0);
25
+ });
26
+ init_hint.src = '../static/temp_images/' + cur_id + '/hint.png?' + getRandomInt(100000).toString();
27
+
28
+ result_canvas.addEventListener('load', function(e) {
29
+ var img = new Image();
30
+ img.addEventListener('load', function() {
31
+ ctx.drawImage(img, 0, 0);
32
+ }, false);
33
+ console.log(window.location.pathname);
34
+ })
35
+
36
+
37
+ canvas.onload = function (e) {
38
+ var img = new Image();
39
+ img.addEventListener('load', function() {
40
+ ctx.drawImage(img, 0, 0);
41
+ }, false);
42
+ console.log(window.location.pathname);
43
+ //img.src = ;
44
+ }
45
+
46
+ function reset() {
47
+ ctx.clearRect(0, 0, canvasWidth, canvasHeight);
48
+ }
49
+
50
+ function getMousePos(canvas, evt) {
51
+ var rect = canvas.getBoundingClientRect();
52
+ return {
53
+ x: (evt.clientX - rect.left) / (rect.right - rect.left) * canvas.width,
54
+ y: (evt.clientY - rect.top) / (rect.bottom - rect.top) * canvas.height
55
+ };
56
+ }
57
+
58
+ function colorize() {
59
+ var file_id = document.location.pathname;
60
+ var image = canvas.toDataURL();
61
+
62
+ $.post("/colorize", { save_file_id: file_id, save_image: image}).done(function( data ) {
63
+ //console.log(document.location.origin + '/img/' + data)
64
+ //window.open(document.location.origin + '/img/' + data, '_blank');
65
+ //result.src = data;
66
+ var img = new Image();
67
+ img.addEventListener('load', function() {
68
+ result_ctx.drawImage(img, 0, 0);
69
+ }, false);
70
+ img.src = data;
71
+ });
72
+ }
73
+
74
+ canvas.addEventListener('mousedown', function(e) {
75
+ var mousePos = getMousePos(canvas, e);
76
+ if (e.button == 0) {
77
+ ctx.fillRect(mousePos['x'], mousePos['y'], 1, 1);
78
+ }
79
+
80
+ if (e.button == 2) {
81
+ prevX = mousePos['x']
82
+ prevY = mousePos['y']
83
+ }
84
+
85
+ })
86
+
87
+ canvas.addEventListener('mouseup', function(e) {
88
+ if (e.button == 2) {
89
+ var mousePos = getMousePos(canvas, e);
90
+ var diff_width = mousePos['x'] - prevX;
91
+ var diff_height = mousePos['y'] - prevY;
92
+
93
+ ctx.clearRect(prevX, prevY, diff_width, diff_height);
94
+ }
95
+ })
96
+
97
+
98
+ canvas.addEventListener('contextmenu', function(evt) {
99
+ evt.preventDefault();
100
+ })
101
+
102
+ function color(color_value){
103
+ ctx.fillStyle = color_value;
104
+ color_indicator.value = color_value;
105
+ }
106
+
107
+ color_indicator.oninput = function() {
108
+ color(this.value);
109
+ }
110
+
111
+ function rgbToHex(rgb){
112
+ return '#' + ((1 << 24) | (rgb[0] << 16) | (rgb[1] << 8) | rgb[2]).toString(16).slice(1);
113
+ };
114
+
115
+ result_canvas.addEventListener('click', function(e) {
116
+ if (e.button == 0) {
117
+ var cur_pixel = result_ctx.getImageData(e.offsetX, e.offsetY, 1, 1).data;
118
+ color(rgbToHex(cur_pixel));
119
+ }
120
+ })
static/temp_images/.gitkeep ADDED
File without changes
templates/drawing.html ADDED
@@ -0,0 +1,206 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <title>Colorization app</title>
6
+ <style>
7
+ .back{
8
+ height:100%;
9
+ width:100%;
10
+ position: absolute;
11
+ background-color: yellow;
12
+ top:0px;
13
+ padding: 10px;
14
+ }
15
+
16
+ #draw_canvas {
17
+ height: 200%;
18
+ border: 3px solid black;
19
+ background-image: linear-gradient(rgba(60,60,60,.85), rgba(60,60,60,.85)), url(../static/{{img_path}});
20
+ background-color: #c7b39b;
21
+ background-size: 100%;
22
+ }
23
+ </style>
24
+
25
+ </head>
26
+ <body>
27
+ <div align="left" style ="margin : 5px; margin-left : 0px">
28
+ <input type="button" onclick="location.href='/';" value="Home" />
29
+ </div>
30
+ <p style="margin: 0; padding: 0">
31
+ Left click - place a color hint, hold the right button and drag - erase hints inside that rectangle, left click on the result - pick that pixel's color.
32
+
33
+ </p>
34
+ <hr style ="margin: 0; padding: 0">
35
+
36
+ <p><table>
37
+ <tr>
38
+ <td>
39
+ <table>
40
+ <tr>
41
+ <td><button style="background-color: #000000; height: 20px; width: 20px;" onclick="color('#000000')"></button>
42
+ <td><button style="background-color: #B0171F; height: 20px; width: 20px;" onclick="color('#B0171F')"></button>
43
+ </tr>
44
+ <tr>
45
+ <td><button style="background-color: #DA70D6; height: 20px; width: 20px;" onclick="color('#DA70D6')"></button>
46
+ <td><button style="background-color: #8A2BE2; height: 20px; width: 20px;" onclick="color('#8A2BE2')"></button>
47
+ </tr>
48
+ <tr>
49
+ <td><button style="background-color: #0000FF; height: 20px; width: 20px;" onclick="color('#0000FF')"></button>
50
+ <td><button style="background-color: #4876FF; height: 20px; width: 20px;" onclick="color('#4876FF')"></button>
51
+ </tr>
52
+ <tr>
53
+ <td><button style="background-color: #CAE1FF; height: 20px; width: 20px;" onclick="color('#CAE1FF')"></button>
54
+ <td><button style="background-color: #6E7B8B; height: 20px; width: 20px;" onclick="color('#6E7B8B')"></button>
55
+ </tr>
56
+ <tr>
57
+ <td><button style="background-color: #00C78C; height: 20px; width: 20px;" onclick="color('#00C78C')"></button>
58
+ <td><button style="background-color: #00FA9A; height: 20px; width: 20px;" onclick="color('#00FA9A')"></button>
59
+ </tr>
60
+ <tr>
61
+ <td><button style="background-color: #00FF7F; height: 20px; width: 20px;" onclick="color('#00FF7F')"></button>
62
+ <td><button style="background-color: #00C957; height: 20px; width: 20px;" onclick="color('#00C957')"></button>
63
+ </tr>
64
+ <tr>
65
+ <td><button style="background-color: #3D9140; height: 20px; width: 20px;" onclick="color('#3D9140')"></button>
66
+ <td><button style="background-color: #32CD32; height: 20px; width: 20px;" onclick="color('#32CD32')"></button>
67
+ </tr>
68
+ <tr>
69
+ <td><button style="background-color: #00EE00; height: 20px; width: 20px;" onclick="color('#00EE00')"></button>
70
+
71
+ <td><button style="background-color: #008B00; height: 20px; width: 20px;" onclick="color('#008B00')"></button>
72
+ </tr>
73
+ <tr>
74
+ <td><button style="background-color: #76EE00; height: 20px; width: 20px;" onclick="color('#76EE00')"></button>
75
+
76
+ <td><button style="background-color: #CAFF70; height: 20px; width: 20px;" onclick="color('#CAFF70')"></button>
77
+ </tr>
78
+ <tr>
79
+ <td><button style="background-color: #FFFF00; height: 20px; width: 20px;" onclick="color('#FFFF00')"></button>
80
+
81
+ <td><button style="background-color: #CDCD00; height: 20px; width: 20px;" onclick="color('#CDCD00')"></button>
82
+ </tr>
83
+ <tr>
84
+ <td><button style="background-color: #FFF68F; height: 20px; width: 20px;" onclick="color('#FFF68F')"></button>
85
+
86
+ <td><button style="background-color: #FFFACD; height: 20px; width: 20px;" onclick="color('#FFFACD')"></button>
87
+ </tr>
88
+ <tr>
89
+ <td><button style="background-color: #FFEC8B; height: 20px; width: 20px;" onclick="color('#FFEC8B')"></button>
90
+
91
+ <td><button style="background-color: #FFD700; height: 20px; width: 20px;" onclick="color('#FFD700')"></button>
92
+ </tr>
93
+ <tr>
94
+ <td><button style="background-color: #F5DEB3; height: 20px; width: 20px;" onclick="color('#F5DEB3')"></button>
95
+
96
+ <td><button style="background-color: #FFE4B5; height: 20px; width: 20px;" onclick="color('#FFE4B5')"></button>
97
+ </tr>
98
+ <tr>
99
+ <td><button style="background-color: #EECFA1; height: 20px; width: 20px;" onclick="color('#EECFA1')"></button>
100
+
101
+ <td><button style="background-color: #FF9912; height: 20px; width: 20px;" onclick="color('#FF9912')"></button>
102
+ </tr>
103
+ <tr>
104
+ <td><button style="background-color: #8E388E; height: 20px; width: 20px;" onclick="color('#8E388E')"></button>
105
+
106
+ <td><button style="background-color: #7171C6; height: 20px; width: 20px;" onclick="color('#7171C6')"></button>
107
+ </tr>
108
+
109
+ <tr>
110
+ <td><button style="background-color: #7D9EC0; height: 20px; width: 20px;" onclick="color('#7D9EC0')"></button>
111
+
112
+ <td><button style="background-color: #388E8E; height: 20px; width: 20px;" onclick="color('#388E8E')"></button>
113
+
114
+ </tr>
115
+
116
+ <tr>
117
+ <td><button style="background-color: #71C671; height: 20px; width: 20px;" onclick="color('#71C671')"></button>
118
+
119
+ <td><button style="background-color: #8E8E38; height: 20px; width: 20px;" onclick="color('#8E8E38')"></button>
120
+ </tr>
121
+ <tr>
122
+ <td><button style="background-color: #C5C1AA; height: 20px; width: 20px;" onclick="color('#C5C1AA')"></button>
123
+
124
+ <td><button style="background-color: #C67171; height: 20px; width: 20px;" onclick="color('#C67171')"></button>
125
+ </tr>
126
+ <tr>
127
+ <td><button style="background-color: #555555; height: 20px; width: 20px;" onclick="color('#555555')"></button>
128
+ <td><button style="background-color: #848484; height: 20px; width: 20px;" onclick="color('#848484')"></button>
129
+ </tr>
130
+ <tr>
131
+ <td><button style="background-color: #FFFFFF; height: 20px; width: 20px;" onclick="color('#FFFFFF')"></button>
132
+ <td><button style="background-color: #EE0000; height: 20px; width: 20px;" onclick="color('#EE0000')"></button>
133
+ </tr>
134
+ <tr>
135
+ <td><button style="background-color: #FF4040; height: 20px; width: 20px;" onclick="color('#FF4040')"></button>
136
+ <td><button style="background-color: #EE6363; height: 20px; width: 20px;" onclick="color('#EE6363')"></button>
137
+ </tr>
138
+ <tr>
139
+ <td><button style="background-color: #FFC1C1; height: 20px; width: 20px;" onclick="color('#FFC1C1')"></button>
140
+ <td><button style="background-color: #FF7256; height: 20px; width: 20px;" onclick="color('#FF7256')"></button>
141
+ </tr>
142
+ <tr>
143
+ <td><button style="background-color: #FF4500; height: 20px; width: 20px;" onclick="color('#FF4500')"></button>
144
+ <td><button style="background-color: #F4A460; height: 20px; width: 20px;" onclick="color('#F4A460')"></button>
145
+ </tr>
146
+ <tr>
147
+ <td><button style="background-color: #FF8000; height: 20px; width: 20px;" onclick="color('FF8000')"></button>
148
+ <td><button style="background-color: #FFD700; height: 20px; width: 20px;" onclick="color('#FFD700')"></button>
149
+ </tr>
150
+ <tr>
151
+ <td><button style="background-color: #8B864E; height: 20px; width: 20px;" onclick="color('#8B864E')"></button>
152
+ <td><button style="background-color: #9ACD32; height: 20px; width: 20px;" onclick="color('#9ACD32')"></button>
153
+ </tr>
154
+ <tr>
155
+ <td><button style="background-color: #66CD00; height: 20px; width: 20px;" onclick="color('#66CD00')"></button>
156
+ <td><button style="background-color: #BDFCC9; height: 20px; width: 20px;" onclick="color('#BDFCC9')"></button>
157
+ </tr>
158
+ <tr>
159
+ <td><button style="background-color: #76EEC6; height: 20px; width: 20px;" onclick="color('#76EEC6')"></button>
160
+ <td><button style="background-color: #40E0D0; height: 20px; width: 20px;" onclick="color('#40E0D0')"></button>
161
+ </tr>
162
+ <tr>
163
+ <td><button style="background-color: #E0EEEE; height: 20px; width: 20px;" onclick="color('#E0EEEE')"></button>
164
+ <td><button style="background-color: #98F5FF; height: 20px; width: 20px;" onclick="color('#98F5FF')"></button>
165
+ </tr>
166
+ <tr>
167
+ <td><button style="background-color: #33A1C9; height: 20px; width: 20px;" onclick="color('#33A1C9')"></button>
168
+ <td><button style="background-color: #F0F8FF; height: 20px; width: 20px;" onclick="color('#F0F8FF')"></button>
169
+ </tr>
170
+ <tr>
171
+ <td><button style="background-color: #4682B4; height: 20px; width: 20px;" onclick="color('#4682B4')"></button>
172
+ <td><button style="background-color: #C6E2FF; height: 20px; width: 20px;" onclick="color('#C6E2FF')"></button>
173
+ </tr>
174
+ <tr>
175
+ <td><button style="background-color: #9B30FF; height: 20px; width: 20px;" onclick="color('#9B30FF')"></button>
176
+ <td><button style="background-color: #EE82EE; height: 20px; width: 20px;" onclick="color('#EE82EE')"></button>
177
+ </tr>
178
+ <tr>
179
+ <td><button style="background-color: #FFC0CB; height: 20px; width: 20px;" onclick="color('#FFC0CB')"></button>
180
+ <td><button style="background-color: #7CFC00; height: 20px; width: 20px;" onclick="color('#7CFC00')"></button>
181
+ </tr>
182
+ <tr>
183
+ <input type="color" id="color">
184
+ </tr>
185
+ </table>
186
+ </td>
187
+ <td>
188
+ <div style="width: 1150px; height: 800px; overflow: auto"><canvas align = "center" id="draw_canvas" width="{{width}}" height="{{height}}"></canvas></div>
189
+ </td>
190
+ <td>
191
+ <canvas id='result'></canvas>
192
+ </td>
193
+ </tr>
194
+ </table></p>
195
+
196
+ <button style="height: 20px; width: 80px" onclick="reset()">Clear</button>
197
+
198
+ <button style="height: 20px; width: 80px" onclick="colorize()" >Colorize</button>
199
+
200
+
201
+
202
+ <script src="http://code.jquery.com/jquery-1.8.3.js"></script>
203
+ <script src="/static/js/draw.js">
204
+ </script>
205
+ </body>
206
+ </html>
templates/submit.html ADDED
@@ -0,0 +1,11 @@
1
+ <form method="POST" action="/" enctype="multipart/form-data" target="_blank">
2
+ {{ form.hidden_tag() }}
3
+ {{ form.file.label }} {{ form.file(size=20) }}
4
+ {{ form.denoise.label }} {{ form.denoise(size=5) }}
5
+ {{ form.denoise_sigma.label }} {{ form.denoise_sigma(size=5) }}
6
+ {{ form.autohint.label }} {{ form.autohint(size=5) }}
7
+ {{ form.autohint_sigma.label }} {{ form.autohint_sigma(size=5) }}
8
+ {{ form.ignore_gray.label }} {{ form.ignore_gray(size=5) }}
9
+ <input type="submit" value="Colorize">
10
+
11
+ </form>
templates/upload.html ADDED
@@ -0,0 +1,20 @@
1
+ <style>
2
+ form{
3
+ position:fixed;
4
+ top:50%;
5
+ left:45%;
6
+ width:1250px;
7
+ }
8
+ </style>
9
+
10
+
11
+ <form id="form" method="POST" action="/" enctype="multipart/form-data">
12
+ {{ form.hidden_tag() }}
13
+ {{ form.file(size=20) }}
14
+ </form>
15
+
16
+ <script>
17
+ document.getElementById("file").onchange = function() {
18
+ document.getElementById("form").submit();
19
+ };
20
+ </script>
train.py ADDED
@@ -0,0 +1,293 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import numpy as np
5
+ import albumentations as albu
6
+ import argparse
7
+ import datetime
8
+
9
+ from utils.utils import open_json, weights_init, weights_init_spectr, generate_mask
10
+ from model.models import Colorizer, Generator, Content, Discriminator
11
+ from model.extractor import get_seresnext_extractor
12
+ from dataset.datasets import TrainDataset, FineTuningDataset
13
+ from PIL import Image
14
+
15
+ def parse_args():
16
+ parser = argparse.ArgumentParser()
17
+ parser.add_argument("-p", "--path", required=True, help = "dataset path")
18
+ parser.add_argument('-ft', '--fine_tuning', dest = 'fine_tuning', action = 'store_true')
19
+ parser.add_argument('-g', '--gpu', dest = 'gpu', action = 'store_true')
20
+ parser.set_defaults(fine_tuning = False)
21
+ parser.set_defaults(gpu = False)
22
+ args = parser.parse_args()
23
+
24
+ return args
25
+
26
+ def get_transforms():
27
+ return albu.Compose([albu.RandomCrop(512, 512, always_apply = True), albu.HorizontalFlip(p = 0.5)], p = 1.)
28
+
29
+ def get_dataloaders(data_path, transforms, batch_size, fine_tuning, mult_number):
30
+ train_dataset = TrainDataset(data_path, transforms)
31
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
32
+
+ finetuning_dataloader = None # avoid a NameError in the return below when fine-tuning is off
33
+ if fine_tuning:
34
+ finetuning_dataset = FineTuningDataset(data_path, transforms)
35
+ finetuning_dataloader = torch.utils.data.DataLoader(finetuning_dataset, batch_size = batch_size, shuffle = True)
36
+
37
+ return train_dataloader, finetuning_dataloader
38
+
39
+ def get_models(device):
40
+ generator = Generator()
41
+ extractor = get_seresnext_extractor()
42
+ colorizer = Colorizer(generator, extractor)
43
+
44
+ colorizer.extractor_eval()
45
+ colorizer = colorizer.to(device)
46
+
47
+ discriminator = Discriminator().to(device)
48
+
49
+ content = Content('model/vgg16-397923af.pth').eval().to(device)
50
+ for param in content.parameters():
51
+ param.requires_grad = False
52
+
53
+ return colorizer, discriminator, content
54
+
55
+ def set_weights(colorizer, discriminator):
56
+ colorizer.generator.apply(weights_init)
57
+ colorizer.load_extractor_weights(torch.load('model/extractor.pth'))
58
+
59
+ discriminator.apply(weights_init_spectr)
60
+
61
+ def generator_loss(disc_output, true_labels, main_output, guide_output, real_image, content_gen, content_true, dist_loss = nn.L1Loss(), content_dist_loss = nn.MSELoss(), class_loss = nn.BCEWithLogitsLoss()):
62
+ sim_loss_full = dist_loss(main_output, real_image)
63
+ sim_loss_guide = dist_loss(guide_output, real_image)
64
+
65
+ adv_loss = class_loss(disc_output, true_labels)
66
+
67
+ content_loss = content_dist_loss(content_gen, content_true)
68
+
69
+ sum_loss = 10 * (sim_loss_full + 0.9 * sim_loss_guide) + adv_loss + content_loss
70
+
71
+ return sum_loss
72
+
73
+ def get_optimizers(colorizer, discriminator, generator_lr, discriminator_lr):
74
+ optimizerG = optim.Adam(colorizer.generator.parameters(), lr = generator_lr, betas=(0.5, 0.9))
75
+ optimizerD = optim.Adam(discriminator.parameters(), lr = discriminator_lr, betas=(0.5, 0.9))
76
+
77
+ return optimizerG, optimizerD
78
+
79
+ def generator_step(inputs, colorizer, discriminator, content, loss_function, optimizer, device, white_penalty = True):
80
+ for p in discriminator.parameters():
81
+ p.requires_grad = False
82
+ for p in colorizer.generator.parameters():
83
+ p.requires_grad = True
84
+
85
+ colorizer.generator.zero_grad()
86
+
87
+ bw, color, hint, dfm = inputs
88
+ bw, color, hint, dfm = bw.to(device), color.to(device), hint.to(device), dfm.to(device)
89
+
90
+ fake, guide = colorizer(torch.cat([bw, dfm, hint], 1))
91
+
92
+ logits_fake = discriminator(fake)
93
+ y_real = torch.ones((bw.size(0), 1), device = device)
94
+
95
+ content_fake = content(fake)
96
+ with torch.no_grad():
97
+ content_true = content(color)
98
+
99
+ generator_loss = loss_function(logits_fake, y_real, fake, guide, color, content_fake, content_true)
100
+
101
+ if white_penalty:
102
+ mask = (~((color > 0.85).float().sum(dim = 1) == 3).unsqueeze(1).repeat((1, 3, 1, 1 ))).float()
103
+ white_zones = mask * (fake + 1) / 2
104
+ white_penalty = (torch.pow(white_zones.sum(dim = 1), 2).sum(dim = (1, 2)) / (mask.sum(dim = (1, 2, 3)) + 1)).mean()
105
+
106
+ generator_loss += white_penalty
107
+
108
+ generator_loss.backward()
109
+
110
+ optimizer.step()
111
+
112
+ return generator_loss.item()
113
+
114
+ def discriminator_step(inputs, colorizer, discriminator, optimizer, device, loss_function = nn.BCEWithLogitsLoss()):
115
+
116
+ for p in discriminator.parameters():
117
+ p.requires_grad = True
118
+ for p in colorizer.generator.parameters():
119
+ p.requires_grad = False
120
+
121
+ discriminator.zero_grad()
122
+
123
+ bw, color, hint, dfm = inputs
124
+ bw, color, hint, dfm = bw.to(device), color.to(device), hint.to(device), dfm.to(device)
125
+
126
+ y_real = torch.full((bw.size(0), 1), 0.9, device = device)
127
+
128
+ y_fake = torch.zeros((bw.size(0), 1), device = device)
129
+
130
+ with torch.no_grad():
131
+ fake_color, _ = colorizer(torch.cat([bw, dfm, hint], 1))
132
+ fake_color = fake_color.detach()
133
+
134
+ logits_fake = discriminator(fake_color)
135
+ logits_real = discriminator(color)
136
+
137
+ fake_loss = loss_function(logits_fake, y_fake)
138
+ real_loss = loss_function(logits_real, y_real)
139
+
140
+ discriminator_loss = real_loss + fake_loss
141
+
142
+ discriminator_loss.backward()
143
+ optimizer.step()
144
+
145
+ return discriminator_loss.item()
146
+
147
+ def decrease_lr(optimizer, rate):
148
+ for group in optimizer.param_groups:
149
+ group['lr'] /= rate
150
+
151
+ def set_lr(optimizer, value):
152
+ for group in optimizer.param_groups:
153
+ group['lr'] = value
154
+
155
+ def train(colorizer, discriminator, content, dataloader, epochs, colorizer_optimizer, discriminator_optimizer, lr_decay_epoch = -1, device = 'cpu'):
156
+ colorizer.generator.train()
157
+ discriminator.train()
158
+
159
+ disc_step = True
160
+
161
+ for epoch in range(epochs):
162
+ if (epoch == lr_decay_epoch):
163
+ decrease_lr(colorizer_optimizer, 10)
164
+ decrease_lr(discriminator_optimizer, 10)
165
+
166
+ sum_disc_loss = 0
167
+ sum_gen_loss = 0
168
+
169
+ for n, inputs in enumerate(dataloader):
170
+ if n % 5 == 0:
171
+ print(datetime.datetime.now().time())
172
+ print('Step : %d Discr loss: %.4f Gen loss : %.4f \n'%(n, sum_disc_loss / (n // 2 + 1), sum_gen_loss / (n // 2 + 1)))
173
+
174
+
175
+ if disc_step:
176
+ step_loss = discriminator_step(inputs, colorizer, discriminator, discriminator_optimizer, device)
177
+ sum_disc_loss += step_loss
178
+ else:
179
+ step_loss = generator_step(inputs, colorizer, discriminator, content, generator_loss, colorizer_optimizer, device)
180
+ sum_gen_loss += step_loss
181
+
182
+ disc_step = not disc_step
183
+
184
+
185
+ print(datetime.datetime.now().time())
186
+ print('Epoch : %d Discr loss: %.4f Gen loss : %.4f \n'%(epoch, sum_disc_loss / (n // 2 + 1), sum_gen_loss / (n // 2 + 1)))
187
+
188
+
189
+ def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_optimizer, device, loss_function = nn.BCEWithLogitsLoss()):
190
+
191
+ for p in discriminator.parameters():
192
+ p.requires_grad = True
193
+ for p in colorizer.generator.parameters():
194
+ p.requires_grad = False
195
+
196
+ for cur_disc_step in range(5):
197
+ discriminator.zero_grad()
198
+
199
+ bw, dfm, color_for_real = next(data_iter)
200
+ bw, dfm, color_for_real = bw.to(device), dfm.to(device), color_for_real.to(device)
201
+
202
+ y_real = torch.full((bw.size(0), 1), 0.9, device = device)
203
+ y_fake = torch.zeros((bw.size(0), 1), device = device)
204
+
205
+ empty_hint = torch.zeros(bw.shape[0], 4, bw.shape[2] , bw.shape[3] ).float().to(device)
206
+
207
+ with torch.no_grad():
208
+ fake_color_manga, _ = colorizer(torch.cat([bw, dfm, empty_hint ], 1))
209
+ fake_color_manga = fake_color_manga.detach()
210
+
211
+ logits_fake = discriminator(fake_color_manga)
212
+ logits_real = discriminator(color_for_real)
213
+
214
+ fake_loss = loss_function(logits_fake, y_fake)
215
+ real_loss = loss_function(logits_real, y_real)
216
+ discriminator_loss = real_loss + fake_loss
217
+
218
+ discriminator_loss.backward()
219
+ disc_optimizer.step()
220
+
221
+
222
+ for p in discriminator.parameters():
223
+ p.requires_grad = False
224
+ for p in colorizer.generator.parameters():
225
+ p.requires_grad = True
226
+
227
+ colorizer.generator.zero_grad()
228
+
229
+ bw, dfm, _ = next(data_iter)
230
+ bw, dfm = bw.to(device), dfm.to(device)
231
+
232
+ y_real = torch.ones((bw.size(0), 1), device = device)
233
+
234
+ empty_hint = torch.zeros(bw.shape[0], 4, bw.shape[2] , bw.shape[3]).float().to(device)
235
+
236
+ fake_manga, _ = colorizer(torch.cat([bw, dfm, empty_hint], 1))
237
+
238
+ logits_fake = discriminator(fake_manga)
239
+ adv_loss = loss_function(logits_fake, y_real)
240
+
241
+ generator_loss = adv_loss
242
+
243
+ generator_loss.backward()
244
+ gen_optimizer.step()
245
+
246
+
247
+
248
+ def fine_tuning(colorizer, discriminator, content, dataloader, iterations, colorizer_optimizer, discriminator_optimizer, data_iter, device = 'cpu'):
249
+ colorizer.generator.train()
250
+ discriminator.train()
251
+
252
+ disc_step = True
253
+
254
+ for n, inputs in enumerate(dataloader):
255
+
256
+ if n == iterations:
257
+ return
258
+
259
+ if disc_step:
260
+ discriminator_step(inputs, colorizer, discriminator, discriminator_optimizer, device)
261
+ else:
262
+ generator_step(inputs, colorizer, discriminator, content, generator_loss, colorizer_optimizer, device)
263
+
264
+ disc_step = not disc_step
265
+
266
+ if n % 10 == 5:
267
+ fine_tuning_step(data_iter, colorizer, discriminator, colorizer_optimizer, discriminator_optimizer, device)
268
+
269
+ if __name__ == '__main__':
270
+ args = parse_args()
271
+ config = open_json('configs/train_config.json')
272
+
273
+ if args.gpu:
274
+ device = 'cuda'
275
+ else:
276
+ device = 'cpu'
277
+
278
+ augmentations = get_transforms()
279
+
280
+ train_dataloader, ft_dataloader = get_dataloaders(args.path, augmentations, config['batch_size'], args.fine_tuning, config['number_of_mults'])
281
+
282
+ colorizer, discriminator, content = get_models(device)
283
+ set_weights(colorizer, discriminator)
284
+
285
+ gen_optimizer, disc_optimizer = get_optimizers(colorizer, discriminator, config['generator_lr'], config['discriminator_lr'])
286
+
287
+ train(colorizer, discriminator, content, train_dataloader, config['epochs'], gen_optimizer, disc_optimizer, config['lr_decrease_epoch'], device)
288
+
289
+ if args.fine_tuning:
290
+ set_lr(gen_optimizer, config["finetuning_generator_lr"])
291
+ fine_tuning(colorizer, discriminator, content, train_dataloader, config['finetuning_iterations'], gen_optimizer, disc_optimizer, iter(ft_dataloader), device)
292
+
293
+ torch.save(colorizer.generator.state_dict(), 'generator_' + datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.pth')
train/bw/blackclover_cl268.png ADDED
train/bw/dfm_blackclover_cl268.png ADDED
train/color/blackclover_cl268.png ADDED
train/real_manga/blackclover_cl268.png ADDED
train/real_manga/dfm_blackclover_cl268.png ADDED
utils/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.78 kB). View file
 
utils/__pycache__/utils.cpython-36.pyc ADDED
Binary file (3.8 kB). View file
 
utils/__pycache__/utils.cpython-39.pyc ADDED
Binary file (3.81 kB). View file
 
utils/dataset_utils.py ADDED
@@ -0,0 +1,141 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import cv2
4
+ import snowy
5
+ import os
6
+
7
+
8
+ def get_resized_image(img, size):
9
+ if len(img.shape) == 2:
10
+ img = np.repeat(np.expand_dims(img, 2), 3, 2)
11
+
12
+ if (img.shape[0] < img.shape[1]):
13
+ height = img.shape[0]
14
+ ratio = height / size
15
+ width = int(np.ceil(img.shape[1] / ratio))
16
+ img = cv2.resize(img, (width, size), interpolation = cv2.INTER_AREA)
17
+ else:
18
+ width = img.shape[1]
19
+ ratio = width / size
20
+ height = int(np.ceil(img.shape[0] / ratio))
21
+ img = cv2.resize(img, (size, height), interpolation = cv2.INTER_AREA)
22
+
23
+ if (img.dtype == 'float32'):
24
+ np.clip(img, 0, 1, out = img)
25
+
26
+ return img
27
+
28
+
29
+ def get_sketch_image(img, sketcher, mult_val):
30
+
31
+ if mult_val:
32
+ sketch_image = sketcher.get_sketch_with_resize(img, mult = mult_val)
33
+ else:
34
+ sketch_image = sketcher.get_sketch_with_resize(img)
35
+
36
+ return sketch_image
37
+
38
+
39
+ def get_dfm_image(sketch):
40
+ dfm_image = snowy.unitize(snowy.generate_sdf(np.expand_dims(1 - sketch, 2) != 0)).squeeze()
41
+ return dfm_image
42
+
43
+ def get_sketch(image, sketcher, dfm, mult = None):
44
+ sketch_image = get_sketch_image(image, sketcher, mult)
45
+
46
+ dfm_image = None
47
+
48
+ if dfm:
49
+ dfm_image = get_dfm_image(sketch_image)
50
+
51
+ sketch_image = (sketch_image * 255).astype('uint8')
52
+
53
+ if dfm:
54
+ dfm_image = (dfm_image * 255).astype('uint8')
55
+
56
+ return sketch_image, dfm_image
57
+
58
+ def get_sketches(image, sketcher, mult_list, dfm):
59
+ for mult in mult_list:
60
+ yield get_sketch(image, sketcher, dfm, mult)
61
+
62
+
63
+ def create_resized_dataset(source_path, target_path, side_size):
64
+ images = os.listdir(source_path)
65
+
66
+ for image_name in images:
67
+
68
+ new_image_name = image_name[:image_name.rfind('.')] + '.png'
69
+ new_path = os.path.join(target_path, new_image_name)
70
+
71
+ if not os.path.exists(new_path):
72
+ try:
73
+ image = cv2.imread(os.path.join(source_path, image_name))
74
+
75
+ if image is None:
76
+ raise Exception()
77
+
78
+ image = get_resized_image(image, side_size)
79
+
80
+ cv2.imwrite(new_path, image)
81
+ except Exception:
82
+ print('Failed to process {}'.format(image_name))
83
+
84
+
85
+ def create_sketches_dataset(source_path, target_path, sketcher, mult_list, dfm = False):
86
+
87
+ images = os.listdir(source_path)
88
+ for image_name in images:
89
+ try:
90
+ image = cv2.imread(os.path.join(source_path, image_name))
91
+
92
+ if image is None:
93
+ raise Exception()
94
+
95
+ for number, (sketch_image, dfm_image) in enumerate(get_sketches(image, sketcher, mult_list, dfm)):
96
+ new_sketch_name = image_name[:image_name.rfind('.')] + '_' + str(number) + '.png'
97
+ cv2.imwrite(os.path.join(target_path, new_sketch_name), sketch_image)
98
+
99
+ if dfm:
100
+ dfm_name = image_name[:image_name.rfind('.')] + '_' + str(number) + '_dfm.png'
101
+ cv2.imwrite(os.path.join(target_path, dfm_name), dfm_image)
102
+
103
+ except Exception:
104
+ print('Failed to process {}'.format(image_name))
105
+
106
+
107
+ def create_dataset(source_path, target_path, sketcher, mult_list, side_size, dfm = False):
108
+ images = os.listdir(source_path)
109
+
110
+ color_path = os.path.join(target_path, 'color')
111
+ sketch_path = os.path.join(target_path, 'bw')
112
+
113
+ if not os.path.exists(color_path):
114
+ os.makedirs(color_path)
115
+
116
+ if not os.path.exists(sketch_path):
117
+ os.makedirs(sketch_path)
118
+
119
+ for image_name in images:
120
+ new_image_name = image_name[:image_name.rfind('.')] + '.png'
121
+
122
+ try:
123
+ image = cv2.imread(os.path.join(source_path, image_name))
124
+
125
+ if image is None:
126
+ raise Exception()
127
+
128
+ resized_image = get_resized_image(image, side_size)
129
+ cv2.imwrite(os.path.join(color_path, new_image_name), resized_image)
130
+
131
+ for number, (sketch_image, dfm_image) in enumerate(get_sketches(resized_image, sketcher, mult_list, dfm)):
132
+ new_sketch_name = image_name[:image_name.rfind('.')] + '_' + str(number) + '.png'
133
+ cv2.imwrite(os.path.join(sketch_path, new_sketch_name), sketch_image)
134
+
135
+ if dfm:
136
+ dfm_name = image_name[:image_name.rfind('.')] + '_' + str(number) + '_dfm.png'
137
+ cv2.imwrite(os.path.join(sketch_path, dfm_name), dfm_image)
138
+
139
+ except Exception:
140
+ print('Failed to process {}'.format(image_name))
141
+
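As a usage sketch, building a training folder from a directory of color pages could look like the following; the `mult_list` and `side_size` values are illustrative, not prescribed anywhere in the repo:

```python
from utils.dataset_utils import create_dataset
from utils.xdog import XDoGSketcher

sketcher = XDoGSketcher()

# Writes resized color pages to train/color/ and, for each mult value, an xdog
# sketch train/bw/<name>_<k>.png plus its distance field train/bw/<name>_<k>_dfm.png.
create_dataset('raw_color_pages', 'train', sketcher, mult_list=[1], side_size=768, dfm=True)
```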
utils/utils.py ADDED
@@ -0,0 +1,102 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import scipy.stats as stats
5
+ import cv2
6
+ import json
7
+ import patoolib
8
+ import re
9
+ from pathlib import Path
10
+ from shutil import rmtree
11
+
12
+ def weights_init(m):
13
+ classname = m.__class__.__name__
14
+ if classname.find('Conv2d') != -1:
15
+ nn.init.xavier_uniform_(m.weight.data)
16
+
17
+ def weights_init_spectr(m):
18
+ classname = m.__class__.__name__
19
+ if classname.find('Conv2d') != -1:
20
+ nn.init.xavier_uniform_(m.weight_bar.data)
21
+
22
+ def generate_mask(height, width, mu = 1, sigma = 0.0005, prob = 0.5, full = True, full_prob = 0.01):
23
+ X = stats.truncnorm((0 - mu) / sigma, (1 - mu) / sigma, loc=mu, scale=sigma)
24
+
25
+ if full:
26
+ if (np.random.binomial(1, p = full_prob) == 1):
27
+ return torch.ones(1, height, width).float()
28
+
29
+ if np.random.binomial(1, p = prob) == 1:
30
+ mask = torch.rand(1, height, width).ge(X.rvs(1)[0]).float()
31
+ else:
32
+ mask = torch.zeros(1, height, width).float()
33
+
34
+ return mask
35
+
36
+ def resize_pad(img, size = 512):
37
+
38
+ if len(img.shape) == 2:
39
+ img = np.expand_dims(img, 2)
40
+
41
+ if img.shape[2] == 1:
42
+ img = np.repeat(img, 3, 2)
43
+
44
+ if img.shape[2] == 4:
45
+ img = img[:, :, :3]
46
+
47
+ pad = None
48
+
49
+ if (img.shape[0] < img.shape[1]):
50
+ height = img.shape[0]
51
+ ratio = height / size
52
+ width = int(np.ceil(img.shape[1] / ratio))
53
+ img = cv2.resize(img, (width, size), interpolation = cv2.INTER_AREA)
54
+
55
+ new_width = width
56
+ while (new_width % 32 != 0):
57
+ new_width += 1
58
+
59
+ pad = (0, new_width - width)
60
+
61
+ img = np.pad(img, ((0, 0), (0, pad[1]), (0, 0)), 'maximum')
62
+ else:
63
+ width = img.shape[1]
64
+ ratio = width / size
65
+ height = int(np.ceil(img.shape[0] / ratio))
66
+ img = cv2.resize(img, (size, height), interpolation = cv2.INTER_AREA)
67
+
68
+ new_height = height
69
+ while (new_height % 32 != 0):
70
+ new_height += 1
71
+
72
+ pad = (new_height - height, 0)
73
+
74
+ img = np.pad(img, ((0, pad[0]), (0, 0), (0, 0)), 'maximum')
75
+
76
+ if (img.dtype == 'float32'):
77
+ np.clip(img, 0, 1, out = img)
78
+
79
+ return img, pad
80
+
81
+ def open_json(file):
82
+ with open(file) as json_file:
83
+ data = json.load(json_file)
84
+
85
+ return data
86
+
87
+ def extract_cbr(file, out_dir):
88
+ patoolib.extract_archive(file, outdir = out_dir, verbosity = 1, interactive = False)
89
+
90
+ def create_cbz(file_path, files):
91
+ patoolib.create_archive(file_path, files, verbosity = 1, interactive = False)
92
+
93
+ def subfolder_image_search(start_folder):
94
+ return [x.as_posix() for x in Path(start_folder).rglob("*.[pPjJ][nNpP][gG]")]
95
+
96
+ def remove_folder(folder_path):
97
+ rmtree(folder_path)
98
+
99
+ def sorted_alphanumeric(data):
100
+ convert = lambda text: int(text) if text.isdigit() else text.lower()
101
+ alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ]
102
+ return sorted(data, key=alphanum_key)
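A quick sketch of the two helpers the rest of the code leans on most, `resize_pad` and `generate_mask` (`page.png` is a hypothetical input):

```python
import cv2

from utils.utils import resize_pad, generate_mask

img = cv2.imread('page.png')
padded, pad = resize_pad(img, size=512)
# The shorter side becomes 512, the longer side is scaled to match and then padded
# up to a multiple of 32; pad = (height_pad, width_pad) lets you crop the padding off later.
print(padded.shape, pad)

mask = generate_mask(512, 512)  # 1 x 512 x 512 float tensor of (usually sparse) random hint positions
print(mask.sum().item())
```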
utils/xdog.py ADDED
@@ -0,0 +1,68 @@
1
+ from cv2 import resize, INTER_LANCZOS4, INTER_AREA
2
+ from skimage.color import rgb2gray
3
+ import numpy as np
4
+ from scipy.ndimage.filters import gaussian_filter
5
+ from skimage.filters import threshold_otsu
6
+ import matplotlib.pyplot as plt
7
+
8
+ class XDoGSketcher:
9
+
10
+ def __init__(self, gamma = 0.95, phi = 89.25, eps = -0.1, k = 8, sigma = 0.5, mult = 1):
11
+ self.params = {}
12
+ self.params['gamma'] = gamma
13
+ self.params['phi'] = phi
14
+ self.params['eps'] = eps
15
+ self.params['k'] = k
16
+ self.params['sigma'] = sigma
17
+
18
+ self.params['mult'] = mult
19
+
20
+ def _xdog(self, im, **transform_params):
21
+ # Source : https://github.com/CemalUnal/XDoG-Filter
22
+ # Reference : XDoG: An eXtended difference-of-Gaussians compendium including advanced image stylization
23
+ # Link : http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.365.151&rep=rep1&type=pdf
24
+
25
+ if im.shape[2] == 3:
26
+ im = rgb2gray(im)
27
+
28
+ imf1 = gaussian_filter(im, transform_params['sigma'])
29
+ imf2 = gaussian_filter(im, transform_params['sigma'] * transform_params['k'])
30
+ imdiff = imf1 - transform_params['gamma'] * imf2
31
+ imdiff = (imdiff < transform_params['eps']) * 1.0 \
32
+ + (imdiff >= transform_params['eps']) * (1.0 + np.tanh(transform_params['phi'] * imdiff))
33
+ imdiff -= imdiff.min()
34
+ imdiff /= imdiff.max()
35
+
36
+
37
+ th = threshold_otsu(imdiff)
38
+ imdiff = imdiff >= th
39
+
40
+ imdiff = imdiff.astype('float32')
41
+
42
+ return imdiff
43
+
44
+
45
+ def get_sketch(self, image, **kwargs):
46
+ current_params = self.params.copy()
47
+
48
+ for key in kwargs.keys():
49
+ if key in current_params.keys():
50
+ current_params[key] = kwargs[key]
51
+
52
+ result_image = self._xdog(image, **current_params)
53
+
54
+ return result_image
55
+
56
+ def get_sketch_with_resize(self, image, **kwargs):
57
+ if 'mult' in kwargs.keys():
58
+ mult = kwargs['mult']
59
+ else:
60
+ mult = self.params['mult']
61
+
62
+ temp_image = resize(image, (image.shape[1] * mult, image.shape[0] * mult), interpolation = INTER_LANCZOS4)
63
+ temp_image = self.get_sketch(temp_image, **kwargs)
64
+ image = resize(temp_image, (image.shape[1], image.shape[0]), interpolation = INTER_AREA)
65
+
66
+ return image
67
+
68
+
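A minimal sketch of producing a sketch image the way the dataset utilities do (`page.png` is hypothetical; the parameters can be overridden from `configs/xdog_config.json` the way `web.py` shows):

```python
import cv2

from utils.xdog import XDoGSketcher

sketcher = XDoGSketcher()

image = cv2.imread('page.png')
sketch = sketcher.get_sketch_with_resize(image)  # float32 array of 0.0 / 1.0 after Otsu thresholding
cv2.imwrite('page_sketch.png', (sketch * 255).astype('uint8'))
```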
web.py ADDED
@@ -0,0 +1,108 @@
1
+ from flask import Flask, request, jsonify, abort, redirect, url_for, render_template, send_file
2
+ from flask_wtf import FlaskForm
3
+ from wtforms import StringField, FileField, BooleanField, DecimalField
4
+ from wtforms.validators import DataRequired
5
+ from flask import after_this_request
6
+
7
+ import torch
8
+
9
+ import os
10
+ from model.models import Colorizer, Generator
11
+ from model.extractor import get_seresnext_extractor
12
+ from utils.xdog import XDoGSketcher
13
+ from utils.utils import open_json
14
+ from denoising.denoiser import FFDNetDenoiser
15
+ from datetime import datetime
16
+
17
+ from inference import colorize_single_image, colorize_images, colorize_cbr
18
+
19
+ if torch.cuda.is_available():
20
+ device = 'cuda'
21
+ else:
22
+ device = 'cpu'
23
+
24
+ generator = Generator()
25
+ generator.load_state_dict(torch.load('model/generator.pth'))
26
+
27
+ extractor = get_seresnext_extractor()
28
+ extractor.load_state_dict(torch.load('model/extractor.pth'))
29
+
30
+ colorizer = Colorizer(generator, extractor)
31
+ colorizer = colorizer.eval().to(device)
32
+
33
+ sketcher = XDoGSketcher()
34
+ xdog_config = open_json('configs/xdog_config.json')
35
+ for key in xdog_config.keys():
36
+ if key in sketcher.params:
37
+ sketcher.params[key] = xdog_config[key]
38
+
39
+ denoiser = FFDNetDenoiser(device)
40
+
41
+
42
+ app = Flask(__name__)
43
+ app.config.update(dict(
44
+ SECRET_KEY="lol kek",
45
+ WTF_CSRF_SECRET_KEY="cheburek"
46
+ ))
47
+
48
+ color_args = {'colorizer':colorizer, 'sketcher':sketcher, 'device':device, 'dfm' : True}
49
+
50
+ class SubmitForm(FlaskForm):
51
+ file = FileField(validators=[DataRequired()])
52
+ denoise = BooleanField(default = 'checked')
53
+ denoise_sigma = DecimalField(label = 'Denoise sigma', validators=[DataRequired()], default = 25, places = None)
54
+ autohint = BooleanField(default = None)
55
+ autohint_sigma = DecimalField(label = 'Autohint sigma', validators=[DataRequired()], default= 0.0003, places = None)
56
+ ignore_gray = BooleanField(label = 'Ignore gray autohint', default = None)
57
+
58
+ @app.route('/img/<path>')
59
+ def show_image(path):
60
+ return f'<img src="/static/{path}">'
61
+
62
+ @app.route('/', methods=('GET', 'POST'))
63
+ def submit_data():
64
+ form = SubmitForm()
65
+ if form.validate_on_submit():
66
+
67
+ input_data = form.file.data
68
+
69
+ _, ext = os.path.splitext(input_data.filename)
70
+ filename = str(datetime.now()) + ext
71
+
72
+ input_data.save(filename)
73
+
74
+ color_args['auto_hint'] = form.autohint.data
75
+ color_args['auto_hint_sigma'] = float(form.autohint_sigma.data)
76
+ color_args['ignore_gray'] = form.ignore_gray.data
77
+ color_args['denoiser'] = None
78
+
79
+ if form.denoise.data:
80
+ color_args['denoiser'] = denoiser
81
+ color_args['denoiser_sigma'] = float(form.denoise_sigma.data)
82
+
83
+ if ext.lower() in ('.cbr', '.cbz', '.rar', '.zip'):
84
+ result_name = colorize_cbr(filename, color_args)
85
+ os.remove(filename)
86
+
87
+ @after_this_request
88
+ def remove_file(response):
89
+ try:
90
+ os.remove(result_name)
91
+ except Exception as error:
92
+ app.logger.error("Error removing or closing downloaded file handle", error)
93
+ return response
94
+
95
+ return send_file(result_name, mimetype='application/vnd.comicbook-rar', attachment_filename=result_name, as_attachment=True)
96
+
97
+ elif ext.lower() in ('.jpg', '.png', '.jpeg'):
98
+ random_name = str(datetime.now()) + '.png'
99
+ new_image_path = os.path.join('static', random_name)
100
+
101
+ colorize_single_image(filename, new_image_path, color_args)
102
+ os.remove(filename)
103
+
104
+ return redirect(f'/img/{random_name}')
105
+ else:
106
+ return 'Wrong format'
107
+
108
+ return render_template('submit.html', form=form)