Commit 7234ee2
Parent(s): 5748770
files added
Files changed:
- .gitattributes +2 -0
- app.py +122 -0
- color2edge.pth +3 -0
- dataloader.py +173 -0
- edge2color.pth +3 -0
- examples/img1.jpg +0 -0
- examples/img2.jpg +0 -0
- examples/img3.jpg +0 -0
- examples/img4.jpg +0 -0
- examples/ref1.jpg +0 -0
- examples/ref2.jpg +0 -0
- examples/ref3.jpg +0 -0
- examples/ref4.jpg +0 -0
- examples/sketch1.jpg +0 -0
- examples/sketch2.jpg +0 -0
- examples/sketch3.jpg +0 -0
- examples/sketch4.jpg +0 -0
- mymodels.py +460 -0
- requirements.txt +4 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+color2edge.pth filter=lfs diff=lfs merge=lfs -text
+edge2color.pth filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,122 @@
+# For plotting
+import numpy as np
+
+# For utilities
+from timeit import default_timer as timer
+
+# For conversion
+import opencv_transforms.transforms as TF
+import opencv_transforms.functional as FF
+
+# For everything
+import torch
+
+# For our model
+import mymodels
+
+# For demo api
+import gradio as gr
+
+# To ignore warning
+import warnings
+
+warnings.simplefilter("ignore", UserWarning)
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ncluster = 9
+nc = 3 * (ncluster + 1)  # 3 edge channels + 3 channels per palette color
+netC2S = mymodels.Color2Sketch(pretrained=True).to(device)
+netG = mymodels.Sketch2Color(nc=nc, pretrained=True).to(device)
+transform = TF.Resize((512, 512))
+
+
+def make_tensor(img):
+    img = FF.to_tensor(img)
+    img = FF.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+    return img
+
+
+def predictC2S(img):
+    final_transform = TF.Resize((img.size[0], img.size[1]))
+    img = np.array(img)
+    img = transform(img)
+    img = make_tensor(img)
+    start_time = timer()
+    with torch.inference_mode():
+        img_edge = netC2S(img.unsqueeze(0).to(device)).squeeze().permute(1, 2, 0).cpu().numpy()
+    img_edge = FF.to_grayscale(img_edge, num_output_channels=3)
+    img = FF.to_tensor(img_edge).permute(1, 2, 0).cpu().numpy()
+    end_time = timer()
+    img = final_transform(img)
+    return img, round(end_time - start_time, 3)
+
+
+def predictS2C(img, ref):
+    final_transform = TF.Resize((img.size[0], img.size[1]))
+    img = np.array(img)
+    ref = np.array(ref)
+    ref = transform(ref)
+    img = transform(img)
+    img = make_tensor(img)
+    color_palette = mymodels.color_cluster(ref)
+    for i in range(0, len(color_palette)):
+        color = color_palette[i]
+        color_palette[i] = make_tensor(color)
+    start_time = timer()
+    with torch.inference_mode():
+        img_edge = netC2S(img.unsqueeze(0).to(device)).squeeze().permute(1, 2, 0).cpu().numpy()
+    img_edge = FF.to_grayscale(img_edge, num_output_channels=3)
+    img = FF.to_tensor(img_edge)
+    input_tensor = torch.cat([img.cpu()] + color_palette, dim=0).to(device)
+    with torch.inference_mode():
+        fake = netG(input_tensor.unsqueeze(0).to(device)).squeeze().permute(1, 2, 0).cpu().numpy()
+    end_time = timer()
+    fake = final_transform(fake)
+    return fake, round(end_time - start_time, 3)
+
+
+example_list1 = [["./examples/img1.jpg", "./examples/ref1.jpg"],
+                 ["./examples/img2.jpg", "./examples/ref2.jpg"],
+                 ["./examples/img3.jpg", "./examples/ref3.jpg"],
+                 ["./examples/img4.jpg", "./examples/ref4.jpg"]]
+example_list2 = [["./examples/sketch1.jpg"],
+                 ["./examples/sketch2.jpg"],
+                 ["./examples/sketch3.jpg"],
+                 ["./examples/sketch4.jpg"]]
+
+with gr.Blocks() as demo:
+    gr.Markdown("# Color2Sketch & Sketch2Color")
+    with gr.Tab("Sketch To Color"):
+        gr.Markdown("### Enter the **Sketch** & **Reference** on the left side. You can use the example list.")
+        with gr.Row():
+            with gr.Column():
+                input1 = [gr.inputs.Image(type="pil", label="Sketch"), gr.inputs.Image(type="pil", label="Reference")]
+                with gr.Row():
+                    # Clear Button
+                    gr.ClearButton(input1)
+                    btn1 = gr.Button("Submit")
+                gr.Examples(examples=example_list1, inputs=input1)
+            with gr.Column():
+                output1 = [gr.inputs.Image(type="pil", label="Colored Sketch"), gr.Number(label="Prediction time (s)")]
+    with gr.Tab("Color To Sketch"):
+        gr.Markdown(
+            "### Enter the **Colored Sketch** on the left side. You can use the example list.")
+        with gr.Row():
+            with gr.Column():
+                input2 = gr.inputs.Image(type="pil", label="Color Sketch")
+                with gr.Row():
+                    # Clear Button
+                    gr.ClearButton(input2)
+                    btn2 = gr.Button("Submit")
+                gr.Examples(example_list2, inputs=input2)
+            with gr.Column():
+                output2 = [gr.inputs.Image(type="pil", label="Sketch"), gr.Number(label="Prediction time (s)")]
+    btn1.click(predictS2C, inputs=input1, outputs=output1)
+    btn2.click(predictC2S, inputs=input2, outputs=output2)
+    gr.Markdown("""
+    ### The model is taken from [this GitHub Repo.](https://github.com/delta6189/Anime-Sketch-Colorizer)
+
+    Email : rajatsingh072002@gmail.com | My [GitHub Repo](https://github.com/Rajatsingh24/Anime-Sketch2Color-Color2Sketch)
+    """)
+if __name__ == "__main__":
+    demo.launch(debug=False)
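The two predict functions can also be exercised without launching the UI. A minimal sketch, not part of the commit, assuming the checkpoints load as above and the example images are present in the working directory:

# Hypothetical smoke test: importing app builds the models and the Blocks UI
# (launch is guarded by __main__), then both predictors run on bundled examples.
from PIL import Image

import app

sketch = Image.open("./examples/sketch1.jpg").convert("RGB")
ref = Image.open("./examples/ref1.jpg").convert("RGB")

colored, t1 = app.predictS2C(sketch, ref)  # sketch + reference -> colored image
edges, t2 = app.predictC2S(ref)            # color image -> edge/sketch image
print(f"sketch2color: {t1}s, color2sketch: {t2}s")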
color2edge.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:922c4e109a0f7086d48e211156c7f6fbeff6b0393baecb606f22f44c7cda9877
+size 254000447
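The three lines above are a Git LFS pointer stub; `git lfs pull` replaces it with the real weights, and until then the file is plain text. A minimal sketch (hypothetical helper, not part of the commit) for reading the metadata out of such a stub:

# Parse a Git LFS pointer file into its key/value metadata.
def read_lfs_pointer(path):
    meta = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            meta[key] = value
    return meta

info = read_lfs_pointer("color2edge.pth")
print(info["oid"], info["size"])  # sha256:922c4e..., 254000447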
dataloader.py
ADDED
@@ -0,0 +1,173 @@
+import cv2
+import numpy as np
+import torch
+import torchvision
+import opencv_transforms.functional as FF
+from torchvision import datasets
+from PIL import Image
+
+def color_cluster(img, nclusters=9):
+    """
+    Apply K-means clustering to the input image
+
+    Args:
+        img: Numpy array which has a shape of (H, W, C)
+        nclusters: # of clusters (default = 9)
+
+    Returns:
+        color_palette: list of 3D numpy arrays which have the same shape as the input image
+            e.g. If the input image has a shape of (256, 256, 3) and nclusters is 4, the returned color_palette is [color1, color2, color3, color4]
+            and each component is a (256, 256, 3) numpy array.
+
+    Note:
+        The K-means clustering algorithm is quite computationally intensive.
+        Thus, before extracting dominant colors, the input images are resized to x0.25 size.
+    """
+    img_size = img.shape
+    small_img = cv2.resize(img, None, fx=0.25, fy=0.25, interpolation=cv2.INTER_AREA)
+    sample = small_img.reshape((-1, 3))
+    sample = np.float32(sample)
+    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
+    flags = cv2.KMEANS_PP_CENTERS
+
+    _, _, centers = cv2.kmeans(sample, nclusters, None, criteria, 10, flags)
+    centers = np.uint8(centers)
+    color_palette = []
+
+    for i in range(0, nclusters):
+        dominant_color = np.zeros(img_size, dtype='uint8')
+        dominant_color[:, :, :] = centers[i]
+        color_palette.append(dominant_color)
+
+    return color_palette
+
+class PairImageFolder(datasets.ImageFolder):
+    """
+    A generic data loader where the images are arranged in this way: ::
+
+        root/dog/xxx.png
+        root/dog/xxy.png
+        root/dog/xxz.png
+
+        root/cat/123.png
+        root/cat/nsdf3.png
+        root/cat/asd932_.png
+
+    This class works properly for paired images in the form [sketch, color_image]
+
+    Args:
+        root (string): Root directory path.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        loader (callable, optional): A function to load an image given its path.
+        is_valid_file (callable, optional): A function that takes the path of an image file
+            and checks if the file is valid (used to check for corrupt files)
+        sketch_net: The network to convert a color image to a sketch image
+        ncluster: Number of clusters when extracting the color palette.
+
+    Attributes:
+        classes (list): List of the class names.
+        class_to_idx (dict): Dict with items (class_name, class_index).
+        imgs (list): List of (image path, class_index) tuples
+
+    Getitem:
+        img_edge: Edge image
+        img: Color image
+        color_palette: Extracted color palette
+    """
+    def __init__(self, root, transform, sketch_net, ncluster):
+        super(PairImageFolder, self).__init__(root, transform)
+        self.ncluster = ncluster
+        self.sketch_net = sketch_net
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    def __getitem__(self, index):
+        path, label = self.imgs[index]
+        img = self.loader(path)
+        img = np.asarray(img)
+        img = img[:, 0:512, :]  # take the left half of the side-by-side paired image
+        img = self.transform(img)
+        color_palette = color_cluster(img, nclusters=self.ncluster)
+        img = self.make_tensor(img)
+
+        with torch.no_grad():
+            img_edge = self.sketch_net(img.unsqueeze(0).to(self.device)).squeeze().permute(1, 2, 0).cpu().numpy()
+        img_edge = FF.to_grayscale(img_edge, num_output_channels=3)
+        img_edge = FF.to_tensor(img_edge)
+
+        for i in range(0, len(color_palette)):
+            color = color_palette[i]
+            color_palette[i] = self.make_tensor(color)
+
+        return img_edge, img, color_palette
+
+    def make_tensor(self, img):
+        img = FF.to_tensor(img)
+        img = FF.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+        return img
+
+class GetImageFolder(datasets.ImageFolder):
+    """
+    A generic data loader where the images are arranged in this way: ::
+
+        root/dog/xxx.png
+        root/dog/xxy.png
+        root/dog/xxz.png
+
+        root/cat/123.png
+        root/cat/nsdf3.png
+        root/cat/asd932_.png
+
+    Args:
+        root (string): Root directory path.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        loader (callable, optional): A function to load an image given its path.
+        is_valid_file (callable, optional): A function that takes the path of an image file
+            and checks if the file is valid (used to check for corrupt files)
+        sketch_net: The network to convert a color image to a sketch image
+        ncluster: Number of clusters when extracting the color palette.
+
+    Attributes:
+        classes (list): List of the class names.
+        class_to_idx (dict): Dict with items (class_name, class_index).
+        imgs (list): List of (image path, class_index) tuples
+
+    Getitem:
+        img_edge: Edge image
+        img: Color image
+        color_palette: Extracted color palette
+    """
+    def __init__(self, root, transform, sketch_net, ncluster):
+        super(GetImageFolder, self).__init__(root, transform)
+        self.ncluster = ncluster
+        self.sketch_net = sketch_net
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    def __getitem__(self, index):
+        path, label = self.imgs[index]
+        img = self.loader(path)
+        img = np.asarray(img)
+        img = self.transform(img)
+        color_palette = color_cluster(img, nclusters=self.ncluster)
+        img = self.make_tensor(img)
+
+        with torch.no_grad():
+            img_edge = self.sketch_net(img.unsqueeze(0).to(self.device)).squeeze().permute(1, 2, 0).cpu().numpy()
+        img_edge = FF.to_grayscale(img_edge, num_output_channels=3)
+        img_edge = FF.to_tensor(img_edge)
+
+        for i in range(0, len(color_palette)):
+            color = color_palette[i]
+            color_palette[i] = self.make_tensor(color)
+
+        return img_edge, img, color_palette
+
+    def make_tensor(self, img):
+        img = FF.to_tensor(img)
+        img = FF.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+        return img
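color_cluster also works standalone on any (H, W, 3) uint8 array; a minimal sketch, not part of the commit, using one of the bundled example images:

# Hypothetical usage: pull a 9-color palette out of a reference image.
import cv2
from dataloader import color_cluster

img = cv2.cvtColor(cv2.imread("./examples/ref1.jpg"), cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
palette = color_cluster(img, nclusters=9)
print(len(palette), palette[0].shape)  # 9 constant-color arrays, each shaped like img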
edge2color.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:221a9a798919af697a590a8c98f6cdb0b29f3ab836ca8876ca5724f8eea4f7bd
+size 254069569
examples/img1.jpg
ADDED
examples/img2.jpg
ADDED
examples/img3.jpg
ADDED
examples/img4.jpg
ADDED
examples/ref1.jpg
ADDED
examples/ref2.jpg
ADDED
examples/ref3.jpg
ADDED
examples/ref4.jpg
ADDED
examples/sketch1.jpg
ADDED
examples/sketch2.jpg
ADDED
examples/sketch3.jpg
ADDED
examples/sketch4.jpg
ADDED
mymodels.py
ADDED
@@ -0,0 +1,460 @@
+import torch
+import torch.nn as nn
+import os
+import cv2
+import numpy as np
+
+__all__ = [
+    'color_cluster', 'Color2Sketch', 'Sketch2Color', 'Discriminator',
+]
+
+
+def color_cluster(img, nclusters=9):
+    """
+    Apply K-means clustering to the input image
+
+    Args:
+        img: Numpy array which has a shape of (H, W, C)
+        nclusters: # of clusters (default = 9)
+
+    Returns:
+        color_palette: list of 3D numpy arrays which have the same shape as the input image
+            e.g. If the input image has a shape of (256, 256, 3) and nclusters is 4, the returned color_palette is [color1, color2, color3, color4]
+            and each component is a (256, 256, 3) numpy array.
+
+    Note:
+        The K-means clustering algorithm is quite computationally intensive.
+        Thus, before extracting dominant colors, the input images are resized to x0.25 size.
+    """
+    img_size = img.shape
+    small_img = cv2.resize(img, None, fx=0.25, fy=0.25, interpolation=cv2.INTER_AREA)
+    sample = small_img.reshape((-1, 3))
+    sample = np.float32(sample)
+    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
+    flags = cv2.KMEANS_PP_CENTERS
+
+    _, _, centers = cv2.kmeans(sample, nclusters, None, criteria, 10, flags)
+    centers = np.uint8(centers)
+    color_palette = []
+
+    for i in range(0, nclusters):
+        dominant_color = np.zeros(img_size, dtype='uint8')
+        dominant_color[:, :, :] = centers[i]
+        color_palette.append(dominant_color)
+
+    return color_palette
+
+
+class ApplyNoise(nn.Module):
+    # Adds learnable per-channel noise, StyleGAN-style
+    def __init__(self, channels):
+        super().__init__()
+        self.weight = nn.Parameter(torch.zeros(channels))
+
+    def forward(self, x, noise=None):
+        if noise is None:
+            noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device, dtype=x.dtype)
+        return x + self.weight.view(1, -1, 1, 1) * noise.to(x.device)
+
+
+class Conv2d_WS(nn.Conv2d):
+    # Conv2d with weight standardization
+    def __init__(self, in_chan, out_chan, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+        super().__init__(in_chan, out_chan, kernel_size, stride, padding, dilation, groups, bias)
+
+    def forward(self, x):
+        weight = self.weight
+        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
+        weight = weight - weight_mean
+        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
+        weight = weight / std.expand_as(weight)
+        return torch.nn.functional.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, stride=1, sample=None):
+        super(ResidualBlock, self).__init__()
+        self.ic = in_channels
+        self.oc = out_channels
+        self.conv1 = Conv2d_WS(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn1 = nn.GroupNorm(32, out_channels)
+        self.conv2 = Conv2d_WS(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn2 = nn.GroupNorm(32, out_channels)
+        self.convr = Conv2d_WS(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
+        self.bnr = nn.GroupNorm(32, out_channels)
+        self.relu = nn.ReLU(inplace=True)
+        self.sample = sample
+        if self.sample == 'down':
+            self.sampling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        elif self.sample == 'up':
+            self.sampling = nn.Upsample(scale_factor=2, mode='nearest')
+
+    def forward(self, x):
+        if self.ic != self.oc:
+            # Project the residual when channel counts differ
+            residual = self.convr(x)
+            residual = self.bnr(residual)
+        else:
+            residual = x
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out += residual
+        out = self.relu(out)
+        if self.sample is not None:
+            out = self.sampling(out)
+        return out
+
+
+class Attention_block(nn.Module):
+    # Attention gate for U-Net-style skip connections
+    def __init__(self, F_g, F_l, F_int):
+        super(Attention_block, self).__init__()
+        self.W_g = nn.Sequential(
+            Conv2d_WS(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+            nn.GroupNorm(32, F_int)
+        )
+
+        self.W_x = nn.Sequential(
+            Conv2d_WS(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+            nn.GroupNorm(32, F_int)
+        )
+
+        self.psi = nn.Sequential(
+            Conv2d_WS(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
+            nn.InstanceNorm2d(1),
+            nn.Sigmoid()
+        )
+
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, g, x):
+        g1 = self.W_g(g)
+        x1 = self.W_x(x)
+        psi = self.relu(g1 + x1)
+        psi = self.psi(psi)
+
+        return x * psi
+
+
+class Color2Sketch(nn.Module):
+    def __init__(self, nc=3, pretrained=False):
+        super(Color2Sketch, self).__init__()
+
+        class Encoder(nn.Module):
+            def __init__(self):
+                super(Encoder, self).__init__()
+                # Build the encoder from downsampling residual blocks
+                self.layer1 = ResidualBlock(nc, 64, sample='down')
+                self.layer2 = ResidualBlock(64, 128, sample='down')
+                self.layer3 = ResidualBlock(128, 256, sample='down')
+                self.layer4 = ResidualBlock(256, 512, sample='down')
+                self.layer5 = ResidualBlock(512, 512, sample='down')
+                self.layer6 = ResidualBlock(512, 512, sample='down')
+                self.layer7 = ResidualBlock(512, 512, sample='down')
+
+            def forward(self, input_image):
+                # Extract multi-scale features from the input
+                x0 = input_image     # nc * 256 * 256
+                x1 = self.layer1(x0) # 64 * 128 * 128
+                x2 = self.layer2(x1) # 128 * 64 * 64
+                x3 = self.layer3(x2) # 256 * 32 * 32
+                x4 = self.layer4(x3) # 512 * 16 * 16
+                x5 = self.layer5(x4) # 512 * 8 * 8
+                x6 = self.layer6(x5) # 512 * 4 * 4
+                x7 = self.layer7(x6) # 512 * 2 * 2
+
+                return x1, x2, x3, x4, x5, x6, x7
+
+        class Decoder(nn.Module):
+            def __init__(self):
+                super(Decoder, self).__init__()
+                # Convolutional layers and upsampling
+                self.noise7 = ApplyNoise(512)
+                self.layer7_up = ResidualBlock(512, 512, sample='up')
+
+                self.Att6 = Attention_block(F_g=512, F_l=512, F_int=256)
+                self.layer6 = ResidualBlock(1024, 512, sample=None)
+                self.noise6 = ApplyNoise(512)
+                self.layer6_up = ResidualBlock(512, 512, sample='up')
+
+                self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
+                self.layer5 = ResidualBlock(1024, 512, sample=None)
+                self.noise5 = ApplyNoise(512)
+                self.layer5_up = ResidualBlock(512, 512, sample='up')
+
+                self.Att4 = Attention_block(F_g=512, F_l=512, F_int=256)
+                self.layer4 = ResidualBlock(1024, 512, sample=None)
+                self.noise4 = ApplyNoise(512)
+                self.layer4_up = ResidualBlock(512, 256, sample='up')
+
+                self.Att3 = Attention_block(F_g=256, F_l=256, F_int=128)
+                self.layer3 = ResidualBlock(512, 256, sample=None)
+                self.noise3 = ApplyNoise(256)
+                self.layer3_up = ResidualBlock(256, 128, sample='up')
+
+                self.Att2 = Attention_block(F_g=128, F_l=128, F_int=64)
+                self.layer2 = ResidualBlock(256, 128, sample=None)
+                self.noise2 = ApplyNoise(128)
+                self.layer2_up = ResidualBlock(128, 64, sample='up')
+
+                self.Att1 = Attention_block(F_g=64, F_l=64, F_int=32)
+                self.layer1 = ResidualBlock(128, 64, sample=None)
+                self.noise1 = ApplyNoise(64)
+                self.layer1_up = ResidualBlock(64, 32, sample='up')
+
+                self.noise0 = ApplyNoise(32)
+                self.layer0 = Conv2d_WS(32, 3, kernel_size=3, stride=1, padding=1)
+                self.activation = nn.ReLU(inplace=True)
+                self.tanh = nn.Tanh()
+
+            def forward(self, midlevel_input):  # , global_input):
+                x1, x2, x3, x4, x5, x6, x7 = midlevel_input
+
+                x = self.noise7(x7)
+                x = self.layer7_up(x)         # 512 * 4 * 4
+
+                x6 = self.Att6(g=x, x=x6)
+                x = torch.cat((x, x6), dim=1) # 1024 * 4 * 4
+                x = self.layer6(x)            # 512 * 4 * 4
+                x = self.noise6(x)
+                x = self.layer6_up(x)         # 512 * 8 * 8
+
+                x5 = self.Att5(g=x, x=x5)
+                x = torch.cat((x, x5), dim=1) # 1024 * 8 * 8
+                x = self.layer5(x)            # 512 * 8 * 8
+                x = self.noise5(x)
+                x = self.layer5_up(x)         # 512 * 16 * 16
+
+                x4 = self.Att4(g=x, x=x4)
+                x = torch.cat((x, x4), dim=1) # 1024 * 16 * 16
+                x = self.layer4(x)            # 512 * 16 * 16
+                x = self.noise4(x)
+                x = self.layer4_up(x)         # 256 * 32 * 32
+
+                x3 = self.Att3(g=x, x=x3)
+                x = torch.cat((x, x3), dim=1) # 512 * 32 * 32
+                x = self.layer3(x)            # 256 * 32 * 32
+                x = self.noise3(x)
+                x = self.layer3_up(x)         # 128 * 64 * 64
+
+                x2 = self.Att2(g=x, x=x2)
+                x = torch.cat((x, x2), dim=1) # 256 * 64 * 64
+                x = self.layer2(x)            # 128 * 64 * 64
+                x = self.noise2(x)
+                x = self.layer2_up(x)         # 64 * 128 * 128
+
+                x1 = self.Att1(g=x, x=x1)
+                x = torch.cat((x, x1), dim=1) # 128 * 128 * 128
+                x = self.layer1(x)            # 64 * 128 * 128
+                x = self.noise1(x)
+                x = self.layer1_up(x)         # 32 * 256 * 256
+
+                x = self.noise0(x)
+                x = self.layer0(x)            # 3 * 256 * 256
+                x = self.tanh(x)
+
+                return x
+
+        self.encoder = Encoder()
+        self.decoder = Decoder()
+        if pretrained:
+            print('Loading pretrained {0} model...'.format('Color2Sketch'), end=' ')
+            assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
+            checkpoint = torch.load('color2edge.pth')
+            self.load_state_dict(checkpoint['netG'], strict=True)
+            print("Done!")
+        else:
+            self.apply(weights_init)
+            print('Weights of {0} model are initialized'.format('Color2Sketch'))
+
+    def forward(self, inputs):
+        encode = self.encoder(inputs)
+        output = self.decoder(encode)
+
+        return output
+
+
+class Sketch2Color(nn.Module):
+    def __init__(self, nc=3, pretrained=False):
+        super(Sketch2Color, self).__init__()
+
+        class Encoder(nn.Module):
+            def __init__(self):
+                super(Encoder, self).__init__()
+                # Build the encoder from downsampling residual blocks
+                self.layer1 = ResidualBlock(nc, 64, sample='down')
+                self.layer2 = ResidualBlock(64, 128, sample='down')
+                self.layer3 = ResidualBlock(128, 256, sample='down')
+                self.layer4 = ResidualBlock(256, 512, sample='down')
+                self.layer5 = ResidualBlock(512, 512, sample='down')
+                self.layer6 = ResidualBlock(512, 512, sample='down')
+                self.layer7 = ResidualBlock(512, 512, sample='down')
+
+            def forward(self, input_image):
+                # Extract multi-scale features from the input
+                x0 = input_image     # nc * 256 * 256
+                x1 = self.layer1(x0) # 64 * 128 * 128
+                x2 = self.layer2(x1) # 128 * 64 * 64
+                x3 = self.layer3(x2) # 256 * 32 * 32
+                x4 = self.layer4(x3) # 512 * 16 * 16
+                x5 = self.layer5(x4) # 512 * 8 * 8
+                x6 = self.layer6(x5) # 512 * 4 * 4
+                x7 = self.layer7(x6) # 512 * 2 * 2
+
+                return x1, x2, x3, x4, x5, x6, x7
+
+        class Decoder(nn.Module):
+            def __init__(self):
+                super(Decoder, self).__init__()
+                # Convolutional layers and upsampling
+                self.noise7 = ApplyNoise(512)
+                self.layer7_up = ResidualBlock(512, 512, sample='up')
+
+                self.Att6 = Attention_block(F_g=512, F_l=512, F_int=256)
+                self.layer6 = ResidualBlock(1024, 512, sample=None)
+                self.noise6 = ApplyNoise(512)
+                self.layer6_up = ResidualBlock(512, 512, sample='up')
+
+                self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
+                self.layer5 = ResidualBlock(1024, 512, sample=None)
+                self.noise5 = ApplyNoise(512)
+                self.layer5_up = ResidualBlock(512, 512, sample='up')
+
+                self.Att4 = Attention_block(F_g=512, F_l=512, F_int=256)
+                self.layer4 = ResidualBlock(1024, 512, sample=None)
+                self.noise4 = ApplyNoise(512)
+                self.layer4_up = ResidualBlock(512, 256, sample='up')
+
+                self.Att3 = Attention_block(F_g=256, F_l=256, F_int=128)
+                self.layer3 = ResidualBlock(512, 256, sample=None)
+                self.noise3 = ApplyNoise(256)
+                self.layer3_up = ResidualBlock(256, 128, sample='up')
+
+                self.Att2 = Attention_block(F_g=128, F_l=128, F_int=64)
+                self.layer2 = ResidualBlock(256, 128, sample=None)
+                self.noise2 = ApplyNoise(128)
+                self.layer2_up = ResidualBlock(128, 64, sample='up')
+
+                self.Att1 = Attention_block(F_g=64, F_l=64, F_int=32)
+                self.layer1 = ResidualBlock(128, 64, sample=None)
+                self.noise1 = ApplyNoise(64)
+                self.layer1_up = ResidualBlock(64, 32, sample='up')
+
+                self.noise0 = ApplyNoise(32)
+                self.layer0 = Conv2d_WS(32, 3, kernel_size=3, stride=1, padding=1)
+                self.activation = nn.ReLU(inplace=True)
+                self.tanh = nn.Tanh()
+
+            def forward(self, midlevel_input):  # , global_input):
+                x1, x2, x3, x4, x5, x6, x7 = midlevel_input
+
+                x = self.noise7(x7)
+                x = self.layer7_up(x)         # 512 * 4 * 4
+
+                x6 = self.Att6(g=x, x=x6)
+                x = torch.cat((x, x6), dim=1) # 1024 * 4 * 4
+                x = self.layer6(x)            # 512 * 4 * 4
+                x = self.noise6(x)
+                x = self.layer6_up(x)         # 512 * 8 * 8
+
+                x5 = self.Att5(g=x, x=x5)
+                x = torch.cat((x, x5), dim=1) # 1024 * 8 * 8
+                x = self.layer5(x)            # 512 * 8 * 8
+                x = self.noise5(x)
+                x = self.layer5_up(x)         # 512 * 16 * 16
+
+                x4 = self.Att4(g=x, x=x4)
+                x = torch.cat((x, x4), dim=1) # 1024 * 16 * 16
+                x = self.layer4(x)            # 512 * 16 * 16
+                x = self.noise4(x)
+                x = self.layer4_up(x)         # 256 * 32 * 32
+
+                x3 = self.Att3(g=x, x=x3)
+                x = torch.cat((x, x3), dim=1) # 512 * 32 * 32
+                x = self.layer3(x)            # 256 * 32 * 32
+                x = self.noise3(x)
+                x = self.layer3_up(x)         # 128 * 64 * 64
+
+                x2 = self.Att2(g=x, x=x2)
+                x = torch.cat((x, x2), dim=1) # 256 * 64 * 64
+                x = self.layer2(x)            # 128 * 64 * 64
+                x = self.noise2(x)
+                x = self.layer2_up(x)         # 64 * 128 * 128
+
+                x1 = self.Att1(g=x, x=x1)
+                x = torch.cat((x, x1), dim=1) # 128 * 128 * 128
+                x = self.layer1(x)            # 64 * 128 * 128
+                x = self.noise1(x)
+                x = self.layer1_up(x)         # 32 * 256 * 256
+
+                x = self.noise0(x)
+                x = self.layer0(x)            # 3 * 256 * 256
+                x = self.tanh(x)
+
+                return x
+
+        self.encoder = Encoder()
+        self.decoder = Decoder()
+        if pretrained:
+            print('Loading pretrained {0} model...'.format('Sketch2Color'), end=' ')
+            assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
+            checkpoint = torch.load('edge2color.pth')
+            self.load_state_dict(checkpoint['netG'], strict=True)
+            print("Done!")
+        else:
+            self.apply(weights_init)
+            print('Weights of {0} model are initialized'.format('Sketch2Color'))
+
+    def forward(self, inputs):
+        encode = self.encoder(inputs)
+        output = self.decoder(encode)
+
+        return output
+
+
+class Discriminator(nn.Module):
+    def __init__(self, nc=6, pretrained=False):
+        super(Discriminator, self).__init__()
+        self.conv1 = torch.nn.utils.spectral_norm(nn.Conv2d(nc, 64, kernel_size=4, stride=2, padding=1))
+        self.bn1 = nn.GroupNorm(32, 64)
+        self.conv2 = torch.nn.utils.spectral_norm(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1))
+        self.bn2 = nn.GroupNorm(32, 128)
+        self.conv3 = torch.nn.utils.spectral_norm(nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1))
+        self.bn3 = nn.GroupNorm(32, 256)
+        self.conv4 = torch.nn.utils.spectral_norm(nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1))
+        self.bn4 = nn.GroupNorm(32, 512)
+        self.conv5 = torch.nn.utils.spectral_norm(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))
+        self.activation = nn.LeakyReLU(0.2, inplace=True)
+        self.sigmoid = nn.Sigmoid()
+
+        if pretrained:
+            pass
+        else:
+            self.apply(weights_init)
+            print('Weights of {0} model are initialized'.format('Discriminator'))
+
+    def forward(self, base, unknown):
+        # Score the concatenated (base, unknown) pair, PatchGAN-style
+        input = torch.cat((base, unknown), dim=1)
+        x = self.activation(self.conv1(input))
+        x = self.activation(self.bn2(self.conv2(x)))
+        x = self.activation(self.bn3(self.conv3(x)))
+        x = self.activation(self.bn4(self.conv4(x)))
+        x = self.sigmoid(self.conv5(x))
+
+        return x.mean((2, 3))
+
+
+# To initialize model weights
+def weights_init(model):
+    classname = model.__class__.__name__
+    if classname.find('Conv') != -1:
+        nn.init.normal_(model.weight.data, 0.0, 0.02)
+    elif classname.find('Conv2d_WS') != -1:
+        nn.init.normal_(model.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm') != -1:
+        nn.init.normal_(model.weight.data, 1.0, 0.02)
+        nn.init.constant_(model.bias.data, 0)
+    elif classname.find('GroupNorm') != -1:
+        nn.init.normal_(model.weight.data, 1.0, 0.02)
+        nn.init.constant_(model.bias.data, 0)
+    else:
+        pass
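Tying the shapes together: Sketch2Color expects nc = 3 * (ncluster + 1) input channels, i.e. the 3-channel edge map concatenated with one 3-channel constant image per palette color (nc = 30 in app.py). A minimal shape check, not part of the commit, using untrained weights (pretrained=False) so no checkpoint files are needed:

# Hypothetical shape check with randomly initialized weights.
import torch
from mymodels import Color2Sketch, Sketch2Color

netC2S = Color2Sketch(nc=3, pretrained=False)
netG = Sketch2Color(nc=30, pretrained=False)  # 3 edge channels + 9 palette colors x 3

x = torch.randn(1, 3, 256, 256)  # stand-in for a normalized color image
with torch.inference_mode():
    edge = netC2S(x)  # -> (1, 3, 256, 256)
    inp = torch.cat([edge] + [torch.randn(1, 3, 256, 256) for _ in range(9)], dim=1)
    out = netG(inp)   # -> (1, 3, 256, 256)
print(edge.shape, out.shape)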
requirements.txt
ADDED
@@ -0,0 +1,4 @@
+torch==2.0.1
+gradio==3.38.0
+opencv-python==4.8.0.74
+opencv-transforms==0.0.6