|
import base64
import io
import os

import cv2
import numpy as np
import pytorch_lightning as pl
import torch
from PIL import Image

from model import ISNetDIS, ISNetGTEncoder, U2NET, U2NET_full2, U2NET_lite2, MODNet
|
|
|
# Module-level model cache: None until load_model() assigns it; animeseg()
# triggers that load on first use and then reuses the same instance.
model = None
|
|
|
def get_mask(model, input_img, s=None):
    """Run ``model`` on a single image and return its foreground mask.

    Args:
        model: the segmentation model (e.g. an ``AnimeSegmentation``). It must
            expose ``.device`` and, when called on a ``(1, 3, H, W)`` float32
            tensor, return a 4-D tensor whose ``[0, 0]`` slice is the mask.
        input_img: image array with 0-255 values, shape ``(h, w, 3)``.
            Grayscale ``(h, w)`` / ``(h, w, 1)`` inputs are replicated to three
            channels, and channels beyond the third (e.g. alpha) are ignored.
        s: optional working resolution. When given, the image is resized so
            its longer side equals ``s`` and centred on an ``s`` x ``s``
            canvas; the predicted mask is cropped back and resized to
            ``(h, w)``. ``None`` (default) runs at the native resolution.

    Returns:
        float32 array of shape ``(h, w, 1)`` holding the model's mask values.
    """
    img = np.asarray(input_img)
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.shape[2] < 3:
        # Single-channel input: replicate it so the network sees 3 channels.
        img = np.repeat(img[:, :, :1], 3, axis=2)
    # float32 (not float16) so the 0-1 scaled values keep full precision.
    img = img[:, :, :3].astype(np.float32) / 255
    h0, w0 = img.shape[:2]

    if s is None:
        h, w = h0, w0
        ph = pw = 0
        canvas = img
    else:
        # Keep the aspect ratio: the longer side becomes s, then letterbox.
        if h0 > w0:
            h, w = s, max(1, int(s * w0 / h0))
        else:
            h, w = max(1, int(s * h0 / w0)), s
        ph, pw = s - h, s - w
        canvas = np.zeros((s, s, 3), dtype=np.float32)
        # cv2.resize takes dsize as (width, height).
        canvas[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(img, (w, h))

    # HWC -> NCHW batch of one, on the model's device.
    tensor = torch.from_numpy(canvas.transpose(2, 0, 1)).unsqueeze(0).to(model.device)

    with torch.no_grad():
        pred = model(tensor)
        mask = pred[0, 0].float().cpu().numpy()

    # Undo the letterbox padding (a no-op crop when s is None).
    mask = mask[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
    if mask.shape != (h0, w0):
        mask = cv2.resize(np.ascontiguousarray(mask), (w0, h0))
    return mask[:, :, np.newaxis]
|
|
|
def get_net(net_name):
    """Instantiate the bare network registered under ``net_name``.

    Args:
        net_name: one of ``"isnet"``, ``"isnet_is"``, ``"isnet_gt"``,
            ``"u2net"``, ``"u2netl"``, ``"modnet"``.

    Returns:
        A freshly constructed network instance.

    Raises:
        NotImplementedError: if ``net_name`` is not recognised.
    """
    # "isnet" and "isnet_is" share the ISNetDIS architecture; the "_is"
    # variant differs only in how AnimeSegmentation wires it up.
    if net_name in ("isnet", "isnet_is"):
        return ISNetDIS()
    if net_name == "isnet_gt":
        return ISNetGTEncoder()
    if net_name == "u2net":
        return U2NET_full2()
    if net_name == "u2netl":
        return U2NET_lite2()
    if net_name == "modnet":
        return MODNet()
    # NotImplemented is a sentinel value, not an exception; raising it would
    # produce a confusing TypeError instead of this message.
    raise NotImplementedError(f"Unknown network name: {net_name!r}")
|
|
|
|
|
class AnimeSegmentation(pl.LightningModule):
    """Lightning module wrapping one of the networks built by ``get_net``.

    ``forward`` hides the differences between the backbones' raw outputs and
    returns a single prediction tensor.
    """

    # Network names accepted by ``__init__`` (and therefore by ``try_load``).
    SUPPORTED_NETS = ("isnet_is", "isnet", "isnet_gt", "u2net", "u2netl", "modnet")

    def __init__(self, net_name):
        """Build the wrapped network.

        Args:
            net_name: one of ``SUPPORTED_NETS``.

        Raises:
            ValueError: if ``net_name`` is not supported.
        """
        super().__init__()
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if net_name not in self.SUPPORTED_NETS:
            raise ValueError(
                f"unsupported net_name {net_name!r}; expected one of {self.SUPPORTED_NETS}"
            )
        self.net = get_net(net_name)
        if net_name == "isnet_is":
            # "isnet_is" pairs the ISNet with an ISNetGTEncoder whose
            # parameters are frozen (no gradients are computed for them).
            self.gt_encoder = get_net("isnet_gt")
            for param in self.gt_encoder.parameters():
                param.requires_grad = False
        else:
            self.gt_encoder = None

    @classmethod
    def try_load(cls, net_name, ckpt_path, map_location=None):
        """Load a model from ``ckpt_path``, accepting three checkpoint layouts.

        * a full Lightning checkpoint (detected by an ``"epoch"`` key) is
          restored with ``load_from_checkpoint``;
        * a state dict with any key starting with ``"net."`` is loaded into
          the whole module;
        * any other state dict is loaded into ``self.net`` only.

        Args:
            net_name: backbone name, see ``SUPPORTED_NETS``.
            ckpt_path: path to the checkpoint file.
            map_location: forwarded to ``torch.load`` / ``load_from_checkpoint``.

        Returns:
            The loaded ``AnimeSegmentation`` instance.
        """
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(ckpt_path, map_location=map_location)
        if "epoch" in state_dict:
            # Full Lightning checkpoint: Lightning re-reads and restores it.
            return cls.load_from_checkpoint(ckpt_path, net_name=net_name, map_location=map_location)

        model = cls(net_name)
        if any(key.startswith("net.") for key in state_dict):
            # Keys were saved from the wrapper module (prefixed with "net.").
            model.load_state_dict(state_dict)
        else:
            # Keys were saved from the bare network.
            model.net.load_state_dict(state_dict)
        return model

    def forward(self, x):
        """Return the wrapped network's prediction for the batch ``x``.

        ISNet and U2Net outputs are passed through a sigmoid; for MODNet the
        third element of its output (called with ``True`` as second argument)
        is returned as-is.

        Raises:
            NotImplementedError: for a network type not handled here.
        """
        if isinstance(self.net, (ISNetDIS, ISNetGTEncoder)):
            return self.net(x)[0][0].sigmoid()
        if isinstance(self.net, U2NET):
            return self.net(x)[0].sigmoid()
        if isinstance(self.net, MODNet):
            return self.net(x, True)[2]
        raise NotImplementedError(
            f"forward is not implemented for {type(self.net).__name__}"
        )
|
|
|
def load_model(net_name="isnet_is", ckpt_path="anime-seg/isnetis.ckpt", device=None):
    """Load the segmentation model, cache it in the global ``model`` and return it.

    Args:
        net_name: backbone name passed to ``AnimeSegmentation.try_load``.
        ckpt_path: checkpoint file to load.
        device: target device; defaults to ``"cuda"`` when available, else ``"cpu"``.

    Returns:
        The loaded model, in eval mode and moved to ``device``.
    """
    global model

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    loaded = AnimeSegmentation.try_load(net_name, ckpt_path, device)
    loaded.eval()
    loaded.to(device)

    model = loaded
    # Return it too: callers such as ``model = load_model()`` would otherwise
    # overwrite the freshly cached global with None.
    return model
|
|
|
def animeseg(image):
    """Cut out the foreground of ``image`` and return it as an RGBA array.

    The global ``model`` is loaded via ``load_model()`` on first use.

    Args:
        image: a PIL image or anything ``np.array(..., dtype=np.uint8)``
            accepts. ``None`` is passed through.

    Returns:
        ``None`` if ``image`` is ``None``; otherwise a uint8 array of shape
        ``(h, w, 4)``. The RGB channels are the image blended toward white
        (255) where the mask is low. The alpha channel is the mask scaled by
        255; it is presumably in 0-1 when the model ends in a sigmoid.
    """
    if image is None:
        return None

    if model is None:
        # load_model() assigns the module-level ``model``; don't rebind it from
        # the return value, which older versions left as None.
        load_model()

    if hasattr(image, "convert"):
        # PIL images: normalise palette/RGBA/grayscale modes to plain RGB.
        image = image.convert("RGB")
    img = np.array(image, dtype=np.uint8)
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.shape[2] < 3:
        img = np.repeat(img[:, :, :1], 3, axis=2)
    # Drop any alpha channel so the result has exactly RGB + mask.
    img = img[:, :, :3]

    mask = get_mask(model, img)
    # Composite over white; the original ``+ 1 - mask`` blended toward ~black.
    rgb = mask * img + 255 * (1 - mask)
    alpha = mask * 255
    return np.clip(np.concatenate((rgb, alpha), axis=2), 0, 255).astype(np.uint8)
|
|
|
def pil_to_webp(img):
    """Encode ``img`` as WebP via its ``save`` method and return the bytes."""
    with io.BytesIO() as out:
        img.save(out, 'webp')
        encoded = out.getvalue()
    return encoded
|
|
|
def bin_to_base64(bin):
    """Return the standard base64 encoding of ``bin`` as an ASCII ``str``.

    Args:
        bin: bytes-like data to encode. The name shadows the ``bin`` builtin
            but is kept so keyword callers keep working.
    """
    # Requires the module-level ``import base64``; it was missing before,
    # so every call raised NameError.
    return base64.b64encode(bin).decode('ascii')
|
|