rajatsingh0702 committed
Commit 7234ee2
1 Parent(s): 5748770

files added

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ color2edge.pth filter=lfs diff=lfs merge=lfs -text
+ edge2color.pth filter=lfs diff=lfs merge=lfs -text
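With these patterns in place, Git LFS stores the two new .pth checkpoints as small pointer stubs in the repository (their pointer contents appear under the color2edge.pth and edge2color.pth entries below) while the actual weight files live in LFS storage.
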
app.py ADDED
@@ -0,0 +1,122 @@
+ # For plotting
+ import numpy as np
+
+ # For utilities
+ from timeit import default_timer as timer
+
+ # For conversion
+ import opencv_transforms.transforms as TF
+ import opencv_transforms.functional as FF
+
+ # For everything
+ import torch
+
+ # For our model
+ import mymodels
+
+ # For demo api
+ import gradio as gr
+
+ # To ignore warnings
+ import warnings
+
+ warnings.simplefilter("ignore", UserWarning)
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ ncluster = 9
+ nc = 3 * (ncluster + 1)  # 30 input channels: a 3-channel edge image plus nine 3-channel palette images
+ netC2S = mymodels.Color2Sketch(pretrained=True).to(device)
+ netG = mymodels.Sketch2Color(nc=nc, pretrained=True).to(device)
+ transform = TF.Resize((512, 512))
+
+
+ def make_tensor(img):
+     img = FF.to_tensor(img)
+     img = FF.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+     return img
+
+
+ def predictC2S(img):
+     # PIL size is (width, height); Resize expects (height, width)
+     final_transform = TF.Resize((img.size[1], img.size[0]))
+     img = np.array(img)
+     img = transform(img)
+     img = make_tensor(img)
+     start_time = timer()
+     with torch.inference_mode():
+         img_edge = netC2S(img.unsqueeze(0).to(device)).squeeze().permute(1, 2, 0).cpu().numpy()
+     img_edge = FF.to_grayscale(img_edge, num_output_channels=3)
+     img = FF.to_tensor(img_edge).permute(1, 2, 0).cpu().numpy()
+     end_time = timer()
+     img = final_transform(img)
+     return img, round(end_time - start_time, 3)
+
+
+ def predictS2C(img, ref):
+     # PIL size is (width, height); Resize expects (height, width)
+     final_transform = TF.Resize((img.size[1], img.size[0]))
+     img = np.array(img)
+     ref = np.array(ref)
+     ref = transform(ref)
+     img = transform(img)
+     img = make_tensor(img)
+     color_palette = mymodels.color_cluster(ref)
+     for i in range(0, len(color_palette)):
+         color = color_palette[i]
+         color_palette[i] = make_tensor(color)
+     start_time = timer()
+     with torch.inference_mode():
+         img_edge = netC2S(img.unsqueeze(0).to(device)).squeeze().permute(1, 2, 0).cpu().numpy()
+     img_edge = FF.to_grayscale(img_edge, num_output_channels=3)
+     img = FF.to_tensor(img_edge)
+     input_tensor = torch.cat([img.cpu()] + color_palette, dim=0).to(device)
+     with torch.inference_mode():
+         fake = netG(input_tensor.unsqueeze(0).to(device)).squeeze().permute(1, 2, 0).cpu().numpy()
+     end_time = timer()
+     fake = final_transform(fake)
+     return fake, round(end_time - start_time, 3)
+
+
+ example_list1 = [["./examples/img1.jpg", "./examples/ref1.jpg"],
+                  ["./examples/img2.jpg", "./examples/ref2.jpg"],
+                  ["./examples/img3.jpg", "./examples/ref3.jpg"],
+                  ["./examples/img4.jpg", "./examples/ref4.jpg"]]
+ example_list2 = [["./examples/sketch1.jpg"],
+                  ["./examples/sketch2.jpg"],
+                  ["./examples/sketch3.jpg"],
+                  ["./examples/sketch4.jpg"]]
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Color2Sketch & Sketch2Color")
+     with gr.Tab("Sketch To Color"):
+         gr.Markdown("### Enter the **Sketch** & **Reference** on the left side. You can use the example list.")
+         with gr.Row():
+             with gr.Column():
+                 input1 = [gr.Image(type="pil", label="Sketch"), gr.Image(type="pil", label="Reference")]
+                 with gr.Row():
+                     # Clear button
+                     gr.ClearButton(input1)
+                     btn1 = gr.Button("Submit")
+                 gr.Examples(examples=example_list1, inputs=input1)
+             with gr.Column():
+                 output1 = [gr.Image(type="pil", label="Colored Sketch"), gr.Number(label="Prediction time (s)")]
+     with gr.Tab("Color To Sketch"):
+         gr.Markdown("### Enter the **Colored Sketch** on the left side. You can use the example list.")
+         with gr.Row():
+             with gr.Column():
+                 input2 = gr.Image(type="pil", label="Color Sketch")
+                 with gr.Row():
+                     # Clear button
+                     gr.ClearButton(input2)
+                     btn2 = gr.Button("Submit")
+                 gr.Examples(example_list2, inputs=input2)
+             with gr.Column():
+                 output2 = [gr.Image(type="pil", label="Sketch"), gr.Number(label="Prediction time (s)")]
+     btn1.click(predictS2C, inputs=input1, outputs=output1)
+     btn2.click(predictC2S, inputs=input2, outputs=output2)
+     gr.Markdown("""
+ ### The model is taken from [this GitHub Repo](https://github.com/delta6189/Anime-Sketch-Colorizer).
+
+ Email : rajatsingh072002@gmail.com | My [GitHub Repo](https://github.com/Rajatsingh24/Anime-Sketch2Color-Color2Sketch)
+ """)
+
+ if __name__ == "__main__":
+     demo.launch(debug=False)
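
For reference, a minimal sketch of calling the two predictors directly, without the Gradio UI. It assumes you run from the repository root (so importing app works and loads both pretrained networks) and uses the example images added in this commit; both functions return a numpy image plus the prediction time in seconds.

    from PIL import Image

    import app  # importing app.py builds the demo and loads both pretrained networks

    # Sketch + reference -> colored image (same pairing the demo's example list uses)
    sketch = Image.open("./examples/img1.jpg").convert("RGB")
    reference = Image.open("./examples/ref1.jpg").convert("RGB")
    colored, seconds = app.predictS2C(sketch, reference)
    print("Sketch2Color:", colored.shape, seconds, "s")

    # Color image -> edge/sketch image
    color_input = Image.open("./examples/sketch1.jpg").convert("RGB")
    edges, seconds = app.predictC2S(color_input)
    print("Color2Sketch:", edges.shape, seconds, "s")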
color2edge.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:922c4e109a0f7086d48e211156c7f6fbeff6b0393baecb606f22f44c7cda9877
+ size 254000447
dataloader.py ADDED
@@ -0,0 +1,173 @@
+ import cv2
+ import numpy as np
+ import torch
+ import torchvision
+ import opencv_transforms.functional as FF
+ from torchvision import datasets
+ from PIL import Image
+
+
+ def color_cluster(img, nclusters=9):
+     """
+     Apply K-means clustering to the input image.
+
+     Args:
+         img: Numpy array which has shape of (H, W, C)
+         nclusters: # of clusters (default = 9)
+
+     Returns:
+         color_palette: list of 3D numpy arrays which have the same shape as the input image.
+             e.g. If the input image has shape (256, 256, 3) and nclusters is 4, the returned color_palette
+             is [color1, color2, color3, color4] and each component is a (256, 256, 3) numpy array.
+
+     Note:
+         The K-means clustering algorithm is quite computationally intensive.
+         Thus, before extracting dominant colors, the input image is resized to 0.25x its size.
+     """
+     img_size = img.shape
+     small_img = cv2.resize(img, None, fx=0.25, fy=0.25, interpolation=cv2.INTER_AREA)
+     sample = small_img.reshape((-1, 3))
+     sample = np.float32(sample)
+     criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
+     flags = cv2.KMEANS_PP_CENTERS
+
+     _, _, centers = cv2.kmeans(sample, nclusters, None, criteria, 10, flags)
+     centers = np.uint8(centers)
+     color_palette = []
+
+     # Build one solid-color image per cluster center
+     for i in range(0, nclusters):
+         dominant_color = np.zeros(img_size, dtype='uint8')
+         dominant_color[:, :, :] = centers[i]
+         color_palette.append(dominant_color)
+
+     return color_palette
+
+
+ class PairImageFolder(datasets.ImageFolder):
+     """
+     A generic data loader where the images are arranged in this way: ::
+
+         root/dog/xxx.png
+         root/dog/xxy.png
+         root/dog/xxz.png
+
+         root/cat/123.png
+         root/cat/nsdf3.png
+         root/cat/asd932_.png
+
+     This class works properly for paired images in the form [sketch, color_image].
+
+     Args:
+         root (string): Root directory path.
+         transform (callable, optional): A function/transform that takes in a PIL image
+             and returns a transformed version. E.g., ``transforms.RandomCrop``
+         target_transform (callable, optional): A function/transform that takes in the
+             target and transforms it.
+         loader (callable, optional): A function to load an image given its path.
+         is_valid_file (callable, optional): A function that takes the path of an image file
+             and checks if the file is valid (used to check for corrupt files).
+         sketch_net: The network used to convert a color image to a sketch image.
+         ncluster: Number of clusters used when extracting the color palette.
+
+     Attributes:
+         classes (list): List of the class names.
+         class_to_idx (dict): Dict with items (class_name, class_index).
+         imgs (list): List of (image path, class_index) tuples.
+
+     Getitem:
+         img_edge: Edge image
+         img: Color image
+         color_palette: Extracted color palette
+     """
+     def __init__(self, root, transform, sketch_net, ncluster):
+         super(PairImageFolder, self).__init__(root, transform)
+         self.ncluster = ncluster
+         self.sketch_net = sketch_net
+         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+     def __getitem__(self, index):
+         path, label = self.imgs[index]
+         img = self.loader(path)
+         img = np.asarray(img)
+         img = img[:, 0:512, :]  # take the first 512 columns of the paired image
+         img = self.transform(img)
+         color_palette = color_cluster(img, nclusters=self.ncluster)
+         img = self.make_tensor(img)
+
+         with torch.no_grad():
+             img_edge = self.sketch_net(img.unsqueeze(0).to(self.device)).squeeze().permute(1, 2, 0).cpu().numpy()
+         img_edge = FF.to_grayscale(img_edge, num_output_channels=3)
+         img_edge = FF.to_tensor(img_edge)
+
+         for i in range(0, len(color_palette)):
+             color = color_palette[i]
+             color_palette[i] = self.make_tensor(color)
+
+         return img_edge, img, color_palette
+
+     def make_tensor(self, img):
+         img = FF.to_tensor(img)
+         img = FF.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+         return img
+
+
+ class GetImageFolder(datasets.ImageFolder):
+     """
+     A generic data loader where the images are arranged in this way: ::
+
+         root/dog/xxx.png
+         root/dog/xxy.png
+         root/dog/xxz.png
+
+         root/cat/123.png
+         root/cat/nsdf3.png
+         root/cat/asd932_.png
+
+     Args:
+         root (string): Root directory path.
+         transform (callable, optional): A function/transform that takes in a PIL image
+             and returns a transformed version. E.g., ``transforms.RandomCrop``
+         target_transform (callable, optional): A function/transform that takes in the
+             target and transforms it.
+         loader (callable, optional): A function to load an image given its path.
+         is_valid_file (callable, optional): A function that takes the path of an image file
+             and checks if the file is valid (used to check for corrupt files).
+         sketch_net: The network used to convert a color image to a sketch image.
+         ncluster: Number of clusters used when extracting the color palette.
+
+     Attributes:
+         classes (list): List of the class names.
+         class_to_idx (dict): Dict with items (class_name, class_index).
+         imgs (list): List of (image path, class_index) tuples.
+
+     Getitem:
+         img_edge: Edge image
+         img: Color image
+         color_palette: Extracted color palette
+     """
+     def __init__(self, root, transform, sketch_net, ncluster):
+         super(GetImageFolder, self).__init__(root, transform)
+         self.ncluster = ncluster
+         self.sketch_net = sketch_net
+         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+     def __getitem__(self, index):
+         path, label = self.imgs[index]
+         img = self.loader(path)
+         img = np.asarray(img)
+         img = self.transform(img)
+         color_palette = color_cluster(img, nclusters=self.ncluster)
+         img = self.make_tensor(img)
+
+         with torch.no_grad():
+             img_edge = self.sketch_net(img.unsqueeze(0).to(self.device)).squeeze().permute(1, 2, 0).cpu().numpy()
+         img_edge = FF.to_grayscale(img_edge, num_output_channels=3)
+         img_edge = FF.to_tensor(img_edge)
+
+         for i in range(0, len(color_palette)):
+             color = color_palette[i]
+             color_palette[i] = self.make_tensor(color)
+
+         return img_edge, img, color_palette
+
+     def make_tensor(self, img):
+         img = FF.to_tensor(img)
+         img = FF.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+         return img
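
A minimal usage sketch for PairImageFolder, assuming a hypothetical ./data root laid out the way ImageFolder expects (root/&lt;class&gt;/*.png) with side-by-side paired images at least 512 pixels wide; the dataset path and batch size are placeholders:

    import torch
    from torch.utils.data import DataLoader
    import opencv_transforms.transforms as TF

    import mymodels
    from dataloader import PairImageFolder

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    sketch_net = mymodels.Color2Sketch(pretrained=True).to(device)  # needs color2edge.pth in the working directory

    # './data' is a placeholder path; ncluster=9 matches the value used in app.py
    dataset = PairImageFolder(root='./data', transform=TF.Resize((512, 512)),
                              sketch_net=sketch_net, ncluster=9)
    loader = DataLoader(dataset, batch_size=2, shuffle=True)

    img_edge, img, color_palette = next(iter(loader))
    print(img_edge.shape, img.shape, len(color_palette))  # e.g. [2, 3, 512, 512], [2, 3, 512, 512], 9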
edge2color.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:221a9a798919af697a590a8c98f6cdb0b29f3ab836ca8876ca5724f8eea4f7bd
+ size 254069569
examples/img1.jpg ADDED
examples/img2.jpg ADDED
examples/img3.jpg ADDED
examples/img4.jpg ADDED
examples/ref1.jpg ADDED
examples/ref2.jpg ADDED
examples/ref3.jpg ADDED
examples/ref4.jpg ADDED
examples/sketch1.jpg ADDED
examples/sketch2.jpg ADDED
examples/sketch3.jpg ADDED
examples/sketch4.jpg ADDED
mymodels.py ADDED
@@ -0,0 +1,460 @@
+ import torch
+ import torch.nn as nn
+ import os
+ import cv2
+ import numpy as np
+
+ __all__ = [
+     'color_cluster', 'Color2Sketch', 'Sketch2Color', 'Discriminator',
+ ]
+
+
+ def color_cluster(img, nclusters=9):
+     """
+     Apply K-means clustering to the input image.
+
+     Args:
+         img: Numpy array which has shape of (H, W, C)
+         nclusters: # of clusters (default = 9)
+
+     Returns:
+         color_palette: list of 3D numpy arrays which have the same shape as the input image.
+             e.g. If the input image has shape (256, 256, 3) and nclusters is 4, the returned color_palette
+             is [color1, color2, color3, color4] and each component is a (256, 256, 3) numpy array.
+
+     Note:
+         The K-means clustering algorithm is quite computationally intensive.
+         Thus, before extracting dominant colors, the input image is resized to 0.25x its size.
+     """
+     img_size = img.shape
+     small_img = cv2.resize(img, None, fx=0.25, fy=0.25, interpolation=cv2.INTER_AREA)
+     sample = small_img.reshape((-1, 3))
+     sample = np.float32(sample)
+     criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
+     flags = cv2.KMEANS_PP_CENTERS
+
+     _, _, centers = cv2.kmeans(sample, nclusters, None, criteria, 10, flags)
+     centers = np.uint8(centers)
+     color_palette = []
+
+     # Build one solid-color image per cluster center
+     for i in range(0, nclusters):
+         dominant_color = np.zeros(img_size, dtype='uint8')
+         dominant_color[:, :, :] = centers[i]
+         color_palette.append(dominant_color)
+
+     return color_palette
+
+
+ class ApplyNoise(nn.Module):
+     def __init__(self, channels):
+         super().__init__()
+         self.weight = nn.Parameter(torch.zeros(channels))
+
+     def forward(self, x, noise=None):
+         if noise is None:
+             noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device, dtype=x.dtype)
+         return x + self.weight.view(1, -1, 1, 1) * noise.to(x.device)
+
+
+ class Conv2d_WS(nn.Conv2d):
+     # Conv2d with weight standardization
+     def __init__(self, in_chan, out_chan, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+         super().__init__(in_chan, out_chan, kernel_size, stride, padding, dilation, groups, bias)
+
+     def forward(self, x):
+         weight = self.weight
+         weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
+         weight = weight - weight_mean
+         std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
+         weight = weight / std.expand_as(weight)
+         return torch.nn.functional.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, stride=1, sample=None):
+         super(ResidualBlock, self).__init__()
+         self.ic = in_channels
+         self.oc = out_channels
+         self.conv1 = Conv2d_WS(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
+         self.bn1 = nn.GroupNorm(32, out_channels)
+         self.conv2 = Conv2d_WS(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
+         self.bn2 = nn.GroupNorm(32, out_channels)
+         self.convr = Conv2d_WS(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
+         self.bnr = nn.GroupNorm(32, out_channels)
+         self.relu = nn.ReLU(inplace=True)
+         self.sample = sample
+         if self.sample == 'down':
+             self.sampling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         elif self.sample == 'up':
+             self.sampling = nn.Upsample(scale_factor=2, mode='nearest')
+
+     def forward(self, x):
+         if self.ic != self.oc:
+             # Project the residual when the channel count changes
+             residual = self.convr(x)
+             residual = self.bnr(residual)
+         else:
+             residual = x
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out += residual
+         out = self.relu(out)
+         if self.sample is not None:
+             out = self.sampling(out)
+         return out
+
+
+ class Attention_block(nn.Module):
+     def __init__(self, F_g, F_l, F_int):
+         super(Attention_block, self).__init__()
+         self.W_g = nn.Sequential(
+             Conv2d_WS(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+             nn.GroupNorm(32, F_int)
+         )
+
+         self.W_x = nn.Sequential(
+             Conv2d_WS(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+             nn.GroupNorm(32, F_int)
+         )
+
+         self.psi = nn.Sequential(
+             Conv2d_WS(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
+             nn.InstanceNorm2d(1),
+             nn.Sigmoid()
+         )
+
+         self.relu = nn.ReLU(inplace=True)
+
+     def forward(self, g, x):
+         g1 = self.W_g(g)
+         x1 = self.W_x(x)
+         psi = self.relu(g1 + x1)
+         psi = self.psi(psi)
+
+         return x * psi
+
+
+ class Color2Sketch(nn.Module):
+     def __init__(self, nc=3, pretrained=False):
+         super(Color2Sketch, self).__init__()
+
+         class Encoder(nn.Module):
+             def __init__(self):
+                 super(Encoder, self).__init__()
+                 # Stack of residual blocks, each downsampling by a factor of 2
+                 self.layer1 = ResidualBlock(nc, 64, sample='down')
+                 self.layer2 = ResidualBlock(64, 128, sample='down')
+                 self.layer3 = ResidualBlock(128, 256, sample='down')
+                 self.layer4 = ResidualBlock(256, 512, sample='down')
+                 self.layer5 = ResidualBlock(512, 512, sample='down')
+                 self.layer6 = ResidualBlock(512, 512, sample='down')
+                 self.layer7 = ResidualBlock(512, 512, sample='down')
+
+             def forward(self, input_image):
+                 # Extract multi-scale features for the skip connections
+                 x0 = input_image      # nc * 256 * 256
+                 x1 = self.layer1(x0)  # 64 * 128 * 128
+                 x2 = self.layer2(x1)  # 128 * 64 * 64
+                 x3 = self.layer3(x2)  # 256 * 32 * 32
+                 x4 = self.layer4(x3)  # 512 * 16 * 16
+                 x5 = self.layer5(x4)  # 512 * 8 * 8
+                 x6 = self.layer6(x5)  # 512 * 4 * 4
+                 x7 = self.layer7(x6)  # 512 * 2 * 2
+
+                 return x1, x2, x3, x4, x5, x6, x7
+
+         class Decoder(nn.Module):
+             def __init__(self):
+                 super(Decoder, self).__init__()
+                 # Convolutional layers and upsampling
+                 self.noise7 = ApplyNoise(512)
+                 self.layer7_up = ResidualBlock(512, 512, sample='up')
+
+                 self.Att6 = Attention_block(F_g=512, F_l=512, F_int=256)
+                 self.layer6 = ResidualBlock(1024, 512, sample=None)
+                 self.noise6 = ApplyNoise(512)
+                 self.layer6_up = ResidualBlock(512, 512, sample='up')
+
+                 self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
+                 self.layer5 = ResidualBlock(1024, 512, sample=None)
+                 self.noise5 = ApplyNoise(512)
+                 self.layer5_up = ResidualBlock(512, 512, sample='up')
+
+                 self.Att4 = Attention_block(F_g=512, F_l=512, F_int=256)
+                 self.layer4 = ResidualBlock(1024, 512, sample=None)
+                 self.noise4 = ApplyNoise(512)
+                 self.layer4_up = ResidualBlock(512, 256, sample='up')
+
+                 self.Att3 = Attention_block(F_g=256, F_l=256, F_int=128)
+                 self.layer3 = ResidualBlock(512, 256, sample=None)
+                 self.noise3 = ApplyNoise(256)
+                 self.layer3_up = ResidualBlock(256, 128, sample='up')
+
+                 self.Att2 = Attention_block(F_g=128, F_l=128, F_int=64)
+                 self.layer2 = ResidualBlock(256, 128, sample=None)
+                 self.noise2 = ApplyNoise(128)
+                 self.layer2_up = ResidualBlock(128, 64, sample='up')
+
+                 self.Att1 = Attention_block(F_g=64, F_l=64, F_int=32)
+                 self.layer1 = ResidualBlock(128, 64, sample=None)
+                 self.noise1 = ApplyNoise(64)
+                 self.layer1_up = ResidualBlock(64, 32, sample='up')
+
+                 self.noise0 = ApplyNoise(32)
+                 self.layer0 = Conv2d_WS(32, 3, kernel_size=3, stride=1, padding=1)
+                 self.activation = nn.ReLU(inplace=True)
+                 self.tanh = nn.Tanh()
+
+             def forward(self, midlevel_input):  # , global_input):
+                 x1, x2, x3, x4, x5, x6, x7 = midlevel_input
+
+                 x = self.noise7(x7)
+                 x = self.layer7_up(x)  # 512 * 4 * 4
+
+                 x6 = self.Att6(g=x, x=x6)
+                 x = torch.cat((x, x6), dim=1)  # 1024 * 4 * 4
+                 x = self.layer6(x)  # 512 * 4 * 4
+                 x = self.noise6(x)
+                 x = self.layer6_up(x)  # 512 * 8 * 8
+
+                 x5 = self.Att5(g=x, x=x5)
+                 x = torch.cat((x, x5), dim=1)  # 1024 * 8 * 8
+                 x = self.layer5(x)  # 512 * 8 * 8
+                 x = self.noise5(x)
+                 x = self.layer5_up(x)  # 512 * 16 * 16
+
+                 x4 = self.Att4(g=x, x=x4)
+                 x = torch.cat((x, x4), dim=1)  # 1024 * 16 * 16
+                 x = self.layer4(x)  # 512 * 16 * 16
+                 x = self.noise4(x)
+                 x = self.layer4_up(x)  # 256 * 32 * 32
+
+                 x3 = self.Att3(g=x, x=x3)
+                 x = torch.cat((x, x3), dim=1)  # 512 * 32 * 32
+                 x = self.layer3(x)  # 256 * 32 * 32
+                 x = self.noise3(x)
+                 x = self.layer3_up(x)  # 128 * 64 * 64
+
+                 x2 = self.Att2(g=x, x=x2)
+                 x = torch.cat((x, x2), dim=1)  # 256 * 64 * 64
+                 x = self.layer2(x)  # 128 * 64 * 64
+                 x = self.noise2(x)
+                 x = self.layer2_up(x)  # 64 * 128 * 128
+
+                 x1 = self.Att1(g=x, x=x1)
+                 x = torch.cat((x, x1), dim=1)  # 128 * 128 * 128
+                 x = self.layer1(x)  # 64 * 128 * 128
+                 x = self.noise1(x)
+                 x = self.layer1_up(x)  # 32 * 256 * 256
+
+                 x = self.noise0(x)
+                 x = self.layer0(x)  # 3 * 256 * 256
+                 x = self.tanh(x)
+
+                 return x
+
+         self.encoder = Encoder()
+         self.decoder = Decoder()
+         if pretrained:
+             print('Loading pretrained {0} model...'.format('Color2Sketch'), end=' ')
+             assert os.path.isfile('color2edge.pth'), 'Error: checkpoint file color2edge.pth not found!'
+             # Load onto CPU first; the caller moves the model to the available device
+             checkpoint = torch.load('color2edge.pth', map_location='cpu')
+             self.load_state_dict(checkpoint['netG'], strict=True)
+             print("Done!")
+         else:
+             self.apply(weights_init)
+             print('Weights of {0} model are initialized'.format('Color2Sketch'))
+
+     def forward(self, inputs):
+         encode = self.encoder(inputs)
+         output = self.decoder(encode)
+
+         return output
+
+
+ class Sketch2Color(nn.Module):
+     def __init__(self, nc=3, pretrained=False):
+         super(Sketch2Color, self).__init__()
+
+         class Encoder(nn.Module):
+             def __init__(self):
+                 super(Encoder, self).__init__()
+                 # Stack of residual blocks, each downsampling by a factor of 2
+                 self.layer1 = ResidualBlock(nc, 64, sample='down')
+                 self.layer2 = ResidualBlock(64, 128, sample='down')
+                 self.layer3 = ResidualBlock(128, 256, sample='down')
+                 self.layer4 = ResidualBlock(256, 512, sample='down')
+                 self.layer5 = ResidualBlock(512, 512, sample='down')
+                 self.layer6 = ResidualBlock(512, 512, sample='down')
+                 self.layer7 = ResidualBlock(512, 512, sample='down')
+
+             def forward(self, input_image):
+                 # Extract multi-scale features for the skip connections
+                 x0 = input_image      # nc * 256 * 256
+                 x1 = self.layer1(x0)  # 64 * 128 * 128
+                 x2 = self.layer2(x1)  # 128 * 64 * 64
+                 x3 = self.layer3(x2)  # 256 * 32 * 32
+                 x4 = self.layer4(x3)  # 512 * 16 * 16
+                 x5 = self.layer5(x4)  # 512 * 8 * 8
+                 x6 = self.layer6(x5)  # 512 * 4 * 4
+                 x7 = self.layer7(x6)  # 512 * 2 * 2
+
+                 return x1, x2, x3, x4, x5, x6, x7
+
+         class Decoder(nn.Module):
+             def __init__(self):
+                 super(Decoder, self).__init__()
+                 # Convolutional layers and upsampling
+                 self.noise7 = ApplyNoise(512)
+                 self.layer7_up = ResidualBlock(512, 512, sample='up')
+
+                 self.Att6 = Attention_block(F_g=512, F_l=512, F_int=256)
+                 self.layer6 = ResidualBlock(1024, 512, sample=None)
+                 self.noise6 = ApplyNoise(512)
+                 self.layer6_up = ResidualBlock(512, 512, sample='up')
+
+                 self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
+                 self.layer5 = ResidualBlock(1024, 512, sample=None)
+                 self.noise5 = ApplyNoise(512)
+                 self.layer5_up = ResidualBlock(512, 512, sample='up')
+
+                 self.Att4 = Attention_block(F_g=512, F_l=512, F_int=256)
+                 self.layer4 = ResidualBlock(1024, 512, sample=None)
+                 self.noise4 = ApplyNoise(512)
+                 self.layer4_up = ResidualBlock(512, 256, sample='up')
+
+                 self.Att3 = Attention_block(F_g=256, F_l=256, F_int=128)
+                 self.layer3 = ResidualBlock(512, 256, sample=None)
+                 self.noise3 = ApplyNoise(256)
+                 self.layer3_up = ResidualBlock(256, 128, sample='up')
+
+                 self.Att2 = Attention_block(F_g=128, F_l=128, F_int=64)
+                 self.layer2 = ResidualBlock(256, 128, sample=None)
+                 self.noise2 = ApplyNoise(128)
+                 self.layer2_up = ResidualBlock(128, 64, sample='up')
+
+                 self.Att1 = Attention_block(F_g=64, F_l=64, F_int=32)
+                 self.layer1 = ResidualBlock(128, 64, sample=None)
+                 self.noise1 = ApplyNoise(64)
+                 self.layer1_up = ResidualBlock(64, 32, sample='up')
+
+                 self.noise0 = ApplyNoise(32)
+                 self.layer0 = Conv2d_WS(32, 3, kernel_size=3, stride=1, padding=1)
+                 self.activation = nn.ReLU(inplace=True)
+                 self.tanh = nn.Tanh()
+
+             def forward(self, midlevel_input):  # , global_input):
+                 x1, x2, x3, x4, x5, x6, x7 = midlevel_input
+
+                 x = self.noise7(x7)
+                 x = self.layer7_up(x)  # 512 * 4 * 4
+
+                 x6 = self.Att6(g=x, x=x6)
+                 x = torch.cat((x, x6), dim=1)  # 1024 * 4 * 4
+                 x = self.layer6(x)  # 512 * 4 * 4
+                 x = self.noise6(x)
+                 x = self.layer6_up(x)  # 512 * 8 * 8
+
+                 x5 = self.Att5(g=x, x=x5)
+                 x = torch.cat((x, x5), dim=1)  # 1024 * 8 * 8
+                 x = self.layer5(x)  # 512 * 8 * 8
+                 x = self.noise5(x)
+                 x = self.layer5_up(x)  # 512 * 16 * 16
+
+                 x4 = self.Att4(g=x, x=x4)
+                 x = torch.cat((x, x4), dim=1)  # 1024 * 16 * 16
+                 x = self.layer4(x)  # 512 * 16 * 16
+                 x = self.noise4(x)
+                 x = self.layer4_up(x)  # 256 * 32 * 32
+
+                 x3 = self.Att3(g=x, x=x3)
+                 x = torch.cat((x, x3), dim=1)  # 512 * 32 * 32
+                 x = self.layer3(x)  # 256 * 32 * 32
+                 x = self.noise3(x)
+                 x = self.layer3_up(x)  # 128 * 64 * 64
+
+                 x2 = self.Att2(g=x, x=x2)
+                 x = torch.cat((x, x2), dim=1)  # 256 * 64 * 64
+                 x = self.layer2(x)  # 128 * 64 * 64
+                 x = self.noise2(x)
+                 x = self.layer2_up(x)  # 64 * 128 * 128
+
+                 x1 = self.Att1(g=x, x=x1)
+                 x = torch.cat((x, x1), dim=1)  # 128 * 128 * 128
+                 x = self.layer1(x)  # 64 * 128 * 128
+                 x = self.noise1(x)
+                 x = self.layer1_up(x)  # 32 * 256 * 256
+
+                 x = self.noise0(x)
+                 x = self.layer0(x)  # 3 * 256 * 256
+                 x = self.tanh(x)
+
+                 return x
+
+         self.encoder = Encoder()
+         self.decoder = Decoder()
+         if pretrained:
+             print('Loading pretrained {0} model...'.format('Sketch2Color'), end=' ')
+             assert os.path.isfile('edge2color.pth'), 'Error: checkpoint file edge2color.pth not found!'
+             # Load onto CPU first; the caller moves the model to the available device
+             checkpoint = torch.load('edge2color.pth', map_location='cpu')
+             self.load_state_dict(checkpoint['netG'], strict=True)
+             print("Done!")
+         else:
+             self.apply(weights_init)
+             print('Weights of {0} model are initialized'.format('Sketch2Color'))
+
+     def forward(self, inputs):
+         encode = self.encoder(inputs)
+         output = self.decoder(encode)
+
+         return output
+
+
+ class Discriminator(nn.Module):
+     def __init__(self, nc=6, pretrained=False):
+         super(Discriminator, self).__init__()
+         self.conv1 = torch.nn.utils.spectral_norm(nn.Conv2d(nc, 64, kernel_size=4, stride=2, padding=1))
+         self.bn1 = nn.GroupNorm(32, 64)
+         self.conv2 = torch.nn.utils.spectral_norm(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1))
+         self.bn2 = nn.GroupNorm(32, 128)
+         self.conv3 = torch.nn.utils.spectral_norm(nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1))
+         self.bn3 = nn.GroupNorm(32, 256)
+         self.conv4 = torch.nn.utils.spectral_norm(nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1))
+         self.bn4 = nn.GroupNorm(32, 512)
+         self.conv5 = torch.nn.utils.spectral_norm(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))
+         self.activation = nn.LeakyReLU(0.2, inplace=True)
+         self.sigmoid = nn.Sigmoid()
+
+         if pretrained:
+             pass
+         else:
+             self.apply(weights_init)
+             print('Weights of {0} model are initialized'.format('Discriminator'))
+
+     def forward(self, base, unknown):
+         x = torch.cat((base, unknown), dim=1)
+         x = self.activation(self.conv1(x))
+         x = self.activation(self.bn2(self.conv2(x)))
+         x = self.activation(self.bn3(self.conv3(x)))
+         x = self.activation(self.bn4(self.conv4(x)))
+         x = self.sigmoid(self.conv5(x))
+
+         return x.mean((2, 3))
+
+
+ # To initialize model weights
+ def weights_init(model):
+     classname = model.__class__.__name__
+     if classname.find('Conv') != -1:
+         nn.init.normal_(model.weight.data, 0.0, 0.02)
+     elif classname.find('Conv2d_WS') != -1:
+         nn.init.normal_(model.weight.data, 0.0, 0.02)
+     elif classname.find('BatchNorm') != -1:
+         nn.init.normal_(model.weight.data, 1.0, 0.02)
+         nn.init.constant_(model.bias.data, 0)
+     elif classname.find('GroupNorm') != -1:
+         nn.init.normal_(model.weight.data, 1.0, 0.02)
+         nn.init.constant_(model.bias.data, 0)
+     else:
+         pass
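
The generator's input width follows directly from color_cluster: the 3-channel edge image is concatenated with ncluster solid-color palette images of 3 channels each, so nc = 3 * (ncluster + 1) = 30 when ncluster = 9, the value app.py uses. A shape-check sketch of that wiring with random stand-in data (pretrained=False, so no checkpoint is needed):

    import numpy as np
    import torch
    import opencv_transforms.functional as FF

    import mymodels

    ncluster = 9
    nc = 3 * (ncluster + 1)  # 30 input channels: edge image + 9 solid-color palette images

    # Random stand-ins for a reference image and an edge map
    reference = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
    palette = mymodels.color_cluster(reference, nclusters=ncluster)  # list of 9 (512, 512, 3) arrays

    edge = torch.rand(3, 512, 512)  # in the app this is produced by Color2Sketch
    palette_tensors = [FF.normalize(FF.to_tensor(p), (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) for p in palette]
    x = torch.cat([edge] + palette_tensors, dim=0).unsqueeze(0)  # shape [1, 30, 512, 512]

    netG = mymodels.Sketch2Color(nc=nc, pretrained=False)  # pretrained=True would load edge2color.pth
    with torch.inference_mode():
        out = netG(x)
    print(out.shape)  # expected: torch.Size([1, 3, 512, 512])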
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch==2.0.1
+ gradio==3.38.0
+ opencv-python==4.8.0.74
+ opencv-transforms==0.0.6
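
These pins can presumably be installed with pip install -r requirements.txt before launching app.py; numpy and Pillow are not pinned here because they arrive transitively (opencv-python and gradio depend on them).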