File size: 2,446 Bytes
395d300
 
b7d88cf
 
 
 
 
 
395d300
 
 
 
 
 
 
b7d88cf
395d300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3dfc879
 
395d300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce90b79
 
a1cfe96
ce90b79
8e69757
ce90b79
a1cfe96
395d300
 
1187a16
395d300
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import gradio as gr

import os

# Install the OpenMMLab packages with `mim` at startup. This has to happen
# before the animeinsseg imports below, which presumably depend on them.
for _install_cmd in (
    "mim install mmengine",
    'mim install "mmcv>=2.0.0"',
    "mim install mmdet",
):
    os.system(_install_cmd)
del _install_cmd

import cv2
from PIL import Image
import numpy as np

from animeinsseg import AnimeInsSeg, AnimeInstances
from animeinsseg.anime_instances import get_color



# Make sure the local model directory exists (no-op if it already does).
os.makedirs("models", exist_ok=True)

# NOTE(review): `lfs-enable-largefiles` configures large-file *uploads* for the
# git repo at "."; it does not look necessary for the clone below — confirm
# before removing.
os.system("huggingface-cli lfs-enable-largefiles .")

# Path of the pretrained checkpoint inside the cloned Hugging Face repo.
ckpt = r'models/AnimeInstanceSegmentation/rtmdetl_e60.ckpt'

# Clone the weights only when they are not on disk yet: re-cloning into an
# existing, non-empty directory just makes `git clone` fail on every restart.
if not os.path.exists(ckpt):
    os.system("git clone https://huggingface.co/dreMaz/AnimeInstanceSegmentation models/AnimeInstanceSegmentation")
    if not os.path.exists(ckpt):
        # Fail early with an actionable message instead of an obscure error
        # from the model loader further down.
        raise FileNotFoundError(
            f"Checkpoint {ckpt!r} not found after cloning "
            "https://huggingface.co/dreMaz/AnimeInstanceSegmentation"
        )

mask_thres = 0.3  # passed to AnimeInsSeg as `mask_thr`
instance_thres = 0.3  # passed to net.infer as `pred_score_thr` in fn()
refine_kwargs = {'refine_method': 'refinenet_isnet'} # set to None if not using refinenet
# refine_kwargs = None

net = AnimeInsSeg(ckpt, mask_thr=mask_thres, refine_kwargs=refine_kwargs)

def fn(image):
    """Segment the subjects in *image* and return an annotated visualization.

    Args:
        image: RGB image as a numpy array, as delivered by the
            ``gr.Image(type="numpy")`` input component, or ``None`` when the
            user submits without providing an image.

    Returns:
        A ``PIL.Image`` of the input with a colored bounding box and a
        semi-transparent colored mask overlay drawn for each detected
        instance. Returns the unmodified image when nothing is detected, and
        ``None`` when no input image was given.
    """
    # Gradio passes None for an empty input component; cv2.cvtColor would
    # raise on it, so leave the output empty instead of crashing.
    if image is None:
        return None

    # Convert to OpenCV's BGR channel order for inference and drawing; the
    # result is flipped back to RGB (``[..., ::-1]``) on return.
    img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    instances: AnimeInstances = net.infer(
        img,
        output_type='numpy',
        pred_score_thr=instance_thres
    )

    drawed = img.copy()

    # instances.bboxes, instances.masks will be None, None if no obj is detected
    if instances.bboxes is None:
        return Image.fromarray(drawed[..., ::-1])

    # Loop-invariant drawing parameters: overlay opacity and a box line width
    # that scales with the image size but is never thinner than 2 px.
    mask_alpha = 0.5
    linewidth = max(round(sum(img.shape) / 2 * 0.003), 2)

    for ii, (xywh, mask) in enumerate(zip(instances.bboxes, instances.masks)):
        color = get_color(ii)

        # draw bbox: boxes are (x, y, w, h); cv2.rectangle wants two corners
        p1 = (int(xywh[0]), int(xywh[1]))
        p2 = (int(xywh[2] + xywh[0]), int(xywh[3] + xywh[1]))
        cv2.rectangle(drawed, p1, p2, color, thickness=linewidth, lineType=cv2.LINE_AA)

        # draw mask: alpha-blend the instance color into the masked pixels.
        # The (H, W, 1) alpha broadcasts against the (3,) color, so no
        # full-size color image has to be allocated per instance.
        alpha_msk = (mask_alpha * mask.astype(np.float32))[..., None]
        color_px = np.asarray(color, dtype=np.float32)
        drawed = drawed * (1 - alpha_msk) + alpha_msk * color_px

        # Back to uint8 so the next cv2.rectangle draws on an image again.
        drawed = drawed.astype(np.uint8)

    return Image.fromarray(drawed[..., ::-1])

# Web UI: a numpy RGB image goes into `fn`, and the annotated PIL image it
# returns is shown as the output. Example images are offered for one-click runs.
iface = gr.Interface(
    title="Anime Subject Instance Segmentation",
    description="Segment image subjects with the proposed model in the paper [*Instance-guided Cartoon Editing with a Large-scale Dataset*](https://cartoonsegmentation.github.io/).",
    examples=["1562990.jpg", "612989.jpg", "sample_3.jpg"],
    fn=fn,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="pil"),
)

# Start the Gradio server (blocks until the app is stopped).
iface.launch()