lampongyuen committed on
Commit
c6ad93b
1 Parent(s): 7ce11fa

Upload 9 files

Files changed (9)
  1. LICENSE +21 -0
  2. README.md +55 -12
  3. app.py +122 -0
  4. makeup.py +107 -0
  5. model.py +283 -0
  6. requirements.txt +8 -0
  7. resnet.py +109 -0
  8. scarlet.jpg +0 -0
  9. test.py +84 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2019 zll
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,55 @@
- ---
- title: Virtual Makeup
- emoji: 🏃
- colorFrom: indigo
- colorTo: red
- sdk: streamlit
- sdk_version: 1.28.2
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # face-makeup.PyTorch
+ Lip and hair color editor using face parsing maps.
+
+ <table>
+
+ <tr>
+ <th>&nbsp;</th>
+ <th>Hair</th>
+ <th>Lip</th>
+ </tr>
+
+ <!-- Line 1: Original Input -->
+ <tr>
+ <td><em>Original Input</em></td>
+ <td><img src="makeup/116_ori.png" height="256" width="256" alt="Original Input"></td>
+ <td><img src="makeup/116_lip_ori.png" height="256" width="256" alt="Original Input"></td>
+ </tr>
+
+ <!-- Line 2: Color -->
+ <tr>
+ <td>Color</td>
+ <td><img src="makeup/116_0.png" height="256" width="256" alt="Color"></td>
+ <td><img src="makeup/116_6.png" height="256" width="256" alt="Color"></td>
+ </tr>
+
+ <!-- Line 3: Color -->
+ <tr>
+ <td>Color</td>
+ <td><img src="makeup/116_1.png" height="256" width="256" alt="Color"></td>
+ <td><img src="makeup/116_3.png" height="256" width="256" alt="Color"></td>
+ </tr>
+
+ <!-- Line 4: Color -->
+ <tr>
+ <td>Color</td>
+ <td><img src="makeup/116_2.png" height="256" width="256" alt="Color"></td>
+ <td><img src="makeup/116_4.png" height="256" width="256" alt="Color"></td>
+ </tr>
+
+ </table>
+
+ ### Using PyTorch 1.0 and Python 3.x
+
+ ## Demo
+ Change hair and lip color:
+ ```Shell
+ python makeup.py --img-path imgs/116.jpg
+ ```
+ ### Try other colors
+ Change the color list in **makeup.py** (line 83):
+ ```
+ colors = [[230, 50, 20], [20, 70, 180], [20, 70, 180]]
+ ```
+ ### Train the face parsing model (optional)
+ Follow [zllrunning/face-parsing.PyTorch](https://github.com/zllrunning/face-parsing.PyTorch).
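The same edit can also be scripted without the CLI. Below is a minimal sketch (an illustration, not part of this commit) that reuses the `evaluate()` helper from test.py and the `hair()` helper from makeup.py; it assumes the checkpoint sits at `cp/79999_iter.pth` and the demo image at `imgs/116.jpg`:

```python
# Hypothetical driver script: recolor hair and lips from Python instead of the CLI.
# Assumes cp/79999_iter.pth and imgs/116.jpg exist next to this script.
import cv2
import numpy as np
from test import evaluate   # face-parsing inference; returns a 512x512 class-id map
from makeup import hair     # HSV-based recoloring of one parsed region

image = cv2.imread('imgs/116.jpg')                        # BGR, as hair() expects
parsing = evaluate('imgs/116.jpg', 'cp/79999_iter.pth')

# Resize the class-id map back to the image size (uint8 keeps cv2.resize happy).
parsing = cv2.resize(parsing.astype(np.uint8),
                     (image.shape[1], image.shape[0]),
                     interpolation=cv2.INTER_NEAREST)

# Part ids follow makeup.py: 17 = hair, 12 = upper lip, 13 = lower lip.
# Colors are BGR triples, like the list at makeup.py line 83.
for part, color in zip([17, 12, 13],
                       [[230, 50, 20], [20, 70, 180], [20, 70, 180]]):
    image = hair(image, parsing, part, color)

cv2.imwrite('116_recolored.jpg', image)
```

The Streamlit front end in app.py wires the same two helpers to color pickers.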
app.py ADDED
@@ -0,0 +1,122 @@
+ import cv2
+ import os
+ import numpy as np
+ from skimage.filters import gaussian
+ from test import evaluate
+ import streamlit as st
+ from PIL import Image, ImageColor
+
+ def sharpen(img):
+     img = img * 1.0
+     gauss_out = gaussian(img, sigma=5, multichannel=True)
+
+     alpha = 1.5
+     img_out = (img - gauss_out) * alpha + img
+
+     img_out = img_out / 255.0
+
+     mask_1 = img_out < 0
+     mask_2 = img_out > 1
+
+     img_out = img_out * (1 - mask_1)
+     img_out = img_out * (1 - mask_2) + mask_2
+     img_out = np.clip(img_out, 0, 1)
+     img_out = img_out * 255
+     return np.array(img_out, dtype=np.uint8)
+
+
+ def hair(image, parsing, part=17, color=[230, 50, 20]):
+     b, g, r = color  # [10, 50, 250] # [10, 250, 10]
+     tar_color = np.zeros_like(image)
+     tar_color[:, :, 0] = b
+     tar_color[:, :, 1] = g
+     tar_color[:, :, 2] = r
+     np.repeat(parsing[:, :, np.newaxis], 3, axis=2)
+
+     image_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+     tar_hsv = cv2.cvtColor(tar_color, cv2.COLOR_BGR2HSV)
+
+     if part == 12 or part == 13:
+         image_hsv[:, :, 0:2] = tar_hsv[:, :, 0:2]
+     else:
+         image_hsv[:, :, 0:1] = tar_hsv[:, :, 0:1]
+
+     changed = cv2.cvtColor(image_hsv, cv2.COLOR_HSV2BGR)
+
+     if part == 17:
+         changed = sharpen(changed)
+
+
+     changed[parsing != part] = image[parsing != part]
+     return changed
+
+ DEMO_IMAGE = 'imgs/116.jpg'
+
+ st.title('Virtual Makeup')
+
+ st.sidebar.title('Virtual Makeup')
+ st.sidebar.subheader('Parameters')
+
+ table = {
+     'hair': 17,
+     'upper_lip': 12,
+     'lower_lip': 13,
+
+ }
+
+ img_file_buffer = st.sidebar.file_uploader("Upload an image", type=["jpg", "jpeg", 'png'])
+
+ if img_file_buffer is not None:
+     image = np.array(Image.open(img_file_buffer))
+     demo_image = img_file_buffer
+
+ else:
+     demo_image = DEMO_IMAGE
+     image = np.array(Image.open(demo_image))
+
+ # st.set_option('deprecation.showfileUploaderEncoding', False)
+
+ new_image = image.copy()
+
+
+
+
+
+ st.subheader('Original Image')
+
+ st.image(image, use_column_width=True)
+
+
+ cp = 'cp/79999_iter.pth'
+ ori = image.copy()
+ h, w, _ = ori.shape
+
+ # print(h)
+ # print(w)
+ image = cv2.resize(image, (1024, 1024))
+
+ parsing = evaluate(demo_image, cp)
+ parsing = cv2.resize(parsing, image.shape[0:2], interpolation=cv2.INTER_NEAREST)
+
+ parts = [table['hair'], table['upper_lip'], table['lower_lip']]
+
+ hair_color = st.sidebar.color_picker('Pick the Hair Color', '#000')
+ hair_color = ImageColor.getcolor(hair_color, "RGB")
+
+ lip_color = st.sidebar.color_picker('Pick the Lip Color', '#edbad1')
+
+ lip_color = ImageColor.getcolor(lip_color, "RGB")
+
+
+
+ colors = [hair_color, lip_color, lip_color]
+
+ for part, color in zip(parts, colors):
+     image = hair(image, parsing, part, color)
+
+ image = cv2.resize(image, (w, h))
+
+
+ st.subheader('Output Image')
+
+ st.image(image, use_column_width=True)
makeup.py ADDED
@@ -0,0 +1,107 @@
+ import cv2
+ import os
+ import numpy as np
+ from skimage.filters import gaussian
+ from test import evaluate
+ import argparse
+
+
+ def parse_args():
+     parse = argparse.ArgumentParser()
+     parse.add_argument('--img-path', default='imgs/116.jpg')
+     return parse.parse_args()
+
+
+ def sharpen(img):
+     img = img * 1.0
+     gauss_out = gaussian(img, sigma=5, multichannel=True)
+
+     alpha = 1.5
+     img_out = (img - gauss_out) * alpha + img
+
+     img_out = img_out / 255.0
+
+     mask_1 = img_out < 0
+     mask_2 = img_out > 1
+
+     img_out = img_out * (1 - mask_1)
+     img_out = img_out * (1 - mask_2) + mask_2
+     img_out = np.clip(img_out, 0, 1)
+     img_out = img_out * 255
+     return np.array(img_out, dtype=np.uint8)
+
+
+ def hair(image, parsing, part=17, color=[230, 50, 20]):
+     b, g, r = color  # [10, 50, 250] # [10, 250, 10]
+     tar_color = np.zeros_like(image)
+     tar_color[:, :, 0] = b
+     tar_color[:, :, 1] = g
+     tar_color[:, :, 2] = r
+
+     image_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+     tar_hsv = cv2.cvtColor(tar_color, cv2.COLOR_BGR2HSV)
+
+     if part == 12 or part == 13:
+         image_hsv[:, :, 0:2] = tar_hsv[:, :, 0:2]
+     else:
+         image_hsv[:, :, 0:1] = tar_hsv[:, :, 0:1]
+
+     changed = cv2.cvtColor(image_hsv, cv2.COLOR_HSV2BGR)
+
+     if part == 17:
+         changed = sharpen(changed)
+
+     changed[parsing != part] = image[parsing != part]
+     return changed
+
+
+ if __name__ == '__main__':
+     # 1 face
+     # 11 teeth
+     # 12 upper lip
+     # 13 lower lip
+     # 17 hair
+
+     args = parse_args()
+
+     table = {
+         'hair': 17,
+         'upper_lip': 12,
+         'lower_lip': 13
+     }
+
+     image_path = args.img_path
+     cp = 'cp/79999_iter.pth'
+
+     image = cv2.imread(image_path)
+     ori = image.copy()
+     parsing = evaluate(image_path, cp)
+     parsing = cv2.resize(parsing, image.shape[0:2], interpolation=cv2.INTER_NEAREST)
+
+     parts = [table['hair'], table['upper_lip'], table['lower_lip']]
+
+     colors = [[230, 50, 20], [20, 70, 180], [20, 70, 180]]
+
+     for part, color in zip(parts, colors):
+         image = hair(image, parsing, part, color)
+
+     #cv2.imshow('image', cv2.resize(ori, (512, 512)))
+     cv2.imshow('color', cv2.resize(image, (512, 512)))
+
+     cv2.waitKey(0)
+     cv2.destroyAllWindows()
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
model.py ADDED
@@ -0,0 +1,283 @@
+ #!/usr/bin/python
+ # -*- encoding: utf-8 -*-
+
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision
+
+ from resnet import Resnet18
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
+
+
+ class ConvBNReLU(nn.Module):
+     def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
+         super(ConvBNReLU, self).__init__()
+         self.conv = nn.Conv2d(in_chan,
+                               out_chan,
+                               kernel_size = ks,
+                               stride = stride,
+                               padding = padding,
+                               bias = False)
+         self.bn = nn.BatchNorm2d(out_chan)
+         self.init_weight()
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = F.relu(self.bn(x))
+         return x
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if not ly.bias is None: nn.init.constant_(ly.bias, 0)
+
+ class BiSeNetOutput(nn.Module):
+     def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
+         super(BiSeNetOutput, self).__init__()
+         self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
+         self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
+         self.init_weight()
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.conv_out(x)
+         return x
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if not ly.bias is None: nn.init.constant_(ly.bias, 0)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
+                 wd_params.append(module.weight)
+                 if not module.bias is None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
+
+
+ class AttentionRefinementModule(nn.Module):
+     def __init__(self, in_chan, out_chan, *args, **kwargs):
+         super(AttentionRefinementModule, self).__init__()
+         self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
+         self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
+         self.bn_atten = nn.BatchNorm2d(out_chan)
+         self.sigmoid_atten = nn.Sigmoid()
+         self.init_weight()
+
+     def forward(self, x):
+         feat = self.conv(x)
+         atten = F.avg_pool2d(feat, feat.size()[2:])
+         atten = self.conv_atten(atten)
+         atten = self.bn_atten(atten)
+         atten = self.sigmoid_atten(atten)
+         out = torch.mul(feat, atten)
+         return out
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if not ly.bias is None: nn.init.constant_(ly.bias, 0)
+
+
+ class ContextPath(nn.Module):
+     def __init__(self, *args, **kwargs):
+         super(ContextPath, self).__init__()
+         self.resnet = Resnet18()
+         self.arm16 = AttentionRefinementModule(256, 128)
+         self.arm32 = AttentionRefinementModule(512, 128)
+         self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
+         self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
+         self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
+
+         self.init_weight()
+
+     def forward(self, x):
+         H0, W0 = x.size()[2:]
+         feat8, feat16, feat32 = self.resnet(x)
+         H8, W8 = feat8.size()[2:]
+         H16, W16 = feat16.size()[2:]
+         H32, W32 = feat32.size()[2:]
+
+         avg = F.avg_pool2d(feat32, feat32.size()[2:])
+         avg = self.conv_avg(avg)
+         avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
+
+         feat32_arm = self.arm32(feat32)
+         feat32_sum = feat32_arm + avg_up
+         feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
+         feat32_up = self.conv_head32(feat32_up)
+
+         feat16_arm = self.arm16(feat16)
+         feat16_sum = feat16_arm + feat32_up
+         feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
+         feat16_up = self.conv_head16(feat16_up)
+
+         return feat8, feat16_up, feat32_up  # x8, x8, x16
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if not ly.bias is None: nn.init.constant_(ly.bias, 0)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, (nn.Linear, nn.Conv2d)):
+                 wd_params.append(module.weight)
+                 if not module.bias is None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
+
+
+ ### This is not used, since I replace this with the resnet feature with the same size
+ class SpatialPath(nn.Module):
+     def __init__(self, *args, **kwargs):
+         super(SpatialPath, self).__init__()
+         self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
+         self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
+         self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
+         self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
+         self.init_weight()
+
+     def forward(self, x):
+         feat = self.conv1(x)
+         feat = self.conv2(feat)
+         feat = self.conv3(feat)
+         feat = self.conv_out(feat)
+         return feat
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if not ly.bias is None: nn.init.constant_(ly.bias, 0)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
+                 wd_params.append(module.weight)
+                 if not module.bias is None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
+
+
+ class FeatureFusionModule(nn.Module):
+     def __init__(self, in_chan, out_chan, *args, **kwargs):
+         super(FeatureFusionModule, self).__init__()
+         self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
+         self.conv1 = nn.Conv2d(out_chan,
+                                out_chan//4,
+                                kernel_size = 1,
+                                stride = 1,
+                                padding = 0,
+                                bias = False)
+         self.conv2 = nn.Conv2d(out_chan//4,
+                                out_chan,
+                                kernel_size = 1,
+                                stride = 1,
+                                padding = 0,
+                                bias = False)
+         self.relu = nn.ReLU(inplace=True)
+         self.sigmoid = nn.Sigmoid()
+         self.init_weight()
+
+     def forward(self, fsp, fcp):
+         fcat = torch.cat([fsp, fcp], dim=1)
+         feat = self.convblk(fcat)
+         atten = F.avg_pool2d(feat, feat.size()[2:])
+         atten = self.conv1(atten)
+         atten = self.relu(atten)
+         atten = self.conv2(atten)
+         atten = self.sigmoid(atten)
+         feat_atten = torch.mul(feat, atten)
+         feat_out = feat_atten + feat
+         return feat_out
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if not ly.bias is None: nn.init.constant_(ly.bias, 0)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
+                 wd_params.append(module.weight)
+                 if not module.bias is None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
+
+
+ class BiSeNet(nn.Module):
+     def __init__(self, n_classes, *args, **kwargs):
+         super(BiSeNet, self).__init__()
+         self.cp = ContextPath()
+         ## here self.sp is deleted
+         self.ffm = FeatureFusionModule(256, 256)
+         self.conv_out = BiSeNetOutput(256, 256, n_classes)
+         self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
+         self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
+         self.init_weight()
+
+     def forward(self, x):
+         H, W = x.size()[2:]
+         feat_res8, feat_cp8, feat_cp16 = self.cp(x)  # here return res3b1 feature
+         feat_sp = feat_res8  # use res3b1 feature to replace spatial path feature
+         feat_fuse = self.ffm(feat_sp, feat_cp8)
+
+         feat_out = self.conv_out(feat_fuse)
+         feat_out16 = self.conv_out16(feat_cp8)
+         feat_out32 = self.conv_out32(feat_cp16)
+
+         feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
+         feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
+         feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
+         return feat_out, feat_out16, feat_out32
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if not ly.bias is None: nn.init.constant_(ly.bias, 0)
+
+     def get_params(self):
+         wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
+         for name, child in self.named_children():
+             child_wd_params, child_nowd_params = child.get_params()
+             if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
+                 lr_mul_wd_params += child_wd_params
+                 lr_mul_nowd_params += child_nowd_params
+             else:
+                 wd_params += child_wd_params
+                 nowd_params += child_nowd_params
+         return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
+
+
+ if __name__ == "__main__":
+     net = BiSeNet(19)
+     # net.cuda()
+     net.eval()
+     in_ten = torch.randn(16, 3, 640, 480)
+     out, out16, out32 = net(in_ten)
+     print(out.shape)
+
+     net.get_params()
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch==1.9.0
+ torchvision==0.10.0
+ scikit_image==0.18.2
+ streamlit==0.85.0
+ numpy==1.18.5
+ opencv_python_headless==4.5.2.54
+ Pillow==8.3.1
+
resnet.py ADDED
@@ -0,0 +1,109 @@
+ #!/usr/bin/python
+ # -*- encoding: utf-8 -*-
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.model_zoo as modelzoo
+
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
+
+ resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
+
+
+ def conv3x3(in_planes, out_planes, stride=1):
+     """3x3 convolution with padding"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                      padding=1, bias=False)
+
+
+ class BasicBlock(nn.Module):
+     def __init__(self, in_chan, out_chan, stride=1):
+         super(BasicBlock, self).__init__()
+         self.conv1 = conv3x3(in_chan, out_chan, stride)
+         self.bn1 = nn.BatchNorm2d(out_chan)
+         self.conv2 = conv3x3(out_chan, out_chan)
+         self.bn2 = nn.BatchNorm2d(out_chan)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = None
+         if in_chan != out_chan or stride != 1:
+             self.downsample = nn.Sequential(
+                 nn.Conv2d(in_chan, out_chan,
+                           kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(out_chan),
+             )
+
+     def forward(self, x):
+         residual = self.conv1(x)
+         residual = F.relu(self.bn1(residual))
+         residual = self.conv2(residual)
+         residual = self.bn2(residual)
+
+         shortcut = x
+         if self.downsample is not None:
+             shortcut = self.downsample(x)
+
+         out = shortcut + residual
+         out = self.relu(out)
+         return out
+
+
+ def create_layer_basic(in_chan, out_chan, bnum, stride=1):
+     layers = [BasicBlock(in_chan, out_chan, stride=stride)]
+     for i in range(bnum-1):
+         layers.append(BasicBlock(out_chan, out_chan, stride=1))
+     return nn.Sequential(*layers)
+
+
+ class Resnet18(nn.Module):
+     def __init__(self):
+         super(Resnet18, self).__init__()
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                                bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
+         self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
+         self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
+         self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
+         self.init_weight()
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = F.relu(self.bn1(x))
+         x = self.maxpool(x)
+
+         x = self.layer1(x)
+         feat8 = self.layer2(x)  # 1/8
+         feat16 = self.layer3(feat8)  # 1/16
+         feat32 = self.layer4(feat16)  # 1/32
+         return feat8, feat16, feat32
+
+     def init_weight(self):
+         state_dict = modelzoo.load_url(resnet18_url)
+         self_state_dict = self.state_dict()
+         for k, v in state_dict.items():
+             if 'fc' in k: continue
+             self_state_dict.update({k: v})
+         self.load_state_dict(self_state_dict)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, (nn.Linear, nn.Conv2d)):
+                 wd_params.append(module.weight)
+                 if not module.bias is None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
+
+
+ if __name__ == "__main__":
+     net = Resnet18()
+     x = torch.randn(16, 3, 224, 224)
+     out = net(x)
+     print(out[0].size())
+     print(out[1].size())
+     print(out[2].size())
+     net.get_params()
scarlet.jpg ADDED
test.py ADDED
@@ -0,0 +1,84 @@
+ #!/usr/bin/python
+ # -*- encoding: utf-8 -*-
+
+ import torch
+ import os
+ from model import BiSeNet
+ import os.path as osp
+ import numpy as np
+ from PIL import Image
+ import torchvision.transforms as transforms
+ import cv2
+
+
+ def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'):
+     # Colors for all 20 parts
+     part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
+                    [255, 0, 85], [255, 0, 170],
+                    [0, 255, 0], [85, 255, 0], [170, 255, 0],
+                    [0, 255, 85], [0, 255, 170],
+                    [0, 0, 255], [85, 0, 255], [170, 0, 255],
+                    [0, 85, 255], [0, 170, 255],
+                    [255, 255, 0], [255, 255, 85], [255, 255, 170],
+                    [255, 0, 255], [255, 85, 255], [255, 170, 255],
+                    [0, 255, 255], [85, 255, 255], [170, 255, 255]]
+
+     im = np.array(im)
+     vis_im = im.copy().astype(np.uint8)
+     vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
+     vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
+     vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
+
+     num_of_class = np.max(vis_parsing_anno)
+
+     for pi in range(1, num_of_class + 1):
+         index = np.where(vis_parsing_anno == pi)
+         vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
+
+     vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
+     # print(vis_parsing_anno_color.shape, vis_im.shape)
+     vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
+
+     # Save result or not
+     if save_im:
+         cv2.imwrite(save_path[:-4] + '.png', vis_parsing_anno)
+         cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
+     return vis_parsing_anno
+     # return vis_im
+
+
+ def evaluate(image_path='./imgs/116.jpg', cp='cp/79999_iter.pth'):
+
+     # if not os.path.exists(respth):
+     #     os.makedirs(respth)
+
+     n_classes = 19
+     net = BiSeNet(n_classes=n_classes)
+     # net.cuda()
+     # net.load_state_dict(torch.load(cp))
+     net.load_state_dict(torch.load(cp, map_location=torch.device('cpu')))
+     net.eval()
+
+     to_tensor = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+     ])
+
+     with torch.no_grad():
+         img = Image.open(image_path)
+         image = img.resize((512, 512), Image.BILINEAR)
+         img = to_tensor(image)
+         img = torch.unsqueeze(img, 0)
+         # img = img.cuda()
+         out = net(img)[0]
+         parsing = out.squeeze(0).cpu().numpy().argmax(0)
+         # print(parsing)
+         # print(np.unique(parsing))
+
+     # vis_parsing_maps(image, parsing, stride=1, save_im=False, save_path=osp.join(respth, dspth))
+     return parsing
+
+ if __name__ == "__main__":
+     evaluate(image_path='/home/zll/data/CelebAMask-HQ/test-img/116.jpg', cp='79999_iter.pth')
+
+
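To inspect what the parser segments before recoloring, the `vis_parsing_maps()` helper above can be paired with `evaluate()`. A minimal sketch (an illustration, not part of this commit), assuming the same `cp/79999_iter.pth` checkpoint and `imgs/116.jpg` demo image used elsewhere in the repo:

```python
# Hypothetical visualization snippet: saves a colored parsing overlay to vis_results/.
import os
from PIL import Image
from test import evaluate, vis_parsing_maps

os.makedirs('vis_results', exist_ok=True)

image_path = 'imgs/116.jpg'
parsing = evaluate(image_path, 'cp/79999_iter.pth')   # 512x512 class-id map

# vis_parsing_maps() expects the image at the same resolution as the parsing map.
image = Image.open(image_path).resize((512, 512), Image.BILINEAR)
vis_parsing_maps(image, parsing, stride=1, save_im=True,
                 save_path='vis_results/parsing_map_on_im.jpg')
```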