import mmcv, torch
from tqdm import tqdm
from einops import rearrange
import os
import os.path as osp
import cv2
import gc
import math
from .anime_instances import AnimeInstances
import numpy as np
from typing import List, Tuple, Union, Optional, Callable

from mmengine import Config
from mmengine.model.utils import revert_sync_batchnorm
from mmdet.utils import register_all_modules, get_test_pipeline_cfg
from mmdet.apis import init_detector
from mmdet.registry import MODELS
from mmdet.structures import DetDataSample, SampleList
from mmdet.structures.bbox.transforms import scale_boxes, get_box_wh
from mmdet.models.dense_heads.rtmdet_ins_head import RTMDetInsHead
from pycocotools.coco import COCO
from mmcv.transforms import Compose
from mmdet.models.detectors.single_stage import SingleStageDetector

from utils.logger import LOGGER
from utils.io_utils import square_pad_resize, find_all_imgs, imglist2grid, mask2rle, dict2json, scaledown_maxsize, resize_pad
from utils.constants import DEFAULT_DEVICE, CATEGORIES
from utils.booru_tagger import Tagger

from .models.animeseg_refine import AnimeSegmentation, load_refinenet, get_mask
from .models.rtmdet_inshead_custom import RTMDetInsSepBNHeadCustom

from torchvision.ops.boxes import box_iou
import torch.nn.functional as F


def prepare_refine_batch(segmentations: np.ndarray, img: np.ndarray, max_batch_size: int = 4, device: str = 'cpu', input_size: int = 720):
    img, (pt, pb, pl, pr) = resize_pad(img, input_size, pad_value=(0, 0, 0))
    img = img.transpose((2, 0, 1)).astype(np.float32) / 255.
    batch = []
    num_seg = len(segmentations)
    for ii, seg in enumerate(segmentations):
        seg, _ = resize_pad(seg, input_size, 0)
        seg = seg[None, ...]
        batch.append(np.concatenate((img, seg)))
        if ii == num_seg - 1:
            yield torch.from_numpy(np.array(batch)).to(device), (pt, pb, pl, pr)
        elif len(batch) >= max_batch_size:
            yield torch.from_numpy(np.array(batch)).to(device), (pt, pb, pl, pr)
            batch = []
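
# Note: each batch yielded above stacks the pad-resized RGB image with one instance mask per
# sample, giving a 4-channel (3 + 1) input for the refinement network consumed in
# AnimeInsSeg._postprocess_refine; the (pt, pb, pl, pr) offsets describe the padding added by
# resize_pad so the predictions can be cropped back to the original aspect ratio afterwards.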


VALID_REFINEMETHODS = {'animeseg', 'refinenet_isnet', 'none'}

register_all_modules()


def single_image_preprocess(img: Union[str, np.ndarray], pipeline: Compose):
    if isinstance(img, str):
        img = mmcv.imread(img)
    elif not isinstance(img, np.ndarray):
        raise NotImplementedError
    # img = square_pad_resize(img, 1024)[0]
    data_ = dict(img=img, img_id=0)
    data_ = pipeline(data_)
    data_['inputs'] = [data_['inputs']]
    data_['data_samples'] = [data_['data_samples']]
    return data_, img
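
# single_image_preprocess produces the dict consumed by mmdet's `test_step` /
# `data_preprocessor` (with `inputs` and `data_samples` wrapped as single-element lists)
# and also returns the loaded BGR image so callers can reuse it for visualization.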


def animeseg_refine(det_pred: DetDataSample, img: np.ndarray, net: AnimeSegmentation, to_rgb=True, input_size: int = 1024):
    num_pred = len(det_pred.pred_instances)
    if num_pred < 1:
        return
    with torch.no_grad():
        if to_rgb:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        seg_thr = 0.5
        mask = get_mask(net, img, s=input_size)[..., 0]
        mask = (mask > seg_thr)

        ins_masks = det_pred.pred_instances.masks
        to_tensor = False
        if isinstance(ins_masks, torch.Tensor):
            tensor_device = ins_masks.device
            tensor_dtype = ins_masks.dtype
            to_tensor = True
            ins_masks = ins_masks.cpu().numpy()

        area_original = np.sum(ins_masks, axis=(1, 2))
        masks_refined = np.bitwise_and(ins_masks, mask[None, ...])
        area_refined = np.sum(masks_refined, axis=(1, 2))
        # only accept the refined mask if it keeps a reasonable fraction of the original area,
        # otherwise the foreground matting has likely removed the instance entirely
        for ii in range(num_pred):
            if area_refined[ii] / area_original[ii] > 0.3:
                ins_masks[ii] = masks_refined[ii]
        ins_masks = np.ascontiguousarray(ins_masks)

        # for ii, insm in enumerate(ins_masks):
        #     cv2.imwrite(f'{ii}.png', insm.astype(np.uint8) * 255)
        if to_tensor:
            ins_masks = torch.from_numpy(ins_masks).to(dtype=tensor_dtype).to(device=tensor_device)
        det_pred.pred_instances.masks = ins_masks

        # rst = np.concatenate((mask * img + 1 - mask, mask * 255), axis=2).astype(np.uint8)
        # cv2.imwrite('rst.png', rst)


def read_imglst_from_txt(filep) -> List[str]:
    with open(filep, 'r', encoding='utf8') as f:
        lines = f.read().splitlines()
    return lines


class AnimeInsSeg:

    def __init__(self, ckpt: str, default_det_size: int = 640, device: str = None,
                 refine_kwargs: dict = {'refine_method': 'refinenet_isnet'},
                 tagger_path: str = 'models/wd-v1-4-swinv2-tagger-v2/model.onnx', mask_thr=0.3) -> None:
        self.ckpt = ckpt
        self.default_det_size = default_det_size
        self.device = DEFAULT_DEVICE if device is None else device

        # init detector in mmdet's way
        ckpt = torch.load(ckpt, map_location='cpu')
        cfg = Config.fromstring(ckpt['meta']['cfg'].replace('file_client_args', 'backend_args'), file_format='.py')
        cfg.visualizer = []
        cfg.vis_backends = {}
        cfg.default_hooks.pop('visualization')
        # self.model: SingleStageDetector = init_detector(cfg, checkpoint=None, device='cpu')
        model = MODELS.build(cfg.model)
        model = revert_sync_batchnorm(model)
        model.load_state_dict(ckpt['state_dict'], strict=False)
        self.model = model.to(self.device).eval()
        self.cfg = cfg.copy()

        test_pipeline = get_test_pipeline_cfg(self.cfg.copy())
        test_pipeline[0].type = 'mmdet.LoadImageFromNDArray'
        test_pipeline = Compose(test_pipeline)
        self.default_data_pipeline = test_pipeline

        self.refinenet = None
        self.refinenet_animeseg: AnimeSegmentation = None
        self.postprocess_refine: Callable = None
        if refine_kwargs is not None:
            self.set_refine_method(**refine_kwargs)

        self.tagger = None
        self.tagger_path = tagger_path
        self.mask_thr = mask_thr

    def init_tagger(self, tagger_path: str = None):
        tagger_path = self.tagger_path if tagger_path is None else tagger_path
        self.tagger = Tagger(tagger_path)

    def infer_tags(self, instances: AnimeInstances, img: np.ndarray, infer_grey: bool = False):
        if self.tagger is None:
            self.init_tagger()
        if infer_grey:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[..., None][..., [0, 0, 0]]
        num_ins = len(instances)
        for ii in range(num_ins):
            bbox = instances.bboxes[ii]
            mask = instances.masks[ii]
            if isinstance(bbox, torch.Tensor):
                bbox = bbox.cpu().numpy()
                mask = mask.cpu().numpy()
            bbox = bbox.astype(np.int32)
            # bboxes are stored as (x, y, w, h): crop the instance and white out the background
            crop = img[bbox[1]: bbox[3] + bbox[1], bbox[0]: bbox[2] + bbox[0]].copy()
            mask = mask[bbox[1]: bbox[3] + bbox[1], bbox[0]: bbox[2] + bbox[0]]
            crop[mask == 0] = 255
            tags, character_tags = self.tagger.label_cv2_bgr(crop)
            exclude_tags = ['simple_background', 'white_background']
            valid_tags = []
            for tag in tags:
                if tag in exclude_tags:
                    continue
                valid_tags.append(tag)
            instances.tags[ii] = ' '.join(valid_tags)
            instances.character_tags[ii] = character_tags

    def infer_embeddings(self, imgs, det_size=None):

        def hijack_bbox_mask_post_process(
                self,
                results,
                mask_feat,
                cfg,
                rescale: bool = False,
                with_nms: bool = True,
                img_meta: Optional[dict] = None):
            stride = self.prior_generator.strides[0][0]
            if rescale:
                assert img_meta.get('scale_factor') is not None
                scale_factor = [1 / s for s in img_meta['scale_factor']]
                results.bboxes = scale_boxes(results.bboxes, scale_factor)

            if hasattr(results, 'score_factors'):
                # TODO: Add sqrt operation in order to be consistent with
                #  the paper.
                score_factors = results.pop('score_factors')
                results.scores = results.scores * score_factors

            # filter small size bboxes
            if cfg.get('min_bbox_size', -1) >= 0:
                w, h = get_box_wh(results.bboxes)
                valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
                if not valid_mask.all():
                    results = results[valid_mask]

            # results.mask_feat = mask_feat
            return results, mask_feat

        def hijack_detector_predict(self: SingleStageDetector,
                                    batch_inputs: torch.Tensor,
                                    batch_data_samples: SampleList,
                                    rescale: bool = True) -> SampleList:
            x = self.extract_feat(batch_inputs)
            bbox_head: RTMDetInsSepBNHeadCustom = self.bbox_head

            # temporarily swap the head's post-processing so raw instance data and the
            # mask feature map are returned instead of fully decoded masks
            old_postprocess = RTMDetInsSepBNHeadCustom._bbox_mask_post_process
            RTMDetInsSepBNHeadCustom._bbox_mask_post_process = hijack_bbox_mask_post_process

            # results_list = bbox_head.predict(
            #     x, batch_data_samples, rescale=rescale)
            batch_img_metas = [
                data_samples.metainfo for data_samples in batch_data_samples
            ]
            outs = bbox_head(x)
            results_list = bbox_head.predict_by_feat(
                *outs, batch_img_metas=batch_img_metas, rescale=rescale)
            # batch_data_samples = self.add_pred_to_datasample(
            #     batch_data_samples, results_list)

            RTMDetInsSepBNHeadCustom._bbox_mask_post_process = old_postprocess
            return results_list

        # patch the detector's predict so it returns (instance_data, mask_feat),
        # run a single image through the pipeline, then restore the original method
        old_predict = SingleStageDetector.predict
        SingleStageDetector.predict = hijack_detector_predict

        test_pipeline, imgs, _ = self.prepare_data_pipeline(imgs, det_size)
        model = self.model

        # only the first image is processed here
        img = imgs[0]
        data_, img = test_pipeline(img)
        data = model.data_preprocessor(data_, False)
        instance_data, mask_feat = model(**data, mode='predict')[0]

        SingleStageDetector.predict = old_predict
        # print((instance_data.scores > 0.9).sum())
        return img, instance_data, mask_feat

    def segment_with_bboxes(self, img, bboxes: torch.Tensor, instance_data, mask_feat: torch.Tensor):
        # instance_data.bboxes: x1, y1, x2, y2
        maxidx = torch.argmax(instance_data.scores)
        bbox = instance_data.bboxes[maxidx].cpu().numpy()
        p1, p2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))

        tgt_bboxes = instance_data.bboxes
        im_h, im_w = img.shape[:2]
        long_side = max(im_h, im_w)
        bbox_head: RTMDetInsSepBNHeadCustom = self.model.bbox_head
        priors, kernels = instance_data.priors, instance_data.kernels
        stride = bbox_head.prior_generator.strides[0][0]

        ins_bboxes, ins_segs, scores = [], [], []
        for bbox in bboxes:
            # match each query box to the predicted instance with the highest IoU and
            # decode that instance's mask from the shared mask feature map
            bbox = torch.from_numpy(np.array([bbox])).to(tgt_bboxes.dtype).to(tgt_bboxes.device)
            ioulst = box_iou(bbox, tgt_bboxes).squeeze()
            matched_idx = torch.argmax(ioulst)

            mask_logits = bbox_head._mask_predict_by_feat_single(
                mask_feat, kernels[matched_idx][None, ...], priors[matched_idx][None, ...])
            mask_logits = F.interpolate(
                mask_logits.unsqueeze(0), scale_factor=stride, mode='bilinear')
            mask_logits = F.interpolate(
                mask_logits,
                size=[long_side, long_side],
                mode='bilinear',
                align_corners=False)[..., :im_h, :im_w]
            mask = mask_logits.sigmoid().squeeze()
            mask = mask > 0.5
            mask = mask.cpu().numpy()
            ins_segs.append(mask)

            matched_iou_score = ioulst[matched_idx]
            matched_score = instance_data.scores[matched_idx]
            scores.append(matched_score.cpu().item())
            matched_bbox = tgt_bboxes[matched_idx]
            ins_bboxes.append(matched_bbox.cpu().numpy())
            # p1, p2 = (int(matched_bbox[0]), int(matched_bbox[1])), (int(matched_bbox[2]), int(matched_bbox[3]))

        if len(ins_bboxes) > 0:
            ins_bboxes = np.array(ins_bboxes).astype(np.int32)
            ins_bboxes[:, 2:] -= ins_bboxes[:, :2]
            ins_segs = np.array(ins_segs)

        instances = AnimeInstances(ins_segs, ins_bboxes, scores)
        self._postprocess_refine(instances, img)
        # drawed = instances.draw_instances(img)
        # cv2.imshow('drawed', drawed)
        # cv2.waitKey(0)
        return instances

    def set_detect_size(self, det_size: Union[int, Tuple]):
        if isinstance(det_size, int):
            det_size = (det_size, det_size)
        self.default_data_pipeline.transforms[1].scale = det_size
        self.default_data_pipeline.transforms[2].size = det_size

    def infer(self, imgs: Union[List, str, np.ndarray],
              pred_score_thr: float = 0.3,
              refine_kwargs: dict = None,
              output_type: str = "tensor",
              det_size: int = None,
              save_dir: str = '',
              save_visualization: bool = False,
              save_annotation: str = '',
              infer_tags: bool = False,
              obj_id_start: int = -1,
              img_id_start: int = -1,
              verbose: bool = False,
              infer_grey: bool = False,
              save_mask_only: bool = False,
              val_dir=None,
              max_instances: int = 100,
              **kwargs) -> Union[List[AnimeInstances], AnimeInstances, None]:
        """
        Args:
            imgs (str, ndarray, Sequence[str/ndarray]):
                Either image files or loaded images.

        Returns:
            :obj:`AnimeInstances` or list[:obj:`AnimeInstances`]:
            If save_visualization or save_annotation is set, returns None instead.
        """
        if det_size is not None:
            self.set_detect_size(det_size)
        if refine_kwargs is not None:
            self.set_refine_method(**refine_kwargs)
        self.set_max_instance(max_instances)

        if isinstance(imgs, str):
            if imgs.endswith('.txt'):
                imgs = read_imglst_from_txt(imgs)

        if save_annotation or save_visualization:
            return self._infer_save_annotations(imgs, pred_score_thr, det_size, save_dir, save_visualization,
                                                save_annotation, infer_tags, obj_id_start, img_id_start, val_dir=val_dir)
        else:
            return self._infer_simple(imgs, pred_score_thr, det_size, output_type, infer_tags, verbose=verbose, infer_grey=infer_grey)

    def _det_forward(self, img, test_pipeline, pred_score_thr: float = 0.3) -> Tuple[AnimeInstances, np.ndarray]:
        data_, img = test_pipeline(img)
        with torch.no_grad():
            results: DetDataSample = self.model.test_step(data_)[0]
        pred_instances = results.pred_instances
        pred_instances = pred_instances[pred_instances.scores > pred_score_thr]
        if len(pred_instances) < 1:
            return AnimeInstances(), img

        del data_
        # convert bboxes from (x1, y1, x2, y2) to (x, y, w, h)
        bboxes = pred_instances.bboxes.to(torch.int32)
        bboxes[:, 2:] -= bboxes[:, :2]
        masks = pred_instances.masks
        scores = pred_instances.scores
        return AnimeInstances(masks, bboxes, scores), img

    def _infer_simple(self, imgs: Union[List, str, np.ndarray],
                      pred_score_thr: float = 0.3,
                      det_size: int = None,
                      output_type: str = "tensor",
                      infer_tags: bool = False,
                      infer_grey: bool = False,
                      verbose: bool = False) -> Union[AnimeInstances, List[AnimeInstances]]:
        return_list = isinstance(imgs, List)
        assert output_type in {'tensor', 'numpy'}

        test_pipeline, imgs, _ = self.prepare_data_pipeline(imgs, det_size)
        predictions = []
        if len(imgs) > 1:
            imgs = tqdm(imgs)
        for img in imgs:
            instances, img = self._det_forward(img, test_pipeline, pred_score_thr)
            # drawed = instances.draw_instances(img)
            # cv2.imwrite('drawed.jpg', drawed)
            self.postprocess_results(instances, img)
            # drawed = instances.draw_instances(img)
            # cv2.imwrite('drawed_post.jpg', drawed)
            if infer_tags:
                self.infer_tags(instances, img, infer_grey)
            if output_type == 'numpy':
                instances.to_numpy()
            predictions.append(instances)

        if return_list:
            return predictions
        else:
            return predictions[0]

    def _infer_save_annotations(self, imgs: Union[List, str, np.ndarray],
                                pred_score_thr: float = 0.3,
                                det_size: int = None,
                                save_dir: str = '',
                                save_visualization: bool = False,
                                save_annotation: str = '',
                                infer_tags: bool = False,
                                obj_id_start: int = 100000000000,
                                img_id_start: int = 100000000000,
                                save_mask_only: bool = False,
                                val_dir=None,
                                **kwargs) -> None:
        coco_api = None
        if isinstance(imgs, str) and imgs.endswith('.json'):
            # imgs is a COCO annotation file: resolve image paths relative to val_dir
            coco_api = COCO(imgs)
            if val_dir is None:
                val_dir = osp.join(osp.dirname(osp.dirname(imgs)), 'val')
            imgs = coco_api.getImgIds()
            imgp2ids = {}
            imgps, coco_imgmetas = [], []
            for imgid in imgs:
                imeta = coco_api.loadImgs(imgid)[0]
                imgname = imeta['file_name']
                imgp = osp.join(val_dir, imgname)
                imgp2ids[imgp] = imgid
                imgps.append(imgp)
                coco_imgmetas.append(imeta)
            imgs = imgps

        test_pipeline, imgs, target_dir = self.prepare_data_pipeline(imgs, det_size)

        if save_dir == '':
            save_dir = osp.join(target_dir,
                                osp.basename(self.ckpt).replace('.ckpt', '').replace('.pth', '').replace('.pt', ''))
        if not osp.exists(save_dir):
            os.makedirs(save_dir)

        det_annotations = []
        image_meta = []
        obj_id = obj_id_start + 1
        image_id = img_id_start + 1
        for ii, img in enumerate(tqdm(imgs)):
            # prepare data
            if isinstance(img, str):
                img_name = osp.basename(img)
            else:
                img_name = f'{ii}'.zfill(12) + '.jpg'
            if coco_api is not None:
                image_id = imgp2ids[img]

            try:
                instances, img = self._det_forward(img, test_pipeline, pred_score_thr)
            except Exception as e:
                if not isinstance(e, torch.cuda.OutOfMemoryError):
                    raise
                # free cached memory and retry once before giving up on this image
                gc.collect()
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()
                try:
                    instances, img = self._det_forward(img, test_pipeline, pred_score_thr)
                except Exception:
                    LOGGER.warning(f'cuda out of memory: {img_name}')
                    if isinstance(img, str):
                        img = cv2.imread(img)
                    instances = None

            if instances is not None:
                self.postprocess_results(instances, img)
                if infer_tags:
                    self.infer_tags(instances, img)

            if save_visualization:
                out_file = osp.join(save_dir, img_name)
                self.save_visualization(out_file, img, instances)

            if save_annotation:
                im_h, im_w = img.shape[:2]
                image_meta.append({
                    "id": image_id, "height": im_h, "width": im_w,
                    "file_name": img_name
                })
                if instances is not None:
                    for jj in range(len(instances)):
                        segmentation = instances.masks[jj].squeeze().cpu().numpy().astype(np.uint8)
                        area = segmentation.sum()
                        segmentation *= 255
                        if save_mask_only:
                            cv2.imwrite(osp.join(save_dir, 'mask_' + str(jj).zfill(3) + '_' + img_name + '.png'), segmentation)
                        else:
                            score = instances.scores[jj]
                            if isinstance(score, torch.Tensor):
                                score = score.item()
                            score = float(score)
                            bbox = instances.bboxes[jj].cpu().numpy()
                            bbox = bbox.astype(np.float32).tolist()
                            segmentation = mask2rle(segmentation)
                            tag_string = instances.tags[jj]
                            tag_string_character = instances.character_tags[jj]
                            det_annotations.append({'id': obj_id, 'category_id': 0, 'iscrowd': 0, 'score': score,
                                                    'segmentation': segmentation, 'image_id': image_id, 'area': area,
                                                    'tag_string': tag_string, 'tag_string_character': tag_string_character, 'bbox': bbox})
                        obj_id += 1
            image_id += 1

        if save_annotation != '' and not save_mask_only:
            det_meta = {"info": {}, "licenses": [], "images": image_meta,
                        "annotations": det_annotations, "categories": CATEGORIES}
            detp = save_annotation
            dict2json(det_meta, detp)
            LOGGER.info(f'annotations saved to {detp}')

    def set_refine_method(self, refine_method: str = 'none', refine_size: int = 720):
        if refine_method == 'none':
            self.postprocess_refine = None
        elif refine_method == 'animeseg':
            if self.refinenet_animeseg is None:
                self.refinenet_animeseg = load_refinenet(refine_method)
            self.postprocess_refine = lambda det_pred, img: \
                animeseg_refine(det_pred, img, self.refinenet_animeseg, True, refine_size)
        elif refine_method == 'refinenet_isnet':
            if self.refinenet is None:
                self.refinenet = load_refinenet(refine_method)
            self.postprocess_refine = self._postprocess_refine
        else:
            raise NotImplementedError(f'Invalid refine method: {refine_method}')

    def _postprocess_refine(self, instances: AnimeInstances, img: np.ndarray, refine_size: int = 720, max_refine_batch: int = 4, **kwargs):
        if instances.is_empty:
            return
        segs = instances.masks
        is_tensor = instances.is_tensor
        if is_tensor:
            segs = segs.cpu().numpy()
        segs = segs.astype(np.float32)

        im_h, im_w = img.shape[:2]
        masks = []
        with torch.no_grad():
            for batch, (pt, pb, pl, pr) in prepare_refine_batch(segs, img, max_refine_batch, self.device, refine_size):
                preds = self.refinenet(batch)[0][0].sigmoid()
                # strip the padding added by resize_pad, then rescale to the original image size
                if pb == 0:
                    pb = -im_h
                if pr == 0:
                    pr = -im_w
                preds = preds[..., pt: -pb, pl: -pr]
                preds = torch.nn.functional.interpolate(preds, (im_h, im_w), mode='bilinear', align_corners=True)
                masks.append(preds.cpu()[:, 0])

        masks = (torch.concat(masks, dim=0) > self.mask_thr).to(self.device)
        if not is_tensor:
            masks = masks.cpu().numpy()
        instances.masks = masks

    def prepare_data_pipeline(self, imgs: Union[str, np.ndarray, List], det_size: int) -> Tuple[Compose, List, str]:
        if det_size is None:
            det_size = self.default_det_size

        target_dir = './workspace/output'
        # cast imgs to a list of np.ndarray or image file paths if necessary
        if isinstance(imgs, str):
            if osp.isdir(imgs):
                target_dir = imgs
                imgs = find_all_imgs(imgs, abs_path=True)
            elif osp.isfile(imgs):
                target_dir = osp.dirname(imgs)
                imgs = [imgs]
            else:
                raise FileNotFoundError(f'Invalid image path: {imgs}')
        elif isinstance(imgs, np.ndarray):
            imgs = [imgs]
        elif isinstance(imgs, List):
            if len(imgs) > 0:
                if not isinstance(imgs[0], (np.ndarray, str)):
                    raise NotImplementedError
        else:
            raise NotImplementedError

        test_pipeline = lambda img: single_image_preprocess(img, pipeline=self.default_data_pipeline)
        return test_pipeline, imgs, target_dir

    def save_visualization(self, out_file: str, img: np.ndarray, instances: AnimeInstances):
        drawed = instances.draw_instances(img)
        mmcv.imwrite(drawed, out_file)

    def postprocess_results(self, results: DetDataSample, img: np.ndarray) -> None:
        if self.postprocess_refine is not None:
            self.postprocess_refine(results, img)

    def set_mask_threshold(self, mask_thr: float):
        self.model.bbox_head.test_cfg['mask_thr_binary'] = mask_thr

    def set_max_instance(self, num_ins):
        self.model.bbox_head.test_cfg['max_per_img'] = num_ins
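

if __name__ == '__main__':
    # Minimal usage sketch (not part of the library API): the checkpoint and image paths
    # are placeholders that must point to an actual trained checkpoint and a test image.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt', default='path/to/checkpoint.ckpt')
    parser.add_argument('--img', default='path/to/image.jpg')
    args = parser.parse_args()

    # refine_kwargs=None skips loading the optional refinement network
    net = AnimeInsSeg(args.ckpt, refine_kwargs=None)
    instances = net.infer(args.img, output_type='numpy')
    drawed = instances.draw_instances(cv2.imread(args.img))
    cv2.imwrite('instances_vis.png', drawed)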