JCTN committed on
Commit
a9cb2a0
1 Parent(s): 98522ca

Upload 4 files

modelscope_modules/cv_unet_skin_retouching_torch/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (177 Bytes)
 
modelscope_modules/cv_unet_skin_retouching_torch/__pycache__/ms_wrapper.cpython-310.pyc ADDED
Binary file (8.83 kB)
 
modelscope_modules/cv_unet_skin_retouching_torch/__init__.py ADDED
File without changes
modelscope_modules/cv_unet_skin_retouching_torch/ms_wrapper.py ADDED
@@ -0,0 +1,331 @@
+
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ import os
+ from typing import Any, Dict
+
+ import cv2
+ import numpy as np
+ import PIL
+ import onnxruntime
+ import torch
+ import torch.nn.functional as F
+ import torchvision.transforms as transforms
+
+ from modelscope.utils.config import Config
+ from modelscope.metainfo import Pipelines
+ from modelscope.models.cv.skin_retouching.detection_model.detection_unet_in import \
+     DetectionUNet
+ from modelscope.models.cv.skin_retouching.inpainting_model.inpainting_unet import \
+     RetouchingNet
+ from modelscope.models.cv.skin_retouching.unet_deploy import UNet
+ from modelscope.models.cv.skin_retouching.utils import *  # noqa F403
+ from modelscope.outputs import OutputKeys
+ from modelscope.pipelines import pipeline
+ from modelscope.pipelines.base import Input, Pipeline
+ from modelscope.pipelines.builder import PIPELINES
+ from modelscope.preprocessors import LoadImage
+ from modelscope.utils.constant import ModelFile, Tasks
+ from modelscope.utils.device import create_device, device_placement
+ from modelscope.utils.logger import get_logger
+
+ logger = get_logger()
+
+
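+ # Note: registering the class under PIPELINES makes it discoverable by name,
+ # so pipeline('skin-retouching-torch', ...) can instantiate it from this module.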
+ @PIPELINES.register_module(
+     'skin-retouching-torch', module_name='skin-retouching-torch')
+ class SkinRetouchingTorchPipeline(Pipeline):
+
+     def __init__(self, model: str, device: str):
+         """Use `model` to create a skin-retouching pipeline for prediction.
+
+         Args:
+             model: model id on the ModelScope hub, or a local model directory.
+         """
+         super().__init__(model=model, device=device)
+
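+         # Four models are set up here: the global retouching UNet, a face
+         # detector pipeline, the local inpainting/detection pair, and a
+         # skin-mask model served via ONNX Runtime.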
+         device = create_device(self.device_name)
+         model_path = os.path.join(self.model, ModelFile.TORCH_MODEL_FILE)
+         local_model_path = os.path.join(self.model, 'joint_20210926.pth')
+         skin_model_path = os.path.join(self.model, 'model.onnx')
+
+         self.generator = UNet(3, 3).to(device)
+         self.generator.load_state_dict(
+             torch.load(model_path, map_location='cpu')['generator'])
+         self.generator.eval()
+
+         det_model_id = 'damo/cv_resnet50_face-detection_retinaface'
+         self.detector = pipeline(Tasks.face_detection, model=det_model_id)
+         self.detector.detector.to(device)
+
+         self.local_model_path = local_model_path
+         ckpt_dict_load = torch.load(self.local_model_path, map_location='cpu')
+         self.inpainting_net = RetouchingNet(
+             in_channels=4, out_channels=3).to(device)
+         self.detection_net = DetectionUNet(
+             n_channels=3, n_classes=1).to(device)
+
+         self.inpainting_net.load_state_dict(ckpt_dict_load['inpainting_net'])
+         self.detection_net.load_state_dict(ckpt_dict_load['detection_net'])
+
+         self.inpainting_net.eval()
+         self.detection_net.eval()
+
+         self.patch_size = 512
+
+         self.skin_model_path = skin_model_path
+         self.sess, self.input_node_name, self.out_node_name = self.load_onnx_model(
+             skin_model_path)
+
+         self.image_files_transforms = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+         ])
+
+         self.diffuse_mask = gen_diffuse_mask()
+         self.diffuse_mask = torch.from_numpy(
+             self.diffuse_mask).to(device).float()
+         self.diffuse_mask = self.diffuse_mask.permute(2, 0, 1)[None, ...]
+
+         self.input_size = 512
+         self.device = device
+
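+     # Helper: open an ONNX Runtime session and record its input/output node
+     # names, so the skin-mask model can be fed by name in forward().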
+     def load_onnx_model(self, onnx_path):
+         sess = onnxruntime.InferenceSession(onnx_path)
+         out_node_name = []
+         input_node_name = []
+         for node in sess.get_outputs():
+             out_node_name.append(node.name)
+
+         for node in sess.get_inputs():
+             input_node_name.append(node.name)
+
+         return sess, input_node_name, out_node_name
+
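+     # preprocess: decode the input into an ndarray; grayscale inputs are
+     # expanded to 3 channels so every downstream model sees a 3-channel image.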
+     def preprocess(self, input: Input) -> Dict[str, Any]:
+         img = LoadImage.convert_to_ndarray(input)
+         if len(img.shape) == 2:
+             img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+         img = img.astype(float)
+         result = {'img': img}
+         return result
+
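+     # forward: (1) predict a skin mask with the ONNX model, (2) detect faces,
+     # (3) retouch each face ROI locally (inpainting) and globally (UNet),
+     # (4) optionally whiten the skin, then return the composed image.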
+     def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+         rgb_image = input['img'].cpu().numpy().astype(np.uint8)
+
+         retouch_local = True
+         whitening = True
+         degree = 1.0
+         whitening_degree = 0.8
+         return_mg = False
+
+         with torch.no_grad():
+             if whitening and whitening_degree > 0 and self.skin_model_path is not None:
+                 rgb_image_small, resize_scale = resize_on_long_side(
+                     rgb_image, 800)
+                 input_feed = {}
+                 input_feed[self.input_node_name[0]] = rgb_image_small.astype(
+                     'float32')
+                 skin_mask = self.sess.run(
+                     self.out_node_name, input_feed=input_feed)[0]
+
+             output_pred = torch.from_numpy(rgb_image).to(self.device)
+             if return_mg:
+                 output_mg = np.ones(
+                     (rgb_image.shape[0], rgb_image.shape[1], 3),
+                     dtype=np.float32) * 0.5
+
+             det_results = self.detector(rgb_image)
+             # list, [{'bbox': [x1, y1, x2, y2], 'score': ...}, ...]
+             results = []
+             for i in range(len(det_results['scores'])):
+                 info_dict = {}
+                 info_dict['bbox'] = np.array(det_results['boxes'][i]).astype(
+                     np.int32).tolist()
+                 info_dict['score'] = det_results['scores'][i]
+                 info_dict['landmarks'] = np.array(
+                     det_results['keypoints'][i]).astype(np.int32).reshape(
+                         5, 2).tolist()
+                 results.append(info_dict)
+
+             crop_bboxes = get_crop_bbox(results)
+
+             face_num = len(crop_bboxes)
+             if face_num == 0:
+                 # No face detected: return the input unchanged under the same
+                 # output key as the normal path so callers never hit a KeyError.
+                 return {
+                     OutputKeys.OUTPUT_IMG:
+                     output_pred.cpu().numpy()[:, :, ::-1],
+                     'face_num': face_num
+                 }
+
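+             # For each detected face: crop an ROI, retouch it, and paste the
+             # result back into the full-resolution output image.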
+             flag_bigKernal = False
+             for bbox in crop_bboxes:
+                 roi, expand, crop_tblr = get_roi_without_padding(
+                     rgb_image, bbox)
+                 roi = roi_to_tensor(roi)  # bgr -> rgb
+
+                 if roi.shape[2] > 0.4 * rgb_image.shape[0]:
+                     flag_bigKernal = True
+
+                 roi = roi.to(self.device)
+
+                 roi = preprocess_roi(roi)
+
+                 if retouch_local and self.local_model_path is not None:
+                     roi = self.retouch_local(roi)
+
+                 roi_output = self.predict_roi(
+                     roi,
+                     degree=degree,
+                     smooth_border=True,
+                     return_mg=return_mg)
+
+                 roi_pred = roi_output['pred']
+                 output_pred[crop_tblr[0]:crop_tblr[1],
+                             crop_tblr[2]:crop_tblr[3]] = roi_pred
+
+                 if return_mg:
+                     roi_mg = roi_output['pred_mg']
+                     output_mg[crop_tblr[0]:crop_tblr[1],
+                               crop_tblr[2]:crop_tblr[3]] = roi_mg
+
+             if whitening and whitening_degree > 0 and self.skin_model_path is not None:
+                 output_pred = whiten_img(
+                     output_pred,
+                     skin_mask,
+                     whitening_degree,
+                     flag_bigKernal=flag_bigKernal)
+
+             if not isinstance(output_pred, np.ndarray):
+                 output_pred = output_pred.cpu().numpy()
+
+             output_pred = output_pred[:, :, ::-1]  # rgb -> bgr for cv2
+
+             return {OutputKeys.OUTPUT_IMG: output_pred}
+
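+     # retouch_local: detect blemish regions with detection_net, then inpaint
+     # only those regions patch-by-patch (512x512 windows) with inpainting_net.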
+     def retouch_local(self, image):
+         """Locally retouch an ROI.
+
+         Args:
+             image: RGB tensor of shape (1, 3, H, W).
+         """
+         with torch.no_grad():
+             sub_H, sub_W = image.shape[2:]
+
+             sub_image_standard = F.interpolate(
+                 image, size=(768, 768), mode='bilinear', align_corners=True)
+             sub_mask_pred = torch.sigmoid(
+                 self.detection_net(sub_image_standard))
+             sub_mask_pred = F.interpolate(
+                 sub_mask_pred, size=(sub_H, sub_W), mode='nearest')
+
+             sub_mask_pred_hard_low = (sub_mask_pred >= 0.35).float()
+             sub_mask_pred_hard_high = (sub_mask_pred >= 0.5).float()
+             sub_mask_pred = sub_mask_pred * (
+                 1 - sub_mask_pred_hard_high) + sub_mask_pred_hard_high
+             sub_mask_pred = sub_mask_pred * sub_mask_pred_hard_low
+             sub_mask_pred = 1 - sub_mask_pred
+
+             sub_H_standard = sub_H if sub_H % self.patch_size == 0 else (
+                 sub_H // self.patch_size + 1) * self.patch_size
+             sub_W_standard = sub_W if sub_W % self.patch_size == 0 else (
+                 sub_W // self.patch_size + 1) * self.patch_size
+
+             sub_image_padding = F.pad(
+                 image,
+                 pad=(0, sub_W_standard - sub_W, 0, sub_H_standard - sub_H, 0,
+                      0),
+                 mode='constant',
+                 value=0)
+             sub_mask_pred_padding = F.pad(
+                 sub_mask_pred,
+                 pad=(0, sub_W_standard - sub_W, 0, sub_H_standard - sub_H, 0,
+                      0),
+                 mode='constant',
+                 value=0)
+
+             sub_image_padding = patch_partition_overlap(
+                 sub_image_padding, p1=self.patch_size, p2=self.patch_size)
+             sub_mask_pred_padding = patch_partition_overlap(
+                 sub_mask_pred_padding, p1=self.patch_size, p2=self.patch_size)
+             B_padding, C_padding, _, _ = sub_image_padding.size()
+
+             sub_comp_padding_list = []
+             for window_item in range(B_padding):
+                 sub_image_padding_window = sub_image_padding[
+                     window_item:window_item + 1]
+                 sub_mask_pred_padding_window = sub_mask_pred_padding[
+                     window_item:window_item + 1]
+
+                 sub_input_image_padding_window = sub_image_padding_window * sub_mask_pred_padding_window
+
+                 sub_output_padding_window = self.inpainting_net(
+                     sub_input_image_padding_window,
+                     sub_mask_pred_padding_window)
+                 sub_comp_padding_window = sub_input_image_padding_window + (
+                     1
+                     - sub_mask_pred_padding_window) * sub_output_padding_window
+
+                 sub_comp_padding_list.append(sub_comp_padding_window)
+
+             sub_comp_padding = torch.cat(sub_comp_padding_list, dim=0)
+             sub_comp = patch_aggregation_overlap(
+                 sub_comp_padding,
+                 h=int(round(sub_H_standard / self.patch_size)),
+                 w=int(round(sub_W_standard
+                             / self.patch_size)))[:, :, :sub_H, :sub_W]
+
+             return sub_comp
+
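+     # predict_roi: run the global retouching UNet on a 512x512 resize of the
+     # ROI to get a blending map 'mg', then compose the result as
+     # (1 - 2*mg) * I^2 + 2*mg * I (a soft-light-style blend in [0, 1]).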
+     def predict_roi(self,
+                     roi,
+                     degree=1.0,
+                     smooth_border=False,
+                     return_mg=False):
+         with torch.no_grad():
+             image = F.interpolate(
+                 roi, (self.input_size, self.input_size), mode='bilinear')
+
+             pred_mg = self.generator(image)  # value: 0~1
+             pred_mg = (pred_mg - 0.5) * degree + 0.5
+             pred_mg = pred_mg.clamp(0.0, 1.0)
+             pred_mg = F.interpolate(pred_mg, roi.shape[2:], mode='bilinear')
+             pred_mg = pred_mg[0].permute(
+                 1, 2, 0)  # tensor, (h, w, 1) or (h0, w0, 3)
+             if len(pred_mg.shape) == 2:
+                 pred_mg = pred_mg[..., None]
+
+             if smooth_border:
+                 pred_mg = smooth_border_mg(self.diffuse_mask, pred_mg)
+
+             image = (roi[0].permute(1, 2, 0) + 1.0) / 2
+
+             pred = (1 - 2 * pred_mg
+                     ) * image * image + 2 * pred_mg * image  # value: 0~1
+
+             pred = (pred * 255.0).byte()  # tensor, (h, w, 3), rgb
+
+             output = {'pred': pred}
+             if return_mg:
+                 output['pred_mg'] = pred_mg.cpu().numpy()
+             return output
+
+     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+         return inputs
+
+
+ # Tip: usr_config_path is the local directory where the configuration is
+ # written; after the model is uploaded to the ModelScope hub, pass the
+ # model_id here instead.
+ usr_config_path = '/tmp/snapdown/'
+ os.makedirs(usr_config_path, exist_ok=True)  # ensure the target dir exists
+ config = Config({
+     "framework": 'pytorch',
+     "task": 'skin-retouching-torch',
+     "pipeline": {"type": "skin-retouching-torch"},
+     "allow_remote": True
+ })
+ config.dump(os.path.join(usr_config_path, 'configuration.json'))
+
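+ # Minimal smoke test: build the pipeline from the local configuration
+ # directory and run it on a sample image (assumes the example file exists
+ # in the working directory).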
+ if __name__ == "__main__":
+     from modelscope.models import Model
+     from modelscope.pipelines import pipeline
+     # model = Model.from_pretrained(usr_config_path)
+     inference = pipeline('skin-retouching-torch', model=usr_config_path)
+     img_name = "skin_retouching_examples_1.jpg"
+     output = inference(img_name)
+
+     cv2.imwrite('result.png', output[OutputKeys.OUTPUT_IMG])
+     print(output)