ljsabc commited on
Commit
395d300
·
1 Parent(s): 77edf8c

Initial commit.

Browse files
animeinsseg/__init__.py ADDED
@@ -0,0 +1,708 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import mmcv, torch
2
+ from tqdm import tqdm
3
+ from einops import rearrange
4
+ import os
5
+ import os.path as osp
6
+ import cv2
7
+ import gc
8
+ import math
9
+
10
+ from .anime_instances import AnimeInstances
11
+ import numpy as np
12
+ from typing import List, Tuple, Union, Optional, Callable
13
+ from mmengine import Config
14
+ from mmengine.model.utils import revert_sync_batchnorm
15
+ from mmdet.utils import register_all_modules, get_test_pipeline_cfg
16
+ from mmdet.apis import init_detector
17
+ from mmdet.registry import MODELS
18
+ from mmdet.structures import DetDataSample, SampleList
19
+ from mmdet.structures.bbox.transforms import scale_boxes, get_box_wh
20
+ from mmdet.models.dense_heads.rtmdet_ins_head import RTMDetInsHead
21
+ from pycocotools.coco import COCO
22
+ from mmcv.transforms import Compose
23
+ from mmdet.models.detectors.single_stage import SingleStageDetector
24
+
25
+ from utils.logger import LOGGER
26
+ from utils.io_utils import square_pad_resize, find_all_imgs, imglist2grid, mask2rle, dict2json, scaledown_maxsize, resize_pad
27
+ from utils.constants import DEFAULT_DEVICE, CATEGORIES
28
+ from utils.booru_tagger import Tagger
29
+
30
+ from .models.animeseg_refine import AnimeSegmentation, load_refinenet, get_mask
31
+ from .models.rtmdet_inshead_custom import RTMDetInsSepBNHeadCustom
32
+
33
+ from torchvision.ops.boxes import box_iou
34
+ import torch.nn.functional as F
35
+
36
+
37
+ def prepare_refine_batch(segmentations: np.ndarray, img: np.ndarray, max_batch_size: int = 4, device: str = 'cpu', input_size: int = 720):
38
+
39
+ img, (pt, pb, pl, pr) = resize_pad(img, input_size, pad_value=(0, 0, 0))
40
+
41
+ img = img.transpose((2, 0, 1)).astype(np.float32) / 255.
42
+
43
+ batch = []
44
+ num_seg = len(segmentations)
45
+
46
+ for ii, seg in enumerate(segmentations):
47
+ seg, _ = resize_pad(seg, input_size, 0)
48
+ seg = seg[None, ...]
49
+ batch.append(np.concatenate((img, seg)))
50
+
51
+ if ii == num_seg - 1:
52
+ yield torch.from_numpy(np.array(batch)).to(device), (pt, pb, pl, pr)
53
+ elif len(batch) >= max_batch_size:
54
+ yield torch.from_numpy(np.array(batch)).to(device), (pt, pb, pl, pr)
55
+ batch = []
56
+
57
+
58
+ VALID_REFINEMETHODS = {'animeseg', 'none'}
59
+
60
+ register_all_modules()
61
+
62
+
63
+ def single_image_preprocess(img: Union[str, np.ndarray], pipeline: Compose):
64
+ if isinstance(img, str):
65
+ img = mmcv.imread(img)
66
+ elif not isinstance(img, np.ndarray):
67
+ raise NotImplementedError
68
+
69
+ # img = square_pad_resize(img, 1024)[0]
70
+
71
+ data_ = dict(img=img, img_id=0)
72
+ data_ = pipeline(data_)
73
+ data_['inputs'] = [data_['inputs']]
74
+ data_['data_samples'] = [data_['data_samples']]
75
+
76
+ return data_, img
77
+
78
+ def animeseg_refine(det_pred: DetDataSample, img: np.ndarray, net: AnimeSegmentation, to_rgb=True, input_size: int = 1024):
79
+
80
+ num_pred = len(det_pred.pred_instances)
81
+ if num_pred < 1:
82
+ return
83
+
84
+ with torch.no_grad():
85
+ if to_rgb:
86
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
87
+ seg_thr = 0.5
88
+ mask = get_mask(net, img, s=input_size)[..., 0]
89
+ mask = (mask > seg_thr)
90
+
91
+ ins_masks = det_pred.pred_instances.masks
92
+
93
+ if isinstance(ins_masks, torch.Tensor):
94
+ tensor_device = ins_masks.device
95
+ tensor_dtype = ins_masks.dtype
96
+ to_tensor = True
97
+ ins_masks = ins_masks.cpu().numpy()
98
+
99
+ area_original = np.sum(ins_masks, axis=(1, 2))
100
+ masks_refined = np.bitwise_and(ins_masks, mask[None, ...])
101
+ area_refined = np.sum(masks_refined, axis=(1, 2))
102
+
103
+ for ii in range(num_pred):
104
+ if area_refined[ii] / area_original[ii] > 0.3:
105
+ ins_masks[ii] = masks_refined[ii]
106
+ ins_masks = np.ascontiguousarray(ins_masks)
107
+
108
+ # for ii, insm in enumerate(ins_masks):
109
+ # cv2.imwrite(f'{ii}.png', insm.astype(np.uint8) * 255)
110
+
111
+ if to_tensor:
112
+ ins_masks = torch.from_numpy(ins_masks).to(dtype=tensor_dtype).to(device=tensor_device)
113
+
114
+ det_pred.pred_instances.masks = ins_masks
115
+ # rst = np.concatenate((mask * img + 1 - mask, mask * 255), axis=2).astype(np.uint8)
116
+ # cv2.imwrite('rst.png', rst)
117
+
118
+
119
+ # def refinenet_forward(det_pred: DetDataSample, img: np.ndarray, net: AnimeSegmentation, to_rgb=True, input_size: int = 1024):
120
+
121
+ # num_pred = len(det_pred.pred_instances)
122
+ # if num_pred < 1:
123
+ # return
124
+
125
+ # with torch.no_grad():
126
+ # if to_rgb:
127
+ # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
128
+ # seg_thr = 0.5
129
+
130
+ # h0, w0 = h, w = img.shape[0], img.shape[1]
131
+ # if h > w:
132
+ # h, w = input_size, int(input_size * w / h)
133
+ # else:
134
+ # h, w = int(input_size * h / w), input_size
135
+ # ph, pw = input_size - h, input_size - w
136
+ # tmpImg = np.zeros([s, s, 3], dtype=np.float32)
137
+ # tmpImg[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(input_img, (w, h)) / 255
138
+ # tmpImg = tmpImg.transpose((2, 0, 1))
139
+ # tmpImg = torch.from_numpy(tmpImg).unsqueeze(0).type(torch.FloatTensor).to(model.device)
140
+ # with torch.no_grad():
141
+ # if use_amp:
142
+ # with amp.autocast():
143
+ # pred = model(tmpImg)
144
+ # pred = pred.to(dtype=torch.float32)
145
+ # else:
146
+ # pred = model(tmpImg)
147
+ # pred = pred[0, :, ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
148
+ # pred = cv2.resize(pred.cpu().numpy().transpose((1, 2, 0)), (w0, h0))[:, :, np.newaxis]
149
+ # return pred
150
+
151
+ # mask = (mask > seg_thr)
152
+
153
+ # ins_masks = det_pred.pred_instances.masks
154
+
155
+ # if isinstance(ins_masks, torch.Tensor):
156
+ # tensor_device = ins_masks.device
157
+ # tensor_dtype = ins_masks.dtype
158
+ # to_tensor = True
159
+ # ins_masks = ins_masks.cpu().numpy()
160
+
161
+ # area_original = np.sum(ins_masks, axis=(1, 2))
162
+ # masks_refined = np.bitwise_and(ins_masks, mask[None, ...])
163
+ # area_refined = np.sum(masks_refined, axis=(1, 2))
164
+
165
+ # for ii in range(num_pred):
166
+ # if area_refined[ii] / area_original[ii] > 0.3:
167
+ # ins_masks[ii] = masks_refined[ii]
168
+ # ins_masks = np.ascontiguousarray(ins_masks)
169
+
170
+ # # for ii, insm in enumerate(ins_masks):
171
+ # # cv2.imwrite(f'{ii}.png', insm.astype(np.uint8) * 255)
172
+
173
+ # if to_tensor:
174
+ # ins_masks = torch.from_numpy(ins_masks).to(dtype=tensor_dtype).to(device=tensor_device)
175
+
176
+ # det_pred.pred_instances.masks = ins_masks
177
+
178
+
179
+ def read_imglst_from_txt(filep) -> List[str]:
180
+ with open(filep, 'r', encoding='utf8') as f:
181
+ lines = f.read().splitlines()
182
+ return lines
183
+
184
+
185
+ class AnimeInsSeg:
186
+
187
+ def __init__(self, ckpt: str, default_det_size: int = 640, device: str = None,
188
+ refine_kwargs: dict = {'refine_method': 'refinenet_isnet'},
189
+ tagger_path: str = 'models/wd-v1-4-swinv2-tagger-v2/model.onnx', mask_thr=0.3) -> None:
190
+ self.ckpt = ckpt
191
+ self.default_det_size = default_det_size
192
+ self.device = DEFAULT_DEVICE if device is None else device
193
+
194
+ # init detector in mmdet's way
195
+
196
+ ckpt = torch.load(ckpt, map_location='cpu')
197
+ cfg = Config.fromstring(ckpt['meta']['cfg'].replace('file_client_args', 'backend_args'), file_format='.py')
198
+ cfg.visualizer = []
199
+ cfg.vis_backends = {}
200
+ cfg.default_hooks.pop('visualization')
201
+
202
+
203
+ # self.model: SingleStageDetector = init_detector(cfg, checkpoint=None, device='cpu')
204
+ model = MODELS.build(cfg.model)
205
+ model = revert_sync_batchnorm(model)
206
+
207
+ self.model = model.to(self.device).eval()
208
+ self.model.load_state_dict(ckpt['state_dict'], strict=False)
209
+ self.model = self.model.to(self.device).eval()
210
+ self.cfg = cfg.copy()
211
+
212
+ test_pipeline = get_test_pipeline_cfg(self.cfg.copy())
213
+ test_pipeline[0].type = 'mmdet.LoadImageFromNDArray'
214
+ test_pipeline = Compose(test_pipeline)
215
+ self.default_data_pipeline = test_pipeline
216
+
217
+ self.refinenet = None
218
+ self.refinenet_animeseg: AnimeSegmentation = None
219
+ self.postprocess_refine: Callable = None
220
+
221
+ if refine_kwargs is not None:
222
+ self.set_refine_method(**refine_kwargs)
223
+
224
+ self.tagger = None
225
+ self.tagger_path = tagger_path
226
+
227
+ self.mask_thr = mask_thr
228
+
229
+ def init_tagger(self, tagger_path: str = None):
230
+ tagger_path = self.tagger_path if tagger_path is None else tagger_path
231
+ self.tagger = Tagger(self.tagger_path)
232
+
233
+ def infer_tags(self, instances: AnimeInstances, img: np.ndarray, infer_grey: bool = False):
234
+ if self.tagger is None:
235
+ self.init_tagger()
236
+
237
+ if infer_grey:
238
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[..., None][..., [0, 0, 0]]
239
+
240
+ num_ins = len(instances)
241
+ for ii in range(num_ins):
242
+ bbox = instances.bboxes[ii]
243
+ mask = instances.masks[ii]
244
+ if isinstance(bbox, torch.Tensor):
245
+ bbox = bbox.cpu().numpy()
246
+ mask = mask.cpu().numpy()
247
+ bbox = bbox.astype(np.int32)
248
+
249
+ crop = img[bbox[1]: bbox[3] + bbox[1], bbox[0]: bbox[2] + bbox[0]].copy()
250
+ mask = mask[bbox[1]: bbox[3] + bbox[1], bbox[0]: bbox[2] + bbox[0]]
251
+ crop[mask == 0] = 255
252
+ tags, character_tags = self.tagger.label_cv2_bgr(crop)
253
+ exclude_tags = ['simple_background', 'white_background']
254
+ valid_tags = []
255
+ for tag in tags:
256
+ if tag in exclude_tags:
257
+ continue
258
+ valid_tags.append(tag)
259
+ instances.tags[ii] = ' '.join(valid_tags)
260
+ instances.character_tags[ii] = character_tags
261
+
262
+ @torch.no_grad()
263
+ def infer_embeddings(self, imgs, det_size = None):
264
+
265
+ def hijack_bbox_mask_post_process(
266
+ self,
267
+ results,
268
+ mask_feat,
269
+ cfg,
270
+ rescale: bool = False,
271
+ with_nms: bool = True,
272
+ img_meta: Optional[dict] = None):
273
+
274
+ stride = self.prior_generator.strides[0][0]
275
+ if rescale:
276
+ assert img_meta.get('scale_factor') is not None
277
+ scale_factor = [1 / s for s in img_meta['scale_factor']]
278
+ results.bboxes = scale_boxes(results.bboxes, scale_factor)
279
+
280
+ if hasattr(results, 'score_factors'):
281
+ # TODO: Add sqrt operation in order to be consistent with
282
+ # the paper.
283
+ score_factors = results.pop('score_factors')
284
+ results.scores = results.scores * score_factors
285
+
286
+ # filter small size bboxes
287
+ if cfg.get('min_bbox_size', -1) >= 0:
288
+ w, h = get_box_wh(results.bboxes)
289
+ valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
290
+ if not valid_mask.all():
291
+ results = results[valid_mask]
292
+
293
+ # results.mask_feat = mask_feat
294
+ return results, mask_feat
295
+
296
+ def hijack_detector_predict(self: SingleStageDetector,
297
+ batch_inputs: torch.Tensor,
298
+ batch_data_samples: SampleList,
299
+ rescale: bool = True) -> SampleList:
300
+ x = self.extract_feat(batch_inputs)
301
+
302
+ bbox_head: RTMDetInsSepBNHeadCustom = self.bbox_head
303
+ old_postprocess = RTMDetInsSepBNHeadCustom._bbox_mask_post_process
304
+ RTMDetInsSepBNHeadCustom._bbox_mask_post_process = hijack_bbox_mask_post_process
305
+ # results_list = bbox_head.predict(
306
+ # x, batch_data_samples, rescale=rescale)
307
+
308
+ batch_img_metas = [
309
+ data_samples.metainfo for data_samples in batch_data_samples
310
+ ]
311
+
312
+ outs = bbox_head(x)
313
+
314
+ results_list = bbox_head.predict_by_feat(
315
+ *outs, batch_img_metas=batch_img_metas, rescale=rescale)
316
+
317
+ # batch_data_samples = self.add_pred_to_datasample(
318
+ # batch_data_samples, results_list)
319
+
320
+ RTMDetInsSepBNHeadCustom._bbox_mask_post_process = old_postprocess
321
+ return results_list
322
+
323
+ old_predict = SingleStageDetector.predict
324
+ SingleStageDetector.predict = hijack_detector_predict
325
+ test_pipeline, imgs, _ = self.prepare_data_pipeline(imgs, det_size)
326
+
327
+ if len(imgs) > 1:
328
+ imgs = tqdm(imgs)
329
+ model = self.model
330
+ img = imgs[0]
331
+ data_, img = test_pipeline(img)
332
+ data = model.data_preprocessor(data_, False)
333
+ instance_data, mask_feat = model(**data, mode='predict')[0]
334
+ SingleStageDetector.predict = old_predict
335
+
336
+ # print((instance_data.scores > 0.9).sum())
337
+ return img, instance_data, mask_feat
338
+
339
+ def segment_with_bboxes(self, img, bboxes: torch.Tensor, instance_data, mask_feat: torch.Tensor):
340
+ # instance_data.bboxes: x1, y1, x2, y2
341
+ maxidx = torch.argmax(instance_data.scores)
342
+ bbox = instance_data.bboxes[maxidx].cpu().numpy()
343
+ p1, p2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
344
+ tgt_bboxes = instance_data.bboxes
345
+
346
+ im_h, im_w = img.shape[:2]
347
+ long_side = max(im_h, im_w)
348
+ bbox_head: RTMDetInsSepBNHeadCustom = self.model.bbox_head
349
+ priors, kernels = instance_data.priors, instance_data.kernels
350
+ stride = bbox_head.prior_generator.strides[0][0]
351
+
352
+ ins_bboxes, ins_segs, scores = [], [], []
353
+ for bbox in bboxes:
354
+ bbox = torch.from_numpy(np.array([bbox])).to(tgt_bboxes.dtype).to(tgt_bboxes.device)
355
+ ioulst = box_iou(bbox, tgt_bboxes).squeeze()
356
+ matched_idx = torch.argmax(ioulst)
357
+
358
+ mask_logits = bbox_head._mask_predict_by_feat_single(
359
+ mask_feat, kernels[matched_idx][None, ...], priors[matched_idx][None, ...])
360
+
361
+ mask_logits = F.interpolate(
362
+ mask_logits.unsqueeze(0), scale_factor=stride, mode='bilinear')
363
+
364
+ mask_logits = F.interpolate(
365
+ mask_logits,
366
+ size=[long_side, long_side],
367
+ mode='bilinear',
368
+ align_corners=False)[..., :im_h, :im_w]
369
+ mask = mask_logits.sigmoid().squeeze()
370
+ mask = mask > 0.5
371
+ mask = mask.cpu().numpy()
372
+ ins_segs.append(mask)
373
+
374
+ matched_iou_score = ioulst[matched_idx]
375
+ matched_score = instance_data.scores[matched_idx]
376
+ scores.append(matched_score.cpu().item())
377
+ matched_bbox = tgt_bboxes[matched_idx]
378
+
379
+ ins_bboxes.append(matched_bbox.cpu().numpy())
380
+ # p1, p2 = (int(matched_bbox[0]), int(matched_bbox[1])), (int(matched_bbox[2]), int(matched_bbox[3]))
381
+
382
+ if len(ins_bboxes) > 0:
383
+ ins_bboxes = np.array(ins_bboxes).astype(np.int32)
384
+ ins_bboxes[:, 2:] -= ins_bboxes[:, :2]
385
+ ins_segs = np.array(ins_segs)
386
+ instances = AnimeInstances(ins_segs, ins_bboxes, scores)
387
+
388
+ self._postprocess_refine(instances, img)
389
+ drawed = instances.draw_instances(img)
390
+ # cv2.imshow('drawed', drawed)
391
+ # cv2.waitKey(0)
392
+
393
+ return instances
394
+
395
+ def set_detect_size(self, det_size: Union[int, Tuple]):
396
+ if isinstance(det_size, int):
397
+ det_size = (det_size, det_size)
398
+ self.default_data_pipeline.transforms[1].scale = det_size
399
+ self.default_data_pipeline.transforms[2].size = det_size
400
+
401
+ @torch.no_grad()
402
+ def infer(self, imgs: Union[List, str, np.ndarray],
403
+ pred_score_thr: float = 0.3,
404
+ refine_kwargs: dict = None,
405
+ output_type: str="tensor",
406
+ det_size: int = None,
407
+ save_dir: str = '',
408
+ save_visualization: bool = False,
409
+ save_annotation: str = '',
410
+ infer_tags: bool = False,
411
+ obj_id_start: int = -1,
412
+ img_id_start: int = -1,
413
+ verbose: bool = False,
414
+ infer_grey: bool = False,
415
+ save_mask_only: bool = False,
416
+ val_dir=None,
417
+ max_instances: int = 100,
418
+ **kwargs) -> Union[List[AnimeInstances], AnimeInstances, None]:
419
+
420
+ """
421
+ Args:
422
+ imgs (str, ndarray, Sequence[str/ndarray]):
423
+ Either image files or loaded images.
424
+
425
+ Returns:
426
+ :obj:`AnimeInstances` or list[:obj:`AnimeInstances`]:
427
+ If save_annotation or save_annotation, return None.
428
+ """
429
+
430
+ if det_size is not None:
431
+ self.set_detect_size(det_size)
432
+ if refine_kwargs is not None:
433
+ self.set_refine_method(**refine_kwargs)
434
+
435
+ self.set_max_instance(max_instances)
436
+
437
+ if isinstance(imgs, str):
438
+ if imgs.endswith('.txt'):
439
+ imgs = read_imglst_from_txt(imgs)
440
+
441
+ if save_annotation or save_visualization:
442
+ return self._infer_save_annotations(imgs, pred_score_thr, det_size, save_dir, save_visualization, \
443
+ save_annotation, infer_tags, obj_id_start, img_id_start, val_dir=val_dir)
444
+ else:
445
+ return self._infer_simple(imgs, pred_score_thr, det_size, output_type, infer_tags, verbose=verbose, infer_grey=infer_grey)
446
+
447
+ def _det_forward(self, img, test_pipeline, pred_score_thr: float = 0.3) -> Tuple[AnimeInstances, np.ndarray]:
448
+ data_, img = test_pipeline(img)
449
+ with torch.no_grad():
450
+ results: DetDataSample = self.model.test_step(data_)[0]
451
+ pred_instances = results.pred_instances
452
+ pred_instances = pred_instances[pred_instances.scores > pred_score_thr]
453
+ if len(pred_instances) < 1:
454
+ return AnimeInstances(), img
455
+
456
+ del data_
457
+
458
+ bboxes = pred_instances.bboxes.to(torch.int32)
459
+ bboxes[:, 2:] -= bboxes[:, :2]
460
+ masks = pred_instances.masks
461
+ scores = pred_instances.scores
462
+ return AnimeInstances(masks, bboxes, scores), img
463
+
464
+ def _infer_simple(self, imgs: Union[List, str, np.ndarray],
465
+ pred_score_thr: float = 0.3,
466
+ det_size: int = None,
467
+ output_type: str = "tensor",
468
+ infer_tags: bool = False,
469
+ infer_grey: bool = False,
470
+ verbose: bool = False) -> Union[DetDataSample, List[DetDataSample]]:
471
+
472
+ if isinstance(imgs, List):
473
+ return_list = True
474
+ else:
475
+ return_list = False
476
+
477
+ assert output_type in {'tensor', 'numpy'}
478
+
479
+ test_pipeline, imgs, _ = self.prepare_data_pipeline(imgs, det_size)
480
+ predictions = []
481
+
482
+ if len(imgs) > 1:
483
+ imgs = tqdm(imgs)
484
+
485
+ for img in imgs:
486
+ instances, img = self._det_forward(img, test_pipeline, pred_score_thr)
487
+ # drawed = instances.draw_instances(img)
488
+ # cv2.imwrite('drawed.jpg', drawed)
489
+ self.postprocess_results(instances, img)
490
+ # drawed = instances.draw_instances(img)
491
+ # cv2.imwrite('drawed_post.jpg', drawed)
492
+
493
+ if infer_tags:
494
+ self.infer_tags(instances, img, infer_grey)
495
+
496
+ if output_type == 'numpy':
497
+ instances.to_numpy()
498
+
499
+ predictions.append(instances)
500
+
501
+ if return_list:
502
+ return predictions
503
+ else:
504
+ return predictions[0]
505
+
506
+ def _infer_save_annotations(self, imgs: Union[List, str, np.ndarray],
507
+ pred_score_thr: float = 0.3,
508
+ det_size: int = None,
509
+ save_dir: str = '',
510
+ save_visualization: bool = False,
511
+ save_annotation: str = '',
512
+ infer_tags: bool = False,
513
+ obj_id_start: int = 100000000000,
514
+ img_id_start: int = 100000000000,
515
+ save_mask_only: bool = False,
516
+ val_dir = None,
517
+ **kwargs) -> None:
518
+
519
+ coco_api = None
520
+ if isinstance(imgs, str) and imgs.endswith('.json'):
521
+ coco_api = COCO(imgs)
522
+
523
+ if val_dir is None:
524
+ val_dir = osp.join(osp.dirname(osp.dirname(imgs)), 'val')
525
+ imgs = coco_api.getImgIds()
526
+ imgp2ids = {}
527
+ imgps, coco_imgmetas = [], []
528
+ for imgid in imgs:
529
+ imeta = coco_api.loadImgs(imgid)[0]
530
+ imgname = imeta['file_name']
531
+ imgp = osp.join(val_dir, imgname)
532
+ imgp2ids[imgp] = imgid
533
+ imgps.append(imgp)
534
+ coco_imgmetas.append(imeta)
535
+ imgs = imgps
536
+
537
+ test_pipeline, imgs, target_dir = self.prepare_data_pipeline(imgs, det_size)
538
+ if save_dir == '':
539
+ save_dir = osp.join(target_dir, \
540
+ osp.basename(self.ckpt).replace('.ckpt', '').replace('.pth', '').replace('.pt', ''))
541
+
542
+ if not osp.exists(save_dir):
543
+ os.makedirs(save_dir)
544
+
545
+ det_annotations = []
546
+ image_meta = []
547
+ obj_id = obj_id_start + 1
548
+ image_id = img_id_start + 1
549
+
550
+ for ii, img in enumerate(tqdm(imgs)):
551
+ # prepare data
552
+ if isinstance(img, str):
553
+ img_name = osp.basename(img)
554
+ else:
555
+ img_name = f'{ii}'.zfill(12) + '.jpg'
556
+
557
+ if coco_api is not None:
558
+ image_id = imgp2ids[img]
559
+
560
+ try:
561
+ instances, img = self._det_forward(img, test_pipeline, pred_score_thr)
562
+ except Exception as e:
563
+ raise e
564
+ if isinstance(e, torch.cuda.OutOfMemoryError):
565
+ gc.collect()
566
+ torch.cuda.empty_cache()
567
+ torch.cuda.ipc_collect()
568
+ try:
569
+ instances, img = self._det_forward(img, test_pipeline, pred_score_thr)
570
+ except:
571
+ LOGGER.warning(f'cuda out of memory: {img_name}')
572
+ if isinstance(img, str):
573
+ img = cv2.imread(img)
574
+ instances = None
575
+
576
+ if instances is not None:
577
+ self.postprocess_results(instances, img)
578
+
579
+ if infer_tags:
580
+ self.infer_tags(instances, img)
581
+
582
+ if save_visualization:
583
+ out_file = osp.join(save_dir, img_name)
584
+ self.save_visualization(out_file, img, instances)
585
+
586
+ if save_annotation:
587
+ im_h, im_w = img.shape[:2]
588
+ image_meta.append({
589
+ "id": image_id,"height": im_h,"width": im_w,
590
+ "file_name": img_name, "id": image_id
591
+ })
592
+ if instances is not None:
593
+ for ii in range(len(instances)):
594
+ segmentation = instances.masks[ii].squeeze().cpu().numpy().astype(np.uint8)
595
+ area = segmentation.sum()
596
+ segmentation *= 255
597
+ if save_mask_only:
598
+ cv2.imwrite(osp.join(save_dir, 'mask_' + str(ii).zfill(3) + '_' +img_name+'.png'), segmentation)
599
+ else:
600
+ score = instances.scores[ii]
601
+ if isinstance(score, torch.Tensor):
602
+ score = score.item()
603
+ score = float(score)
604
+ bbox = instances.bboxes[ii].cpu().numpy()
605
+ bbox = bbox.astype(np.float32).tolist()
606
+ segmentation = mask2rle(segmentation)
607
+ tag_string = instances.tags[ii]
608
+ tag_string_character = instances.character_tags[ii]
609
+ det_annotations.append({'id': obj_id, 'category_id': 0, 'iscrowd': 0, 'score': score,
610
+ 'segmentation': segmentation, 'image_id': image_id, 'area': area,
611
+ 'tag_string': tag_string, 'tag_string_character': tag_string_character, 'bbox': bbox
612
+ })
613
+ obj_id += 1
614
+ image_id += 1
615
+
616
+ if save_annotation != '' and not save_mask_only:
617
+ det_meta = {"info": {},"licenses": [], "images": image_meta,
618
+ "annotations": det_annotations, "categories": CATEGORIES}
619
+ detp = save_annotation
620
+ dict2json(det_meta, detp)
621
+ LOGGER.info(f'annotations saved to {detp}')
622
+
623
+ def set_refine_method(self, refine_method: str = 'none', refine_size: int = 720):
624
+ if refine_method == 'none':
625
+ self.postprocess_refine = None
626
+ elif refine_method == 'animeseg':
627
+ if self.refinenet_animeseg is None:
628
+ self.refinenet_animeseg = load_refinenet(refine_method)
629
+ self.postprocess_refine = lambda det_pred, img: \
630
+ animeseg_refine(det_pred, img, self.refinenet_animeseg, True, refine_size)
631
+ elif refine_method == 'refinenet_isnet':
632
+ if self.refinenet is None:
633
+ self.refinenet = load_refinenet(refine_method)
634
+ self.postprocess_refine = self._postprocess_refine
635
+ else:
636
+ raise NotImplementedError(f'Invalid refine method: {refine_method}')
637
+
638
+ def _postprocess_refine(self, instances: AnimeInstances, img: np.ndarray, refine_size: int = 720, max_refine_batch: int = 4, **kwargs):
639
+
640
+ if instances.is_empty:
641
+ return
642
+
643
+ segs = instances.masks
644
+ is_tensor = instances.is_tensor
645
+ if is_tensor:
646
+ segs = segs.cpu().numpy()
647
+ segs = segs.astype(np.float32)
648
+ im_h, im_w = img.shape[:2]
649
+
650
+ masks = []
651
+ with torch.no_grad():
652
+ for batch, (pt, pb, pl, pr) in prepare_refine_batch(segs, img, max_refine_batch, self.device, refine_size):
653
+ preds = self.refinenet(batch)[0][0].sigmoid()
654
+ if pb == 0:
655
+ pb = -im_h
656
+ if pr == 0:
657
+ pr = -im_w
658
+ preds = preds[..., pt: -pb, pl: -pr]
659
+ preds = torch.nn.functional.interpolate(preds, (im_h, im_w), mode='bilinear', align_corners=True)
660
+ masks.append(preds.cpu()[:, 0])
661
+
662
+ masks = (torch.concat(masks, dim=0) > self.mask_thr).to(self.device)
663
+ if not is_tensor:
664
+ masks = masks.cpu().numpy()
665
+ instances.masks = masks
666
+
667
+
668
+ def prepare_data_pipeline(self, imgs: Union[str, np.ndarray, List], det_size: int) -> Tuple[Compose, List, str]:
669
+
670
+ if det_size is None:
671
+ det_size = self.default_det_size
672
+
673
+ target_dir = './workspace/output'
674
+ # cast imgs to a list of np.ndarray or image_file_path if necessary
675
+ if isinstance(imgs, str):
676
+ if osp.isdir(imgs):
677
+ target_dir = imgs
678
+ imgs = find_all_imgs(imgs, abs_path=True)
679
+ elif osp.isfile(imgs):
680
+ target_dir = osp.dirname(imgs)
681
+ imgs = [imgs]
682
+ elif isinstance(imgs, np.ndarray) or isinstance(imgs, str):
683
+ imgs = [imgs]
684
+ elif isinstance(imgs, List):
685
+ if len(imgs) > 0:
686
+ if isinstance(imgs[0], np.ndarray) or isinstance(imgs[0], str):
687
+ pass
688
+ else:
689
+ raise NotImplementedError
690
+ else:
691
+ raise NotImplementedError
692
+
693
+ test_pipeline = lambda img: single_image_preprocess(img, pipeline=self.default_data_pipeline)
694
+ return test_pipeline, imgs, target_dir
695
+
696
+ def save_visualization(self, out_file: str, img: np.ndarray, instances: AnimeInstances):
697
+ drawed = instances.draw_instances(img)
698
+ mmcv.imwrite(drawed, out_file)
699
+
700
+ def postprocess_results(self, results: DetDataSample, img: np.ndarray) -> None:
701
+ if self.postprocess_refine is not None:
702
+ self.postprocess_refine(results, img)
703
+
704
+ def set_mask_threshold(self, mask_thr: float):
705
+ self.model.bbox_head.test_cfg['mask_thr_binary'] = mask_thr
706
+
707
+ def set_max_instance(self, num_ins):
708
+ self.model.bbox_head.test_cfg['max_per_img'] = num_ins
animeinsseg/anime_instances.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ from typing import List, Union, Tuple
4
+ import torch
5
+ from utils.constants import COLOR_PALETTE
6
+ from utils.constants import get_color
7
+ import cv2
8
+
9
+ def tags2multilines(tags: Union[str, List], lw, tf, max_width):
10
+ if isinstance(tags, str):
11
+ taglist = tags.split(' ')
12
+ else:
13
+ taglist = tags
14
+
15
+ sz = cv2.getTextSize(' ', 0, lw / 3, tf)
16
+ line_height = sz[0][1]
17
+ line_width = 0
18
+ if len(taglist) > 0:
19
+ lines = [taglist[0]]
20
+ if len(taglist) > 1:
21
+ for t in taglist[1:]:
22
+ textl = len(t) * line_height
23
+ if line_width + line_height + textl > max_width:
24
+ lines.append(t)
25
+ line_width = 0
26
+ else:
27
+ line_width = line_width + line_height + textl
28
+ lines[-1] = lines[-1] + ' ' + t
29
+ return lines, line_height
30
+
31
+ class AnimeInstances:
32
+
33
+ def __init__(self,
34
+ masks: Union[np.ndarray, torch.Tensor ]= None,
35
+ bboxes: Union[np.ndarray, torch.Tensor ] = None,
36
+ scores: Union[np.ndarray, torch.Tensor ] = None,
37
+ tags: List[str] = None, character_tags: List[str] = None) -> None:
38
+ self.masks = masks
39
+ self.tags = tags
40
+ self.bboxes = bboxes
41
+
42
+
43
+ if scores is None:
44
+ scores = [1.] * len(self)
45
+ if self.is_numpy:
46
+ scores = np.array(scores)
47
+ elif self.is_tensor:
48
+ scores = torch.tensor(scores)
49
+
50
+ self.scores = scores
51
+
52
+ if tags is None:
53
+ self.tags = [''] * len(self)
54
+ self.character_tags = [''] * len(self)
55
+ else:
56
+ self.tags = tags
57
+ self.character_tags = character_tags
58
+
59
+ @property
60
+ def is_cuda(self):
61
+ if isinstance(self.masks, torch.Tensor) and self.masks.is_cuda:
62
+ return True
63
+ else:
64
+ return False
65
+
66
+ @property
67
+ def is_tensor(self):
68
+ if self.is_empty:
69
+ return False
70
+ else:
71
+ return isinstance(self.masks, torch.Tensor)
72
+
73
+ @property
74
+ def is_numpy(self):
75
+ if self.is_empty:
76
+ return True
77
+ else:
78
+ return isinstance(self.masks, np.ndarray)
79
+
80
+ @property
81
+ def is_empty(self):
82
+ return self.masks is None or len(self.masks) == 0\
83
+
84
+ def remove_duplicated(self):
85
+
86
+ num_masks = len(self)
87
+ if num_masks < 2:
88
+ return
89
+
90
+ need_cvt = False
91
+ if self.is_numpy:
92
+ need_cvt = True
93
+ self.to_tensor()
94
+
95
+ mask_areas = torch.Tensor([mask.sum() for mask in self.masks])
96
+ sids = torch.argsort(mask_areas, descending=True)
97
+ sids = sids.cpu().numpy().tolist()
98
+ mask_areas = mask_areas[sids]
99
+ masks = self.masks[sids]
100
+ bboxes = self.bboxes[sids]
101
+ tags = [self.tags[sid] for sid in sids]
102
+ scores = self.scores[sids]
103
+
104
+ canvas = masks[0]
105
+
106
+ valid_ids: List = np.arange(num_masks).tolist()
107
+ for ii, mask in enumerate(masks[1:]):
108
+
109
+ mask_id = ii + 1
110
+ canvas_and = torch.bitwise_and(canvas, mask)
111
+
112
+ and_area = canvas_and.sum()
113
+ mask_area = mask_areas[mask_id]
114
+
115
+ if and_area / mask_area > 0.8:
116
+ valid_ids.remove(mask_id)
117
+ elif mask_id != num_masks - 1:
118
+ canvas = torch.bitwise_or(canvas, mask)
119
+
120
+ sids = valid_ids
121
+ self.masks = masks[sids]
122
+ self.bboxes = bboxes[sids]
123
+ self.tags = [tags[sid] for sid in sids]
124
+ self.scores = scores[sids]
125
+
126
+ if need_cvt:
127
+ self.to_numpy()
128
+
129
+ # sids =
130
+
131
+ def draw_instances(self,
132
+ img: np.ndarray,
133
+ draw_bbox: bool = True,
134
+ draw_ins_mask: bool = True,
135
+ draw_ins_contour: bool = True,
136
+ draw_tags: bool = False,
137
+ draw_indices: List = None,
138
+ mask_alpha: float = 0.4):
139
+
140
+ mask_alpha = 0.75
141
+
142
+
143
+ drawed = img.copy()
144
+
145
+ if self.is_empty:
146
+ return drawed
147
+
148
+ im_h, im_w = img.shape[:2]
149
+
150
+ mask_shape = self.masks[0].shape
151
+ if mask_shape[0] != im_h or mask_shape[1] != im_w:
152
+ drawed = cv2.resize(drawed, (mask_shape[1], mask_shape[0]), interpolation=cv2.INTER_AREA)
153
+ im_h, im_w = mask_shape[0], mask_shape[1]
154
+
155
+ if draw_indices is None:
156
+ draw_indices = list(range(len(self)))
157
+ ins_dict = {'mask': [], 'tags': [], 'score': [], 'bbox': [], 'character_tags': []}
158
+ colors = []
159
+ for idx in draw_indices:
160
+ ins = self.get_instance(idx, out_type='numpy')
161
+ for key, data in ins.items():
162
+ ins_dict[key].append(data)
163
+ colors.append(get_color(idx))
164
+
165
+ if draw_bbox:
166
+ lw = max(round(sum(drawed.shape) / 2 * 0.003), 2)
167
+ for color, bbox in zip(colors, ins_dict['bbox']):
168
+ p1, p2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2] + bbox[0]), int(bbox[3] + bbox[1]))
169
+ cv2.rectangle(drawed, p1, p2, color, thickness=lw, lineType=cv2.LINE_AA)
170
+
171
+ if draw_ins_mask:
172
+ drawed = drawed.astype(np.float32)
173
+ for color, mask in zip(colors, ins_dict['mask']):
174
+ p = mask.astype(np.float32)
175
+ blend_mask = np.full((im_h, im_w, 3), color, dtype=np.float32)
176
+ alpha_msk = (mask_alpha * p)[..., None]
177
+ alpha_ori = 1 - alpha_msk
178
+ drawed = drawed * alpha_ori + alpha_msk * blend_mask
179
+ drawed = drawed.astype(np.uint8)
180
+
181
+ if draw_tags:
182
+ lw = max(round(sum(drawed.shape) / 2 * 0.002), 2)
183
+ tf = max(lw - 1, 1)
184
+ for color, tags, bbox in zip(colors, ins_dict['tags'], ins_dict['bbox']):
185
+ if not tags:
186
+ continue
187
+ lines, line_height = tags2multilines(tags, lw, tf, bbox[2])
188
+ for ii, l in enumerate(lines):
189
+ xy = (bbox[0], bbox[1] + line_height + int(line_height * 1.2 * ii))
190
+ cv2.putText(drawed, l, xy, 0, lw / 3, color, thickness=tf, lineType=cv2.LINE_AA)
191
+
192
+ # cv2.imshow('canvas', drawed)
193
+ # cv2.waitKey(0)
194
+ return drawed
195
+
196
+
197
+ def cuda(self):
198
+ if self.is_empty:
199
+ return self
200
+ self.to_tensor(device='cuda')
201
+ return self
202
+
203
+ def cpu(self):
204
+ if not self.is_tensor or not self.is_cuda:
205
+ return self
206
+ self.masks = self.masks.cpu()
207
+ self.scores = self.scores.cpu()
208
+ self.bboxes = self.bboxes.cpu()
209
+ return self
210
+
211
+ def to_tensor(self, device: str = 'cpu'):
212
+ if self.is_empty:
213
+ return self
214
+ elif self.is_tensor and self.masks.device == device:
215
+ return self
216
+ self.masks = torch.from_numpy(self.masks).to(device)
217
+ self.bboxes = torch.from_numpy(self.bboxes).to(device)
218
+ self.scores = torch.from_numpy(self.scores ).to(device)
219
+ return self
220
+
221
+ def to_numpy(self):
222
+ if self.is_numpy:
223
+ return self
224
+ if self.is_cuda:
225
+ self.masks = self.masks.cpu().numpy()
226
+ self.scores = self.scores.cpu().numpy()
227
+ self.bboxes = self.bboxes.cpu().numpy()
228
+ else:
229
+ self.masks = self.masks.numpy()
230
+ self.scores = self.scores.numpy()
231
+ self.bboxes = self.bboxes.numpy()
232
+ return self
233
+
234
+ def get_instance(self, ins_idx: int, out_type: str = None, device: str = None):
235
+ mask = self.masks[ins_idx]
236
+ tags = self.tags[ins_idx]
237
+ character_tags = self.character_tags[ins_idx]
238
+ bbox = self.bboxes[ins_idx]
239
+ score = self.scores[ins_idx]
240
+ if out_type is not None:
241
+ if out_type == 'numpy' and not self.is_numpy:
242
+ mask = mask.cpu().numpy()
243
+ bbox = bbox.cpu().numpy()
244
+ score = score.cpu().numpy()
245
+ if out_type == 'tensor' and not self.is_tensor:
246
+ mask = torch.from_numpy(mask)
247
+ bbox = torch.from_numpy(bbox)
248
+ score = torch.from_numpy(score)
249
+ if isinstance(mask, torch.Tensor) and device is not None and mask.device != device:
250
+ mask = mask.to(device)
251
+ bbox = bbox.to(device)
252
+ score = score.to(device)
253
+
254
+ return {
255
+ 'mask': mask,
256
+ 'tags': tags,
257
+ 'character_tags': character_tags,
258
+ 'bbox': bbox,
259
+ 'score': score
260
+ }
261
+
262
+ def __len__(self):
263
+ if self.is_empty:
264
+ return 0
265
+ else:
266
+ return len(self.masks)
267
+
268
+ def resize(self, h, w, mode = 'area'):
269
+ if self.is_empty:
270
+ return
271
+ if self.is_tensor:
272
+ masks = self.masks.to(torch.float).unsqueeze(1)
273
+ oh, ow = masks.shape[2], masks.shape[3]
274
+ hs, ws = h / oh, w / ow
275
+ bboxes = self.bboxes.float()
276
+ bboxes[:, ::2] *= hs
277
+ bboxes[:, 1::2] *= ws
278
+ self.bboxes = torch.round(bboxes).int()
279
+ masks = torch.nn.functional.interpolate(masks, (h, w), mode=mode)
280
+ self.masks = masks.squeeze(1) > 0.3
281
+
282
+ def compose_masks(self, output_type=None):
283
+ if self.is_empty:
284
+ return None
285
+ else:
286
+ mask = self.masks[0]
287
+ if len(self.masks) > 1:
288
+ for m in self.masks[1:]:
289
+ if self.is_numpy:
290
+ mask = np.logical_or(mask, m)
291
+ else:
292
+ mask = torch.logical_or(mask, m)
293
+ if output_type is not None:
294
+ if output_type == 'numpy' and not self.is_numpy:
295
+ mask = mask.cpu().numpy()
296
+ if output_type == 'tensor' and not self.is_tensor:
297
+ mask = torch.from_numpy(mask)
298
+ return mask
299
+
300
+
301
+
animeinsseg/data/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # from .dataset import *
2
+ # from .syndataset import *
animeinsseg/data/dataset.py ADDED
@@ -0,0 +1,929 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import numpy as np
3
+ from typing import List, Optional, Sequence, Tuple, Union
4
+ import copy
5
+ from time import time
6
+ import mmcv
7
+ from mmcv.transforms import to_tensor
8
+ from mmdet.datasets.transforms import LoadAnnotations, RandomCrop, PackDetInputs, Mosaic, CachedMosaic, CachedMixUp, FilterAnnotations
9
+ from mmdet.structures.mask import BitmapMasks, PolygonMasks
10
+ from mmdet.datasets import CocoDataset
11
+ from mmdet.registry import DATASETS, TRANSFORMS
12
+ from numpy import random
13
+ from mmdet.structures.bbox import autocast_box_type, BaseBoxes
14
+ from mmengine.structures import InstanceData, PixelData
15
+ from mmdet.structures import DetDataSample
16
+ from utils.io_utils import bbox_overlap_xy
17
+ from utils.logger import LOGGER
18
+
19
+ @DATASETS.register_module()
20
+ class AnimeMangaMixedDataset(CocoDataset):
21
+
22
+ def __init__(self, animeins_root: str = None, animeins_annfile: str = None, manga109_annfile: str = None, manga109_root: str = None, *args, **kwargs) -> None:
23
+ self.animeins_annfile = animeins_annfile
24
+ self.animeins_root = animeins_root
25
+ self.manga109_annfile = manga109_annfile
26
+ self.manga109_root = manga109_root
27
+ self.cat_ids = []
28
+ self.cat_img_map = {}
29
+ super().__init__(*args, **kwargs)
30
+ LOGGER.info(f'total num data: {len(self.data_list)}')
31
+
32
+
33
+ def parse_data_info(self, raw_data_info: dict, data_prefix: str) -> Union[dict, List[dict]]:
34
+ """Parse raw annotation to target format.
35
+
36
+ Args:
37
+ raw_data_info (dict): Raw data information load from ``ann_file``
38
+
39
+ Returns:
40
+ Union[dict, List[dict]]: Parsed annotation.
41
+ """
42
+ img_info = raw_data_info['raw_img_info']
43
+ ann_info = raw_data_info['raw_ann_info']
44
+
45
+ data_info = {}
46
+
47
+ # TODO: need to change data_prefix['img'] to data_prefix['img_path']
48
+ img_path = osp.join(data_prefix, img_info['file_name'])
49
+ if self.data_prefix.get('seg', None):
50
+ seg_map_path = osp.join(
51
+ self.data_prefix['seg'],
52
+ img_info['file_name'].rsplit('.', 1)[0] + self.seg_map_suffix)
53
+ else:
54
+ seg_map_path = None
55
+ data_info['img_path'] = img_path
56
+ data_info['img_id'] = img_info['img_id']
57
+ data_info['seg_map_path'] = seg_map_path
58
+ data_info['height'] = img_info['height']
59
+ data_info['width'] = img_info['width']
60
+
61
+ instances = []
62
+ for i, ann in enumerate(ann_info):
63
+ instance = {}
64
+
65
+ if ann.get('ignore', False):
66
+ continue
67
+ x1, y1, w, h = ann['bbox']
68
+ inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
69
+ inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
70
+ if inter_w * inter_h == 0:
71
+ continue
72
+ if ann['area'] <= 0 or w < 1 or h < 1:
73
+ continue
74
+ if ann['category_id'] not in self.cat_ids:
75
+ continue
76
+ bbox = [x1, y1, x1 + w, y1 + h]
77
+
78
+ if ann.get('iscrowd', False):
79
+ instance['ignore_flag'] = 1
80
+ else:
81
+ instance['ignore_flag'] = 0
82
+ instance['bbox'] = bbox
83
+ instance['bbox_label'] = self.cat2label[ann['category_id']]
84
+
85
+ if ann.get('segmentation', None):
86
+ instance['mask'] = ann['segmentation']
87
+
88
+ instances.append(instance)
89
+ data_info['instances'] = instances
90
+ return data_info
91
+
92
+
93
+ def load_data_list(self) -> List[dict]:
94
+ data_lst = []
95
+ if self.manga109_root is not None:
96
+ data_lst += self._data_list(self.manga109_annfile, osp.join(self.manga109_root, 'images'))
97
+ # if len(data_lst) > 8000:
98
+ # data_lst = data_lst[:500]
99
+ LOGGER.info(f'num data from manga109: {len(data_lst)}')
100
+ if self.animeins_root is not None:
101
+ animeins_annfile = osp.join(self.animeins_root, self.animeins_annfile)
102
+ data_prefix = osp.join(self.animeins_root, self.data_prefix['img'])
103
+ anime_lst = self._data_list(animeins_annfile, data_prefix)
104
+ # if len(anime_lst) > 8000:
105
+ # anime_lst = anime_lst[:500]
106
+ data_lst += anime_lst
107
+ LOGGER.info(f'num data from animeins: {len(data_lst)}')
108
+ return data_lst
109
+
110
+ def _data_list(self, annfile: str, data_prefix: str) -> List[dict]:
111
+ """Load annotations from an annotation file named as ``ann_file``
112
+
113
+ Returns:
114
+ List[dict]: A list of annotation.
115
+ """ # noqa: E501
116
+ with self.file_client.get_local_path(annfile) as local_path:
117
+ self.coco = self.COCOAPI(local_path)
118
+ # The order of returned `cat_ids` will not
119
+ # change with the order of the `classes`
120
+ self.cat_ids = self.coco.get_cat_ids(
121
+ cat_names=self.metainfo['classes'])
122
+ self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
123
+ cat_img_map = copy.deepcopy(self.coco.cat_img_map)
124
+ for key, val in cat_img_map.items():
125
+ if key in self.cat_img_map:
126
+ self.cat_img_map[key] += val
127
+ else:
128
+ self.cat_img_map[key] = val
129
+
130
+ img_ids = self.coco.get_img_ids()
131
+ data_list = []
132
+ total_ann_ids = []
133
+ for img_id in img_ids:
134
+ raw_img_info = self.coco.load_imgs([img_id])[0]
135
+ raw_img_info['img_id'] = img_id
136
+
137
+ ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
138
+ raw_ann_info = self.coco.load_anns(ann_ids)
139
+ total_ann_ids.extend(ann_ids)
140
+
141
+ parsed_data_info = self.parse_data_info({
142
+ 'raw_ann_info':
143
+ raw_ann_info,
144
+ 'raw_img_info':
145
+ raw_img_info
146
+ }, data_prefix)
147
+ data_list.append(parsed_data_info)
148
+ if self.ANN_ID_UNIQUE:
149
+ assert len(set(total_ann_ids)) == len(
150
+ total_ann_ids
151
+ ), f"Annotation ids in '{annfile}' are not unique!"
152
+
153
+ del self.coco
154
+
155
+ return data_list
156
+
157
+
158
+
159
+ @TRANSFORMS.register_module()
160
+ class LoadAnnotationsNoSegs(LoadAnnotations):
161
+
162
+ def _process_masks(self, results: dict) -> list:
163
+ """Process gt_masks and filter invalid polygons.
164
+
165
+ Args:
166
+ results (dict): Result dict from :obj:``mmengine.BaseDataset``.
167
+
168
+ Returns:
169
+ list: Processed gt_masks.
170
+ """
171
+ gt_masks = []
172
+ gt_ignore_flags = []
173
+ gt_ignore_mask_flags = []
174
+ for instance in results.get('instances', []):
175
+ gt_mask = instance['mask']
176
+ ignore_mask = False
177
+ # If the annotation of segmentation mask is invalid,
178
+ # ignore the whole instance.
179
+ if isinstance(gt_mask, list):
180
+ gt_mask = [
181
+ np.array(polygon) for polygon in gt_mask
182
+ if len(polygon) % 2 == 0 and len(polygon) >= 6
183
+ ]
184
+ if len(gt_mask) == 0:
185
+ # ignore this instance and set gt_mask to a fake mask
186
+ instance['ignore_flag'] = 1
187
+ gt_mask = [np.zeros(6)]
188
+ elif not self.poly2mask:
189
+ # `PolygonMasks` requires a ploygon of format List[np.array],
190
+ # other formats are invalid.
191
+ instance['ignore_flag'] = 1
192
+ gt_mask = [np.zeros(6)]
193
+ elif isinstance(gt_mask, dict) and \
194
+ not (gt_mask.get('counts') is not None and
195
+ gt_mask.get('size') is not None and
196
+ isinstance(gt_mask['counts'], (list, str))):
197
+ # if gt_mask is a dict, it should include `counts` and `size`,
198
+ # so that `BitmapMasks` can uncompressed RLE
199
+ # instance['ignore_flag'] = 1
200
+ ignore_mask = True
201
+ gt_mask = [np.zeros(6)]
202
+ gt_masks.append(gt_mask)
203
+ # re-process gt_ignore_flags
204
+ gt_ignore_flags.append(instance['ignore_flag'])
205
+ gt_ignore_mask_flags.append(ignore_mask)
206
+ results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)
207
+ results['gt_ignore_mask_flags'] = np.array(gt_ignore_mask_flags, dtype=bool)
208
+ return gt_masks
209
+
210
+ def _load_masks(self, results: dict) -> None:
211
+ """Private function to load mask annotations.
212
+
213
+ Args:
214
+ results (dict): Result dict from :obj:``mmengine.BaseDataset``.
215
+ """
216
+ h, w = results['ori_shape']
217
+ gt_masks = self._process_masks(results)
218
+ if self.poly2mask:
219
+ p2masks = []
220
+ if len(gt_masks) > 0:
221
+ for ins, mask, ignore_mask in zip(results['instances'], gt_masks, results['gt_ignore_mask_flags']):
222
+ bbox = [int(c) for c in ins['bbox']]
223
+ if ignore_mask:
224
+ m = np.zeros((h, w), dtype=np.uint8)
225
+ m[bbox[1]:bbox[3], bbox[0]: bbox[2]] = 255
226
+ # m[bbox[1]:bbox[3], bbox[0]: bbox[2]]
227
+ p2masks.append(m)
228
+ else:
229
+ p2masks.append(self._poly2mask(mask, h, w))
230
+ # import cv2
231
+ # # cv2.imwrite('tmp_mask.png', p2masks[-1] * 255)
232
+ # cv2.imwrite('tmp_img.png', results['img'])
233
+ # cv2.imwrite('tmp_bbox.png', m * 225)
234
+ # print(p2masks[-1].shape, p2masks[-1].dtype)
235
+ gt_masks = BitmapMasks(p2masks, h, w)
236
+ else:
237
+ # fake polygon masks will be ignored in `PackDetInputs`
238
+ gt_masks = PolygonMasks([mask for mask in gt_masks], h, w)
239
+ results['gt_masks'] = gt_masks
240
+
241
+ def transform(self, results: dict) -> dict:
242
+ """Function to load multiple types annotations.
243
+
244
+ Args:
245
+ results (dict): Result dict from :obj:``mmengine.BaseDataset``.
246
+
247
+ Returns:
248
+ dict: The dict contains loaded bounding box, label and
249
+ semantic segmentation.
250
+ """
251
+
252
+ if self.with_bbox:
253
+ self._load_bboxes(results)
254
+ if self.with_label:
255
+ self._load_labels(results)
256
+ if self.with_mask:
257
+ self._load_masks(results)
258
+ if self.with_seg:
259
+ self._load_seg_map(results)
260
+
261
+ return results
262
+
263
+
264
+
265
+ @TRANSFORMS.register_module()
266
+ class PackDetIputsNoSeg(PackDetInputs):
267
+
268
+ mapping_table = {
269
+ 'gt_bboxes': 'bboxes',
270
+ 'gt_bboxes_labels': 'labels',
271
+ 'gt_ignore_mask_flags': 'ignore_mask',
272
+ 'gt_masks': 'masks'
273
+ }
274
+
275
+ def transform(self, results: dict) -> dict:
276
+ """Method to pack the input data.
277
+
278
+ Args:
279
+ results (dict): Result dict from the data pipeline.
280
+
281
+ Returns:
282
+ dict:
283
+
284
+ - 'inputs' (obj:`torch.Tensor`): The forward data of models.
285
+ - 'data_sample' (obj:`DetDataSample`): The annotation info of the
286
+ sample.
287
+ """
288
+ packed_results = dict()
289
+ if 'img' in results:
290
+ img = results['img']
291
+ if len(img.shape) < 3:
292
+ img = np.expand_dims(img, -1)
293
+ img = np.ascontiguousarray(img.transpose(2, 0, 1))
294
+ packed_results['inputs'] = to_tensor(img)
295
+
296
+ if 'gt_ignore_flags' in results:
297
+ valid_idx = np.where(results['gt_ignore_flags'] == 0)[0]
298
+ ignore_idx = np.where(results['gt_ignore_flags'] == 1)[0]
299
+
300
+ data_sample = DetDataSample()
301
+ instance_data = InstanceData()
302
+ ignore_instance_data = InstanceData()
303
+
304
+ for key in self.mapping_table.keys():
305
+ if key not in results:
306
+ continue
307
+ if key == 'gt_masks' or isinstance(results[key], BaseBoxes):
308
+ if 'gt_ignore_flags' in results:
309
+ instance_data[
310
+ self.mapping_table[key]] = results[key][valid_idx]
311
+ ignore_instance_data[
312
+ self.mapping_table[key]] = results[key][ignore_idx]
313
+ else:
314
+ instance_data[self.mapping_table[key]] = results[key]
315
+ else:
316
+ if 'gt_ignore_flags' in results:
317
+ instance_data[self.mapping_table[key]] = to_tensor(
318
+ results[key][valid_idx])
319
+ ignore_instance_data[self.mapping_table[key]] = to_tensor(
320
+ results[key][ignore_idx])
321
+ else:
322
+ instance_data[self.mapping_table[key]] = to_tensor(
323
+ results[key])
324
+ data_sample.gt_instances = instance_data
325
+ data_sample.ignored_instances = ignore_instance_data
326
+
327
+ if 'proposals' in results:
328
+ proposals = InstanceData(
329
+ bboxes=to_tensor(results['proposals']),
330
+ scores=to_tensor(results['proposals_scores']))
331
+ data_sample.proposals = proposals
332
+
333
+ if 'gt_seg_map' in results:
334
+ gt_sem_seg_data = dict(
335
+ sem_seg=to_tensor(results['gt_seg_map'][None, ...].copy()))
336
+ data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
337
+
338
+ img_meta = {}
339
+ for key in self.meta_keys:
340
+ assert key in results, f'`{key}` is not found in `results`, ' \
341
+ f'the valid keys are {list(results)}.'
342
+ img_meta[key] = results[key]
343
+
344
+ data_sample.set_metainfo(img_meta)
345
+ packed_results['data_samples'] = data_sample
346
+
347
+ return packed_results
348
+
349
+
350
+
351
+ def translate_bitmapmask(bitmap_masks: BitmapMasks,
352
+ out_shape,
353
+ offset_x,
354
+ offset_y,):
355
+
356
+ if len(bitmap_masks.masks) == 0:
357
+ translated_masks = np.empty((0, *out_shape), dtype=np.uint8)
358
+ else:
359
+ masks = bitmap_masks.masks
360
+ out_h, out_w = out_shape
361
+ mask_h, mask_w = masks.shape[1:]
362
+
363
+ translated_masks = np.zeros((masks.shape[0], *out_shape),
364
+ dtype=masks.dtype)
365
+
366
+ ix, iy = bbox_overlap_xy([0, 0, out_w, out_h], [offset_x, offset_y, mask_w, mask_h])
367
+ if ix > 2 and iy > 2:
368
+ if offset_x > 0:
369
+ mx1 = 0
370
+ tx1 = offset_x
371
+ else:
372
+ mx1 = -offset_x
373
+ tx1 = 0
374
+ mx2 = min(out_w - offset_x, mask_w)
375
+ tx2 = tx1 + mx2 - mx1
376
+
377
+ if offset_y > 0:
378
+ my1 = 0
379
+ ty1 = offset_y
380
+ else:
381
+ my1 = -offset_y
382
+ ty1 = 0
383
+ my2 = min(out_h - offset_y, mask_h)
384
+ ty2 = ty1 + my2 - my1
385
+
386
+ translated_masks[:, ty1: ty2, tx1: tx2] = \
387
+ masks[:, my1: my2, mx1: mx2]
388
+
389
+ return BitmapMasks(translated_masks, *out_shape)
390
+
391
+
392
+ @TRANSFORMS.register_module()
393
+ class CachedMosaicNoSeg(CachedMosaic):
394
+
395
+ @autocast_box_type()
396
+ def transform(self, results: dict) -> dict:
397
+
398
+ """Mosaic transform function.
399
+
400
+ Args:
401
+ results (dict): Result dict.
402
+
403
+ Returns:
404
+ dict: Updated result dict.
405
+ """
406
+ # cache and pop images
407
+ self.results_cache.append(copy.deepcopy(results))
408
+ if len(self.results_cache) > self.max_cached_images:
409
+ if self.random_pop:
410
+ index = random.randint(0, len(self.results_cache) - 1)
411
+ else:
412
+ index = 0
413
+ self.results_cache.pop(index)
414
+
415
+ if len(self.results_cache) <= 4:
416
+ return results
417
+
418
+ if random.uniform(0, 1) > self.prob:
419
+ return results
420
+ indices = self.get_indexes(self.results_cache)
421
+ mix_results = [copy.deepcopy(self.results_cache[i]) for i in indices]
422
+
423
+ # TODO: refactor mosaic to reuse these code.
424
+ mosaic_bboxes = []
425
+ mosaic_bboxes_labels = []
426
+ mosaic_ignore_flags = []
427
+ mosaic_masks = []
428
+ mosaic_ignore_mask_flags = []
429
+ with_mask = True if 'gt_masks' in results else False
430
+
431
+ if len(results['img'].shape) == 3:
432
+ mosaic_img = np.full(
433
+ (int(self.img_scale[1] * 2), int(self.img_scale[0] * 2), 3),
434
+ self.pad_val,
435
+ dtype=results['img'].dtype)
436
+ else:
437
+ mosaic_img = np.full(
438
+ (int(self.img_scale[1] * 2), int(self.img_scale[0] * 2)),
439
+ self.pad_val,
440
+ dtype=results['img'].dtype)
441
+
442
+ # mosaic center x, y
443
+ center_x = int(
444
+ random.uniform(*self.center_ratio_range) * self.img_scale[0])
445
+ center_y = int(
446
+ random.uniform(*self.center_ratio_range) * self.img_scale[1])
447
+ center_position = (center_x, center_y)
448
+
449
+ loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
450
+
451
+ n_manga = 0
452
+ for i, loc in enumerate(loc_strs):
453
+ if loc == 'top_left':
454
+ results_patch = copy.deepcopy(results)
455
+ else:
456
+ results_patch = copy.deepcopy(mix_results[i - 1])
457
+
458
+ is_manga = results_patch['img_id'] > 900000000
459
+ if is_manga:
460
+ n_manga += 1
461
+ if n_manga > 3:
462
+ continue
463
+ im_h, im_w = results_patch['img'].shape[:2]
464
+ if im_w > im_h and random.random() < 0.75:
465
+ results_patch = hcrop(results_patch, (im_h, im_w // 2), True)
466
+
467
+ img_i = results_patch['img']
468
+ h_i, w_i = img_i.shape[:2]
469
+ # keep_ratio resize
470
+ scale_ratio_i = min(self.img_scale[1] / h_i,
471
+ self.img_scale[0] / w_i)
472
+ img_i = mmcv.imresize(
473
+ img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)))
474
+
475
+ # compute the combine parameters
476
+ paste_coord, crop_coord = self._mosaic_combine(
477
+ loc, center_position, img_i.shape[:2][::-1])
478
+ x1_p, y1_p, x2_p, y2_p = paste_coord
479
+ x1_c, y1_c, x2_c, y2_c = crop_coord
480
+
481
+ # crop and paste image
482
+ mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c]
483
+
484
+ # adjust coordinate
485
+ gt_bboxes_i = results_patch['gt_bboxes']
486
+ gt_bboxes_labels_i = results_patch['gt_bboxes_labels']
487
+ gt_ignore_flags_i = results_patch['gt_ignore_flags']
488
+ gt_ignore_mask_i = results_patch['gt_ignore_mask_flags']
489
+
490
+ padw = x1_p - x1_c
491
+ padh = y1_p - y1_c
492
+ gt_bboxes_i.rescale_([scale_ratio_i, scale_ratio_i])
493
+ gt_bboxes_i.translate_([padw, padh])
494
+ mosaic_bboxes.append(gt_bboxes_i)
495
+ mosaic_bboxes_labels.append(gt_bboxes_labels_i)
496
+ mosaic_ignore_flags.append(gt_ignore_flags_i)
497
+ mosaic_ignore_mask_flags.append(gt_ignore_mask_i)
498
+ if with_mask and results_patch.get('gt_masks', None) is not None:
499
+
500
+ gt_masks_i = results_patch['gt_masks']
501
+ gt_masks_i = gt_masks_i.rescale(float(scale_ratio_i))
502
+
503
+ gt_masks_i = translate_bitmapmask(gt_masks_i,
504
+ out_shape=(int(self.img_scale[0] * 2),
505
+ int(self.img_scale[1] * 2)),
506
+ offset_x=padw, offset_y=padh)
507
+
508
+ # gt_masks_i = gt_masks_i.translate(
509
+ # out_shape=(int(self.img_scale[0] * 2),
510
+ # int(self.img_scale[1] * 2)),
511
+ # offset=padw,
512
+ # direction='horizontal')
513
+ # gt_masks_i = gt_masks_i.translate(
514
+ # out_shape=(int(self.img_scale[0] * 2),
515
+ # int(self.img_scale[1] * 2)),
516
+ # offset=padh,
517
+ # direction='vertical')
518
+ mosaic_masks.append(gt_masks_i)
519
+
520
+ mosaic_bboxes = mosaic_bboxes[0].cat(mosaic_bboxes, 0)
521
+ mosaic_bboxes_labels = np.concatenate(mosaic_bboxes_labels, 0)
522
+ mosaic_ignore_flags = np.concatenate(mosaic_ignore_flags, 0)
523
+ mosaic_ignore_mask_flags = np.concatenate(mosaic_ignore_mask_flags, 0)
524
+
525
+ if self.bbox_clip_border:
526
+ mosaic_bboxes.clip_([2 * self.img_scale[1], 2 * self.img_scale[0]])
527
+ # remove outside bboxes
528
+ inside_inds = mosaic_bboxes.is_inside(
529
+ [2 * self.img_scale[1], 2 * self.img_scale[0]]).numpy()
530
+
531
+ mosaic_bboxes = mosaic_bboxes[inside_inds]
532
+ mosaic_bboxes_labels = mosaic_bboxes_labels[inside_inds]
533
+ mosaic_ignore_flags = mosaic_ignore_flags[inside_inds]
534
+ mosaic_ignore_mask_flags = mosaic_ignore_mask_flags[inside_inds]
535
+
536
+ results['img'] = mosaic_img
537
+ results['img_shape'] = mosaic_img.shape
538
+ results['gt_bboxes'] = mosaic_bboxes
539
+ results['gt_bboxes_labels'] = mosaic_bboxes_labels
540
+ results['gt_ignore_flags'] = mosaic_ignore_flags
541
+ results['gt_ignore_mask_flags'] = mosaic_ignore_mask_flags
542
+
543
+
544
+ if with_mask:
545
+ total_instances = len(inside_inds)
546
+ assert total_instances == np.array([m.masks.shape[0] for m in mosaic_masks]).sum()
547
+ if total_instances > 10:
548
+ masks = np.empty((inside_inds.sum(), mosaic_masks[0].height, mosaic_masks[0].width), dtype=np.uint8)
549
+ msk_idx = 0
550
+ mmsk_idx = 0
551
+ for m in mosaic_masks:
552
+ for ii in range(m.masks.shape[0]):
553
+ if inside_inds[msk_idx]:
554
+ masks[mmsk_idx] = m.masks[ii]
555
+ mmsk_idx += 1
556
+ msk_idx += 1
557
+ results['gt_masks'] = BitmapMasks(masks, mosaic_masks[0].height, mosaic_masks[0].width)
558
+ else:
559
+ mosaic_masks = mosaic_masks[0].cat(mosaic_masks)
560
+ results['gt_masks'] = mosaic_masks[inside_inds]
561
+ # assert np.all(results['gt_masks'].masks == masks) and results['gt_masks'].masks.shape == masks.shape
562
+
563
+ # assert inside_inds.sum() == results['gt_masks'].masks.shape[0]
564
+ return results
565
+
566
+ @TRANSFORMS.register_module()
567
+ class FilterAnnotationsNoSeg(FilterAnnotations):
568
+
569
+ def __init__(self,
570
+ min_gt_bbox_wh: Tuple[int, int] = (1, 1),
571
+ min_gt_mask_area: int = 1,
572
+ by_box: bool = True,
573
+ by_mask: bool = False,
574
+ keep_empty: bool = True) -> None:
575
+ # TODO: add more filter options
576
+ assert by_box or by_mask
577
+ self.min_gt_bbox_wh = min_gt_bbox_wh
578
+ self.min_gt_mask_area = min_gt_mask_area
579
+ self.by_box = by_box
580
+ self.by_mask = by_mask
581
+ self.keep_empty = keep_empty
582
+
583
+ @autocast_box_type()
584
+ def transform(self, results: dict) -> Union[dict, None]:
585
+ """Transform function to filter annotations.
586
+
587
+ Args:
588
+ results (dict): Result dict.
589
+
590
+ Returns:
591
+ dict: Updated result dict.
592
+ """
593
+ assert 'gt_bboxes' in results
594
+ gt_bboxes = results['gt_bboxes']
595
+ if gt_bboxes.shape[0] == 0:
596
+ return results
597
+
598
+ tests = []
599
+ if self.by_box:
600
+ tests.append(
601
+ ((gt_bboxes.widths > self.min_gt_bbox_wh[0]) &
602
+ (gt_bboxes.heights > self.min_gt_bbox_wh[1])).numpy())
603
+
604
+ if self.by_mask:
605
+ assert 'gt_masks' in results
606
+ gt_masks = results['gt_masks']
607
+ tests.append(gt_masks.areas >= self.min_gt_mask_area)
608
+
609
+ keep = tests[0]
610
+ for t in tests[1:]:
611
+ keep = keep & t
612
+
613
+ # if not keep.any():
614
+ # if self.keep_empty:
615
+ # return None
616
+
617
+ assert len(results['gt_ignore_flags']) == len(results['gt_ignore_mask_flags'])
618
+ keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_masks', 'gt_ignore_flags', 'gt_ignore_mask_flags')
619
+ for key in keys:
620
+ if key in results:
621
+ try:
622
+ results[key] = results[key][keep]
623
+ except Exception as e:
624
+ raise e
625
+
626
+ return results
627
+
628
+
629
+ def hcrop(results: dict, crop_size: Tuple[int, int],
630
+ allow_negative_crop: bool) -> Union[dict, None]:
631
+
632
+ assert crop_size[0] > 0 and crop_size[1] > 0
633
+ img = results['img']
634
+ offset_h, offset_w = 0, random.choice([0, crop_size[1]])
635
+ crop_y1, crop_y2 = offset_h, offset_h + crop_size[0]
636
+ crop_x1, crop_x2 = offset_w, offset_w + crop_size[1]
637
+
638
+ # Record the homography matrix for the RandomCrop
639
+ homography_matrix = np.array(
640
+ [[1, 0, -offset_w], [0, 1, -offset_h], [0, 0, 1]],
641
+ dtype=np.float32)
642
+ if results.get('homography_matrix', None) is None:
643
+ results['homography_matrix'] = homography_matrix
644
+ else:
645
+ results['homography_matrix'] = homography_matrix @ results[
646
+ 'homography_matrix']
647
+
648
+ # crop the image
649
+ img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
650
+ img_shape = img.shape
651
+ results['img'] = img
652
+ results['img_shape'] = img_shape
653
+
654
+ # crop bboxes accordingly and clip to the image boundary
655
+ if results.get('gt_bboxes', None) is not None:
656
+ bboxes = results['gt_bboxes']
657
+ bboxes.translate_([-offset_w, -offset_h])
658
+ bboxes.clip_(img_shape[:2])
659
+ valid_inds = bboxes.is_inside(img_shape[:2]).numpy()
660
+ # If the crop does not contain any gt-bbox area and
661
+ # allow_negative_crop is False, skip this image.
662
+ if (not valid_inds.any() and not allow_negative_crop):
663
+ return None
664
+
665
+ results['gt_bboxes'] = bboxes[valid_inds]
666
+
667
+ if results.get('gt_ignore_flags', None) is not None:
668
+ results['gt_ignore_flags'] = \
669
+ results['gt_ignore_flags'][valid_inds]
670
+
671
+ if results.get('gt_ignore_mask_flags', None) is not None:
672
+ results['gt_ignore_mask_flags'] = \
673
+ results['gt_ignore_mask_flags'][valid_inds]
674
+
675
+ if results.get('gt_bboxes_labels', None) is not None:
676
+ results['gt_bboxes_labels'] = \
677
+ results['gt_bboxes_labels'][valid_inds]
678
+
679
+ if results.get('gt_masks', None) is not None:
680
+ results['gt_masks'] = results['gt_masks'][
681
+ valid_inds.nonzero()[0]].crop(
682
+ np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))
683
+ results['gt_bboxes'] = results['gt_masks'].get_bboxes(
684
+ type(results['gt_bboxes']))
685
+
686
+ # crop semantic seg
687
+ if results.get('gt_seg_map', None) is not None:
688
+ results['gt_seg_map'] = results['gt_seg_map'][crop_y1:crop_y2,
689
+ crop_x1:crop_x2]
690
+
691
+ return results
692
+
693
+
694
+ @TRANSFORMS.register_module()
695
+ class RandomCropNoSeg(RandomCrop):
696
+
697
+ def _crop_data(self, results: dict, crop_size: Tuple[int, int],
698
+ allow_negative_crop: bool) -> Union[dict, None]:
699
+
700
+ assert crop_size[0] > 0 and crop_size[1] > 0
701
+ img = results['img']
702
+ margin_h = max(img.shape[0] - crop_size[0], 0)
703
+ margin_w = max(img.shape[1] - crop_size[1], 0)
704
+ offset_h, offset_w = self._rand_offset((margin_h, margin_w))
705
+ crop_y1, crop_y2 = offset_h, offset_h + crop_size[0]
706
+ crop_x1, crop_x2 = offset_w, offset_w + crop_size[1]
707
+
708
+ # Record the homography matrix for the RandomCrop
709
+ homography_matrix = np.array(
710
+ [[1, 0, -offset_w], [0, 1, -offset_h], [0, 0, 1]],
711
+ dtype=np.float32)
712
+ if results.get('homography_matrix', None) is None:
713
+ results['homography_matrix'] = homography_matrix
714
+ else:
715
+ results['homography_matrix'] = homography_matrix @ results[
716
+ 'homography_matrix']
717
+
718
+ # crop the image
719
+ img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
720
+ img_shape = img.shape
721
+ results['img'] = img
722
+ results['img_shape'] = img_shape
723
+
724
+ # crop bboxes accordingly and clip to the image boundary
725
+ if results.get('gt_bboxes', None) is not None:
726
+ bboxes = results['gt_bboxes']
727
+ bboxes.translate_([-offset_w, -offset_h])
728
+ if self.bbox_clip_border:
729
+ bboxes.clip_(img_shape[:2])
730
+ valid_inds = bboxes.is_inside(img_shape[:2]).numpy()
731
+ # If the crop does not contain any gt-bbox area and
732
+ # allow_negative_crop is False, skip this image.
733
+ if (not valid_inds.any() and not allow_negative_crop):
734
+ return None
735
+
736
+ results['gt_bboxes'] = bboxes[valid_inds]
737
+
738
+ if results.get('gt_ignore_flags', None) is not None:
739
+ results['gt_ignore_flags'] = \
740
+ results['gt_ignore_flags'][valid_inds]
741
+
742
+ if results.get('gt_ignore_mask_flags', None) is not None:
743
+ results['gt_ignore_mask_flags'] = \
744
+ results['gt_ignore_mask_flags'][valid_inds]
745
+
746
+ if results.get('gt_bboxes_labels', None) is not None:
747
+ results['gt_bboxes_labels'] = \
748
+ results['gt_bboxes_labels'][valid_inds]
749
+
750
+ if results.get('gt_masks', None) is not None:
751
+ results['gt_masks'] = results['gt_masks'][
752
+ valid_inds.nonzero()[0]].crop(
753
+ np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))
754
+ if self.recompute_bbox:
755
+ results['gt_bboxes'] = results['gt_masks'].get_bboxes(
756
+ type(results['gt_bboxes']))
757
+
758
+ # crop semantic seg
759
+ if results.get('gt_seg_map', None) is not None:
760
+ results['gt_seg_map'] = results['gt_seg_map'][crop_y1:crop_y2,
761
+ crop_x1:crop_x2]
762
+
763
+ return results
764
+
765
+
766
+
767
+ @TRANSFORMS.register_module()
768
+ class CachedMixUpNoSeg(CachedMixUp):
769
+
770
+ @autocast_box_type()
771
+ def transform(self, results: dict) -> dict:
772
+ """MixUp transform function.
773
+
774
+ Args:
775
+ results (dict): Result dict.
776
+
777
+ Returns:
778
+ dict: Updated result dict.
779
+ """
780
+ # cache and pop images
781
+ self.results_cache.append(copy.deepcopy(results))
782
+ if len(self.results_cache) > self.max_cached_images:
783
+ if self.random_pop:
784
+ index = random.randint(0, len(self.results_cache) - 1)
785
+ else:
786
+ index = 0
787
+ self.results_cache.pop(index)
788
+
789
+ if len(self.results_cache) <= 1:
790
+ return results
791
+
792
+ if random.uniform(0, 1) > self.prob:
793
+ return results
794
+
795
+ index = self.get_indexes(self.results_cache)
796
+ retrieve_results = copy.deepcopy(self.results_cache[index])
797
+
798
+ # TODO: refactor mixup to reuse these code.
799
+ if retrieve_results['gt_bboxes'].shape[0] == 0:
800
+ # empty bbox
801
+ return results
802
+
803
+ retrieve_img = retrieve_results['img']
804
+ with_mask = True if 'gt_masks' in results else False
805
+
806
+ jit_factor = random.uniform(*self.ratio_range)
807
+ is_filp = random.uniform(0, 1) > self.flip_ratio
808
+
809
+ if len(retrieve_img.shape) == 3:
810
+ out_img = np.ones(
811
+ (self.dynamic_scale[1], self.dynamic_scale[0], 3),
812
+ dtype=retrieve_img.dtype) * self.pad_val
813
+ else:
814
+ out_img = np.ones(
815
+ self.dynamic_scale[::-1],
816
+ dtype=retrieve_img.dtype) * self.pad_val
817
+
818
+ # 1. keep_ratio resize
819
+ scale_ratio = min(self.dynamic_scale[1] / retrieve_img.shape[0],
820
+ self.dynamic_scale[0] / retrieve_img.shape[1])
821
+ retrieve_img = mmcv.imresize(
822
+ retrieve_img, (int(retrieve_img.shape[1] * scale_ratio),
823
+ int(retrieve_img.shape[0] * scale_ratio)))
824
+
825
+ # 2. paste
826
+ out_img[:retrieve_img.shape[0], :retrieve_img.shape[1]] = retrieve_img
827
+
828
+ # 3. scale jit
829
+ scale_ratio *= jit_factor
830
+ out_img = mmcv.imresize(out_img, (int(out_img.shape[1] * jit_factor),
831
+ int(out_img.shape[0] * jit_factor)))
832
+
833
+ # 4. flip
834
+ if is_filp:
835
+ out_img = out_img[:, ::-1, :]
836
+
837
+ # 5. random crop
838
+ ori_img = results['img']
839
+ origin_h, origin_w = out_img.shape[:2]
840
+ target_h, target_w = ori_img.shape[:2]
841
+ padded_img = np.ones((max(origin_h, target_h), max(
842
+ origin_w, target_w), 3)) * self.pad_val
843
+ padded_img = padded_img.astype(np.uint8)
844
+ padded_img[:origin_h, :origin_w] = out_img
845
+
846
+ x_offset, y_offset = 0, 0
847
+ if padded_img.shape[0] > target_h:
848
+ y_offset = random.randint(0, padded_img.shape[0] - target_h)
849
+ if padded_img.shape[1] > target_w:
850
+ x_offset = random.randint(0, padded_img.shape[1] - target_w)
851
+ padded_cropped_img = padded_img[y_offset:y_offset + target_h,
852
+ x_offset:x_offset + target_w]
853
+
854
+ # 6. adjust bbox
855
+ retrieve_gt_bboxes = retrieve_results['gt_bboxes']
856
+ retrieve_gt_bboxes.rescale_([scale_ratio, scale_ratio])
857
+ if with_mask:
858
+ retrieve_gt_masks = retrieve_results['gt_masks'].rescale(
859
+ scale_ratio)
860
+
861
+ if self.bbox_clip_border:
862
+ retrieve_gt_bboxes.clip_([origin_h, origin_w])
863
+
864
+ if is_filp:
865
+ retrieve_gt_bboxes.flip_([origin_h, origin_w],
866
+ direction='horizontal')
867
+ if with_mask:
868
+ retrieve_gt_masks = retrieve_gt_masks.flip()
869
+
870
+ # 7. filter
871
+ cp_retrieve_gt_bboxes = retrieve_gt_bboxes.clone()
872
+ cp_retrieve_gt_bboxes.translate_([-x_offset, -y_offset])
873
+ if with_mask:
874
+
875
+ retrieve_gt_masks = translate_bitmapmask(retrieve_gt_masks,
876
+ out_shape=(target_h, target_w),
877
+ offset_x=-x_offset, offset_y=-y_offset)
878
+
879
+ # retrieve_gt_masks = retrieve_gt_masks.translate(
880
+ # out_shape=(target_h, target_w),
881
+ # offset=-x_offset,
882
+ # direction='horizontal')
883
+ # retrieve_gt_masks = retrieve_gt_masks.translate(
884
+ # out_shape=(target_h, target_w),
885
+ # offset=-y_offset,
886
+ # direction='vertical')
887
+
888
+ if self.bbox_clip_border:
889
+ cp_retrieve_gt_bboxes.clip_([target_h, target_w])
890
+
891
+ # 8. mix up
892
+ ori_img = ori_img.astype(np.float32)
893
+ mixup_img = 0.5 * ori_img + 0.5 * padded_cropped_img.astype(np.float32)
894
+
895
+ retrieve_gt_bboxes_labels = retrieve_results['gt_bboxes_labels']
896
+ retrieve_gt_ignore_flags = retrieve_results['gt_ignore_flags']
897
+ retrieve_gt_ignore_mask_flags = retrieve_results['gt_ignore_mask_flags']
898
+
899
+ mixup_gt_bboxes = cp_retrieve_gt_bboxes.cat(
900
+ (results['gt_bboxes'], cp_retrieve_gt_bboxes), dim=0)
901
+ mixup_gt_bboxes_labels = np.concatenate(
902
+ (results['gt_bboxes_labels'], retrieve_gt_bboxes_labels), axis=0)
903
+ mixup_gt_ignore_flags = np.concatenate(
904
+ (results['gt_ignore_flags'], retrieve_gt_ignore_flags), axis=0)
905
+ mixup_gt_ignore_mask_flags = np.concatenate(
906
+ (results['gt_ignore_mask_flags'], retrieve_gt_ignore_mask_flags), axis=0)
907
+
908
+ if with_mask:
909
+ mixup_gt_masks = retrieve_gt_masks.cat(
910
+ [results['gt_masks'], retrieve_gt_masks])
911
+
912
+ # remove outside bbox
913
+ inside_inds = mixup_gt_bboxes.is_inside([target_h, target_w]).numpy()
914
+ mixup_gt_bboxes = mixup_gt_bboxes[inside_inds]
915
+ mixup_gt_bboxes_labels = mixup_gt_bboxes_labels[inside_inds]
916
+ mixup_gt_ignore_flags = mixup_gt_ignore_flags[inside_inds]
917
+ mixup_gt_ignore_mask_flags = mixup_gt_ignore_mask_flags[inside_inds]
918
+ if with_mask:
919
+ mixup_gt_masks = mixup_gt_masks[inside_inds]
920
+
921
+ results['img'] = mixup_img.astype(np.uint8)
922
+ results['img_shape'] = mixup_img.shape
923
+ results['gt_bboxes'] = mixup_gt_bboxes
924
+ results['gt_bboxes_labels'] = mixup_gt_bboxes_labels
925
+ results['gt_ignore_flags'] = mixup_gt_ignore_flags
926
+ results['gt_ignore_mask_flags'] = mixup_gt_ignore_mask_flags
927
+ if with_mask:
928
+ results['gt_masks'] = mixup_gt_masks
929
+ return results
animeinsseg/data/maskrefine_dataset.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import albumentations as A
2
+
3
+ from torch.utils.data import Dataset, DataLoader
4
+ import pycocotools.mask as maskUtils
5
+ from pycocotools.coco import COCO
6
+ import random
7
+ import os.path as osp
8
+ import cv2
9
+ import numpy as np
10
+ from scipy.ndimage import distance_transform_bf, distance_transform_edt, distance_transform_cdt
11
+
12
+
13
+ def is_grey(img: np.ndarray):
14
+ if len(img.shape) == 3 and img.shape[2] == 3:
15
+ return False
16
+ else:
17
+ return True
18
+
19
+
20
+ def square_pad_resize(img: np.ndarray, tgt_size: int, pad_value = (0, 0, 0)):
21
+ h, w = img.shape[:2]
22
+ pad_h, pad_w = 0, 0
23
+
24
+ # make square image
25
+ if w < h:
26
+ pad_w = h - w
27
+ w += pad_w
28
+ elif h < w:
29
+ pad_h = w - h
30
+ h += pad_h
31
+
32
+ pad_size = tgt_size - h
33
+ if pad_size > 0:
34
+ pad_h += pad_size
35
+ pad_w += pad_size
36
+
37
+ if pad_h > 0 or pad_w > 0:
38
+ c = 1
39
+ if is_grey(img):
40
+ if isinstance(pad_value, tuple):
41
+ pad_value = pad_value[0]
42
+ else:
43
+ if isinstance(pad_value, int):
44
+ pad_value = (pad_value, pad_value, pad_value)
45
+
46
+ img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=pad_value)
47
+
48
+ resize_ratio = tgt_size / img.shape[0]
49
+ if resize_ratio < 1:
50
+ img = cv2.resize(img, (tgt_size, tgt_size), interpolation=cv2.INTER_AREA)
51
+ elif resize_ratio > 1:
52
+ img = cv2.resize(img, (tgt_size, tgt_size), interpolation=cv2.INTER_LINEAR)
53
+
54
+ return img, resize_ratio, pad_h, pad_w
55
+
56
+
57
+ class MaskRefineDataset(Dataset):
58
+
59
+ def __init__(self,
60
+ refine_ann_path: str,
61
+ data_root: str,
62
+ load_instance_mask: bool = True,
63
+ aug_ins_prob: float = 0.,
64
+ ins_rect_prob: float = 0.,
65
+ output_size: int = 720,
66
+ augmentation: bool = False,
67
+ with_distance: bool = False):
68
+ self.load_instance_mask = load_instance_mask
69
+ self.ann_util = COCO(refine_ann_path)
70
+ self.img_ids = self.ann_util.getImgIds()
71
+ self.set_load_method(load_instance_mask)
72
+ self.data_root = data_root
73
+
74
+ self.ins_rect_prob = ins_rect_prob
75
+ self.aug_ins_prob = aug_ins_prob
76
+ self.augmentation = augmentation
77
+ if augmentation:
78
+ transform = [
79
+ A.OpticalDistortion(),
80
+ A.HorizontalFlip(),
81
+ A.CLAHE(),
82
+ A.Posterize(),
83
+ A.CropAndPad(percent=0.1, p=0.3, pad_mode=cv2.BORDER_CONSTANT, pad_cval=0, pad_cval_mask=0, keep_size=True),
84
+ A.RandomContrast(),
85
+ A.Rotate(30, p=0.3, mask_value=0, border_mode=cv2.BORDER_CONSTANT)
86
+ ]
87
+ self._aug_transform = A.Compose(transform)
88
+ else:
89
+ self._aug_transform = None
90
+
91
+ self.output_size = output_size
92
+ self.with_distance = with_distance
93
+
94
+ def set_output_size(self, size: int):
95
+ self.output_size = size
96
+
97
+ def set_load_method(self, load_instance_mask: bool):
98
+ if load_instance_mask:
99
+ self._load_mask = self._load_with_instance
100
+ else:
101
+ self._load_mask = self._load_without_instance
102
+
103
+ def __getitem__(self, idx: int):
104
+ img_id = self.img_ids[idx]
105
+ img_meta = self.ann_util.imgs[img_id]
106
+ img_path = osp.join(self.data_root, img_meta['file_name'])
107
+ img = cv2.imread(img_path)
108
+
109
+ annids = self.ann_util.getAnnIds([img_id])
110
+ if len(annids) > 0:
111
+ ann = random.choice(annids)
112
+ ann = self.ann_util.anns[ann]
113
+ assert ann['image_id'] == img_id
114
+ else:
115
+ ann = None
116
+
117
+ return self._load_mask(img, ann)
118
+
119
+ def transform(self, img: np.ndarray, mask: np.ndarray, ins_seg: np.ndarray = None) -> dict:
120
+ if ins_seg is not None:
121
+ use_seg = True
122
+ else:
123
+ use_seg = False
124
+
125
+ if self.augmentation:
126
+ masks = [mask]
127
+ if use_seg:
128
+ masks.append(ins_seg)
129
+ data = self._aug_transform(image=img, masks=masks)
130
+ img = data['image']
131
+ masks = data['masks']
132
+ mask = masks[0]
133
+ if use_seg:
134
+ ins_seg = masks[1]
135
+
136
+ img = square_pad_resize(img, self.output_size, random.randint(0, 255))[0]
137
+ mask = square_pad_resize(mask, self.output_size, 0)[0]
138
+ if ins_seg is not None:
139
+ ins_seg = square_pad_resize(ins_seg, self.output_size, 0)[0]
140
+
141
+ img = (img.astype(np.float32) / 255.).transpose((2, 0, 1))
142
+ mask = mask[None, ...]
143
+
144
+
145
+ if use_seg:
146
+ ins_seg = ins_seg[None, ...]
147
+ img = np.concatenate((img, ins_seg), axis=0)
148
+
149
+ data = {'img': img, 'mask': mask}
150
+ if self.with_distance:
151
+ dist = distance_transform_edt(mask[0])
152
+ dist_max = dist.max()
153
+ if dist_max != 0:
154
+ dist = 1 - dist / dist_max
155
+ # diff_mat = cv2.bitwise_xor(mask[0], ins_seg[0])
156
+ # dist = dist + diff_mat + 0.2
157
+ dist = dist + 0.2
158
+ dist = dist.size / (dist.sum() + 1) * dist
159
+ dist = np.clip(dist, 0, 20)
160
+ else:
161
+ dist = np.ones_like(dist)
162
+ # print(dist.max(), dist.min())
163
+ data['dist_weight'] = dist[None, ...]
164
+ return data
165
+
166
+ def _load_with_instance(self, img: np.ndarray, ann: dict):
167
+ if ann is None:
168
+ mask = np.zeros(img.shape[:2], dtype=np.float32)
169
+ ins_seg = mask
170
+ else:
171
+ mask = maskUtils.decode(ann['segmentation']).astype(np.float32)
172
+ if self.augmentation and random.random() < self.ins_rect_prob:
173
+ ins_seg = np.zeros_like(mask)
174
+ bbox = [int(b) for b in ann['bbox']]
175
+ ins_seg[bbox[1]: bbox[1] + bbox[3], bbox[0]: bbox[0] + bbox[2]] = 1
176
+ elif len(ann['pred_segmentations']) > 0:
177
+ ins_seg = random.choice(ann['pred_segmentations'])
178
+ ins_seg = maskUtils.decode(ins_seg).astype(np.float32)
179
+ else:
180
+ ins_seg = mask
181
+ if self.augmentation and random.random() < self.aug_ins_prob:
182
+ ksize = random.choice([1, 3, 5, 7])
183
+ ksize = ksize * 2 + 1
184
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, ksize=(ksize, ksize))
185
+ if random.random() < 0.5:
186
+ ins_seg = cv2.dilate(ins_seg, kernel)
187
+ else:
188
+ ins_seg = cv2.erode(ins_seg, kernel)
189
+
190
+ return self.transform(img, mask, ins_seg)
191
+
192
+ def _load_without_instance(self, img: np.ndarray, ann: dict):
193
+ if ann is None:
194
+ mask = np.zeros(img.shape[:2], dtype=np.float32)
195
+ else:
196
+ mask = maskUtils.decode(ann['segmentation']).astype(np.float32)
197
+ return self.transform(img, mask)
198
+
199
+ def __len__(self):
200
+ return len(self.img_ids)
201
+
202
+
203
+ if __name__ == '__main__':
204
+ ann_path = r'workspace/test_syndata/annotations/refine_train.json'
205
+ data_root = r'workspace/test_syndata/train'
206
+
207
+ ann_path = r'workspace/test_syndata/annotations/refine_train.json'
208
+ data_root = r'workspace/test_syndata/train'
209
+ aug_ins_prob = 0.5
210
+ load_instance_mask = True
211
+ ins_rect_prob = 0.25
212
+ output_size = 640
213
+ augmentation = True
214
+
215
+ random.seed(0)
216
+
217
+ md = MaskRefineDataset(ann_path, data_root, load_instance_mask, aug_ins_prob, ins_rect_prob, output_size, augmentation, with_distance=True)
218
+
219
+ dl = DataLoader(md, batch_size=1, shuffle=False, persistent_workers=True,
220
+ num_workers=1, pin_memory=True)
221
+ for data in dl:
222
+ img = data['img'].cpu().numpy()
223
+ img = (img[0, :3].transpose((1, 2, 0)) * 255).astype(np.uint8)
224
+ mask = (data['mask'].cpu().numpy()[0][0] * 255).astype(np.uint8)
225
+ if load_instance_mask:
226
+ ins = (data['img'].cpu().numpy()[0][3] * 255).astype(np.uint8)
227
+ cv2.imshow('ins', ins)
228
+ dist = data['dist_weight'].cpu().numpy()[0][0]
229
+ dist = (dist / dist.max() * 255).astype(np.uint8)
230
+ cv2.imshow('img', img)
231
+ cv2.imshow('mask', mask)
232
+ cv2.imshow('dist_weight', dist)
233
+ cv2.waitKey(0)
234
+
235
+ # cv2.imwrite('')
animeinsseg/data/metrics.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import datetime
3
+ import itertools
4
+ import os.path as osp
5
+ import tempfile
6
+ from collections import OrderedDict
7
+ from typing import Dict, List, Optional, Sequence, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+ from mmengine.evaluator import BaseMetric
12
+ from mmengine.fileio import FileClient, dump, load
13
+ from mmengine.logging import MMLogger
14
+ from terminaltables import AsciiTable
15
+
16
+ from mmdet.datasets.api_wrappers import COCO, COCOeval
17
+ from mmdet.registry import METRICS
18
+ from mmdet.structures.mask import encode_mask_results
19
+ # from ..functional import eval_recalls
20
+ from mmdet.evaluation.metrics import CocoMetric
21
+
22
+
23
+ @METRICS.register_module()
24
+ class AnimeMangaMetric(CocoMetric):
25
+
26
+ def __init__(self,
27
+ manga109_annfile=None,
28
+ animeins_annfile=None,
29
+ ann_file: Optional[str] = None,
30
+ metric: Union[str, List[str]] = 'bbox',
31
+ classwise: bool = False,
32
+ proposal_nums: Sequence[int] = (100, 300, 1000),
33
+ iou_thrs: Optional[Union[float, Sequence[float]]] = None,
34
+ metric_items: Optional[Sequence[str]] = None,
35
+ format_only: bool = False,
36
+ outfile_prefix: Optional[str] = None,
37
+ file_client_args: dict = dict(backend='disk'),
38
+ collect_device: str = 'cpu',
39
+ prefix: Optional[str] = None,
40
+ sort_categories: bool = False) -> None:
41
+
42
+ super().__init__(ann_file, metric, classwise, proposal_nums, iou_thrs, metric_items, format_only, outfile_prefix, file_client_args, collect_device, prefix, sort_categories)
43
+
44
+ self.manga109_img_ids = set()
45
+ if manga109_annfile is not None:
46
+ with self.file_client.get_local_path(manga109_annfile) as local_path:
47
+ self._manga109_coco_api = COCO(local_path)
48
+ if sort_categories:
49
+ # 'categories' list in objects365_train.json and
50
+ # objects365_val.json is inconsistent, need sort
51
+ # list(or dict) before get cat_ids.
52
+ cats = self._manga109_coco_api.cats
53
+ sorted_cats = {i: cats[i] for i in sorted(cats)}
54
+ self._manga109_coco_api.cats = sorted_cats
55
+ categories = self._manga109_coco_api.dataset['categories']
56
+ sorted_categories = sorted(
57
+ categories, key=lambda i: i['id'])
58
+ self._manga109_coco_api.dataset['categories'] = sorted_categories
59
+ self.manga109_img_ids = set(self._manga109_coco_api.get_img_ids())
60
+ else:
61
+ self._manga109_coco_api = None
62
+
63
+ self.animeins_img_ids = set()
64
+ if animeins_annfile is not None:
65
+ with self.file_client.get_local_path(animeins_annfile) as local_path:
66
+ self._animeins_coco_api = COCO(local_path)
67
+ if sort_categories:
68
+ # 'categories' list in objects365_train.json and
69
+ # objects365_val.json is inconsistent, need sort
70
+ # list(or dict) before get cat_ids.
71
+ cats = self._animeins_coco_api.cats
72
+ sorted_cats = {i: cats[i] for i in sorted(cats)}
73
+ self._animeins_coco_api.cats = sorted_cats
74
+ categories = self._animeins_coco_api.dataset['categories']
75
+ sorted_categories = sorted(
76
+ categories, key=lambda i: i['id'])
77
+ self._animeins_coco_api.dataset['categories'] = sorted_categories
78
+ self.animeins_img_ids = set(self._animeins_coco_api.get_img_ids())
79
+ else:
80
+ self._animeins_coco_api = None
81
+
82
+ if self._animeins_coco_api is not None:
83
+ self._coco_api = self._animeins_coco_api
84
+ else:
85
+ self._coco_api = self._manga109_coco_api
86
+
87
+
88
+ def compute_metrics(self, results: list) -> Dict[str, float]:
89
+
90
+ # split gt and prediction list
91
+ gts, preds = zip(*results)
92
+
93
+ manga109_gts, animeins_gts = [], []
94
+ manga109_preds, animeins_preds = [], []
95
+ for gt, pred in zip(gts, preds):
96
+ if gt['img_id'] in self.manga109_img_ids:
97
+ manga109_gts.append(gt)
98
+ manga109_preds.append(pred)
99
+ else:
100
+ animeins_gts.append(gt)
101
+ animeins_preds.append(pred)
102
+
103
+ tmp_dir = None
104
+ if self.outfile_prefix is None:
105
+ tmp_dir = tempfile.TemporaryDirectory()
106
+ outfile_prefix = osp.join(tmp_dir.name, 'results')
107
+ else:
108
+ outfile_prefix = self.outfile_prefix
109
+
110
+ eval_results = OrderedDict()
111
+
112
+ if len(manga109_gts) > 0:
113
+ metrics = []
114
+ for m in self.metrics:
115
+ if m != 'segm':
116
+ metrics.append(m)
117
+
118
+ self.cat_ids = self._manga109_coco_api.get_cat_ids(cat_names=self.dataset_meta['classes'])
119
+ self.img_ids = self._manga109_coco_api.get_img_ids()
120
+ rst = self._compute_metrics(metrics, self._manga109_coco_api, manga109_preds, outfile_prefix, tmp_dir)
121
+ for key, item in rst.items():
122
+ eval_results['manga109_'+key] = item
123
+
124
+ if len(animeins_gts) > 0:
125
+ self.cat_ids = self._animeins_coco_api.get_cat_ids(cat_names=self.dataset_meta['classes'])
126
+ self.img_ids = self._animeins_coco_api.get_img_ids()
127
+ rst = self._compute_metrics(self.metrics, self._animeins_coco_api, animeins_preds, outfile_prefix, tmp_dir)
128
+ for key, item in rst.items():
129
+ eval_results['animeins_'+key] = item
130
+
131
+ return eval_results
132
+
133
+ def results2json(self, results: Sequence[dict],
134
+ outfile_prefix: str) -> dict:
135
+ """Dump the detection results to a COCO style json file.
136
+
137
+ There are 3 types of results: proposals, bbox predictions, mask
138
+ predictions, and they have different data types. This method will
139
+ automatically recognize the type, and dump them to json files.
140
+
141
+ Args:
142
+ results (Sequence[dict]): Testing results of the
143
+ dataset.
144
+ outfile_prefix (str): The filename prefix of the json files. If the
145
+ prefix is "somepath/xxx", the json files will be named
146
+ "somepath/xxx.bbox.json", "somepath/xxx.segm.json",
147
+ "somepath/xxx.proposal.json".
148
+
149
+ Returns:
150
+ dict: Possible keys are "bbox", "segm", "proposal", and
151
+ values are corresponding filenames.
152
+ """
153
+ bbox_json_results = []
154
+ segm_json_results = [] if 'masks' in results[0] else None
155
+ for idx, result in enumerate(results):
156
+ image_id = result.get('img_id', idx)
157
+ labels = result['labels']
158
+ bboxes = result['bboxes']
159
+ scores = result['scores']
160
+ # bbox results
161
+ for i, label in enumerate(labels):
162
+ data = dict()
163
+ data['image_id'] = image_id
164
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
165
+ data['score'] = float(scores[i])
166
+ data['category_id'] = self.cat_ids[label]
167
+ bbox_json_results.append(data)
168
+
169
+ if segm_json_results is None:
170
+ continue
171
+
172
+ # segm results
173
+ masks = result['masks']
174
+ mask_scores = result.get('mask_scores', scores)
175
+ for i, label in enumerate(labels):
176
+ data = dict()
177
+ data['image_id'] = image_id
178
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
179
+ data['score'] = float(mask_scores[i])
180
+ data['category_id'] = self.cat_ids[label]
181
+ if isinstance(masks[i]['counts'], bytes):
182
+ masks[i]['counts'] = masks[i]['counts'].decode()
183
+ data['segmentation'] = masks[i]
184
+ segm_json_results.append(data)
185
+
186
+ logger: MMLogger = MMLogger.get_current_instance()
187
+ logger.info('dumping predictions ... ')
188
+ result_files = dict()
189
+ result_files['bbox'] = f'{outfile_prefix}.bbox.json'
190
+ result_files['proposal'] = f'{outfile_prefix}.bbox.json'
191
+ dump(bbox_json_results, result_files['bbox'])
192
+
193
+ if segm_json_results is not None:
194
+ result_files['segm'] = f'{outfile_prefix}.segm.json'
195
+ dump(segm_json_results, result_files['segm'])
196
+
197
+ return result_files
198
+
199
+ def _compute_metrics(self, metrics, tgt_api, preds, outfile_prefix, tmp_dir):
200
+ logger: MMLogger = MMLogger.get_current_instance()
201
+
202
+ result_files = self.results2json(preds, outfile_prefix)
203
+
204
+ eval_results = OrderedDict()
205
+ if self.format_only:
206
+ logger.info('results are saved in '
207
+ f'{osp.dirname(outfile_prefix)}')
208
+ return eval_results
209
+
210
+ for metric in metrics:
211
+ logger.info(f'Evaluating {metric}...')
212
+
213
+ # TODO: May refactor fast_eval_recall to an independent metric?
214
+ # fast eval recall
215
+ if metric == 'proposal_fast':
216
+ ar = self.fast_eval_recall(
217
+ preds, self.proposal_nums, self.iou_thrs, logger=logger)
218
+ log_msg = []
219
+ for i, num in enumerate(self.proposal_nums):
220
+ eval_results[f'AR@{num}'] = ar[i]
221
+ log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
222
+ log_msg = ''.join(log_msg)
223
+ logger.info(log_msg)
224
+ continue
225
+
226
+ # evaluate proposal, bbox and segm
227
+ iou_type = 'bbox' if metric == 'proposal' else metric
228
+ if metric not in result_files:
229
+ raise KeyError(f'{metric} is not in results')
230
+ try:
231
+ predictions = load(result_files[metric])
232
+ if iou_type == 'segm':
233
+ # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331 # noqa
234
+ # When evaluating mask AP, if the results contain bbox,
235
+ # cocoapi will use the box area instead of the mask area
236
+ # for calculating the instance area. Though the overall AP
237
+ # is not affected, this leads to different
238
+ # small/medium/large mask AP results.
239
+ for x in predictions:
240
+ x.pop('bbox')
241
+ coco_dt = tgt_api.loadRes(predictions)
242
+
243
+ except IndexError:
244
+ logger.error(
245
+ 'The testing results of the whole dataset is empty.')
246
+ break
247
+
248
+ coco_eval = COCOeval(tgt_api, coco_dt, iou_type)
249
+
250
+ coco_eval.params.catIds = self.cat_ids
251
+ coco_eval.params.imgIds = self.img_ids
252
+ coco_eval.params.maxDets = list(self.proposal_nums)
253
+ coco_eval.params.iouThrs = self.iou_thrs
254
+
255
+ # mapping of cocoEval.stats
256
+ coco_metric_names = {
257
+ 'mAP': 0,
258
+ 'mAP_50': 1,
259
+ 'mAP_75': 2,
260
+ 'mAP_s': 3,
261
+ 'mAP_m': 4,
262
+ 'mAP_l': 5,
263
+ 'AR@100': 6,
264
+ 'AR@300': 7,
265
+ 'AR@1000': 8,
266
+ 'AR_s@1000': 9,
267
+ 'AR_m@1000': 10,
268
+ 'AR_l@1000': 11
269
+ }
270
+ metric_items = self.metric_items
271
+ if metric_items is not None:
272
+ for metric_item in metric_items:
273
+ if metric_item not in coco_metric_names:
274
+ raise KeyError(
275
+ f'metric item "{metric_item}" is not supported')
276
+
277
+ if metric == 'proposal':
278
+ coco_eval.params.useCats = 0
279
+ coco_eval.evaluate()
280
+ coco_eval.accumulate()
281
+ coco_eval.summarize()
282
+ if metric_items is None:
283
+ metric_items = [
284
+ 'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',
285
+ 'AR_m@1000', 'AR_l@1000'
286
+ ]
287
+
288
+ for item in metric_items:
289
+ val = float(
290
+ f'{coco_eval.stats[coco_metric_names[item]]:.3f}')
291
+ eval_results[item] = val
292
+ else:
293
+ coco_eval.evaluate()
294
+ coco_eval.accumulate()
295
+ coco_eval.summarize()
296
+ if self.classwise: # Compute per-category AP
297
+ # Compute per-category AP
298
+ # from https://github.com/facebookresearch/detectron2/
299
+ precisions = coco_eval.eval['precision']
300
+ # precision: (iou, recall, cls, area range, max dets)
301
+ assert len(self.cat_ids) == precisions.shape[2]
302
+
303
+ results_per_category = []
304
+ for idx, cat_id in enumerate(self.cat_ids):
305
+ # area range index 0: all area ranges
306
+ # max dets index -1: typically 100 per image
307
+ nm = tgt_api.loadCats(cat_id)[0]
308
+ precision = precisions[:, :, idx, 0, -1]
309
+ precision = precision[precision > -1]
310
+ if precision.size:
311
+ ap = np.mean(precision)
312
+ else:
313
+ ap = float('nan')
314
+ results_per_category.append(
315
+ (f'{nm["name"]}', f'{round(ap, 3)}'))
316
+ eval_results[f'{nm["name"]}_precision'] = round(ap, 3)
317
+
318
+ num_columns = min(6, len(results_per_category) * 2)
319
+ results_flatten = list(
320
+ itertools.chain(*results_per_category))
321
+ headers = ['category', 'AP'] * (num_columns // 2)
322
+ results_2d = itertools.zip_longest(*[
323
+ results_flatten[i::num_columns]
324
+ for i in range(num_columns)
325
+ ])
326
+ table_data = [headers]
327
+ table_data += [result for result in results_2d]
328
+ table = AsciiTable(table_data)
329
+ logger.info('\n' + table.table)
330
+
331
+ if metric_items is None:
332
+ metric_items = [
333
+ 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
334
+ ]
335
+
336
+ for metric_item in metric_items:
337
+ key = f'{metric}_{metric_item}'
338
+ val = coco_eval.stats[coco_metric_names[metric_item]]
339
+ eval_results[key] = float(f'{round(val, 3)}')
340
+
341
+ ap = coco_eval.stats[:6]
342
+ logger.info(f'{metric}_mAP_copypaste: {ap[0]:.3f} '
343
+ f'{ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
344
+ f'{ap[4]:.3f} {ap[5]:.3f}')
345
+
346
+ if tmp_dir is not None:
347
+ tmp_dir.cleanup()
348
+ return eval_results
animeinsseg/data/paste_methods.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import List, Union, Tuple, Dict
3
+ import random
4
+ from PIL import Image
5
+ import cv2
6
+ import os.path as osp
7
+ from tqdm import tqdm
8
+ from panopticapi.utils import rgb2id, id2rgb
9
+ from time import time
10
+ import traceback
11
+
12
+ from utils.io_utils import bbox_overlap_area
13
+ from utils.logger import LOGGER
14
+ from utils.constants import COLOR_PALETTE
15
+
16
+
17
+
18
+ class PartitionTree:
19
+
20
+ def __init__(self, bleft: int, btop: int, bright: int, bbottom: int, parent = None) -> None:
21
+ self.left: PartitionTree = None
22
+ self.right: PartitionTree = None
23
+ self.top: PartitionTree = None
24
+ self.bottom: PartitionTree = None
25
+
26
+ if bright < bleft:
27
+ bright = bleft
28
+ if bbottom < btop:
29
+ bbottom = btop
30
+
31
+ self.bleft = bleft
32
+ self.bright = bright
33
+ self.btop = btop
34
+ self.bbottom = bbottom
35
+ self.parent: PartitionTree = parent
36
+
37
+ def is_leaf(self):
38
+ return self.left is None
39
+
40
+ def new_partition(self, new_rect: List):
41
+ self.left = PartitionTree(self.bleft, self.btop, new_rect[0], self.bbottom, self)
42
+ self.top = PartitionTree(self.bleft, self.btop, self.bright, new_rect[1], self)
43
+ self.right = PartitionTree(new_rect[2], self.btop, self.bright, self.bbottom, self)
44
+ self.bottom = PartitionTree(self.bleft, new_rect[3], self.bright, self.bbottom, self)
45
+ if self.parent is not None:
46
+ self.root_update_rect(new_rect)
47
+
48
+ def root_update_rect(self, rect):
49
+ root = self.get_root()
50
+ root.update_child_rect(rect)
51
+
52
+ def update_child_rect(self, rect: List):
53
+ if self.is_leaf():
54
+ self.update_from_rect(rect)
55
+ else:
56
+ self.left.update_child_rect(rect)
57
+ self.right.update_child_rect(rect)
58
+ self.top.update_child_rect(rect)
59
+ self.bottom.update_child_rect(rect)
60
+
61
+ def get_root(self):
62
+ if self.parent is not None:
63
+ return self.parent.get_root()
64
+ else:
65
+ return self
66
+
67
+
68
+ def update_from_rect(self, rect: List):
69
+ if not self.is_leaf():
70
+ return
71
+ ix = min(self.bright, rect[2]) - max(self.bleft, rect[0])
72
+ iy = min(self.bbottom, rect[3]) - max(self.btop, rect[1])
73
+ if not (ix > 0 and iy > 0):
74
+ return
75
+
76
+ new_ltrb0 = np.array([self.bleft, self.btop, self.bright, self.bbottom])
77
+ new_ltrb1 = new_ltrb0.copy()
78
+
79
+ if rect[0] > self.bleft and rect[0] < self.bright:
80
+ new_ltrb0[2] = rect[0]
81
+ else:
82
+ new_ltrb0[0] = rect[2]
83
+
84
+ if rect[1] > self.btop and rect[1] < self.bbottom:
85
+ new_ltrb1[3]= rect[1]
86
+ else:
87
+ new_ltrb1[1] = rect[3]
88
+
89
+ if (new_ltrb0[2:] - new_ltrb0[:2]).prod() > (new_ltrb1[2:] - new_ltrb1[:2]).prod():
90
+ self.bleft, self.btop, self.bright, self.bbottom = new_ltrb0
91
+ else:
92
+ self.bleft, self.btop, self.bright, self.bbottom = new_ltrb1
93
+
94
+ @property
95
+ def width(self) -> int:
96
+ return self.bright - self.bleft
97
+
98
+ @property
99
+ def height(self) -> int:
100
+ return self.bbottom - self.btop
101
+
102
+ def prefer_partition(self, tgt_h: int, tgt_w: int):
103
+ if self.is_leaf():
104
+ return self, min(self.width / tgt_w, 1.2) * min(self.height / tgt_h, 1.2)
105
+ else:
106
+ lp, ls = self.left.prefer_partition(tgt_h, tgt_w)
107
+ rp, rs = self.right.prefer_partition(tgt_h, tgt_w)
108
+ tp, ts = self.top.prefer_partition(tgt_h, tgt_w)
109
+ bp, bs = self.bottom.prefer_partition(tgt_h, tgt_w)
110
+ preferp = [(p, s) for s, p in sorted(zip([ls, rs, ts, bs],[lp, rp, tp, bp]), key=lambda pair: pair[0], reverse=True)][0]
111
+ return preferp
112
+
113
+ def new_random_pos(self, fg_h: int, fg_w: int, im_h: int, im_w: int, random_sample: bool = False):
114
+ extx, exty = int(fg_w / 3), int(fg_h / 3)
115
+ extxb, extyb = int(fg_w / 10), int(fg_h / 10)
116
+ region_w, region_h = self.width + extx, self.height + exty
117
+ downscale_ratio = max(min(region_w / fg_w, region_h / fg_h), 0.8)
118
+ if downscale_ratio < 1:
119
+ fg_h = int(downscale_ratio * fg_h)
120
+ fg_w = int(downscale_ratio * fg_w)
121
+
122
+ max_x, max_y = self.bright + extx - fg_w, self.bbottom + exty - fg_h
123
+ max_x = min(im_w+extxb-fg_w, max_x)
124
+ max_y = min(im_h+extyb-fg_h, max_y)
125
+ min_x = max(min(self.bright + extx - fg_w, self.bleft - extx), -extx)
126
+ min_x = max(-extxb, min_x)
127
+ min_y = max(min(self.bbottom + exty - fg_h, self.btop - exty), -exty)
128
+ min_y = max(-extyb, min_y)
129
+ px, py = min_x, min_y
130
+ if min_x < max_x:
131
+ if random_sample:
132
+ px = random.randint(min_x, max_x)
133
+ else:
134
+ px = int((min_x + max_x) / 2)
135
+ if min_y < max_y:
136
+ if random_sample:
137
+ py = random.randint(min_y, max_y)
138
+ else:
139
+ py = int((min_y + max_y) / 2)
140
+ return px, py, downscale_ratio
141
+
142
+ def drawpartition(self, image: np.ndarray, color = None):
143
+ if color is None:
144
+ color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
145
+ if not self.is_leaf():
146
+ cv2.rectangle(image, (self.bleft, self.btop), (self.bright, self.bbottom), color, 2)
147
+ if not self.is_leaf():
148
+ c = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
149
+ self.left.drawpartition(image, c)
150
+ self.right.drawpartition(image, c)
151
+ self.top.drawpartition(image, c)
152
+ self.bottom.drawpartition(image, c)
153
+
154
+
155
+ def paste_one_fg(fg_pil: Image, bg: Image, segments: np.ndarray, px: int, py: int, seg_color: Tuple, cal_area=True):
156
+
157
+ fg_h, fg_w = fg_pil.height, fg_pil.width
158
+ im_h, im_w = bg.height, bg.width
159
+
160
+ bg.paste(fg_pil, (px, py), mask=fg_pil)
161
+
162
+
163
+ bgx1, bgx2, bgy1, bgy2 = px, px+fg_w, py, py+fg_h
164
+ fgx1, fgx2, fgy1, fgy2 = 0, fg_w, 0, fg_h
165
+ if bgx1 < 0:
166
+ fgx1 = -bgx1
167
+ bgx1 = 0
168
+ if bgy1 < 0:
169
+ fgy1 = -bgy1
170
+ bgy1 = 0
171
+ if bgx2 > im_w:
172
+ fgx2 = im_w - bgx2
173
+ bgx2 = im_w
174
+ if bgy2 > im_h:
175
+ fgy2 = im_h - bgy2
176
+ bgy2 = im_h
177
+
178
+ fg_mask = np.array(fg_pil)[fgy1: fgy2, fgx1: fgx2, 3] > 30
179
+ segments[bgy1: bgy2, bgx1: bgx2][np.where(fg_mask)] = seg_color
180
+
181
+ if cal_area:
182
+ area = fg_mask.sum()
183
+ else:
184
+ area = 1
185
+ bbox = [bgx1, bgy1, bgx2-bgx1, bgy2-bgy1]
186
+ return area, bbox, [bgx1, bgy1, bgx2, bgy2]
187
+
188
+
189
+ def partition_paste(fg_list, bg: Image):
190
+ segments_info = []
191
+
192
+ fg_list.sort(key = lambda x: x['image'].shape[0] * x['image'].shape[1], reverse=True)
193
+ pnode: PartitionTree = None
194
+ im_h, im_w = bg.height, bg.width
195
+
196
+ ptree = PartitionTree(0, 0, bg.width, bg.height)
197
+
198
+ segments = np.zeros((im_h, im_w, 3), np.uint8)
199
+ for ii, fg_dict in enumerate(fg_list):
200
+ fg = fg_dict['image']
201
+ fg_h, fg_w = fg.shape[:2]
202
+ pnode, _ = ptree.prefer_partition(fg_h, fg_w)
203
+ px, py, downscale_ratio = pnode.new_random_pos(fg_h, fg_w, im_h, im_w, True)
204
+
205
+ fg_pil = Image.fromarray(fg)
206
+ if downscale_ratio < 1:
207
+ fg_pil = fg_pil.resize((int(fg_w * downscale_ratio), int(fg_h * downscale_ratio)), resample=Image.Resampling.LANCZOS)
208
+ # fg_h, fg_w = fg_pil.height, fg_pil.width
209
+
210
+ seg_color = COLOR_PALETTE[ii]
211
+ area, bbox, xyxy = paste_one_fg(fg_pil, bg, segments, px,py, seg_color, cal_area=False)
212
+ pnode.new_partition(xyxy)
213
+
214
+ segments_info.append({
215
+ 'id': rgb2id(seg_color),
216
+ 'bbox': bbox,
217
+ 'area': area
218
+ })
219
+
220
+ return segments_info, segments
221
+ # if downscale_ratio < 1:
222
+ # fg_pil = fg_pil.resize((int(fg_w * downscale_ratio), int(fg_h * downscale_ratio)), resample=Image.Resampling.LANCZOS)
223
+ # fg_h, fg_w = fg_pil.height, fg_pil.width
224
+
225
+
226
+ def gen_fg_regbboxes(fg_list: List[Dict], tgt_size: int, min_overlap=0.15, max_overlap=0.8):
227
+
228
+ def _sample_y(h):
229
+ y = (tgt_size - h) // 2
230
+ if y > 0:
231
+ yrange = min(y, h // 4)
232
+ y += random.randint(-yrange, yrange)
233
+ return y
234
+ else:
235
+ return 0
236
+
237
+ shape_list = []
238
+ depth_list = []
239
+
240
+
241
+ for fg_dict in fg_list:
242
+ shape_list.append(fg_dict['image'].shape[:2])
243
+
244
+ shape_list = np.array(shape_list)
245
+ depth_list = np.random.random(len(fg_list))
246
+ depth_list[shape_list[..., 1] > 0.6 * tgt_size] += 1
247
+
248
+ # num_fg = len(fg_list)
249
+ # grid_sample = random.random() < 0.4 or num_fg > 6
250
+ # grid_sample = grid_sample and num_fg < 9 and num_fg > 3
251
+ # grid_sample = False
252
+ # if grid_sample:
253
+ # grid_pos = np.arange(9)
254
+ # np.random.shuffle(grid_pos)
255
+ # grid_pos = grid_pos[: num_fg]
256
+ # grid_x = grid_pos % 3
257
+ # grid_y = grid_pos // 3
258
+
259
+ # else:
260
+ pos_list = [[0, _sample_y(shape_list[0][0])]]
261
+ pre_overlap = 0
262
+ for ii, ((h, w), d) in enumerate(zip(shape_list[1:], depth_list[1:])):
263
+ (preh, prew), predepth, (prex, prey) = shape_list[ii], depth_list[ii], pos_list[ii]
264
+
265
+ isfg = d < predepth
266
+ y = _sample_y(h)
267
+ x = prex+prew
268
+ if isfg:
269
+ min_x = max_x = x
270
+ if pre_overlap < max_overlap:
271
+ min_x -= (max_overlap - pre_overlap) * prew
272
+ min_x = int(min_x)
273
+ if pre_overlap < min_overlap:
274
+ max_x -= (min_overlap - pre_overlap) * prew
275
+ max_x = int(max_x)
276
+ x = random.randint(min_x, max_x)
277
+ pre_overlap = 0
278
+ else:
279
+ overlap = random.uniform(min_overlap, max_overlap)
280
+ x -= int(overlap * w)
281
+ area = h * w
282
+ overlap_area = bbox_overlap_area([x, y, w, h], [prex, prey, prew, preh])
283
+ pre_overlap = overlap_area / area
284
+
285
+ pos_list.append([x, y])
286
+
287
+ pos_list = np.array(pos_list)
288
+ last_x2 = pos_list[-1][0] + shape_list[-1][1]
289
+ valid_shiftx = tgt_size - last_x2
290
+ if valid_shiftx > 0:
291
+ shiftx = random.randint(0, valid_shiftx)
292
+ pos_list[:, 0] += shiftx
293
+ else:
294
+ pos_list[:, 0] += valid_shiftx // 2
295
+
296
+ for pos, fg_dict, depth in zip(pos_list, fg_list, depth_list):
297
+ fg_dict['pos'] = pos
298
+ fg_dict['depth'] = depth
299
+ fg_list.sort(key=lambda x: x['depth'], reverse=True)
300
+
301
+
302
+
303
+ def regular_paste(fg_list, bg: Image, regen_bboxes=False):
304
+ segments_info = []
305
+ im_h, im_w = bg.height, bg.width
306
+
307
+ if regen_bboxes:
308
+ random.shuffle(fg_list)
309
+ gen_fg_regbboxes(fg_list, im_h)
310
+
311
+ segments = np.zeros((im_h, im_w, 3), np.uint8)
312
+ for ii, fg_dict in enumerate(fg_list):
313
+ fg = fg_dict['image']
314
+
315
+ px, py = fg_dict.pop('pos')
316
+ fg_pil = Image.fromarray(fg)
317
+
318
+ seg_color = COLOR_PALETTE[ii]
319
+ area, bbox, xyxy = paste_one_fg(fg_pil, bg, segments, px,py, seg_color, cal_area=True)
320
+
321
+ segments_info.append({
322
+ 'id': rgb2id(seg_color),
323
+ 'bbox': bbox,
324
+ 'area': area
325
+ })
326
+
327
+ return segments_info, segments
animeinsseg/data/sampler.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from random import choice as rchoice
3
+ from random import randint
4
+ import random
5
+ import cv2, traceback, imageio
6
+ import os.path as osp
7
+
8
+ from typing import Optional, List, Union, Tuple, Dict
9
+ from utils.io_utils import imread_nogrey_rgb, json2dict
10
+ from .transforms import rotate_image
11
+ from utils.logger import LOGGER
12
+
13
+
14
+ class NameSampler:
15
+
16
+ def __init__(self, name_prob_dict, sample_num=2048) -> None:
17
+ self.name_prob_dict = name_prob_dict
18
+ self._id2name = list(name_prob_dict.keys())
19
+ self.sample_ids = []
20
+
21
+ total_prob = 0.
22
+ for ii, (_, prob) in enumerate(name_prob_dict.items()):
23
+ tgt_num = int(prob * sample_num)
24
+ total_prob += prob
25
+ if tgt_num > 0:
26
+ self.sample_ids += [ii] * tgt_num
27
+
28
+ nsamples = len(self.sample_ids)
29
+ assert prob <= 1
30
+ if prob < 1 and nsamples < sample_num:
31
+ self.sample_ids += [len(self._id2name)] * (sample_num - nsamples)
32
+ self._id2name.append('_')
33
+
34
+ def sample(self) -> str:
35
+ return self._id2name[rchoice(self.sample_ids)]
36
+
37
+
38
+ class PossionSampler:
39
+ def __init__(self, lam=3, min_val=1, max_val=8) -> None:
40
+ self._distr = np.random.poisson(lam, 1024)
41
+ invalid = np.where(np.logical_or(self._distr<min_val, self._distr > max_val))
42
+ self._distr[invalid] = np.random.randint(min_val, max_val, len(invalid[0]))
43
+
44
+ def sample(self) -> int:
45
+ return rchoice(self._distr)
46
+
47
+
48
+ class NormalSampler:
49
+ def __init__(self, loc=0.33, std=0.2, min_scale=0.15, max_scale=0.85, scalar=1, to_int = True):
50
+ s = np.random.normal(loc, std, 4096)
51
+ valid = np.where(np.logical_and(s>min_scale, s<max_scale))
52
+ self._distr = s[valid] * scalar
53
+ if to_int:
54
+ self._distr = self._distr.astype(np.int32)
55
+
56
+ def sample(self) -> int:
57
+ return rchoice(self._distr)
58
+
59
+
60
+ class PersonBBoxSampler:
61
+
62
+ def __init__(self, sample_path: Union[str, List]='data/cocoperson_bbox_samples.json', fg_info_list: List = None, fg_transform=None, is_train=True) -> None:
63
+ if isinstance(sample_path, str):
64
+ sample_path = [sample_path]
65
+ self.bbox_list = []
66
+ for sp in sample_path:
67
+ bboxlist = json2dict(sp)
68
+ for bboxes in bboxlist:
69
+ if isinstance(bboxes, dict):
70
+ bboxes = bboxes['bboxes']
71
+ bboxes = np.array(bboxes)
72
+ bboxes[:, [0, 1]] -= bboxes[:, [0, 1]].min(axis=0)
73
+ self.bbox_list.append(bboxes)
74
+
75
+ self.fg_info_list = fg_info_list
76
+ self.fg_transform = fg_transform
77
+ self.is_train = is_train
78
+
79
+ def sample(self, tgt_size: int, scale_range=(1, 1), size_thres=(0.02, 0.85)) -> List[np.ndarray]:
80
+ bboxes_normalized = rchoice(self.bbox_list)
81
+ if scale_range[0] != 1 or scale_range[1] != 1:
82
+ bbox_scale = random.uniform(scale_range[0], scale_range[1])
83
+ else:
84
+ bbox_scale = 1
85
+ bboxes = (bboxes_normalized * tgt_size * bbox_scale).astype(np.int32)
86
+
87
+ xyxy_array = np.copy(bboxes)
88
+ xyxy_array[:, [2, 3]] += xyxy_array[:, [0, 1]]
89
+ x_max, y_max = xyxy_array[:, 2].max(), xyxy_array[:, 3].max()
90
+
91
+ x_shift = tgt_size - x_max
92
+ x_shift = randint(0, x_shift) if x_shift > 0 else 0
93
+ y_shift = tgt_size - y_max
94
+ y_shift = randint(0, y_shift) if y_shift > 0 else 0
95
+
96
+ bboxes[:, [0, 1]] += [x_shift, y_shift]
97
+ valid_bboxes = []
98
+ max_size = size_thres[1] * tgt_size
99
+ min_size = size_thres[0] * tgt_size
100
+ for bbox in bboxes:
101
+ w = min(bbox[2], tgt_size - bbox[0])
102
+ h = min(bbox[3], tgt_size - bbox[1])
103
+ if max(h, w) < max_size and min(h, w) > min_size:
104
+ valid_bboxes.append(bbox)
105
+ return valid_bboxes
106
+
107
+ def sample_matchfg(self, tgt_size: int):
108
+ while True:
109
+ bboxes = self.sample(tgt_size, (1.1, 1.8))
110
+ if len(bboxes) > 0:
111
+ break
112
+ MIN_FG_SIZE = 20
113
+ num_fg = len(bboxes)
114
+ rotate = 20 if self.is_train else 15
115
+ fgs = random_load_nfg(num_fg, self.fg_info_list, random_rotate_prob=0.33, random_rotate=rotate)
116
+ assert len(fgs) == num_fg
117
+
118
+ bboxes.sort(key=lambda x: x[2] / x[3])
119
+ fgs.sort(key=lambda x: x['asp_ratio'])
120
+
121
+ for fg, bbox in zip(fgs, bboxes):
122
+ x, y, w, h = bbox
123
+ img = fg['image']
124
+ im_h, im_w = img.shape[:2]
125
+ if im_h < h and im_w < w:
126
+ scale = min(h / im_h, w / im_w)
127
+ new_h, new_w = int(scale * im_h), int(scale * im_w)
128
+ img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
129
+ else:
130
+ scale_h, scale_w = min(1, h / im_h), min(1, w / im_w)
131
+ scale = (scale_h + scale_w) / 2
132
+ if scale < 1:
133
+ new_h, new_w = max(int(scale * im_h), MIN_FG_SIZE), max(int(scale * im_w), MIN_FG_SIZE)
134
+ img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
135
+
136
+ if self.fg_transform is not None:
137
+ img = self.fg_transform(image=img)['image']
138
+
139
+ im_h, im_w = img.shape[:2]
140
+ fg['image'] = img
141
+ px = int(x + w / 2 - im_w / 2)
142
+ py = int(y + h / 2 - im_h / 2)
143
+ fg['pos'] = (px, py)
144
+
145
+ random.shuffle(fgs)
146
+
147
+ slist, llist = [], []
148
+ large_size = int(tgt_size * 0.55)
149
+ for fg in fgs:
150
+ if max(fg['image'].shape[:2]) > large_size:
151
+ llist.append(fg)
152
+ else:
153
+ slist.append(fg)
154
+ return llist + slist
155
+
156
+
157
+ def random_load_nfg(num_fg: int, fg_info_list: List[Union[Dict, str]], random_rotate=0, random_rotate_prob=0.):
158
+ fgs = []
159
+ while len(fgs) < num_fg:
160
+ fg, fginfo = random_load_valid_fg(fg_info_list)
161
+ if random.random() < random_rotate_prob:
162
+ rotate_deg = randint(-random_rotate, random_rotate)
163
+ fg = rotate_image(fg, rotate_deg, alpha_crop=True)
164
+
165
+ asp_ratio = fg.shape[1] / fg.shape[0]
166
+ fgs.append({'image': fg, 'asp_ratio': asp_ratio, 'fginfo': fginfo})
167
+ while len(fgs) < num_fg and random.random() < 0.12:
168
+ fgs.append({'image': fg, 'asp_ratio': asp_ratio, 'fginfo': fginfo})
169
+
170
+ return fgs
171
+
172
+
173
+ def random_load_valid_fg(fg_info_list: List[Union[Dict, str]]) -> Tuple[np.ndarray, Dict]:
174
+ while True:
175
+ item = fginfo = rchoice(fg_info_list)
176
+
177
+ file_path = fginfo['file_path']
178
+ if 'root_dir' in fginfo and fginfo['root_dir']:
179
+ file_path = osp.join(fginfo['root_dir'], file_path)
180
+
181
+ try:
182
+ fg = imageio.imread(file_path)
183
+ except:
184
+ LOGGER.error(traceback.format_exc())
185
+ LOGGER.error(f'invalid fg: {file_path}')
186
+ fg_info_list.remove(item)
187
+ continue
188
+
189
+ c = 1
190
+ if len(fg.shape) == 3:
191
+ c = fg.shape[-1]
192
+ if c != 4:
193
+ LOGGER.warning(f'fg {file_path} doesnt have alpha channel')
194
+ fg_info_list.remove(item)
195
+ else:
196
+ if 'xyxy' in fginfo:
197
+ x1, y1, x2, y2 = fginfo['xyxy']
198
+ else:
199
+ oh, ow = fg.shape[:2]
200
+ ksize = 5
201
+ mask = cv2.blur(fg[..., 3], (ksize,ksize))
202
+ _, mask = cv2.threshold(mask, 20, 255, cv2.THRESH_BINARY)
203
+
204
+ x1, y1, w, h = cv2.boundingRect(cv2.findNonZero(mask))
205
+ x2, y2 = x1 + w, y1 + h
206
+ if oh - h > 15 or ow - w > 15:
207
+ crop = True
208
+ else:
209
+ x1 = y1 = 0
210
+ x2, y2 = ow, oh
211
+
212
+ fginfo['xyxy'] = [x1, y1, x2, y2]
213
+ fg = fg[y1: y2, x1: x2]
214
+ return fg, fginfo
215
+
216
+
217
+ def random_load_valid_bg(bg_list: List[str]) -> np.ndarray:
218
+ while True:
219
+ try:
220
+ bgp = rchoice(bg_list)
221
+ return imread_nogrey_rgb(bgp)
222
+ except:
223
+ LOGGER.error(traceback.format_exc())
224
+ LOGGER.error(f'invalid bg: {bgp}')
225
+ bg_list.remove(bgp)
226
+ continue
animeinsseg/data/syndataset.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import List, Union, Tuple, Dict
3
+ import random
4
+ from PIL import Image
5
+ import cv2
6
+ import imageio, os
7
+ import os.path as osp
8
+ from tqdm import tqdm
9
+ from panopticapi.utils import rgb2id
10
+ import traceback
11
+
12
+ from utils.io_utils import mask2rle, dict2json, fgbg_hist_matching
13
+ from utils.logger import LOGGER
14
+ from utils.constants import CATEGORIES, IMAGE_ID_ZFILL
15
+ from .transforms import get_fg_transforms, get_bg_transforms, quantize_image, resize2height, rotate_image
16
+ from .sampler import random_load_valid_bg, random_load_valid_fg, NameSampler, NormalSampler, PossionSampler, PersonBBoxSampler
17
+ from .paste_methods import regular_paste, partition_paste
18
+
19
+
20
+ def syn_animecoco_dataset(
21
+ bg_list: List, fg_info_list: List[Dict], dataset_save_dir: str, policy: str='train',
22
+ tgt_size=640, syn_num_multiplier=2.5, regular_paste_prob=0.4, person_paste_prob=0.4,
23
+ max_syn_num=-1, image_id_start=0, obj_id_start=0, hist_match_prob=0.2, quantize_prob=0.25):
24
+
25
+ LOGGER.info(f'syn data policy: {policy}')
26
+ LOGGER.info(f'background: {len(bg_list)} foreground: {len(fg_info_list)}')
27
+
28
+ numfg_sampler = PossionSampler(min_val=1, max_val=9, lam=2.5)
29
+ numfg_regpaste_sampler = PossionSampler(min_val=2, max_val=9, lam=3.5)
30
+ regpaste_size_sampler = NormalSampler(scalar=tgt_size, to_int=True, max_scale=0.75)
31
+ color_correction_sampler = NameSampler({'hist_match': hist_match_prob, 'quantize': quantize_prob}, )
32
+ paste_method_sampler = NameSampler({'regular': regular_paste_prob, 'personbbox': person_paste_prob,
33
+ 'partition': 1-regular_paste_prob-person_paste_prob})
34
+
35
+ fg_transform = get_fg_transforms(tgt_size, transform_variant=policy)
36
+ fg_distort_transform = get_fg_transforms(tgt_size, transform_variant='distort_only')
37
+ bg_transform = get_bg_transforms('train', tgt_size)
38
+
39
+ image_id = image_id_start + 1
40
+ obj_id = obj_id_start + 1
41
+
42
+ det_annotations, image_meta = [], []
43
+
44
+ syn_num = int(syn_num_multiplier * len(fg_info_list))
45
+ if max_syn_num > 0:
46
+ syn_num = max_syn_num
47
+
48
+ ann_save_dir = osp.join(dataset_save_dir, 'annotations')
49
+ image_save_dir = osp.join(dataset_save_dir, policy)
50
+
51
+ if not osp.exists(image_save_dir):
52
+ os.makedirs(image_save_dir)
53
+ if not osp.exists(ann_save_dir):
54
+ os.makedirs(ann_save_dir)
55
+
56
+ is_train = policy == 'train'
57
+ if is_train:
58
+ jpg_save_quality = [75, 85, 95]
59
+ else:
60
+ jpg_save_quality = [95]
61
+
62
+ if isinstance(fg_info_list[0], str):
63
+ for ii, fgp in enumerate(fg_info_list):
64
+ if isinstance(fgp, str):
65
+ fg_info_list[ii] = {'file_path': fgp, 'tag_string': [], 'danbooru': False, 'category_id': 0}
66
+
67
+ if person_paste_prob > 0:
68
+ personbbox_sampler = PersonBBoxSampler(
69
+ 'data/cocoperson_bbox_samples.json', fg_info_list,
70
+ fg_transform=fg_distort_transform if is_train else None, is_train=is_train)
71
+
72
+ total = tqdm(range(syn_num))
73
+ for fin in total:
74
+ try:
75
+ paste_method = paste_method_sampler.sample()
76
+
77
+ fgs = []
78
+ if paste_method == 'regular':
79
+ num_fg = numfg_regpaste_sampler.sample()
80
+ size = regpaste_size_sampler.sample()
81
+ while len(fgs) < num_fg:
82
+ tgt_height = int(random.uniform(0.7, 1.2) * size)
83
+ fg, fginfo = random_load_valid_fg(fg_info_list)
84
+ fg = resize2height(fg, tgt_height)
85
+ if is_train:
86
+ fg = fg_distort_transform(image=fg)['image']
87
+ rotate_deg = random.randint(-40, 40)
88
+ else:
89
+ rotate_deg = random.randint(-30, 30)
90
+ if random.random() < 0.3:
91
+ fg = rotate_image(fg, rotate_deg, alpha_crop=True)
92
+ fgs.append({'image': fg, 'fginfo': fginfo})
93
+ while len(fgs) < num_fg and random.random() < 0.15:
94
+ fgs.append({'image': fg, 'fginfo': fginfo})
95
+ elif paste_method == 'personbbox':
96
+ fgs = personbbox_sampler.sample_matchfg(tgt_size)
97
+ else:
98
+ num_fg = numfg_sampler.sample()
99
+ fgs = []
100
+ for ii in range(num_fg):
101
+ fg, fginfo = random_load_valid_fg(fg_info_list)
102
+ fg = fg_transform(image=fg)['image']
103
+ h, w = fg.shape[:2]
104
+ if num_fg > 6:
105
+ downscale = min(tgt_size / 2.5 / w, tgt_size / 2.5 / h)
106
+ if downscale < 1:
107
+ fg = cv2.resize(fg, (int(w * downscale), int(h * downscale)), interpolation=cv2.INTER_AREA)
108
+ fgs.append({'image': fg, 'fginfo': fginfo})
109
+
110
+ bg = random_load_valid_bg(bg_list)
111
+ bg = bg_transform(image=bg)['image']
112
+
113
+ color_correct = color_correction_sampler.sample()
114
+
115
+ if color_correct == 'hist_match':
116
+ fgbg_hist_matching(fgs, bg)
117
+
118
+ bg: Image = Image.fromarray(bg)
119
+
120
+ if paste_method == 'regular':
121
+ segments_info, segments = regular_paste(fgs, bg, regen_bboxes=True)
122
+ elif paste_method == 'personbbox':
123
+ segments_info, segments = regular_paste(fgs, bg, regen_bboxes=False)
124
+ elif paste_method == 'partition':
125
+ segments_info, segments = partition_paste(fgs, bg, )
126
+ else:
127
+ print(f'invalid paste method: {paste_method}')
128
+ raise NotImplementedError
129
+
130
+ image = np.array(bg)
131
+ if color_correct == 'quantize':
132
+ mask = cv2.inRange(segments, np.array([0,0,0]), np.array([0,0,0]))
133
+ # cv2.imshow("mask", mask)
134
+ image = quantize_image(image, random.choice([12, 16, 32]), 'kmeans', mask=mask)[0]
135
+
136
+ # postprocess & check if instance is valid
137
+ for ii, segi in enumerate(segments_info):
138
+ if segi['area'] == 0:
139
+ continue
140
+ x, y, w, h = segi['bbox']
141
+ x2, y2 = x+w, y+h
142
+ c = segments[y: y2, x: x2]
143
+ pan_png = rgb2id(c)
144
+ cmask = (pan_png == segi['id'])
145
+ area = cmask.sum()
146
+
147
+ if paste_method != 'partition' and \
148
+ area / (fgs[ii]['image'][..., 3] > 30).sum() < 0.25:
149
+ # cv2.imshow('im', fgs[ii]['image'])
150
+ # cv2.imshow('mask', fgs[ii]['image'][..., 3])
151
+ # cv2.imshow('seg', segments)
152
+ # cv2.waitKey(0)
153
+ cmask_ids = np.where(cmask)
154
+ segments[y: y2, x: x2][cmask_ids] = 0
155
+ image[y: y2, x: x2][cmask_ids] = (127, 127, 127)
156
+ continue
157
+
158
+ cmask = cmask.astype(np.uint8) * 255
159
+ dx, dy, w, h = cv2.boundingRect(cv2.findNonZero(cmask))
160
+ _bbox = [dx + x, dy + y, w, h]
161
+
162
+ seg = cv2.copyMakeBorder(cmask, y, tgt_size-y2, x, tgt_size-x2, cv2.BORDER_CONSTANT) > 0
163
+ assert seg.shape[0] == tgt_size and seg.shape[1] == tgt_size
164
+ segmentation = mask2rle(seg)
165
+
166
+ det_annotations.append({
167
+ 'id': obj_id,
168
+ 'category_id': fgs[ii]['fginfo']['category_id'],
169
+ 'iscrowd': 0,
170
+ 'segmentation': segmentation,
171
+ 'image_id': image_id,
172
+ 'area': area,
173
+ 'tag_string': fgs[ii]['fginfo']['tag_string'],
174
+ 'tag_string_character': fgs[ii]['fginfo']['tag_string_character'],
175
+ 'bbox': [float(c) for c in _bbox]
176
+ })
177
+
178
+ obj_id += 1
179
+ # cv2.imshow('c', cv2.cvtColor(c, cv2.COLOR_RGB2BGR))
180
+ # cv2.imshow('cmask', cmask)
181
+ # cv2.waitKey(0)
182
+
183
+ image_id_str = str(image_id).zfill(IMAGE_ID_ZFILL)
184
+ image_file_name = image_id_str + '.jpg'
185
+ image_meta.append({
186
+ "id": image_id,"height": tgt_size,"width": tgt_size, "file_name": image_file_name, "id": image_id
187
+ })
188
+
189
+ # LOGGER.info(f'paste method: {paste_method} color correct: {color_correct}')
190
+ # cv2.imshow('image', cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
191
+ # cv2.imshow('segments', cv2.cvtColor(segments, cv2.COLOR_RGB2BGR))
192
+ # cv2.waitKey(0)
193
+
194
+ imageio.imwrite(osp.join(image_save_dir, image_file_name), image, quality=random.choice(jpg_save_quality))
195
+ image_id += 1
196
+
197
+ except:
198
+ LOGGER.error(traceback.format_exc())
199
+ continue
200
+
201
+ det_meta = {
202
+ "info": {},
203
+ "licenses": [],
204
+ "images": image_meta,
205
+ "annotations": det_annotations,
206
+ "categories": CATEGORIES
207
+ }
208
+
209
+ detp = osp.join(ann_save_dir, f'det_{policy}.json')
210
+ dict2json(det_meta, detp)
211
+ LOGGER.info(f'annotations saved to {detp}')
212
+
213
+ return image_id, obj_id
animeinsseg/data/transforms.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import albumentations as A
2
+ from albumentations import DualIAATransform, to_tuple
3
+ import imgaug.augmenters as iaa
4
+ import cv2
5
+ from tqdm import tqdm
6
+ from sklearn.cluster import KMeans
7
+ from sklearn.metrics import pairwise_distances_argmin
8
+ from sklearn.utils import shuffle
9
+ import numpy as np
10
+
11
+ class IAAAffine2(DualIAATransform):
12
+ """Place a regular grid of points on the input and randomly move the neighbourhood of these point around
13
+ via affine transformations.
14
+ Note: This class introduce interpolation artifacts to mask if it has values other than {0;1}
15
+ Args:
16
+ p (float): probability of applying the transform. Default: 0.5.
17
+ Targets:
18
+ image, mask
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ scale=(0.7, 1.3),
24
+ translate_percent=None,
25
+ translate_px=None,
26
+ rotate=0.0,
27
+ shear=(-0.1, 0.1),
28
+ order=1,
29
+ cval=0,
30
+ mode="reflect",
31
+ always_apply=False,
32
+ p=0.5,
33
+ ):
34
+ super(IAAAffine2, self).__init__(always_apply, p)
35
+ self.scale = dict(x=scale, y=scale)
36
+ self.translate_percent = to_tuple(translate_percent, 0)
37
+ self.translate_px = to_tuple(translate_px, 0)
38
+ self.rotate = to_tuple(rotate)
39
+ self.shear = dict(x=shear, y=shear)
40
+ self.order = order
41
+ self.cval = cval
42
+ self.mode = mode
43
+
44
+ @property
45
+ def processor(self):
46
+ return iaa.Affine(
47
+ self.scale,
48
+ self.translate_percent,
49
+ self.translate_px,
50
+ self.rotate,
51
+ self.shear,
52
+ self.order,
53
+ self.cval,
54
+ self.mode,
55
+ )
56
+
57
+ def get_transform_init_args_names(self):
58
+ return ("scale", "translate_percent", "translate_px", "rotate", "shear", "order", "cval", "mode")
59
+
60
+
61
+ class IAAPerspective2(DualIAATransform):
62
+ """Perform a random four point perspective transform of the input.
63
+ Note: This class introduce interpolation artifacts to mask if it has values other than {0;1}
64
+ Args:
65
+ scale ((float, float): standard deviation of the normal distributions. These are used to sample
66
+ the random distances of the subimage's corners from the full image's corners. Default: (0.05, 0.1).
67
+ p (float): probability of applying the transform. Default: 0.5.
68
+ Targets:
69
+ image, mask
70
+ """
71
+
72
+ def __init__(self, scale=(0.05, 0.1), keep_size=True, always_apply=False, p=0.5,
73
+ order=1, cval=0, mode="replicate"):
74
+ super(IAAPerspective2, self).__init__(always_apply, p)
75
+ self.scale = to_tuple(scale, 1.0)
76
+ self.keep_size = keep_size
77
+ self.cval = cval
78
+ self.mode = mode
79
+
80
+ @property
81
+ def processor(self):
82
+ return iaa.PerspectiveTransform(self.scale, keep_size=self.keep_size, mode=self.mode, cval=self.cval)
83
+
84
+ def get_transform_init_args_names(self):
85
+ return ("scale", "keep_size")
86
+
87
+
88
+ def get_bg_transforms(transform_variant, out_size):
89
+ max_size = int(out_size * 1.2)
90
+ if transform_variant == 'train':
91
+ transform = [
92
+ A.SmallestMaxSize(max_size, always_apply=True, interpolation=cv2.INTER_AREA),
93
+ A.RandomResizedCrop(out_size, out_size, scale=(0.9, 1.5), p=1, ratio=(0.9, 1.1)),
94
+ ]
95
+ else:
96
+ transform = [
97
+ A.SmallestMaxSize(out_size, always_apply=True),
98
+ A.RandomCrop(out_size, out_size, True),
99
+ ]
100
+ return A.Compose(transform)
101
+
102
+
103
+ def get_fg_transforms(out_size, scale_limit=(-0.85, -0.3), transform_variant='train'):
104
+ if transform_variant == 'train':
105
+ transform = [
106
+ A.LongestMaxSize(out_size),
107
+ A.RandomScale(scale_limit=scale_limit, always_apply=True, interpolation=cv2.INTER_AREA),
108
+ IAAAffine2(scale=(1, 1),
109
+ rotate=(-15, 15),
110
+ shear=(-0.1, 0.1), p=0.3, mode='constant'),
111
+ IAAPerspective2(scale=(0.0, 0.06), p=0.3, mode='constant'),
112
+ A.HorizontalFlip(),
113
+ A.ElasticTransform(alpha=0.3, sigma=15, alpha_affine=15, border_mode=cv2.BORDER_CONSTANT, p=0.3),
114
+ A.GridDistortion(border_mode=cv2.BORDER_CONSTANT, p=0.3)
115
+ ]
116
+ elif transform_variant == 'distort_only':
117
+ transform = [
118
+ IAAAffine2(scale=(1, 1),
119
+ shear=(-0.1, 0.1), p=0.3, mode='constant'),
120
+ IAAPerspective2(scale=(0.0, 0.06), p=0.3, mode='constant'),
121
+ A.HorizontalFlip(),
122
+ A.ElasticTransform(alpha=0.3, sigma=15, alpha_affine=15, border_mode=cv2.BORDER_CONSTANT, p=0.3),
123
+ A.GridDistortion(border_mode=cv2.BORDER_CONSTANT, p=0.3)
124
+ ]
125
+ else:
126
+ transform = [
127
+ A.LongestMaxSize(out_size),
128
+ A.RandomScale(scale_limit=scale_limit, always_apply=True, interpolation=cv2.INTER_LINEAR)
129
+ ]
130
+ return A.Compose(transform)
131
+
132
+
133
+ def get_transforms(transform_variant, out_size, to_float=True):
134
+ if transform_variant == 'distortions':
135
+ transform = [
136
+ IAAAffine2(scale=(1, 1.3),
137
+ rotate=(-20, 20),
138
+ shear=(-0.1, 0.1), p=1, mode='constant'),
139
+ IAAPerspective2(scale=(0.0, 0.06), p=0.3, mode='constant'),
140
+ A.OpticalDistortion(),
141
+ A.HorizontalFlip(),
142
+ A.Sharpen(p=0.3),
143
+ A.CLAHE(),
144
+ A.GaussNoise(p=0.3),
145
+ A.Posterize(),
146
+ A.ElasticTransform(alpha=0.3, sigma=15, alpha_affine=15, border_mode=cv2.BORDER_CONSTANT),
147
+ ]
148
+ elif transform_variant == 'default':
149
+ transform = [
150
+ A.HorizontalFlip(),
151
+ A.Rotate(20, p=0.3)
152
+ ]
153
+ elif transform_variant == 'identity':
154
+ transform = []
155
+ else:
156
+ raise ValueError(f'Unexpected transform_variant {transform_variant}')
157
+ if to_float:
158
+ transform.append(A.ToFloat())
159
+ return A.Compose(transform)
160
+
161
+
162
+ def get_template_transforms(transform_variant, out_size, to_float=True):
163
+ if transform_variant == 'distortions':
164
+ transform = [
165
+ A.Cutout(p=0.3, max_w_size=30, max_h_size=30, num_holes=1),
166
+ IAAAffine2(scale=(1, 1.3),
167
+ rotate=(-20, 20),
168
+ shear=(-0.1, 0.1), p=1, mode='constant'),
169
+ IAAPerspective2(scale=(0.0, 0.06), p=0.3, mode='constant'),
170
+ A.OpticalDistortion(),
171
+ A.HorizontalFlip(),
172
+ A.Sharpen(p=0.3),
173
+ A.CLAHE(),
174
+ A.GaussNoise(p=0.3),
175
+ A.Posterize(),
176
+ A.ElasticTransform(alpha=0.3, sigma=15, alpha_affine=15, border_mode=cv2.BORDER_CONSTANT),
177
+ ]
178
+ elif transform_variant == 'identity':
179
+ transform = []
180
+ else:
181
+ raise ValueError(f'Unexpected transform_variant {transform_variant}')
182
+ if to_float:
183
+ transform.append(A.ToFloat())
184
+ return A.Compose(transform)
185
+
186
+
187
+ def rotate_image(mat: np.ndarray, angle: float, alpha_crop: bool = False) -> np.ndarray:
188
+ """
189
+ Rotates an image (angle in degrees) and expands image to avoid cropping
190
+ # https://stackoverflow.com/questions/43892506/opencv-python-rotate-image-without-cropping-sides
191
+ """
192
+
193
+ height, width = mat.shape[:2] # image shape has 3 dimensions
194
+ image_center = (width/2, height/2) # getRotationMatrix2D needs coordinates in reverse order (width, height) compared to shape
195
+
196
+ rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.)
197
+
198
+ # rotation calculates the cos and sin, taking absolutes of those.
199
+ abs_cos = abs(rotation_mat[0,0])
200
+ abs_sin = abs(rotation_mat[0,1])
201
+
202
+ # find the new width and height bounds
203
+ bound_w = int(height * abs_sin + width * abs_cos)
204
+ bound_h = int(height * abs_cos + width * abs_sin)
205
+
206
+ # subtract old image center (bringing image back to origo) and adding the new image center coordinates
207
+ rotation_mat[0, 2] += bound_w/2 - image_center[0]
208
+ rotation_mat[1, 2] += bound_h/2 - image_center[1]
209
+
210
+ # rotate image with the new bounds and translated rotation matrix
211
+ rotated_mat = cv2.warpAffine(mat, rotation_mat, (bound_w, bound_h))
212
+
213
+ if alpha_crop and len(rotated_mat.shape) == 3 and rotated_mat.shape[-1] == 4:
214
+ x, y, w, h = cv2.boundingRect(rotated_mat[..., -1])
215
+ rotated_mat = rotated_mat[y: y+h, x: x+w]
216
+
217
+ return rotated_mat
218
+
219
+
220
+ def recreate_image(codebook, labels, w, h):
221
+ """Recreate the (compressed) image from the code book & labels"""
222
+ return (codebook[labels].reshape(w, h, -1) * 255).astype(np.uint8)
223
+
224
+ def quantize_image(image: np.ndarray, n_colors: int, method='kmeans', mask=None):
225
+ # https://scikit-learn.org/stable/auto_examples/cluster/plot_color_quantization.html
226
+ image = np.array(image, dtype=np.float64) / 255
227
+
228
+ if len(image.shape) == 3:
229
+ w, h, d = tuple(image.shape)
230
+ else:
231
+ w, h = image.shape
232
+ d = 1
233
+
234
+ # assert d == 3
235
+ image_array = image.reshape(-1, d)
236
+
237
+ if method == 'kmeans':
238
+
239
+ image_array_sample = None
240
+ if mask is not None:
241
+ ids = np.where(mask)
242
+ if len(ids[0]) > 10:
243
+ bg = image[ids][::2]
244
+ fg = image[np.where(mask == 0)]
245
+ max_bg_num = int(fg.shape[0] * 1.5)
246
+ if bg.shape[0] > max_bg_num:
247
+ bg = shuffle(bg, random_state=0, n_samples=max_bg_num)
248
+ image_array_sample = np.concatenate((fg, bg), axis=0)
249
+ if image_array_sample.shape[0] > 2048:
250
+ image_array_sample = shuffle(image_array_sample, random_state=0, n_samples=2048)
251
+ else:
252
+ image_array_sample = None
253
+
254
+ if image_array_sample is None:
255
+ image_array_sample = shuffle(image_array, random_state=0, n_samples=2048)
256
+
257
+ kmeans = KMeans(n_clusters=n_colors, n_init=10, random_state=0).fit(
258
+ image_array_sample
259
+ )
260
+
261
+ labels = kmeans.predict(image_array)
262
+ quantized = recreate_image(kmeans.cluster_centers_, labels, w, h)
263
+ return quantized, kmeans.cluster_centers_, labels
264
+
265
+ else:
266
+
267
+ codebook_random = shuffle(image_array, random_state=0, n_samples=n_colors)
268
+ labels_random = pairwise_distances_argmin(codebook_random, image_array, axis=0)
269
+
270
+ return [recreate_image(codebook_random, labels_random, w, h)]
271
+
272
+
273
+ def resize2height(img: np.ndarray, height: int):
274
+ im_h, im_w = img.shape[:2]
275
+ if im_h > height:
276
+ interpolation = cv2.INTER_AREA
277
+ else:
278
+ interpolation = cv2.INTER_LINEAR
279
+ if im_h != height:
280
+ img = cv2.resize(img, (int(height / im_h * im_w), height), interpolation=interpolation)
281
+ return img
282
+
283
+ if __name__ == '__main__':
284
+ import os.path as osp
285
+
286
+ img_path = r'tmp\megumin.png'
287
+ save_dir = r'tmp'
288
+ sample_num = 24
289
+
290
+ tv = 'distortions'
291
+ out_size = 224
292
+ transforms = get_transforms(tv, out_size ,to_float=False)
293
+ img = cv2.imread(img_path)
294
+ for idx in tqdm(range(sample_num)):
295
+ transformed = transforms(image=img)['image']
296
+ print(transformed.shape)
297
+ cv2.imwrite(osp.join(save_dir, str(idx)+'-transform.jpg'), transformed)
298
+ # cv2.waitKey(0)
299
+ pass
animeinsseg/inpainting/__init__.py ADDED
File without changes
animeinsseg/inpainting/ldm_inpaint.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+ from omegaconf import OmegaConf
5
+ import safetensors
6
+ import os
7
+ import einops
8
+ import cv2
9
+ from PIL import Image, ImageFilter, ImageOps
10
+ from utils.io_utils import resize_pad2divisior
11
+ import os
12
+ from utils.io_utils import submit_request, img2b64
13
+ import json
14
+ # Debug by Francis
15
+ # from ldm.util import instantiate_from_config
16
+ # from ldm.models.diffusion.ddpm import LatentDiffusion
17
+ # from ldm.models.diffusion.ddim import DDIMSampler
18
+ # from ldm.modules.diffusionmodules.util import noise_like
19
+ import io
20
+ import base64
21
+ from requests.auth import HTTPBasicAuth
22
+
23
+ # Debug by Francis
24
+ # def create_model(config_path):
25
+ # config = OmegaConf.load(config_path)
26
+ # model = instantiate_from_config(config.model).cpu()
27
+ # return model
28
+ #
29
+ # def get_state_dict(d):
30
+ # return d.get('state_dict', d)
31
+ #
32
+ # def load_state_dict(ckpt_path, location='cpu'):
33
+ # _, extension = os.path.splitext(ckpt_path)
34
+ # if extension.lower() == ".safetensors":
35
+ # import safetensors.torch
36
+ # state_dict = safetensors.torch.load_file(ckpt_path, device=location)
37
+ # else:
38
+ # state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
39
+ # state_dict = get_state_dict(state_dict)
40
+ # return state_dict
41
+ #
42
+ #
43
+ # def load_ldm_sd(model, path) :
44
+ # if path.endswith('.safetensor') :
45
+ # sd = safetensors.torch.load_file(path)
46
+ # else :
47
+ # sd = load_state_dict(path)
48
+ # model.load_state_dict(sd, strict = False)
49
+ #
50
+ # def fill_mask_input(image, mask):
51
+ # """fills masked regions with colors from image using blur. Not extremely effective."""
52
+ #
53
+ # image_mod = Image.new('RGBA', (image.width, image.height))
54
+ #
55
+ # image_masked = Image.new('RGBa', (image.width, image.height))
56
+ # image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
57
+ #
58
+ # image_masked = image_masked.convert('RGBa')
59
+ #
60
+ # for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
61
+ # blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
62
+ # for _ in range(repeats):
63
+ # image_mod.alpha_composite(blurred)
64
+ #
65
+ # return image_mod.convert("RGB")
66
+ #
67
+ #
68
+ # def get_inpainting_image_condition(model, image, mask) :
69
+ # conditioning_mask = np.array(mask.convert("L"))
70
+ # conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
71
+ # conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
72
+ # conditioning_mask = torch.round(conditioning_mask)
73
+ # conditioning_mask = conditioning_mask.to(device=image.device, dtype=image.dtype)
74
+ # conditioning_image = torch.lerp(
75
+ # image,
76
+ # image * (1.0 - conditioning_mask),
77
+ # 1
78
+ # )
79
+ # conditioning_image = model.get_first_stage_encoding(model.encode_first_stage(conditioning_image))
80
+ # conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=conditioning_image.shape[-2:])
81
+ # conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
82
+ # image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
83
+ # return image_conditioning
84
+ #
85
+ #
86
+ # class GuidedLDM(LatentDiffusion):
87
+ # def __init__(self, *args, **kwargs):
88
+ # super().__init__(*args, **kwargs)
89
+ #
90
+ # @torch.no_grad()
91
+ # def img2img_inpaint(
92
+ # self,
93
+ # image: Image.Image,
94
+ # c_text: str,
95
+ # uc_text: str,
96
+ # mask: Image.Image,
97
+ # ddim_steps = 50,
98
+ # mask_blur: int = 0,
99
+ # use_cuda: bool = True,
100
+ # **kwargs) -> Image.Image :
101
+ # ddim_sampler = GuidedDDIMSample(self)
102
+ # if use_cuda :
103
+ # self.cond_stage_model.cuda()
104
+ # self.first_stage_model.cuda()
105
+ # c_text = self.get_learned_conditioning([c_text])
106
+ # uc_text = self.get_learned_conditioning([uc_text])
107
+ # cond = {"c_crossattn": [c_text]}
108
+ # uc_cond = {"c_crossattn": [uc_text]}
109
+ #
110
+ # if use_cuda :
111
+ # device = torch.device('cuda:0')
112
+ # else :
113
+ # device = torch.device('cpu')
114
+ #
115
+ # image_mask = mask
116
+ # image_mask = image_mask.convert('L')
117
+ # image_mask = image_mask.filter(ImageFilter.GaussianBlur(mask_blur))
118
+ # latent_mask = image_mask
119
+ # # image = fill_mask_input(image, latent_mask)
120
+ # # image.save('image_fill.png')
121
+ # image = np.array(image).astype(np.float32) / 127.5 - 1.0
122
+ # image = np.moveaxis(image, 2, 0)
123
+ # image = torch.from_numpy(image).to(device)[None]
124
+ # init_latent = self.get_first_stage_encoding(self.encode_first_stage(image))
125
+ # init_mask = latent_mask
126
+ # latmask = init_mask.convert('RGB').resize((init_latent.shape[3], init_latent.shape[2]))
127
+ # latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
128
+ # latmask = latmask[0]
129
+ # latmask = np.around(latmask)
130
+ # latmask = np.tile(latmask[None], (4, 1, 1))
131
+ # nmask = torch.asarray(latmask).to(init_latent.device).float()
132
+ # init_latent = (1 - nmask) * init_latent + nmask * torch.randn_like(init_latent)
133
+ #
134
+ # denoising_strength = 1
135
+ # if self.model.conditioning_key == 'hybrid' :
136
+ # image_cdt = get_inpainting_image_condition(self, image, image_mask)
137
+ # cond["c_concat"] = [image_cdt]
138
+ # uc_cond["c_concat"] = [image_cdt]
139
+ #
140
+ # steps = ddim_steps
141
+ # t_enc = int(min(denoising_strength, 0.999) * steps)
142
+ # eta = 0
143
+ #
144
+ # noise = torch.randn_like(init_latent)
145
+ # ddim_sampler.make_schedule(ddim_num_steps=steps, ddim_eta=eta, ddim_discretize="uniform", verbose=False)
146
+ # x1 = ddim_sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * int(init_latent.shape[0])).to(device), noise=noise)
147
+ #
148
+ # if use_cuda :
149
+ # self.cond_stage_model.cpu()
150
+ # self.first_stage_model.cpu()
151
+ #
152
+ # if use_cuda :
153
+ # self.model.cuda()
154
+ # decoded = ddim_sampler.decode(x1, cond,t_enc,init_latent=init_latent,nmask=nmask,unconditional_guidance_scale=7,unconditional_conditioning=uc_cond)
155
+ # if use_cuda :
156
+ # self.model.cpu()
157
+ #
158
+ # if mask is not None :
159
+ # decoded = init_latent * (1 - nmask) + decoded * nmask
160
+ #
161
+ # if use_cuda :
162
+ # self.first_stage_model.cuda()
163
+ # with torch.cuda.amp.autocast(enabled=False):
164
+ # x_samples = self.decode_first_stage(decoded.to(torch.float32))
165
+ # if use_cuda :
166
+ # self.first_stage_model.cpu()
167
+ # return torch.clip(x_samples, -1, 1)
168
+ #
169
+ #
170
+ #
171
+ # class GuidedDDIMSample(DDIMSampler) :
172
+ # def __init__(self, *args, **kwargs):
173
+ # super().__init__(*args, **kwargs)
174
+ #
175
+ # @torch.no_grad()
176
+ # def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
177
+ # temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
178
+ # unconditional_guidance_scale=1., unconditional_conditioning=None,
179
+ # dynamic_threshold=None):
180
+ # b, *_, device = *x.shape, x.device
181
+ #
182
+ # if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
183
+ # model_output = self.model.apply_model(x, t, c)
184
+ # else:
185
+ # x_in = torch.cat([x] * 2)
186
+ # t_in = torch.cat([t] * 2)
187
+ # if isinstance(c, dict):
188
+ # assert isinstance(unconditional_conditioning, dict)
189
+ # c_in = dict()
190
+ # for k in c:
191
+ # if isinstance(c[k], list):
192
+ # c_in[k] = [torch.cat([
193
+ # unconditional_conditioning[k][i],
194
+ # c[k][i]]) for i in range(len(c[k]))]
195
+ # else:
196
+ # c_in[k] = torch.cat([
197
+ # unconditional_conditioning[k],
198
+ # c[k]])
199
+ # elif isinstance(c, list):
200
+ # c_in = list()
201
+ # assert isinstance(unconditional_conditioning, list)
202
+ # for i in range(len(c)):
203
+ # c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
204
+ # else:
205
+ # c_in = torch.cat([unconditional_conditioning, c])
206
+ # model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
207
+ # model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
208
+ #
209
+ # e_t = model_output
210
+ #
211
+ # alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
212
+ # alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
213
+ # sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
214
+ # sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
215
+ # # select parameters corresponding to the currently considered timestep
216
+ # a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
217
+ # a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
218
+ # sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
219
+ # sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
220
+ #
221
+ # # current prediction for x_0
222
+ # pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
223
+ #
224
+ # # direction pointing to x_t
225
+ # dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
226
+ # noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
227
+ # if noise_dropout > 0.:
228
+ # noise = torch.nn.functional.dropout(noise, p=noise_dropout)
229
+ # x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
230
+ # return x_prev, pred_x0
231
+ #
232
+ # @torch.no_grad()
233
+ # def decode(self, x_latent, cond, t_start, init_latent=None, nmask=None, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
234
+ # use_original_steps=False, callback=None):
235
+ #
236
+ # timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
237
+ # total_steps = len(timesteps)
238
+ # timesteps = timesteps[:t_start]
239
+ #
240
+ # time_range = np.flip(timesteps)
241
+ # total_steps = timesteps.shape[0]
242
+ # print(f"Running Guided DDIM Sampling with {len(timesteps)} timesteps, t_start={t_start}")
243
+ # iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
244
+ # x_dec = x_latent
245
+ # for i, step in enumerate(iterator):
246
+ # p = (i + (total_steps - t_start) + 1) / (total_steps)
247
+ # index = total_steps - i - 1
248
+ # ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
249
+ # if nmask is not None :
250
+ # noised_input = self.model.q_sample(init_latent.to(x_latent.device), ts.to(x_latent.device))
251
+ # x_dec = (1 - nmask) * noised_input + nmask * x_dec
252
+ # x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
253
+ # unconditional_guidance_scale=unconditional_guidance_scale,
254
+ # unconditional_conditioning=unconditional_conditioning)
255
+ # if callback: callback(i)
256
+ # return x_dec
257
+ #
258
+ #
259
+ # def ldm_inpaint(model, img, mask, inpaint_size=720, pos_prompt='', neg_prompt = '', use_cuda=True):
260
+ # img_original = np.copy(img)
261
+ # im_h, im_w = img.shape[:2]
262
+ # img_resized, (pad_h, pad_w) = resize_pad2divisior(img, inpaint_size)
263
+ #
264
+ # mask_original = np.copy(mask)
265
+ # mask_original[mask_original < 127] = 0
266
+ # mask_original[mask_original >= 127] = 1
267
+ # mask_original = mask_original[:, :, None]
268
+ # mask, _ = resize_pad2divisior(mask, inpaint_size)
269
+ #
270
+ # # cv2.imwrite('img_resized.png', img_resized)
271
+ # # cv2.imwrite('mask_resized.png', mask)
272
+ #
273
+ #
274
+ # if use_cuda :
275
+ # with torch.autocast(enabled = True, device_type = 'cuda') :
276
+ # img = model.img2img_inpaint(
277
+ # image = Image.fromarray(img_resized),
278
+ # c_text = pos_prompt,
279
+ # uc_text = neg_prompt,
280
+ # mask = Image.fromarray(mask),
281
+ # use_cuda = True
282
+ # )
283
+ # else :
284
+ # img = model.img2img_inpaint(
285
+ # image = Image.fromarray(img_resized),
286
+ # c_text = pos_prompt,
287
+ # uc_text = neg_prompt,
288
+ # mask = Image.fromarray(mask),
289
+ # use_cuda = False
290
+ # )
291
+ #
292
+ # img_inpainted = (einops.rearrange(img, '1 c h w -> h w c').cpu().numpy() * 127.5 + 127.5).astype(np.uint8)
293
+ # if pad_h != 0:
294
+ # img_inpainted = img_inpainted[:-pad_h]
295
+ # if pad_w != 0:
296
+ # img_inpainted = img_inpainted[:, :-pad_w]
297
+ #
298
+ #
299
+ # if img_inpainted.shape[0] != im_h or img_inpainted.shape[1] != im_w:
300
+ # img_inpainted = cv2.resize(img_inpainted, (im_w, im_h), interpolation = cv2.INTER_LINEAR)
301
+ # ans = img_inpainted * mask_original + img_original * (1 - mask_original)
302
+ # ans = img_inpainted
303
+ # return ans
304
+
305
+
306
+
307
+
308
+ import requests
309
+ from PIL import Image
310
+ def ldm_inpaint_webui(
311
+ img, mask, resolution: int, url: str, prompt: str = '', neg_prompt: str = '',
312
+ **inpaint_ldm_options):
313
+ if isinstance(img, np.ndarray):
314
+ img = Image.fromarray(img)
315
+
316
+ im_h, im_w = img.height, img.width
317
+
318
+ if img.height > img.width:
319
+ W = resolution
320
+ H = (img.height / img.width * resolution) // 32 * 32
321
+ H = int(H)
322
+ else:
323
+ H = resolution
324
+ W = (img.width / img.height * resolution) // 32 * 32
325
+ W = int(W)
326
+
327
+ auth = None
328
+ if 'username' in inpaint_ldm_options:
329
+ username = inpaint_ldm_options.pop('username')
330
+ password = inpaint_ldm_options.pop('password')
331
+ auth = HTTPBasicAuth(username, password)
332
+
333
+ img_b64 = img2b64(img)
334
+ mask_b64 = img2b64(mask)
335
+ data = {
336
+ "init_images": [img_b64],
337
+ "mask": mask_b64,
338
+ "prompt": prompt,
339
+ "negative_prompt": neg_prompt,
340
+ "width": W,
341
+ "height": H,
342
+ **inpaint_ldm_options,
343
+ }
344
+ data = json.dumps(data)
345
+
346
+ response = submit_request(url, data, auth=auth)
347
+
348
+ inpainted_b64 = response.json()['images'][0]
349
+ inpainted = Image.open(io.BytesIO(base64.b64decode(inpainted_b64)))
350
+ if inpainted.height != im_h or inpainted.width != im_w:
351
+ inpainted = inpainted.resize((im_w, im_h), resample=Image.Resampling.LANCZOS)
352
+ inpainted = np.array(inpainted)
353
+ return inpainted
animeinsseg/inpainting/patch_match.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # File : patch_match.py
4
+ # Author : Jiayuan Mao
5
+ # Email : maojiayuan@gmail.com
6
+ # Date : 01/09/2020
7
+ #
8
+ # Distributed under terms of the MIT license.
9
+
10
+ import ctypes, os
11
+ import os.path as osp
12
+ from typing import Optional, Union
13
+
14
+ import numpy as np
15
+ from PIL import Image
16
+
17
+ # try:
18
+ # # If the Jacinle library (https://github.com/vacancy/Jacinle) is present, use its auto_travis feature.
19
+ # from jacinle.jit.cext import auto_travis
20
+ # auto_travis(__file__, required_files=['*.so'])
21
+ # except ImportError as e:
22
+ # # Otherwise, fall back to the subprocess.
23
+ # import subprocess
24
+ # print('Compiling and loading c extensions from "{}".'.format(osp.realpath(osp.dirname(__file__))))
25
+ # subprocess.check_call(['./travis.sh'], cwd=osp.dirname(__file__))
26
+
27
+
28
+ __all__ = ['set_random_seed', 'set_verbose', 'inpaint', 'inpaint_regularity']
29
+
30
+
31
+ class CShapeT(ctypes.Structure):
32
+ _fields_ = [
33
+ ('width', ctypes.c_int),
34
+ ('height', ctypes.c_int),
35
+ ('channels', ctypes.c_int),
36
+ ]
37
+
38
+ class CMatT(ctypes.Structure):
39
+ _fields_ = [
40
+ ('data_ptr', ctypes.c_void_p),
41
+ ('shape', CShapeT),
42
+ ('dtype', ctypes.c_int)
43
+ ]
44
+
45
+ import sys
46
+ if sys.platform == 'linux':
47
+ PMLIB = ctypes.CDLL('data/libs/libpatchmatch_inpaint.so')
48
+ else:
49
+ PMLIB = ctypes.CDLL('data/libs/libpatchmatch.dll')
50
+
51
+ PMLIB.PM_set_random_seed.argtypes = [ctypes.c_uint]
52
+ PMLIB.PM_set_verbose.argtypes = [ctypes.c_int]
53
+ PMLIB.PM_free_pymat.argtypes = [CMatT]
54
+ PMLIB.PM_inpaint.argtypes = [CMatT, CMatT, ctypes.c_int]
55
+ PMLIB.PM_inpaint.restype = CMatT
56
+ PMLIB.PM_inpaint_regularity.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
57
+ PMLIB.PM_inpaint_regularity.restype = CMatT
58
+ PMLIB.PM_inpaint2.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int]
59
+ PMLIB.PM_inpaint2.restype = CMatT
60
+ PMLIB.PM_inpaint2_regularity.argtypes = [CMatT, CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
61
+ PMLIB.PM_inpaint2_regularity.restype = CMatT
62
+
63
+
64
+ def set_random_seed(seed: int):
65
+ PMLIB.PM_set_random_seed(ctypes.c_uint(seed))
66
+
67
+
68
+ def set_verbose(verbose: bool):
69
+ PMLIB.PM_set_verbose(ctypes.c_int(verbose))
70
+
71
+
72
+ def inpaint(
73
+ image: Union[np.ndarray, Image.Image],
74
+ mask: Optional[Union[np.ndarray, Image.Image]] = None,
75
+ *,
76
+ global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
77
+ patch_size: int = 15
78
+ ) -> np.ndarray:
79
+ """
80
+ PatchMatch based inpainting proposed in:
81
+
82
+ PatchMatch : A Randomized Correspondence Algorithm for Structural Image Editing
83
+ C.Barnes, E.Shechtman, A.Finkelstein and Dan B.Goldman
84
+ SIGGRAPH 2009
85
+
86
+ Args:
87
+ image (Union[np.ndarray, Image.Image]): the input image, should be 3-channel RGB/BGR.
88
+ mask (Union[np.array, Image.Image], optional): the mask of the hole(s) to be filled, should be 1-channel.
89
+ If not provided (None), the algorithm will treat all purely white pixels as the holes (255, 255, 255).
90
+ global_mask (Union[np.array, Image.Image], optional): the target mask of the output image.
91
+ patch_size (int): the patch size for the inpainting algorithm.
92
+
93
+ Return:
94
+ result (np.ndarray): the repaired image, of the same size as the input image.
95
+ """
96
+
97
+ if isinstance(image, Image.Image):
98
+ image = np.array(image)
99
+ image = np.ascontiguousarray(image)
100
+ assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'
101
+
102
+ if mask is None:
103
+ mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
104
+ mask = np.ascontiguousarray(mask)
105
+ else:
106
+ mask = _canonize_mask_array(mask)
107
+
108
+ if global_mask is None:
109
+ ret_pymat = PMLIB.PM_inpaint(np_to_pymat(image), np_to_pymat(mask), ctypes.c_int(patch_size))
110
+ else:
111
+ global_mask = _canonize_mask_array(global_mask)
112
+ ret_pymat = PMLIB.PM_inpaint2(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), ctypes.c_int(patch_size))
113
+
114
+ ret_npmat = pymat_to_np(ret_pymat)
115
+ PMLIB.PM_free_pymat(ret_pymat)
116
+
117
+ return ret_npmat
118
+
119
+
120
+ def inpaint_regularity(
121
+ image: Union[np.ndarray, Image.Image],
122
+ mask: Optional[Union[np.ndarray, Image.Image]],
123
+ ijmap: np.ndarray,
124
+ *,
125
+ global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
126
+ patch_size: int = 15, guide_weight: float = 0.25
127
+ ) -> np.ndarray:
128
+ if isinstance(image, Image.Image):
129
+ image = np.array(image)
130
+ image = np.ascontiguousarray(image)
131
+
132
+ assert isinstance(ijmap, np.ndarray) and ijmap.ndim == 3 and ijmap.shape[2] == 3 and ijmap.dtype == 'float32'
133
+ ijmap = np.ascontiguousarray(ijmap)
134
+
135
+ assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'
136
+ if mask is None:
137
+ mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
138
+ mask = np.ascontiguousarray(mask)
139
+ else:
140
+ mask = _canonize_mask_array(mask)
141
+
142
+
143
+ if global_mask is None:
144
+ ret_pymat = PMLIB.PM_inpaint_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
145
+ else:
146
+ global_mask = _canonize_mask_array(global_mask)
147
+ ret_pymat = PMLIB.PM_inpaint2_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
148
+
149
+ ret_npmat = pymat_to_np(ret_pymat)
150
+ PMLIB.PM_free_pymat(ret_pymat)
151
+
152
+ return ret_npmat
153
+
154
+
155
+ def _canonize_mask_array(mask):
156
+ if isinstance(mask, Image.Image):
157
+ mask = np.array(mask)
158
+ if mask.ndim == 2 and mask.dtype == 'uint8':
159
+ mask = mask[..., np.newaxis]
160
+ assert mask.ndim == 3 and mask.shape[2] == 1 and mask.dtype == 'uint8'
161
+ return np.ascontiguousarray(mask)
162
+
163
+
164
+ dtype_pymat_to_ctypes = [
165
+ ctypes.c_uint8,
166
+ ctypes.c_int8,
167
+ ctypes.c_uint16,
168
+ ctypes.c_int16,
169
+ ctypes.c_int32,
170
+ ctypes.c_float,
171
+ ctypes.c_double,
172
+ ]
173
+
174
+
175
+ dtype_np_to_pymat = {
176
+ 'uint8': 0,
177
+ 'int8': 1,
178
+ 'uint16': 2,
179
+ 'int16': 3,
180
+ 'int32': 4,
181
+ 'float32': 5,
182
+ 'float64': 6,
183
+ }
184
+
185
+
186
+ def np_to_pymat(npmat):
187
+ assert npmat.ndim == 3
188
+ return CMatT(
189
+ ctypes.cast(npmat.ctypes.data, ctypes.c_void_p),
190
+ CShapeT(npmat.shape[1], npmat.shape[0], npmat.shape[2]),
191
+ dtype_np_to_pymat[str(npmat.dtype)]
192
+ )
193
+
194
+
195
+ def pymat_to_np(pymat):
196
+ npmat = np.ctypeslib.as_array(
197
+ ctypes.cast(pymat.data_ptr, ctypes.POINTER(dtype_pymat_to_ctypes[pymat.dtype])),
198
+ (pymat.shape.height, pymat.shape.width, pymat.shape.channels)
199
+ )
200
+ ret = np.empty(npmat.shape, npmat.dtype)
201
+ ret[:] = npmat
202
+ return ret
203
+
animeinsseg/models/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+ from typing import Union
5
+
6
+
7
+
animeinsseg/models/animeseg_refine/__init__.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/SkyTNT/anime-segmentation/blob/main/train.py
2
+ import os
3
+
4
+ import argparse
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from pytorch_lightning import Trainer
8
+ from pytorch_lightning.callbacks import ModelCheckpoint
9
+ from torch.utils.data import Dataset, DataLoader
10
+ import torch.optim as optim
11
+ import numpy as np
12
+ import cv2
13
+ from torch.cuda import amp
14
+
15
+ from utils.constants import DEFAULT_DEVICE
16
+ # from data_loader import create_training_datasets
17
+
18
+
19
+ import pytorch_lightning as pl
20
+ import warnings
21
+
22
+ from .isnet import ISNetDIS, ISNetGTEncoder
23
+ from .u2net import U2NET, U2NET_full, U2NET_full2, U2NET_lite2
24
+ from .modnet import MODNet
25
+
26
+ # warnings.filterwarnings("ignore")
27
+
28
+ def get_net(net_name):
29
+ if net_name == "isnet":
30
+ return ISNetDIS()
31
+ elif net_name == "isnet_is":
32
+ return ISNetDIS()
33
+ elif net_name == "isnet_gt":
34
+ return ISNetGTEncoder()
35
+ elif net_name == "u2net":
36
+ return U2NET_full2()
37
+ elif net_name == "u2netl":
38
+ return U2NET_lite2()
39
+ elif net_name == "modnet":
40
+ return MODNet()
41
+ raise NotImplemented
42
+
43
+
44
+ def f1_torch(pred, gt):
45
+ # micro F1-score
46
+ pred = pred.float().view(pred.shape[0], -1)
47
+ gt = gt.float().view(gt.shape[0], -1)
48
+ tp1 = torch.sum(pred * gt, dim=1)
49
+ tp_fp1 = torch.sum(pred, dim=1)
50
+ tp_fn1 = torch.sum(gt, dim=1)
51
+ pred = 1 - pred
52
+ gt = 1 - gt
53
+ tp2 = torch.sum(pred * gt, dim=1)
54
+ tp_fp2 = torch.sum(pred, dim=1)
55
+ tp_fn2 = torch.sum(gt, dim=1)
56
+ precision = (tp1 + tp2) / (tp_fp1 + tp_fp2 + 0.0001)
57
+ recall = (tp1 + tp2) / (tp_fn1 + tp_fn2 + 0.0001)
58
+ f1 = (1 + 0.3) * precision * recall / (0.3 * precision + recall + 0.0001)
59
+ return precision, recall, f1
60
+
61
+
62
+ class AnimeSegmentation(pl.LightningModule):
63
+
64
+ def __init__(self, net_name):
65
+ super().__init__()
66
+ assert net_name in ["isnet_is", "isnet", "isnet_gt", "u2net", "u2netl", "modnet"]
67
+ self.net = get_net(net_name)
68
+ if net_name == "isnet_is":
69
+ self.gt_encoder = get_net("isnet_gt")
70
+ self.gt_encoder.requires_grad_(False)
71
+ else:
72
+ self.gt_encoder = None
73
+
74
+ @classmethod
75
+ def try_load(cls, net_name, ckpt_path, map_location=None):
76
+ state_dict = torch.load(ckpt_path, map_location=map_location)
77
+ if "epoch" in state_dict:
78
+ return cls.load_from_checkpoint(ckpt_path, net_name=net_name, map_location=map_location)
79
+ else:
80
+ model = cls(net_name)
81
+ if any([k.startswith("net.") for k, v in state_dict.items()]):
82
+ model.load_state_dict(state_dict)
83
+ else:
84
+ model.net.load_state_dict(state_dict)
85
+ return model
86
+
87
+ def configure_optimizers(self):
88
+ optimizer = optim.Adam(self.net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
89
+ return optimizer
90
+
91
+ def forward(self, x):
92
+ if isinstance(self.net, ISNetDIS):
93
+ return self.net(x)[0][0].sigmoid()
94
+ if isinstance(self.net, ISNetGTEncoder):
95
+ return self.net(x)[0][0].sigmoid()
96
+ elif isinstance(self.net, U2NET):
97
+ return self.net(x)[0].sigmoid()
98
+ elif isinstance(self.net, MODNet):
99
+ return self.net(x, True)[2]
100
+ raise NotImplemented
101
+
102
+ def training_step(self, batch, batch_idx):
103
+ images, labels = batch["image"], batch["label"]
104
+ if isinstance(self.net, ISNetDIS):
105
+ ds, dfs = self.net(images)
106
+ loss_args = [ds, dfs, labels]
107
+ elif isinstance(self.net, ISNetGTEncoder):
108
+ ds = self.net(labels)[0]
109
+ loss_args = [ds, labels]
110
+ elif isinstance(self.net, U2NET):
111
+ ds = self.net(images)
112
+ loss_args = [ds, labels]
113
+ elif isinstance(self.net, MODNet):
114
+ trimaps = batch["trimap"]
115
+ pred_semantic, pred_detail, pred_matte = self.net(images, False)
116
+ loss_args = [pred_semantic, pred_detail, pred_matte, images, trimaps, labels]
117
+ else:
118
+ raise NotImplemented
119
+ if self.gt_encoder is not None:
120
+ fs = self.gt_encoder(labels)[1]
121
+ loss_args.append(fs)
122
+
123
+ loss0, loss = self.net.compute_loss(loss_args)
124
+ self.log_dict({"train/loss": loss, "train/loss_tar": loss0})
125
+ return loss
126
+
127
+ def validation_step(self, batch, batch_idx):
128
+ images, labels = batch["image"], batch["label"]
129
+ if isinstance(self.net, ISNetGTEncoder):
130
+ preds = self.forward(labels)
131
+ else:
132
+ preds = self.forward(images)
133
+ pre, rec, f1, = f1_torch(preds.nan_to_num(nan=0, posinf=1, neginf=0), labels)
134
+ mae_m = F.l1_loss(preds, labels, reduction="mean")
135
+ pre_m = pre.mean()
136
+ rec_m = rec.mean()
137
+ f1_m = f1.mean()
138
+ self.log_dict({"val/precision": pre_m, "val/recall": rec_m, "val/f1": f1_m, "val/mae": mae_m}, sync_dist=True)
139
+
140
+
141
+ def get_gt_encoder(train_dataloader, val_dataloader, opt):
142
+ print("---start train ground truth encoder---")
143
+ gt_encoder = AnimeSegmentation("isnet_gt")
144
+ trainer = Trainer(precision=32 if opt.fp32 else 16, accelerator=opt.accelerator,
145
+ devices=opt.devices, max_epochs=opt.gt_epoch,
146
+ benchmark=opt.benchmark, accumulate_grad_batches=opt.acc_step,
147
+ check_val_every_n_epoch=opt.val_epoch, log_every_n_steps=opt.log_step,
148
+ strategy="ddp_find_unused_parameters_false" if opt.devices > 1 else None,
149
+ )
150
+ trainer.fit(gt_encoder, train_dataloader, val_dataloader)
151
+ return gt_encoder.net
152
+
153
+
154
+ def load_refinenet(refine_method = 'animeseg', device: str = None) -> AnimeSegmentation:
155
+ if device is None:
156
+ device = DEFAULT_DEVICE
157
+ if refine_method == 'animeseg':
158
+ model = AnimeSegmentation.try_load('isnet_is', 'models/anime-seg/isnetis.ckpt', device)
159
+ elif refine_method == 'refinenet_isnet':
160
+ model = ISNetDIS(in_ch=4)
161
+ sd = torch.load('models/AnimeInstanceSegmentation/refine_last.ckpt', map_location='cpu')
162
+ # sd = torch.load('models/AnimeInstanceSegmentation/refine_noweight_dist.ckpt', map_location='cpu')
163
+ # sd = torch.load('models/AnimeInstanceSegmentation/refine_f3loss.ckpt', map_location='cpu')
164
+ model.load_state_dict(sd)
165
+ else:
166
+ raise NotImplementedError
167
+ return model.eval().to(device)
168
+
169
+ def get_mask(model, input_img, use_amp=True, s=640):
170
+ h0, w0 = h, w = input_img.shape[0], input_img.shape[1]
171
+ if h > w:
172
+ h, w = s, int(s * w / h)
173
+ else:
174
+ h, w = int(s * h / w), s
175
+ ph, pw = s - h, s - w
176
+ tmpImg = np.zeros([s, s, 3], dtype=np.float32)
177
+ tmpImg[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(input_img, (w, h)) / 255
178
+ tmpImg = tmpImg.transpose((2, 0, 1))
179
+ tmpImg = torch.from_numpy(tmpImg).unsqueeze(0).type(torch.FloatTensor).to(model.device)
180
+ with torch.no_grad():
181
+ if use_amp:
182
+ with amp.autocast():
183
+ pred = model(tmpImg)
184
+ pred = pred.to(dtype=torch.float32)
185
+ else:
186
+ pred = model(tmpImg)
187
+ pred = pred[0, :, ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
188
+ pred = cv2.resize(pred.cpu().numpy().transpose((1, 2, 0)), (w0, h0))[:, :, np.newaxis]
189
+ return pred
animeinsseg/models/animeseg_refine/encoders.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+
5
+ from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
6
+
7
+
8
+ class AbstractEncoder(nn.Module):
9
+ def __init__(self):
10
+ super().__init__()
11
+
12
+ def encode(self, *args, **kwargs):
13
+ raise NotImplementedError
14
+
15
+
16
+ class IdentityEncoder(AbstractEncoder):
17
+
18
+ def encode(self, x):
19
+ return x
20
+
21
+
22
+ class ClassEmbedder(nn.Module):
23
+ def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
24
+ super().__init__()
25
+ self.key = key
26
+ self.embedding = nn.Embedding(n_classes, embed_dim)
27
+ self.n_classes = n_classes
28
+ self.ucg_rate = ucg_rate
29
+
30
+ def forward(self, batch, key=None, disable_dropout=False):
31
+ if key is None:
32
+ key = self.key
33
+ # this is for use in crossattn
34
+ c = batch[key][:, None]
35
+ if self.ucg_rate > 0. and not disable_dropout:
36
+ mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
37
+ c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
38
+ c = c.long()
39
+ c = self.embedding(c)
40
+ return c
41
+
42
+ def get_unconditional_conditioning(self, bs, device="cuda"):
43
+ uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
44
+ uc = torch.ones((bs,), device=device) * uc_class
45
+ uc = {self.key: uc}
46
+ return uc
47
+
48
+
49
+ class DanbooruEmbedder(AbstractEncoder):
50
+ def __init__(self):
51
+ super().__init__()
animeinsseg/models/animeseg_refine/isnet.py ADDED
@@ -0,0 +1,645 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Codes are borrowed from
2
+ # https://github.com/xuebinqin/DIS/blob/main/IS-Net/models/isnet.py
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchvision import models
7
+ import torch.nn.functional as F
8
+
9
+ _bce_loss = nn.BCEWithLogitsLoss(reduction="mean")
10
+ _bce_loss_none = nn.BCEWithLogitsLoss(reduction='none')
11
+
12
+ def bce_loss(p, t, weights=None):
13
+ if weights is None:
14
+ return _bce_loss(p, t)
15
+ else:
16
+ loss = _bce_loss_none(p, t)
17
+ loss = loss * weights
18
+ return loss.mean()
19
+
20
+
21
+ _fea_loss = nn.MSELoss(reduction="mean")
22
+ _fea_loss_none = nn.MSELoss(reduction="none")
23
+
24
+ def fea_loss(p, t, weights=None):
25
+ return _fea_loss(p, t)
26
+
27
+ kl_loss = nn.KLDivLoss(reduction="mean")
28
+ l1_loss = nn.L1Loss(reduction="mean")
29
+ smooth_l1_loss = nn.SmoothL1Loss(reduction="mean")
30
+
31
+
32
+ def structure_loss(pred, mask):
33
+ weit = 1+5*torch.abs(F.avg_pool2d(mask, kernel_size=15, stride=1, padding=7)-mask)
34
+ wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
35
+ wbce = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3))
36
+
37
+ pred = torch.sigmoid(pred)
38
+ inter = ((pred*mask)*weit).sum(dim=(2,3))
39
+ union = ((pred+mask)*weit).sum(dim=(2,3))
40
+ wiou = 1-(inter+1)/(union-inter+1)
41
+ return (wbce+wiou).mean()
42
+
43
+
44
+ def muti_loss_fusion(preds, target, dist_weight=None, loss0_weight=1.0):
45
+ loss0 = 0.0
46
+ loss = 0.0
47
+
48
+ for i in range(0, len(preds)):
49
+ weight = dist_weight if i == 0 else None
50
+ if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
51
+ tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
52
+ loss = loss + structure_loss(preds[i], tmp_target)
53
+ else:
54
+ # loss = loss + bce_loss(preds[i], target, weight)
55
+ loss = loss + structure_loss(preds[i], target)
56
+ if i == 0:
57
+ loss *= loss0_weight
58
+ loss0 = loss
59
+ return loss0, loss
60
+
61
+
62
+
63
+ def muti_loss_fusion_kl(preds, target, dfs, fs, mode='MSE', dist_weight=None, loss0_weight=1.0):
64
+ loss0 = 0.0
65
+ loss = 0.0
66
+
67
+ for i in range(0, len(preds)):
68
+ weight = dist_weight if i == 0 else None
69
+ if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
70
+ tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
71
+ # loss = loss + bce_loss(preds[i], tmp_target, weight)
72
+ loss = loss + structure_loss(preds[i], tmp_target)
73
+ else:
74
+ # loss = loss + bce_loss(preds[i], target, weight)
75
+ loss = loss + structure_loss(preds[i], target)
76
+ if i == 0:
77
+ loss *= loss0_weight
78
+ loss0 = loss
79
+
80
+ for i in range(0, len(dfs)):
81
+ df = dfs[i]
82
+ fs_i = fs[i]
83
+ if mode == 'MSE':
84
+ loss = loss + fea_loss(df, fs_i, dist_weight) ### add the mse loss of features as additional constraints
85
+ elif mode == 'KL':
86
+ loss = loss + kl_loss(F.log_softmax(df, dim=1), F.softmax(fs_i, dim=1))
87
+ elif mode == 'MAE':
88
+ loss = loss + l1_loss(df, fs_i)
89
+ elif mode == 'SmoothL1':
90
+ loss = loss + smooth_l1_loss(df, fs_i)
91
+
92
+ return loss0, loss
93
+
94
+
95
+ class REBNCONV(nn.Module):
96
+ def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
97
+ super(REBNCONV, self).__init__()
98
+
99
+ self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride)
100
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
101
+ self.relu_s1 = nn.ReLU(inplace=True)
102
+
103
+ def forward(self, x):
104
+ hx = x
105
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
106
+
107
+ return xout
108
+
109
+
110
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
111
+ def _upsample_like(src, tar):
112
+ src = F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=False)
113
+
114
+ return src
115
+
116
+
117
+ ### RSU-7 ###
118
+ class RSU7(nn.Module):
119
+
120
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
121
+ super(RSU7, self).__init__()
122
+
123
+ self.in_ch = in_ch
124
+ self.mid_ch = mid_ch
125
+ self.out_ch = out_ch
126
+
127
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
128
+
129
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
130
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
131
+
132
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
133
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
134
+
135
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
136
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
137
+
138
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
139
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
140
+
141
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
142
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
143
+
144
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
145
+
146
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
147
+
148
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
149
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
150
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
151
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
152
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
153
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
154
+
155
+ def forward(self, x):
156
+ b, c, h, w = x.shape
157
+
158
+ hx = x
159
+ hxin = self.rebnconvin(hx)
160
+
161
+ hx1 = self.rebnconv1(hxin)
162
+ hx = self.pool1(hx1)
163
+
164
+ hx2 = self.rebnconv2(hx)
165
+ hx = self.pool2(hx2)
166
+
167
+ hx3 = self.rebnconv3(hx)
168
+ hx = self.pool3(hx3)
169
+
170
+ hx4 = self.rebnconv4(hx)
171
+ hx = self.pool4(hx4)
172
+
173
+ hx5 = self.rebnconv5(hx)
174
+ hx = self.pool5(hx5)
175
+
176
+ hx6 = self.rebnconv6(hx)
177
+
178
+ hx7 = self.rebnconv7(hx6)
179
+
180
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
181
+ hx6dup = _upsample_like(hx6d, hx5)
182
+
183
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
184
+ hx5dup = _upsample_like(hx5d, hx4)
185
+
186
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
187
+ hx4dup = _upsample_like(hx4d, hx3)
188
+
189
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
190
+ hx3dup = _upsample_like(hx3d, hx2)
191
+
192
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
193
+ hx2dup = _upsample_like(hx2d, hx1)
194
+
195
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
196
+
197
+ return hx1d + hxin
198
+
199
+
200
+ ### RSU-6 ###
201
+ class RSU6(nn.Module):
202
+
203
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
204
+ super(RSU6, self).__init__()
205
+
206
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
207
+
208
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
209
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
210
+
211
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
212
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
213
+
214
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
215
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
216
+
217
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
218
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
219
+
220
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
221
+
222
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
223
+
224
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
225
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
226
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
227
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
228
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
229
+
230
+ def forward(self, x):
231
+ hx = x
232
+
233
+ hxin = self.rebnconvin(hx)
234
+
235
+ hx1 = self.rebnconv1(hxin)
236
+ hx = self.pool1(hx1)
237
+
238
+ hx2 = self.rebnconv2(hx)
239
+ hx = self.pool2(hx2)
240
+
241
+ hx3 = self.rebnconv3(hx)
242
+ hx = self.pool3(hx3)
243
+
244
+ hx4 = self.rebnconv4(hx)
245
+ hx = self.pool4(hx4)
246
+
247
+ hx5 = self.rebnconv5(hx)
248
+
249
+ hx6 = self.rebnconv6(hx5)
250
+
251
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
252
+ hx5dup = _upsample_like(hx5d, hx4)
253
+
254
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
255
+ hx4dup = _upsample_like(hx4d, hx3)
256
+
257
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
258
+ hx3dup = _upsample_like(hx3d, hx2)
259
+
260
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
261
+ hx2dup = _upsample_like(hx2d, hx1)
262
+
263
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
264
+
265
+ return hx1d + hxin
266
+
267
+
268
+ ### RSU-5 ###
269
+ class RSU5(nn.Module):
270
+
271
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
272
+ super(RSU5, self).__init__()
273
+
274
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
275
+
276
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
277
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
278
+
279
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
280
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
281
+
282
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
283
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
284
+
285
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
286
+
287
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
288
+
289
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
290
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
291
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
292
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
293
+
294
+ def forward(self, x):
295
+ hx = x
296
+
297
+ hxin = self.rebnconvin(hx)
298
+
299
+ hx1 = self.rebnconv1(hxin)
300
+ hx = self.pool1(hx1)
301
+
302
+ hx2 = self.rebnconv2(hx)
303
+ hx = self.pool2(hx2)
304
+
305
+ hx3 = self.rebnconv3(hx)
306
+ hx = self.pool3(hx3)
307
+
308
+ hx4 = self.rebnconv4(hx)
309
+
310
+ hx5 = self.rebnconv5(hx4)
311
+
312
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
313
+ hx4dup = _upsample_like(hx4d, hx3)
314
+
315
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
316
+ hx3dup = _upsample_like(hx3d, hx2)
317
+
318
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
319
+ hx2dup = _upsample_like(hx2d, hx1)
320
+
321
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
322
+
323
+ return hx1d + hxin
324
+
325
+
326
+ ### RSU-4 ###
327
+ class RSU4(nn.Module):
328
+
329
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
330
+ super(RSU4, self).__init__()
331
+
332
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
333
+
334
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
335
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
336
+
337
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
338
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
339
+
340
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
341
+
342
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
343
+
344
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
345
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
346
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
347
+
348
+ def forward(self, x):
349
+ hx = x
350
+
351
+ hxin = self.rebnconvin(hx)
352
+
353
+ hx1 = self.rebnconv1(hxin)
354
+ hx = self.pool1(hx1)
355
+
356
+ hx2 = self.rebnconv2(hx)
357
+ hx = self.pool2(hx2)
358
+
359
+ hx3 = self.rebnconv3(hx)
360
+
361
+ hx4 = self.rebnconv4(hx3)
362
+
363
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
364
+ hx3dup = _upsample_like(hx3d, hx2)
365
+
366
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
367
+ hx2dup = _upsample_like(hx2d, hx1)
368
+
369
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
370
+
371
+ return hx1d + hxin
372
+
373
+
374
+ ### RSU-4F ###
375
+ class RSU4F(nn.Module):
376
+
377
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
378
+ super(RSU4F, self).__init__()
379
+
380
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
381
+
382
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
383
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
384
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
385
+
386
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
387
+
388
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
389
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
390
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
391
+
392
+ def forward(self, x):
393
+ hx = x
394
+
395
+ hxin = self.rebnconvin(hx)
396
+
397
+ hx1 = self.rebnconv1(hxin)
398
+ hx2 = self.rebnconv2(hx1)
399
+ hx3 = self.rebnconv3(hx2)
400
+
401
+ hx4 = self.rebnconv4(hx3)
402
+
403
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
404
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
405
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
406
+
407
+ return hx1d + hxin
408
+
409
+
410
+ class myrebnconv(nn.Module):
411
+ def __init__(self, in_ch=3,
412
+ out_ch=1,
413
+ kernel_size=3,
414
+ stride=1,
415
+ padding=1,
416
+ dilation=1,
417
+ groups=1):
418
+ super(myrebnconv, self).__init__()
419
+
420
+ self.conv = nn.Conv2d(in_ch,
421
+ out_ch,
422
+ kernel_size=kernel_size,
423
+ stride=stride,
424
+ padding=padding,
425
+ dilation=dilation,
426
+ groups=groups)
427
+ self.bn = nn.BatchNorm2d(out_ch)
428
+ self.rl = nn.ReLU(inplace=True)
429
+
430
+ def forward(self, x):
431
+ return self.rl(self.bn(self.conv(x)))
432
+
433
+
434
+ class ISNetGTEncoder(nn.Module):
435
+
436
+ def __init__(self, in_ch=1, out_ch=1):
437
+ super(ISNetGTEncoder, self).__init__()
438
+
439
+ self.conv_in = myrebnconv(in_ch, 16, 3, stride=2, padding=1) # nn.Conv2d(in_ch,64,3,stride=2,padding=1)
440
+
441
+ self.stage1 = RSU7(16, 16, 64)
442
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
443
+
444
+ self.stage2 = RSU6(64, 16, 64)
445
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
446
+
447
+ self.stage3 = RSU5(64, 32, 128)
448
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
449
+
450
+ self.stage4 = RSU4(128, 32, 256)
451
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
452
+
453
+ self.stage5 = RSU4F(256, 64, 512)
454
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
455
+
456
+ self.stage6 = RSU4F(512, 64, 512)
457
+
458
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
459
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
460
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
461
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
462
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
463
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
464
+
465
+ @staticmethod
466
+ def compute_loss(args, dist_weight=None):
467
+ preds, targets = args
468
+ return muti_loss_fusion(preds, targets, dist_weight)
469
+
470
+ def forward(self, x):
471
+ hx = x
472
+
473
+ hxin = self.conv_in(hx)
474
+ # hx = self.pool_in(hxin)
475
+
476
+ # stage 1
477
+ hx1 = self.stage1(hxin)
478
+ hx = self.pool12(hx1)
479
+
480
+ # stage 2
481
+ hx2 = self.stage2(hx)
482
+ hx = self.pool23(hx2)
483
+
484
+ # stage 3
485
+ hx3 = self.stage3(hx)
486
+ hx = self.pool34(hx3)
487
+
488
+ # stage 4
489
+ hx4 = self.stage4(hx)
490
+ hx = self.pool45(hx4)
491
+
492
+ # stage 5
493
+ hx5 = self.stage5(hx)
494
+ hx = self.pool56(hx5)
495
+
496
+ # stage 6
497
+ hx6 = self.stage6(hx)
498
+
499
+ # side output
500
+ d1 = self.side1(hx1)
501
+ d1 = _upsample_like(d1, x)
502
+
503
+ d2 = self.side2(hx2)
504
+ d2 = _upsample_like(d2, x)
505
+
506
+ d3 = self.side3(hx3)
507
+ d3 = _upsample_like(d3, x)
508
+
509
+ d4 = self.side4(hx4)
510
+ d4 = _upsample_like(d4, x)
511
+
512
+ d5 = self.side5(hx5)
513
+ d5 = _upsample_like(d5, x)
514
+
515
+ d6 = self.side6(hx6)
516
+ d6 = _upsample_like(d6, x)
517
+
518
+ # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
519
+
520
+ # return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1, hx2, hx3, hx4, hx5, hx6]
521
+ return [d1, d2, d3, d4, d5, d6], [hx1, hx2, hx3, hx4, hx5, hx6]
522
+
523
+
524
+ class ISNetDIS(nn.Module):
525
+
526
+ def __init__(self, in_ch=3, out_ch=1):
527
+ super(ISNetDIS, self).__init__()
528
+
529
+ self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
530
+ self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
531
+
532
+ self.stage1 = RSU7(64, 32, 64)
533
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
534
+
535
+ self.stage2 = RSU6(64, 32, 128)
536
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
537
+
538
+ self.stage3 = RSU5(128, 64, 256)
539
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
540
+
541
+ self.stage4 = RSU4(256, 128, 512)
542
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
543
+
544
+ self.stage5 = RSU4F(512, 256, 512)
545
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
546
+
547
+ self.stage6 = RSU4F(512, 256, 512)
548
+
549
+ # decoder
550
+ self.stage5d = RSU4F(1024, 256, 512)
551
+ self.stage4d = RSU4(1024, 128, 256)
552
+ self.stage3d = RSU5(512, 64, 128)
553
+ self.stage2d = RSU6(256, 32, 64)
554
+ self.stage1d = RSU7(128, 16, 64)
555
+
556
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
557
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
558
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
559
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
560
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
561
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
562
+
563
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
564
+
565
+ @staticmethod
566
+ def compute_loss_kl(preds, targets, dfs, fs, mode='MSE'):
567
+ return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode, loss0_weight=5.0)
568
+
569
+ @staticmethod
570
+ def compute_loss(args, dist_weight=None):
571
+ if len(args) == 3:
572
+ ds, dfs, labels = args
573
+ return muti_loss_fusion(ds, labels, dist_weight, loss0_weight=5.0)
574
+ else:
575
+ ds, dfs, labels, fs = args
576
+ return muti_loss_fusion_kl(ds, labels, dfs, fs, mode="MSE", dist_weight=dist_weight, loss0_weight=5.0)
577
+
578
+ def forward(self, x):
579
+ hx = x
580
+
581
+ hxin = self.conv_in(hx)
582
+ hx = self.pool_in(hxin)
583
+
584
+ # stage 1
585
+ hx1 = self.stage1(hxin)
586
+ hx = self.pool12(hx1)
587
+
588
+ # stage 2
589
+ hx2 = self.stage2(hx)
590
+ hx = self.pool23(hx2)
591
+
592
+ # stage 3
593
+ hx3 = self.stage3(hx)
594
+ hx = self.pool34(hx3)
595
+
596
+ # stage 4
597
+ hx4 = self.stage4(hx)
598
+ hx = self.pool45(hx4)
599
+
600
+ # stage 5
601
+ hx5 = self.stage5(hx)
602
+ hx = self.pool56(hx5)
603
+
604
+ # stage 6
605
+ hx6 = self.stage6(hx)
606
+ hx6up = _upsample_like(hx6, hx5)
607
+
608
+ # -------------------- decoder --------------------
609
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
610
+ hx5dup = _upsample_like(hx5d, hx4)
611
+
612
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
613
+ hx4dup = _upsample_like(hx4d, hx3)
614
+
615
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
616
+ hx3dup = _upsample_like(hx3d, hx2)
617
+
618
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
619
+ hx2dup = _upsample_like(hx2d, hx1)
620
+
621
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
622
+
623
+ # side output
624
+ d1 = self.side1(hx1d)
625
+ d1 = _upsample_like(d1, x)
626
+
627
+ d2 = self.side2(hx2d)
628
+ d2 = _upsample_like(d2, x)
629
+
630
+ d3 = self.side3(hx3d)
631
+ d3 = _upsample_like(d3, x)
632
+
633
+ d4 = self.side4(hx4d)
634
+ d4 = _upsample_like(d4, x)
635
+
636
+ d5 = self.side5(hx5d)
637
+ d5 = _upsample_like(d5, x)
638
+
639
+ d6 = self.side6(hx6)
640
+ d6 = _upsample_like(d6, x)
641
+
642
+ # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
643
+
644
+ # return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
645
+ return [d1, d2, d3, d4, d5, d6], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
animeinsseg/models/animeseg_refine/models.py ADDED
File without changes
animeinsseg/models/animeseg_refine/modnet.py ADDED
@@ -0,0 +1,667 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Codes are borrowed from
2
+ # https://github.com/ZHKKKe/MODNet/blob/master/src/trainer.py
3
+ # https://github.com/ZHKKKe/MODNet/blob/master/src/models/backbones/mobilenetv2.py
4
+ # https://github.com/ZHKKKe/MODNet/blob/master/src/models/modnet.py
5
+
6
+ import numpy as np
7
+ import scipy
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import os
12
+ import math
13
+ import torch
14
+ from scipy.ndimage import gaussian_filter
15
+
16
+
17
+ # ----------------------------------------------------------------------------------
18
+ # Loss Functions
19
+ # ----------------------------------------------------------------------------------
20
+
21
+
22
+ class GaussianBlurLayer(nn.Module):
23
+ """ Add Gaussian Blur to a 4D tensors
24
+ This layer takes a 4D tensor of {N, C, H, W} as input.
25
+ The Gaussian blur will be performed in given channel number (C) splitly.
26
+ """
27
+
28
+ def __init__(self, channels, kernel_size):
29
+ """
30
+ Arguments:
31
+ channels (int): Channel for input tensor
32
+ kernel_size (int): Size of the kernel used in blurring
33
+ """
34
+
35
+ super(GaussianBlurLayer, self).__init__()
36
+ self.channels = channels
37
+ self.kernel_size = kernel_size
38
+ assert self.kernel_size % 2 != 0
39
+
40
+ self.op = nn.Sequential(
41
+ nn.ReflectionPad2d(math.floor(self.kernel_size / 2)),
42
+ nn.Conv2d(channels, channels, self.kernel_size,
43
+ stride=1, padding=0, bias=None, groups=channels)
44
+ )
45
+
46
+ self._init_kernel()
47
+
48
+ def forward(self, x):
49
+ """
50
+ Arguments:
51
+ x (torch.Tensor): input 4D tensor
52
+ Returns:
53
+ torch.Tensor: Blurred version of the input
54
+ """
55
+
56
+ if not len(list(x.shape)) == 4:
57
+ print('\'GaussianBlurLayer\' requires a 4D tensor as input\n')
58
+ exit()
59
+ elif not x.shape[1] == self.channels:
60
+ print('In \'GaussianBlurLayer\', the required channel ({0}) is'
61
+ 'not the same as input ({1})\n'.format(self.channels, x.shape[1]))
62
+ exit()
63
+
64
+ return self.op(x)
65
+
66
+ def _init_kernel(self):
67
+ sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8
68
+
69
+ n = np.zeros((self.kernel_size, self.kernel_size))
70
+ i = math.floor(self.kernel_size / 2)
71
+ n[i, i] = 1
72
+ kernel = gaussian_filter(n, sigma)
73
+
74
+ for name, param in self.named_parameters():
75
+ param.data.copy_(torch.from_numpy(kernel))
76
+ param.requires_grad = False
77
+
78
+
79
+ blurer = GaussianBlurLayer(1, 3)
80
+
81
+
82
+ def loss_func(pred_semantic, pred_detail, pred_matte, image, trimap, gt_matte,
83
+ semantic_scale=10.0, detail_scale=10.0, matte_scale=1.0):
84
+ """ loss of MODNet
85
+ Arguments:
86
+ blurer: GaussianBlurLayer
87
+ pred_semantic: model output
88
+ pred_detail: model output
89
+ pred_matte: model output
90
+ image : input RGB image ts pixel values should be normalized
91
+ trimap : trimap used to calculate the losses
92
+ its pixel values can be 0, 0.5, or 1
93
+ (foreground=1, background=0, unknown=0.5)
94
+ gt_matte: ground truth alpha matte its pixel values are between [0, 1]
95
+ semantic_scale (float): scale of the semantic loss
96
+ NOTE: please adjust according to your dataset
97
+ detail_scale (float): scale of the detail loss
98
+ NOTE: please adjust according to your dataset
99
+ matte_scale (float): scale of the matte loss
100
+ NOTE: please adjust according to your dataset
101
+
102
+ Returns:
103
+ semantic_loss (torch.Tensor): loss of the semantic estimation [Low-Resolution (LR) Branch]
104
+ detail_loss (torch.Tensor): loss of the detail prediction [High-Resolution (HR) Branch]
105
+ matte_loss (torch.Tensor): loss of the semantic-detail fusion [Fusion Branch]
106
+ """
107
+
108
+ trimap = trimap.float()
109
+ # calculate the boundary mask from the trimap
110
+ boundaries = (trimap < 0.5) + (trimap > 0.5)
111
+
112
+ # calculate the semantic loss
113
+ gt_semantic = F.interpolate(gt_matte, scale_factor=1 / 16, mode='bilinear')
114
+ gt_semantic = blurer(gt_semantic)
115
+ semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic))
116
+ semantic_loss = semantic_scale * semantic_loss
117
+
118
+ # calculate the detail loss
119
+ pred_boundary_detail = torch.where(boundaries, trimap, pred_detail.float())
120
+ gt_detail = torch.where(boundaries, trimap, gt_matte.float())
121
+ detail_loss = torch.mean(F.l1_loss(pred_boundary_detail, gt_detail.float()))
122
+ detail_loss = detail_scale * detail_loss
123
+
124
+ # calculate the matte loss
125
+ pred_boundary_matte = torch.where(boundaries, trimap, pred_matte.float())
126
+ matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)
127
+ matte_compositional_loss = F.l1_loss(image * pred_matte, image * gt_matte) \
128
+ + 4.0 * F.l1_loss(image * pred_boundary_matte, image * gt_matte)
129
+ matte_loss = torch.mean(matte_l1_loss + matte_compositional_loss)
130
+ matte_loss = matte_scale * matte_loss
131
+
132
+ return semantic_loss, detail_loss, matte_loss
133
+
134
+
135
+ # ------------------------------------------------------------------------------
136
+ # Useful functions
137
+ # ------------------------------------------------------------------------------
138
+
139
+ def _make_divisible(v, divisor, min_value=None):
140
+ if min_value is None:
141
+ min_value = divisor
142
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
143
+ # Make sure that round down does not go down by more than 10%.
144
+ if new_v < 0.9 * v:
145
+ new_v += divisor
146
+ return new_v
147
+
148
+
149
+ def conv_bn(inp, oup, stride):
150
+ return nn.Sequential(
151
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
152
+ nn.BatchNorm2d(oup),
153
+ nn.ReLU6(inplace=True)
154
+ )
155
+
156
+
157
+ def conv_1x1_bn(inp, oup):
158
+ return nn.Sequential(
159
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
160
+ nn.BatchNorm2d(oup),
161
+ nn.ReLU6(inplace=True)
162
+ )
163
+
164
+
165
+ # ------------------------------------------------------------------------------
166
+ # Class of Inverted Residual block
167
+ # ------------------------------------------------------------------------------
168
+
169
+ class InvertedResidual(nn.Module):
170
+ def __init__(self, inp, oup, stride, expansion, dilation=1):
171
+ super(InvertedResidual, self).__init__()
172
+ self.stride = stride
173
+ assert stride in [1, 2]
174
+
175
+ hidden_dim = round(inp * expansion)
176
+ self.use_res_connect = self.stride == 1 and inp == oup
177
+
178
+ if expansion == 1:
179
+ self.conv = nn.Sequential(
180
+ # dw
181
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
182
+ nn.BatchNorm2d(hidden_dim),
183
+ nn.ReLU6(inplace=True),
184
+ # pw-linear
185
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
186
+ nn.BatchNorm2d(oup),
187
+ )
188
+ else:
189
+ self.conv = nn.Sequential(
190
+ # pw
191
+ nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
192
+ nn.BatchNorm2d(hidden_dim),
193
+ nn.ReLU6(inplace=True),
194
+ # dw
195
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
196
+ nn.BatchNorm2d(hidden_dim),
197
+ nn.ReLU6(inplace=True),
198
+ # pw-linear
199
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
200
+ nn.BatchNorm2d(oup),
201
+ )
202
+
203
+ def forward(self, x):
204
+ if self.use_res_connect:
205
+ return x + self.conv(x)
206
+ else:
207
+ return self.conv(x)
208
+
209
+
210
+ # ------------------------------------------------------------------------------
211
+ # Class of MobileNetV2
212
+ # ------------------------------------------------------------------------------
213
+
214
+ class MobileNetV2(nn.Module):
215
+ def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000):
216
+ super(MobileNetV2, self).__init__()
217
+ self.in_channels = in_channels
218
+ self.num_classes = num_classes
219
+ input_channel = 32
220
+ last_channel = 1280
221
+ interverted_residual_setting = [
222
+ # t, c, n, s
223
+ [1, 16, 1, 1],
224
+ [expansion, 24, 2, 2],
225
+ [expansion, 32, 3, 2],
226
+ [expansion, 64, 4, 2],
227
+ [expansion, 96, 3, 1],
228
+ [expansion, 160, 3, 2],
229
+ [expansion, 320, 1, 1],
230
+ ]
231
+
232
+ # building first layer
233
+ input_channel = _make_divisible(input_channel * alpha, 8)
234
+ self.last_channel = _make_divisible(last_channel * alpha, 8) if alpha > 1.0 else last_channel
235
+ self.features = [conv_bn(self.in_channels, input_channel, 2)]
236
+
237
+ # building inverted residual blocks
238
+ for t, c, n, s in interverted_residual_setting:
239
+ output_channel = _make_divisible(int(c * alpha), 8)
240
+ for i in range(n):
241
+ if i == 0:
242
+ self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t))
243
+ else:
244
+ self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t))
245
+ input_channel = output_channel
246
+
247
+ # building last several layers
248
+ self.features.append(conv_1x1_bn(input_channel, self.last_channel))
249
+
250
+ # make it nn.Sequential
251
+ self.features = nn.Sequential(*self.features)
252
+
253
+ # building classifier
254
+ if self.num_classes is not None:
255
+ self.classifier = nn.Sequential(
256
+ nn.Dropout(0.2),
257
+ nn.Linear(self.last_channel, num_classes),
258
+ )
259
+
260
+ # Initialize weights
261
+ self._init_weights()
262
+
263
+ def forward(self, x):
264
+ # Stage1
265
+ x = self.features[0](x)
266
+ x = self.features[1](x)
267
+ # Stage2
268
+ x = self.features[2](x)
269
+ x = self.features[3](x)
270
+ # Stage3
271
+ x = self.features[4](x)
272
+ x = self.features[5](x)
273
+ x = self.features[6](x)
274
+ # Stage4
275
+ x = self.features[7](x)
276
+ x = self.features[8](x)
277
+ x = self.features[9](x)
278
+ x = self.features[10](x)
279
+ x = self.features[11](x)
280
+ x = self.features[12](x)
281
+ x = self.features[13](x)
282
+ # Stage5
283
+ x = self.features[14](x)
284
+ x = self.features[15](x)
285
+ x = self.features[16](x)
286
+ x = self.features[17](x)
287
+ x = self.features[18](x)
288
+
289
+ # Classification
290
+ if self.num_classes is not None:
291
+ x = x.mean(dim=(2, 3))
292
+ x = self.classifier(x)
293
+
294
+ # Output
295
+ return x
296
+
297
+ def _load_pretrained_model(self, pretrained_file):
298
+ pretrain_dict = torch.load(pretrained_file, map_location='cpu')
299
+ model_dict = {}
300
+ state_dict = self.state_dict()
301
+ print("[MobileNetV2] Loading pretrained model...")
302
+ for k, v in pretrain_dict.items():
303
+ if k in state_dict:
304
+ model_dict[k] = v
305
+ else:
306
+ print(k, "is ignored")
307
+ state_dict.update(model_dict)
308
+ self.load_state_dict(state_dict)
309
+
310
+ def _init_weights(self):
311
+ for m in self.modules():
312
+ if isinstance(m, nn.Conv2d):
313
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
314
+ m.weight.data.normal_(0, math.sqrt(2. / n))
315
+ if m.bias is not None:
316
+ m.bias.data.zero_()
317
+ elif isinstance(m, nn.BatchNorm2d):
318
+ m.weight.data.fill_(1)
319
+ m.bias.data.zero_()
320
+ elif isinstance(m, nn.Linear):
321
+ n = m.weight.size(1)
322
+ m.weight.data.normal_(0, 0.01)
323
+ m.bias.data.zero_()
324
+
325
+
326
+ class BaseBackbone(nn.Module):
327
+ """ Superclass of Replaceable Backbone Model for Semantic Estimation
328
+ """
329
+
330
+ def __init__(self, in_channels):
331
+ super(BaseBackbone, self).__init__()
332
+ self.in_channels = in_channels
333
+
334
+ self.model = None
335
+ self.enc_channels = []
336
+
337
+ def forward(self, x):
338
+ raise NotImplementedError
339
+
340
+ def load_pretrained_ckpt(self):
341
+ raise NotImplementedError
342
+
343
+
344
+ class MobileNetV2Backbone(BaseBackbone):
345
+ """ MobileNetV2 Backbone
346
+ """
347
+
348
+ def __init__(self, in_channels):
349
+ super(MobileNetV2Backbone, self).__init__(in_channels)
350
+
351
+ self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None)
352
+ self.enc_channels = [16, 24, 32, 96, 1280]
353
+
354
+ def forward(self, x):
355
+ # x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x)
356
+ x = self.model.features[0](x)
357
+ x = self.model.features[1](x)
358
+ enc2x = x
359
+
360
+ # x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x)
361
+ x = self.model.features[2](x)
362
+ x = self.model.features[3](x)
363
+ enc4x = x
364
+
365
+ # x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x)
366
+ x = self.model.features[4](x)
367
+ x = self.model.features[5](x)
368
+ x = self.model.features[6](x)
369
+ enc8x = x
370
+
371
+ # x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x)
372
+ x = self.model.features[7](x)
373
+ x = self.model.features[8](x)
374
+ x = self.model.features[9](x)
375
+ x = self.model.features[10](x)
376
+ x = self.model.features[11](x)
377
+ x = self.model.features[12](x)
378
+ x = self.model.features[13](x)
379
+ enc16x = x
380
+
381
+ # x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x)
382
+ x = self.model.features[14](x)
383
+ x = self.model.features[15](x)
384
+ x = self.model.features[16](x)
385
+ x = self.model.features[17](x)
386
+ x = self.model.features[18](x)
387
+ enc32x = x
388
+ return [enc2x, enc4x, enc8x, enc16x, enc32x]
389
+
390
+ def load_pretrained_ckpt(self):
391
+ # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch
392
+ ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt'
393
+ if not os.path.exists(ckpt_path):
394
+ print('cannot find the pretrained mobilenetv2 backbone')
395
+ exit()
396
+
397
+ ckpt = torch.load(ckpt_path)
398
+ self.model.load_state_dict(ckpt)
399
+
400
+
401
+ SUPPORTED_BACKBONES = {
402
+ 'mobilenetv2': MobileNetV2Backbone,
403
+ }
404
+
405
+
406
+ # ------------------------------------------------------------------------------
407
+ # MODNet Basic Modules
408
+ # ------------------------------------------------------------------------------
409
+
410
+ class IBNorm(nn.Module):
411
+ """ Combine Instance Norm and Batch Norm into One Layer
412
+ """
413
+
414
+ def __init__(self, in_channels):
415
+ super(IBNorm, self).__init__()
416
+ in_channels = in_channels
417
+ self.bnorm_channels = int(in_channels / 2)
418
+ self.inorm_channels = in_channels - self.bnorm_channels
419
+
420
+ self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
421
+ self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)
422
+
423
+ def forward(self, x):
424
+ bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
425
+ in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())
426
+
427
+ return torch.cat((bn_x, in_x), 1)
428
+
429
+
430
+ class Conv2dIBNormRelu(nn.Module):
431
+ """ Convolution + IBNorm + ReLu
432
+ """
433
+
434
+ def __init__(self, in_channels, out_channels, kernel_size,
435
+ stride=1, padding=0, dilation=1, groups=1, bias=True,
436
+ with_ibn=True, with_relu=True):
437
+ super(Conv2dIBNormRelu, self).__init__()
438
+
439
+ layers = [
440
+ nn.Conv2d(in_channels, out_channels, kernel_size,
441
+ stride=stride, padding=padding, dilation=dilation,
442
+ groups=groups, bias=bias)
443
+ ]
444
+
445
+ if with_ibn:
446
+ layers.append(IBNorm(out_channels))
447
+ if with_relu:
448
+ layers.append(nn.ReLU(inplace=True))
449
+
450
+ self.layers = nn.Sequential(*layers)
451
+
452
+ def forward(self, x):
453
+ return self.layers(x)
454
+
455
+
456
+ class SEBlock(nn.Module):
457
+ """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf
458
+ """
459
+
460
+ def __init__(self, in_channels, out_channels, reduction=1):
461
+ super(SEBlock, self).__init__()
462
+ self.pool = nn.AdaptiveAvgPool2d(1)
463
+ self.fc = nn.Sequential(
464
+ nn.Linear(in_channels, int(in_channels // reduction), bias=False),
465
+ nn.ReLU(inplace=True),
466
+ nn.Linear(int(in_channels // reduction), out_channels, bias=False),
467
+ nn.Sigmoid()
468
+ )
469
+
470
+ def forward(self, x):
471
+ b, c, _, _ = x.size()
472
+ w = self.pool(x).view(b, c)
473
+ w = self.fc(w).view(b, c, 1, 1)
474
+
475
+ return x * w.expand_as(x)
476
+
477
+
478
+ # ------------------------------------------------------------------------------
479
+ # MODNet Branches
480
+ # ------------------------------------------------------------------------------
481
+
482
+ class LRBranch(nn.Module):
483
+ """ Low Resolution Branch of MODNet
484
+ """
485
+
486
+ def __init__(self, backbone):
487
+ super(LRBranch, self).__init__()
488
+
489
+ enc_channels = backbone.enc_channels
490
+
491
+ self.backbone = backbone
492
+ self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)
493
+ self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)
494
+ self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2)
495
+ self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False,
496
+ with_relu=False)
497
+
498
+ def forward(self, img, inference):
499
+ enc_features = self.backbone.forward(img)
500
+ enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4]
501
+
502
+ enc32x = self.se_block(enc32x)
503
+ lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)
504
+ lr16x = self.conv_lr16x(lr16x)
505
+ lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)
506
+ lr8x = self.conv_lr8x(lr8x)
507
+
508
+ pred_semantic = None
509
+ if not inference:
510
+ lr = self.conv_lr(lr8x)
511
+ pred_semantic = torch.sigmoid(lr)
512
+
513
+ return pred_semantic, lr8x, [enc2x, enc4x]
514
+
515
+
516
+ class HRBranch(nn.Module):
517
+ """ High Resolution Branch of MODNet
518
+ """
519
+
520
+ def __init__(self, hr_channels, enc_channels):
521
+ super(HRBranch, self).__init__()
522
+
523
+ self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)
524
+ self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)
525
+
526
+ self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)
527
+ self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)
528
+
529
+ self.conv_hr4x = nn.Sequential(
530
+ Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1),
531
+ Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
532
+ Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
533
+ )
534
+
535
+ self.conv_hr2x = nn.Sequential(
536
+ Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
537
+ Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
538
+ Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
539
+ Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
540
+ )
541
+
542
+ self.conv_hr = nn.Sequential(
543
+ Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),
544
+ Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),
545
+ )
546
+
547
+ def forward(self, img, enc2x, enc4x, lr8x, inference):
548
+ img2x = F.interpolate(img, scale_factor=1 / 2, mode='bilinear', align_corners=False)
549
+ img4x = F.interpolate(img, scale_factor=1 / 4, mode='bilinear', align_corners=False)
550
+
551
+ enc2x = self.tohr_enc2x(enc2x)
552
+ hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1))
553
+
554
+ enc4x = self.tohr_enc4x(enc4x)
555
+ hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1))
556
+
557
+ lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
558
+ hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1))
559
+
560
+ hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False)
561
+ hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1))
562
+
563
+ pred_detail = None
564
+ if not inference:
565
+ hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False)
566
+ hr = self.conv_hr(torch.cat((hr, img), dim=1))
567
+ pred_detail = torch.sigmoid(hr)
568
+
569
+ return pred_detail, hr2x
570
+
571
+
572
+ class FusionBranch(nn.Module):
573
+ """ Fusion Branch of MODNet
574
+ """
575
+
576
+ def __init__(self, hr_channels, enc_channels):
577
+ super(FusionBranch, self).__init__()
578
+ self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)
579
+
580
+ self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
581
+ self.conv_f = nn.Sequential(
582
+ Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
583
+ Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False),
584
+ )
585
+
586
+ def forward(self, img, lr8x, hr2x):
587
+ lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
588
+ lr4x = self.conv_lr4x(lr4x)
589
+ lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False)
590
+
591
+ f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1))
592
+ f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False)
593
+ f = self.conv_f(torch.cat((f, img), dim=1))
594
+ pred_matte = torch.sigmoid(f)
595
+
596
+ return pred_matte
597
+
598
+
599
+ # ------------------------------------------------------------------------------
600
+ # MODNet
601
+ # ------------------------------------------------------------------------------
602
+
603
+ class MODNet(nn.Module):
604
+ """ Architecture of MODNet
605
+ """
606
+
607
+ def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=False):
608
+ super(MODNet, self).__init__()
609
+
610
+ self.in_channels = in_channels
611
+ self.hr_channels = hr_channels
612
+ self.backbone_arch = backbone_arch
613
+ self.backbone_pretrained = backbone_pretrained
614
+
615
+ self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels)
616
+
617
+ self.lr_branch = LRBranch(self.backbone)
618
+ self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)
619
+ self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)
620
+
621
+ for m in self.modules():
622
+ if isinstance(m, nn.Conv2d):
623
+ self._init_conv(m)
624
+ elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
625
+ self._init_norm(m)
626
+
627
+ if self.backbone_pretrained:
628
+ self.backbone.load_pretrained_ckpt()
629
+
630
+ def forward(self, img, inference):
631
+ pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)
632
+ pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference)
633
+ pred_matte = self.f_branch(img, lr8x, hr2x)
634
+
635
+ return pred_semantic, pred_detail, pred_matte
636
+
637
+ @staticmethod
638
+ def compute_loss(args):
639
+ pred_semantic, pred_detail, pred_matte, image, trimap, gt_matte = args
640
+ semantic_loss, detail_loss, matte_loss = loss_func(pred_semantic, pred_detail, pred_matte,
641
+ image, trimap, gt_matte)
642
+ loss = semantic_loss + detail_loss + matte_loss
643
+ return matte_loss, loss
644
+
645
+ def freeze_norm(self):
646
+ norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d]
647
+ for m in self.modules():
648
+ for n in norm_types:
649
+ if isinstance(m, n):
650
+ m.eval()
651
+ continue
652
+
653
+ def _init_conv(self, conv):
654
+ nn.init.kaiming_uniform_(
655
+ conv.weight, a=0, mode='fan_in', nonlinearity='relu')
656
+ if conv.bias is not None:
657
+ nn.init.constant_(conv.bias, 0)
658
+
659
+ def _init_norm(self, norm):
660
+ if norm.weight is not None:
661
+ nn.init.constant_(norm.weight, 1)
662
+ nn.init.constant_(norm.bias, 0)
663
+
664
+ def _apply(self, fn):
665
+ super(MODNet, self)._apply(fn)
666
+ blurer._apply(fn) # let blurer's device same as modnet
667
+ return self
animeinsseg/models/animeseg_refine/u2net.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Codes are borrowed from
2
+ # https://github.com/xuebinqin/U-2-Net/blob/master/model/u2net_refactor.py
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import math
8
+
9
+ __all__ = ['U2NET_full', 'U2NET_full2', 'U2NET_lite', 'U2NET_lite2', "U2NET"]
10
+
11
+ bce_loss = nn.BCEWithLogitsLoss(reduction='mean')
12
+
13
+
14
+ def _upsample_like(x, size):
15
+ return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
16
+
17
+
18
+ def _size_map(x, height):
19
+ # {height: size} for Upsample
20
+ size = list(x.shape[-2:])
21
+ sizes = {}
22
+ for h in range(1, height):
23
+ sizes[h] = size
24
+ size = [math.ceil(w / 2) for w in size]
25
+ return sizes
26
+
27
+
28
+ class REBNCONV(nn.Module):
29
+ def __init__(self, in_ch=3, out_ch=3, dilate=1):
30
+ super(REBNCONV, self).__init__()
31
+
32
+ self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dilate, dilation=1 * dilate)
33
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
34
+ self.relu_s1 = nn.ReLU(inplace=True)
35
+
36
+ def forward(self, x):
37
+ return self.relu_s1(self.bn_s1(self.conv_s1(x)))
38
+
39
+
40
+ class RSU(nn.Module):
41
+ def __init__(self, name, height, in_ch, mid_ch, out_ch, dilated=False):
42
+ super(RSU, self).__init__()
43
+ self.name = name
44
+ self.height = height
45
+ self.dilated = dilated
46
+ self._make_layers(height, in_ch, mid_ch, out_ch, dilated)
47
+
48
+ def forward(self, x):
49
+ sizes = _size_map(x, self.height)
50
+ x = self.rebnconvin(x)
51
+
52
+ # U-Net like symmetric encoder-decoder structure
53
+ def unet(x, height=1):
54
+ if height < self.height:
55
+ x1 = getattr(self, f'rebnconv{height}')(x)
56
+ if not self.dilated and height < self.height - 1:
57
+ x2 = unet(getattr(self, 'downsample')(x1), height + 1)
58
+ else:
59
+ x2 = unet(x1, height + 1)
60
+
61
+ x = getattr(self, f'rebnconv{height}d')(torch.cat((x2, x1), 1))
62
+ return _upsample_like(x, sizes[height - 1]) if not self.dilated and height > 1 else x
63
+ else:
64
+ return getattr(self, f'rebnconv{height}')(x)
65
+
66
+ return x + unet(x)
67
+
68
+ def _make_layers(self, height, in_ch, mid_ch, out_ch, dilated=False):
69
+ self.add_module('rebnconvin', REBNCONV(in_ch, out_ch))
70
+ self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))
71
+
72
+ self.add_module(f'rebnconv1', REBNCONV(out_ch, mid_ch))
73
+ self.add_module(f'rebnconv1d', REBNCONV(mid_ch * 2, out_ch))
74
+
75
+ for i in range(2, height):
76
+ dilate = 1 if not dilated else 2 ** (i - 1)
77
+ self.add_module(f'rebnconv{i}', REBNCONV(mid_ch, mid_ch, dilate=dilate))
78
+ self.add_module(f'rebnconv{i}d', REBNCONV(mid_ch * 2, mid_ch, dilate=dilate))
79
+
80
+ dilate = 2 if not dilated else 2 ** (height - 1)
81
+ self.add_module(f'rebnconv{height}', REBNCONV(mid_ch, mid_ch, dilate=dilate))
82
+
83
+
84
+ class U2NET(nn.Module):
85
+ def __init__(self, cfgs, out_ch):
86
+ super(U2NET, self).__init__()
87
+ self.out_ch = out_ch
88
+ self._make_layers(cfgs)
89
+
90
+ def forward(self, x):
91
+ sizes = _size_map(x, self.height)
92
+ maps = [] # storage for maps
93
+
94
+ # side saliency map
95
+ def unet(x, height=1):
96
+ if height < 6:
97
+ x1 = getattr(self, f'stage{height}')(x)
98
+ x2 = unet(getattr(self, 'downsample')(x1), height + 1)
99
+ x = getattr(self, f'stage{height}d')(torch.cat((x2, x1), 1))
100
+ side(x, height)
101
+ return _upsample_like(x, sizes[height - 1]) if height > 1 else x
102
+ else:
103
+ x = getattr(self, f'stage{height}')(x)
104
+ side(x, height)
105
+ return _upsample_like(x, sizes[height - 1])
106
+
107
+ def side(x, h):
108
+ # side output saliency map (before sigmoid)
109
+ x = getattr(self, f'side{h}')(x)
110
+ x = _upsample_like(x, sizes[1])
111
+ maps.append(x)
112
+
113
+ def fuse():
114
+ # fuse saliency probability maps
115
+ maps.reverse()
116
+ x = torch.cat(maps, 1)
117
+ x = getattr(self, 'outconv')(x)
118
+ maps.insert(0, x)
119
+ # return [torch.sigmoid(x) for x in maps]
120
+ return [x for x in maps]
121
+
122
+ unet(x)
123
+ maps = fuse()
124
+ return maps
125
+
126
+ @staticmethod
127
+ def compute_loss(args):
128
+ preds, labels_v = args
129
+ d0, d1, d2, d3, d4, d5, d6 = preds
130
+ loss0 = bce_loss(d0, labels_v)
131
+ loss1 = bce_loss(d1, labels_v)
132
+ loss2 = bce_loss(d2, labels_v)
133
+ loss3 = bce_loss(d3, labels_v)
134
+ loss4 = bce_loss(d4, labels_v)
135
+ loss5 = bce_loss(d5, labels_v)
136
+ loss6 = bce_loss(d6, labels_v)
137
+
138
+ loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
139
+
140
+ return loss0, loss
141
+
142
+ def _make_layers(self, cfgs):
143
+ self.height = int((len(cfgs) + 1) / 2)
144
+ self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))
145
+ for k, v in cfgs.items():
146
+ # build rsu block
147
+ self.add_module(k, RSU(v[0], *v[1]))
148
+ if v[2] > 0:
149
+ # build side layer
150
+ self.add_module(f'side{v[0][-1]}', nn.Conv2d(v[2], self.out_ch, 3, padding=1))
151
+ # build fuse layer
152
+ self.add_module('outconv', nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1))
153
+
154
+
155
+ def U2NET_full():
156
+ full = {
157
+ # cfgs for building RSUs and sides
158
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
159
+ 'stage1': ['En_1', (7, 3, 32, 64), -1],
160
+ 'stage2': ['En_2', (6, 64, 32, 128), -1],
161
+ 'stage3': ['En_3', (5, 128, 64, 256), -1],
162
+ 'stage4': ['En_4', (4, 256, 128, 512), -1],
163
+ 'stage5': ['En_5', (4, 512, 256, 512, True), -1],
164
+ 'stage6': ['En_6', (4, 512, 256, 512, True), 512],
165
+ 'stage5d': ['De_5', (4, 1024, 256, 512, True), 512],
166
+ 'stage4d': ['De_4', (4, 1024, 128, 256), 256],
167
+ 'stage3d': ['De_3', (5, 512, 64, 128), 128],
168
+ 'stage2d': ['De_2', (6, 256, 32, 64), 64],
169
+ 'stage1d': ['De_1', (7, 128, 16, 64), 64],
170
+ }
171
+ return U2NET(cfgs=full, out_ch=1)
172
+
173
+
174
+ def U2NET_full2():
175
+ full = {
176
+ # cfgs for building RSUs and sides
177
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
178
+ 'stage1': ['En_1', (8, 3, 32, 64), -1],
179
+ 'stage2': ['En_2', (7, 64, 32, 128), -1],
180
+ 'stage3': ['En_3', (6, 128, 64, 256), -1],
181
+ 'stage4': ['En_4', (5, 256, 128, 512), -1],
182
+ 'stage5': ['En_5', (5, 512, 256, 512, True), -1],
183
+ 'stage6': ['En_6', (5, 512, 256, 512, True), 512],
184
+ 'stage5d': ['De_5', (5, 1024, 256, 512, True), 512],
185
+ 'stage4d': ['De_4', (5, 1024, 128, 256), 256],
186
+ 'stage3d': ['De_3', (6, 512, 64, 128), 128],
187
+ 'stage2d': ['De_2', (7, 256, 32, 64), 64],
188
+ 'stage1d': ['De_1', (8, 128, 16, 64), 64],
189
+ }
190
+ return U2NET(cfgs=full, out_ch=1)
191
+
192
+
193
+ def U2NET_lite():
194
+ lite = {
195
+ # cfgs for building RSUs and sides
196
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
197
+ 'stage1': ['En_1', (7, 3, 16, 64), -1],
198
+ 'stage2': ['En_2', (6, 64, 16, 64), -1],
199
+ 'stage3': ['En_3', (5, 64, 16, 64), -1],
200
+ 'stage4': ['En_4', (4, 64, 16, 64), -1],
201
+ 'stage5': ['En_5', (4, 64, 16, 64, True), -1],
202
+ 'stage6': ['En_6', (4, 64, 16, 64, True), 64],
203
+ 'stage5d': ['De_5', (4, 128, 16, 64, True), 64],
204
+ 'stage4d': ['De_4', (4, 128, 16, 64), 64],
205
+ 'stage3d': ['De_3', (5, 128, 16, 64), 64],
206
+ 'stage2d': ['De_2', (6, 128, 16, 64), 64],
207
+ 'stage1d': ['De_1', (7, 128, 16, 64), 64],
208
+ }
209
+ return U2NET(cfgs=lite, out_ch=1)
210
+
211
+
212
+ def U2NET_lite2():
213
+ lite = {
214
+ # cfgs for building RSUs and sides
215
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
216
+ 'stage1': ['En_1', (8, 3, 16, 64), -1],
217
+ 'stage2': ['En_2', (7, 64, 16, 64), -1],
218
+ 'stage3': ['En_3', (6, 64, 16, 64), -1],
219
+ 'stage4': ['En_4', (5, 64, 16, 64), -1],
220
+ 'stage5': ['En_5', (5, 64, 16, 64, True), -1],
221
+ 'stage6': ['En_6', (5, 64, 16, 64, True), 64],
222
+ 'stage5d': ['De_5', (5, 128, 16, 64, True), 64],
223
+ 'stage4d': ['De_4', (5, 128, 16, 64), 64],
224
+ 'stage3d': ['De_3', (6, 128, 16, 64), 64],
225
+ 'stage2d': ['De_2', (7, 128, 16, 64), 64],
226
+ 'stage1d': ['De_1', (8, 128, 16, 64), 64],
227
+ }
228
+ return U2NET(cfgs=lite, out_ch=1)
animeinsseg/models/rtmdet_inshead_custom.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ import math
4
+ from typing import List, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from mmcv.cnn import ConvModule, is_norm
10
+ from mmcv.ops import batched_nms
11
+ from mmengine.model import (BaseModule, bias_init_with_prob, constant_init,
12
+ normal_init)
13
+ from mmengine.structures import InstanceData
14
+ from torch import Tensor
15
+
16
+ from mmdet.models.layers.transformer import inverse_sigmoid
17
+ from mmdet.models.utils import (filter_scores_and_topk, multi_apply,
18
+ select_single_mlvl, sigmoid_geometric_mean)
19
+ from mmdet.registry import MODELS
20
+ from mmdet.structures.bbox import (cat_boxes, distance2bbox, get_box_tensor,
21
+ get_box_wh, scale_boxes)
22
+ from mmdet.utils import ConfigType, InstanceList, OptInstanceList, reduce_mean
23
+ from mmdet.models.dense_heads.rtmdet_head import RTMDetHead
24
+ from mmdet.models.dense_heads.rtmdet_ins_head import RTMDetInsHead, RTMDetInsSepBNHead, MaskFeatModule
25
+
26
+ from mmdet.utils import AvoidCUDAOOM
27
+
28
+
29
+
30
+ def sthgoeswrong(logits):
31
+ return torch.any(torch.isnan(logits)) or torch.any(torch.isinf(logits))
32
+
33
+ from time import time
34
+
35
+ @MODELS.register_module(force=True)
36
+ class RTMDetInsHeadCustom(RTMDetInsHead):
37
+
38
+ def loss_by_feat(self,
39
+ cls_scores: List[Tensor],
40
+ bbox_preds: List[Tensor],
41
+ kernel_preds: List[Tensor],
42
+ mask_feat: Tensor,
43
+ batch_gt_instances: InstanceList,
44
+ batch_img_metas: List[dict],
45
+ batch_gt_instances_ignore: OptInstanceList = None):
46
+ """Compute losses of the head.
47
+
48
+ Args:
49
+ cls_scores (list[Tensor]): Box scores for each scale level
50
+ Has shape (N, num_anchors * num_classes, H, W)
51
+ bbox_preds (list[Tensor]): Decoded box for each scale
52
+ level with shape (N, num_anchors * 4, H, W) in
53
+ [tl_x, tl_y, br_x, br_y] format.
54
+ batch_gt_instances (list[:obj:`InstanceData`]): Batch of
55
+ gt_instance. It usually includes ``bboxes`` and ``labels``
56
+ attributes.
57
+ batch_img_metas (list[dict]): Meta information of each image, e.g.,
58
+ image size, scaling factor, etc.
59
+ batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
60
+ Batch of gt_instances_ignore. It includes ``bboxes`` attribute
61
+ data that is ignored during training and testing.
62
+ Defaults to None.
63
+
64
+ Returns:
65
+ dict[str, Tensor]: A dictionary of loss components.
66
+ """
67
+ num_imgs = len(batch_img_metas)
68
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
69
+ assert len(featmap_sizes) == self.prior_generator.num_levels
70
+
71
+ device = cls_scores[0].device
72
+ anchor_list, valid_flag_list = self.get_anchors(
73
+ featmap_sizes, batch_img_metas, device=device)
74
+ flatten_cls_scores = torch.cat([
75
+ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
76
+ self.cls_out_channels)
77
+ for cls_score in cls_scores
78
+ ], 1)
79
+ flatten_kernels = torch.cat([
80
+ kernel_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
81
+ self.num_gen_params)
82
+ for kernel_pred in kernel_preds
83
+ ], 1)
84
+ decoded_bboxes = []
85
+ for anchor, bbox_pred in zip(anchor_list[0], bbox_preds):
86
+ anchor = anchor.reshape(-1, 4)
87
+ bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
88
+ bbox_pred = distance2bbox(anchor, bbox_pred)
89
+ decoded_bboxes.append(bbox_pred)
90
+
91
+ flatten_bboxes = torch.cat(decoded_bboxes, 1)
92
+ for gt_instances in batch_gt_instances:
93
+ gt_instances.masks = gt_instances.masks.to_tensor(
94
+ dtype=torch.bool, device=device)
95
+
96
+ cls_reg_targets = self.get_targets(
97
+ flatten_cls_scores,
98
+ flatten_bboxes,
99
+ anchor_list,
100
+ valid_flag_list,
101
+ batch_gt_instances,
102
+ batch_img_metas,
103
+ batch_gt_instances_ignore=batch_gt_instances_ignore)
104
+ (anchor_list, labels_list, label_weights_list, bbox_targets_list,
105
+ assign_metrics_list, sampling_results_list) = cls_reg_targets
106
+
107
+ losses_cls, losses_bbox,\
108
+ cls_avg_factors, bbox_avg_factors = multi_apply(
109
+ self.loss_by_feat_single,
110
+ cls_scores,
111
+ decoded_bboxes,
112
+ labels_list,
113
+ label_weights_list,
114
+ bbox_targets_list,
115
+ assign_metrics_list,
116
+ self.prior_generator.strides)
117
+
118
+ cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item()
119
+ losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls))
120
+
121
+ bbox_avg_factor = reduce_mean(
122
+ sum(bbox_avg_factors)).clamp_(min=1).item()
123
+ losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))
124
+
125
+ loss_mask = self.loss_mask_by_feat(mask_feat, flatten_kernels,
126
+ sampling_results_list,
127
+ batch_gt_instances)
128
+ loss = dict(
129
+ loss_cls=losses_cls, loss_bbox=losses_bbox, loss_mask=loss_mask)
130
+
131
+ return loss
132
+
133
+
134
+ def _mask_predict_by_feat_single(self, mask_feat: Tensor, kernels: Tensor,
135
+ priors: Tensor) -> Tensor:
136
+
137
+ ori_maskfeat = mask_feat
138
+
139
+ num_inst = priors.shape[0]
140
+ h, w = mask_feat.size()[-2:]
141
+ if num_inst < 1:
142
+ return torch.empty(
143
+ size=(num_inst, h, w),
144
+ dtype=mask_feat.dtype,
145
+ device=mask_feat.device)
146
+ if len(mask_feat.shape) < 4:
147
+ mask_feat.unsqueeze(0)
148
+
149
+ coord = self.prior_generator.single_level_grid_priors(
150
+ (h, w), level_idx=0, device=mask_feat.device).reshape(1, -1, 2)
151
+ num_inst = priors.shape[0]
152
+ points = priors[:, :2].reshape(-1, 1, 2)
153
+ strides = priors[:, 2:].reshape(-1, 1, 2)
154
+ relative_coord = (points - coord).permute(0, 2, 1) / (
155
+ strides[..., 0].reshape(-1, 1, 1) * 8)
156
+ relative_coord = relative_coord.reshape(num_inst, 2, h, w)
157
+
158
+ mask_feat = torch.cat(
159
+ [relative_coord,
160
+ mask_feat.repeat(num_inst, 1, 1, 1)], dim=1)
161
+ weights, biases = self.parse_dynamic_params(kernels)
162
+
163
+ fp16_used = weights[0].dtype == torch.float16
164
+
165
+ n_layers = len(weights)
166
+ x = mask_feat.reshape(1, -1, h, w)
167
+ for i, (weight, bias) in enumerate(zip(weights, biases)):
168
+ with torch.cuda.amp.autocast(enabled=False):
169
+ if fp16_used:
170
+ weight = weight.to(torch.float32)
171
+ bias = bias.to(torch.float32)
172
+ x = F.conv2d(
173
+ x, weight, bias=bias, stride=1, padding=0, groups=num_inst)
174
+ if i < n_layers - 1:
175
+ x = F.relu(x)
176
+
177
+ if fp16_used:
178
+ x = torch.clip(x, -8192, 8192)
179
+ if sthgoeswrong(x):
180
+ torch.save({'mask_feat': ori_maskfeat, 'kernels': kernels, 'priors': priors}, 'maskhead_nan_input.pt')
181
+ raise Exception('Mask Head NaN')
182
+
183
+ x = x.reshape(num_inst, h, w)
184
+ return x
185
+
186
+ def loss_mask_by_feat(self, mask_feats: Tensor, flatten_kernels: Tensor,
187
+ sampling_results_list: list,
188
+ batch_gt_instances: InstanceList) -> Tensor:
189
+ batch_pos_mask_logits = []
190
+ pos_gt_masks = []
191
+ ignore_masks = []
192
+ for idx, (mask_feat, kernels, sampling_results,
193
+ gt_instances) in enumerate(
194
+ zip(mask_feats, flatten_kernels, sampling_results_list,
195
+ batch_gt_instances)):
196
+ pos_priors = sampling_results.pos_priors
197
+ pos_inds = sampling_results.pos_inds
198
+ pos_kernels = kernels[pos_inds] # n_pos, num_gen_params
199
+ pos_mask_logits = self._mask_predict_by_feat_single(
200
+ mask_feat, pos_kernels, pos_priors)
201
+ if gt_instances.masks.numel() == 0:
202
+ gt_masks = torch.empty_like(gt_instances.masks)
203
+ if gt_masks.shape[0] > 0:
204
+ ignore = torch.zeros(gt_masks.shape[0], dtype=torch.bool).to(device=gt_masks.device)
205
+ ignore_masks.append(ignore)
206
+ else:
207
+ gt_masks = gt_instances.masks[
208
+ sampling_results.pos_assigned_gt_inds, :]
209
+ ignore_masks.append(gt_instances.ignore_mask[sampling_results.pos_assigned_gt_inds])
210
+ batch_pos_mask_logits.append(pos_mask_logits)
211
+ pos_gt_masks.append(gt_masks)
212
+
213
+ pos_gt_masks = torch.cat(pos_gt_masks, 0)
214
+ batch_pos_mask_logits = torch.cat(batch_pos_mask_logits, 0)
215
+ ignore_masks = torch.logical_not(torch.cat(ignore_masks, 0))
216
+
217
+ pos_gt_masks = pos_gt_masks[ignore_masks]
218
+ batch_pos_mask_logits = batch_pos_mask_logits[ignore_masks]
219
+
220
+
221
+ # avg_factor
222
+ num_pos = batch_pos_mask_logits.shape[0]
223
+ num_pos = reduce_mean(mask_feats.new_tensor([num_pos
224
+ ])).clamp_(min=1).item()
225
+
226
+ if batch_pos_mask_logits.shape[0] == 0:
227
+ return mask_feats.sum() * 0
228
+
229
+ scale = self.prior_generator.strides[0][0] // self.mask_loss_stride
230
+ # upsample pred masks
231
+ batch_pos_mask_logits = F.interpolate(
232
+ batch_pos_mask_logits.unsqueeze(0),
233
+ scale_factor=scale,
234
+ mode='bilinear',
235
+ align_corners=False).squeeze(0)
236
+ # downsample gt masks
237
+ pos_gt_masks = pos_gt_masks[:, self.mask_loss_stride //
238
+ 2::self.mask_loss_stride,
239
+ self.mask_loss_stride //
240
+ 2::self.mask_loss_stride]
241
+
242
+ loss_mask = self.loss_mask(
243
+ batch_pos_mask_logits,
244
+ pos_gt_masks,
245
+ weight=None,
246
+ avg_factor=num_pos)
247
+
248
+ return loss_mask
249
+
250
+
251
+ @MODELS.register_module()
252
+ class RTMDetInsSepBNHeadCustom(RTMDetInsSepBNHead):
253
+ def _mask_predict_by_feat_single(self, mask_feat: Tensor, kernels: Tensor,
254
+ priors: Tensor) -> Tensor:
255
+
256
+ ori_maskfeat = mask_feat
257
+
258
+ num_inst = priors.shape[0]
259
+ h, w = mask_feat.size()[-2:]
260
+ if num_inst < 1:
261
+ return torch.empty(
262
+ size=(num_inst, h, w),
263
+ dtype=mask_feat.dtype,
264
+ device=mask_feat.device)
265
+ if len(mask_feat.shape) < 4:
266
+ mask_feat.unsqueeze(0)
267
+
268
+ coord = self.prior_generator.single_level_grid_priors(
269
+ (h, w), level_idx=0, device=mask_feat.device).reshape(1, -1, 2)
270
+ num_inst = priors.shape[0]
271
+ points = priors[:, :2].reshape(-1, 1, 2)
272
+ strides = priors[:, 2:].reshape(-1, 1, 2)
273
+ relative_coord = (points - coord).permute(0, 2, 1) / (
274
+ strides[..., 0].reshape(-1, 1, 1) * 8)
275
+ relative_coord = relative_coord.reshape(num_inst, 2, h, w)
276
+
277
+ mask_feat = torch.cat(
278
+ [relative_coord,
279
+ mask_feat.repeat(num_inst, 1, 1, 1)], dim=1)
280
+ weights, biases = self.parse_dynamic_params(kernels)
281
+
282
+ fp16_used = weights[0].dtype == torch.float16
283
+
284
+ n_layers = len(weights)
285
+ x = mask_feat.reshape(1, -1, h, w)
286
+ for i, (weight, bias) in enumerate(zip(weights, biases)):
287
+ with torch.cuda.amp.autocast(enabled=False):
288
+ if fp16_used:
289
+ weight = weight.to(torch.float32)
290
+ bias = bias.to(torch.float32)
291
+ x = F.conv2d(
292
+ x, weight, bias=bias, stride=1, padding=0, groups=num_inst)
293
+ if i < n_layers - 1:
294
+ x = F.relu(x)
295
+
296
+ if fp16_used:
297
+ x = torch.clip(x, -8192, 8192)
298
+ if sthgoeswrong(x):
299
+ torch.save({'mask_feat': ori_maskfeat, 'kernels': kernels, 'priors': priors}, 'maskhead_nan_input.pt')
300
+ raise Exception('Mask Head NaN')
301
+
302
+ x = x.reshape(num_inst, h, w)
303
+ return x
304
+
305
+ @AvoidCUDAOOM.retry_if_cuda_oom
306
+ def loss_mask_by_feat(self, mask_feats: Tensor, flatten_kernels: Tensor,
307
+ sampling_results_list: list,
308
+ batch_gt_instances: InstanceList) -> Tensor:
309
+ batch_pos_mask_logits = []
310
+ pos_gt_masks = []
311
+ ignore_masks = []
312
+ for idx, (mask_feat, kernels, sampling_results,
313
+ gt_instances) in enumerate(
314
+ zip(mask_feats, flatten_kernels, sampling_results_list,
315
+ batch_gt_instances)):
316
+ pos_priors = sampling_results.pos_priors
317
+ pos_inds = sampling_results.pos_inds
318
+ pos_kernels = kernels[pos_inds] # n_pos, num_gen_params
319
+ pos_mask_logits = self._mask_predict_by_feat_single(
320
+ mask_feat, pos_kernels, pos_priors)
321
+ if gt_instances.masks.numel() == 0:
322
+ gt_masks = torch.empty_like(gt_instances.masks)
323
+ # if gt_masks.shape[0] > 0:
324
+ # ignore = torch.zeros(gt_masks.shape[0], dtype=torch.bool).to(device=gt_masks.device)
325
+ # ignore_masks.append(ignore)
326
+ else:
327
+ msk = torch.logical_not(gt_instances.ignore_mask[sampling_results.pos_assigned_gt_inds])
328
+ gt_masks = gt_instances.masks[
329
+ sampling_results.pos_assigned_gt_inds, :][msk]
330
+ pos_mask_logits = pos_mask_logits[msk]
331
+ # ignore_masks.append(gt_instances.ignore_mask[sampling_results.pos_assigned_gt_inds])
332
+ batch_pos_mask_logits.append(pos_mask_logits)
333
+ pos_gt_masks.append(gt_masks)
334
+
335
+ pos_gt_masks = torch.cat(pos_gt_masks, 0)
336
+ batch_pos_mask_logits = torch.cat(batch_pos_mask_logits, 0)
337
+ # ignore_masks = torch.logical_not(torch.cat(ignore_masks, 0))
338
+
339
+ # pos_gt_masks = pos_gt_masks[ignore_masks]
340
+ # batch_pos_mask_logits = batch_pos_mask_logits[ignore_masks]
341
+
342
+
343
+ # avg_factor
344
+ num_pos = batch_pos_mask_logits.shape[0]
345
+ num_pos = reduce_mean(mask_feats.new_tensor([num_pos
346
+ ])).clamp_(min=1).item()
347
+
348
+ if batch_pos_mask_logits.shape[0] == 0:
349
+ return mask_feats.sum() * 0
350
+
351
+ scale = self.prior_generator.strides[0][0] // self.mask_loss_stride
352
+ # upsample pred masks
353
+ batch_pos_mask_logits = F.interpolate(
354
+ batch_pos_mask_logits.unsqueeze(0),
355
+ scale_factor=scale,
356
+ mode='bilinear',
357
+ align_corners=False).squeeze(0)
358
+ # downsample gt masks
359
+ pos_gt_masks = pos_gt_masks[:, self.mask_loss_stride //
360
+ 2::self.mask_loss_stride,
361
+ self.mask_loss_stride //
362
+ 2::self.mask_loss_stride]
363
+
364
+ loss_mask = self.loss_mask(
365
+ batch_pos_mask_logits,
366
+ pos_gt_masks,
367
+ weight=None,
368
+ avg_factor=num_pos)
369
+
370
+ return loss_mask
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ import cv2
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
+ from animeinsseg import AnimeInsSeg, AnimeInstances
8
+ from animeinsseg.anime_instances import get_color
9
+
10
+ import os
11
+
12
+ if not os.path.exists("models"):
13
+ os.mkdir("models")
14
+
15
+ os.system("huggingface-cli lfs-enable-largefiles .")
16
+ os.system("git clone https://huggingface.co/dreMaz/AnimeInstanceSegmentation models/AnimeInstanceSegmentation")
17
+
18
+ ckpt = r'models/AnimeInstanceSegmentation/rtmdetl_e60.ckpt'
19
+
20
+ mask_thres = 0.3
21
+ instance_thres = 0.3
22
+ refine_kwargs = {'refine_method': 'refinenet_isnet'} # set to None if not using refinenet
23
+ # refine_kwargs = None
24
+
25
+ net = AnimeInsSeg(ckpt, mask_thr=mask_thres, refine_kwargs=refine_kwargs)
26
+
27
+ def fn(image):
28
+ img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
29
+ instances: AnimeInstances = net.infer(
30
+ img,
31
+ output_type='numpy',
32
+ pred_score_thr=instance_thres
33
+ )
34
+
35
+ drawed = img.copy()
36
+ im_h, im_w = img.shape[:2]
37
+
38
+ # instances.bboxes, instances.masks will be None, None if no obj is detected
39
+
40
+ for ii, (xywh, mask) in enumerate(zip(instances.bboxes, instances.masks)):
41
+ color = get_color(ii)
42
+
43
+ mask_alpha = 0.5
44
+ linewidth = max(round(sum(img.shape) / 2 * 0.003), 2)
45
+
46
+ # draw bbox
47
+ p1, p2 = (int(xywh[0]), int(xywh[1])), (int(xywh[2] + xywh[0]), int(xywh[3] + xywh[1]))
48
+ cv2.rectangle(drawed, p1, p2, color, thickness=linewidth, lineType=cv2.LINE_AA)
49
+
50
+ # draw mask
51
+ p = mask.astype(np.float32)
52
+ blend_mask = np.full((im_h, im_w, 3), color, dtype=np.float32)
53
+ alpha_msk = (mask_alpha * p)[..., None]
54
+ alpha_ori = 1 - alpha_msk
55
+ drawed = drawed * alpha_ori + alpha_msk * blend_mask
56
+
57
+ drawed = drawed.astype(np.uint8)
58
+
59
+ return Image.fromarray(drawed[..., ::-1])
60
+
61
+ iface = gr.Interface(
62
+ inputs=gr.Image(type="numpy"),
63
+ outputs="Image",
64
+ fn=fn
65
+ )
66
+
67
+
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ einops
2
+ imageio
3
+ git+https://github.com/cocodataset/panopticapi.git
4
+ pytorch-lightning
5
+ albumentations
6
+ huggingface_hub
7
+
8
+ # For Web UI
9
+ gradio
10
+ torch
11
+ torchvision
12
+ openmim
13
+ mmengine
14
+ mmcv>=2.0.0
15
+ mmdet