import os
import io
import base64
import torch
import numpy as np
from PIL import Image
import cv2
import pytorch_lightning as pl
from model import ISNetDIS, ISNetGTEncoder, U2NET, U2NET_full2, U2NET_lite2, MODNet

model = None

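# get_mask runs the segmentation net on an HxWx3 uint8 RGB image and returns an
# HxWx1 float mask in [0, 1]. The padding offsets ph/pw are fixed to 0 here, so
# the image is processed at its original resolution.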
def get_mask(model, input_img):
    h, w = input_img.shape[0], input_img.shape[1]
    ph, pw = 0, 0
    tmpImg = np.zeros([h, w, 3], dtype=np.float16)
    tmpImg[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(input_img, (w, h)) / 255
    tmpImg = tmpImg.transpose((2, 0, 1))
    tmpImg = torch.from_numpy(tmpImg).unsqueeze(0).type(torch.FloatTensor).to(model.device)
    with torch.no_grad():
        pred = model(tmpImg)
        pred = pred[0, :, ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
        pred = cv2.resize(pred.cpu().numpy().transpose((1, 2, 0)), (w, h))[:, :, np.newaxis]
        return pred

def get_net(net_name):
    if net_name == "isnet":
        return ISNetDIS()
    elif net_name == "isnet_is":
        return ISNetDIS()
    elif net_name == "isnet_gt":
        return ISNetGTEncoder()
    elif net_name == "u2net":
        return U2NET_full2()
    elif net_name == "u2netl":
        return U2NET_lite2()
    elif net_name == "modnet":
        return MODNet()
    raise NotImplementedError(net_name)

# from anime-segmentation.train
class AnimeSegmentation(pl.LightningModule):
    def __init__(self, net_name):
        super().__init__()
        assert net_name in ["isnet_is", "isnet", "isnet_gt", "u2net", "u2netl", "modnet"]
        self.net = get_net(net_name)
        if net_name == "isnet_is":
            self.gt_encoder = get_net("isnet_gt")
            for param in self.gt_encoder.parameters():
                param.requires_grad = False
        else:
            self.gt_encoder = None

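    # try_load accepts three checkpoint layouts: a full Lightning checkpoint
    # (contains an "epoch" key), a state dict for the whole module (keys
    # prefixed with "net."), or a state dict for the bare network.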
    @classmethod
    def try_load(cls, net_name, ckpt_path, map_location=None):
        state_dict = torch.load(ckpt_path, map_location=map_location)
        if "epoch" in state_dict:
            return cls.load_from_checkpoint(ckpt_path, net_name=net_name, map_location=map_location)
        else:
            model = cls(net_name)
            if any([k.startswith("net.") for k, v in state_dict.items()]):
                model.load_state_dict(state_dict)
            else:
                model.net.load_state_dict(state_dict)
            return model

    def forward(self, x):
        if isinstance(self.net, (ISNetDIS, ISNetGTEncoder)):
            # Take the first side output and squash it to a [0, 1] probability map.
            return self.net(x)[0][0].sigmoid()
        elif isinstance(self.net, U2NET):
            return self.net(x)[0].sigmoid()
        elif isinstance(self.net, MODNet):
            # Inference mode; the third output is the predicted alpha matte.
            return self.net(x, True)[2]
        raise NotImplementedError(type(self.net).__name__)

def load_model():
    global model

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = AnimeSegmentation.try_load('isnet_is', 'anime-seg/isnetis.ckpt', device)
    model.eval()
    model.to(device)
    return model
    
def animeseg(image):
    global model

    if image is None:
        return None

    if model is None:
        model = load_model()

    img = np.array(image, dtype=np.uint8)
    mask = get_mask(model, img)
    # Append the mask (scaled to 0-255) as an alpha channel so the background
    # becomes transparent in the resulting RGBA array.
    img = np.concatenate((mask * img + 1 - mask, mask * 255), axis=2).astype(np.uint8)
    return img

def pil_to_webp(img):
    buffer = io.BytesIO()
    img.save(buffer, 'webp')

    return buffer.getvalue()

def bin_to_base64(bin):
    return base64.b64encode(bin).decode('ascii')
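
# Minimal usage sketch (the file names below are illustrative assumptions, not
# part of the module's interface):
if __name__ == "__main__":
    src = Image.open("input.png").convert("RGB")
    rgba = animeseg(src)                        # HxWx4 uint8, transparent background
    result = Image.fromarray(rgba, mode="RGBA")
    result.save("output.png")
    print(bin_to_base64(pil_to_webp(result))[:60])  # base64-encoded WebP payload (truncated)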