katanuki / fn.py
aka7774's picture
Upload 2 files
2c53c97 verified
raw
history blame contribute delete
No virus
3.44 kB
import base64
import io
import os

import torch
import numpy as np
from PIL import Image
import cv2
import pytorch_lightning as pl

from model import ISNetDIS, ISNetGTEncoder, U2NET, U2NET_full2, U2NET_lite2, MODNet
# Module-level cache for the segmentation model: None until load_model()
# assigns it (animeseg() triggers that lazily on its first call).
model = None
def get_mask(model, input_img):
    """Run *model* on an RGB image and return its per-pixel prediction.

    Args:
        model: callable model with a ``device`` attribute (e.g. an
            AnimeSegmentation instance). It receives a float32 tensor of shape
            (1, 3, H, W) with values in [0, 1]; its output is indexed as
            (batch, channel, height, width).
        input_img: array of shape (H, W, 3), presumably uint8 RGB.

    Returns:
        float32 numpy array of shape (H, W, 1) holding the model output for
        the first batch item, resized to (H, W) if the model returned a
        different spatial size.

    Raises:
        ValueError: if *input_img* is not an (H, W, 3) array.
    """
    if input_img.ndim != 3 or input_img.shape[2] != 3:
        raise ValueError(
            f"expected an (H, W, 3) image, got shape {input_img.shape}")
    h, w = input_img.shape[:2]

    # Normalise straight to float32: staging through float16 (as before)
    # quantised the pixel values before they reached the model.
    x = torch.from_numpy(input_img.astype(np.float32) / 255)
    # HWC -> NCHW with a batch of one, on the model's device.
    x = x.permute(2, 0, 1).unsqueeze(0).contiguous().to(model.device)

    with torch.no_grad():
        pred = model(x)

    # First batch item, cropped to the input extent, as (H, W, C).
    pred = pred[0, :, :h, :w].cpu().numpy().transpose((1, 2, 0))
    if pred.shape[:2] != (h, w):
        # Only needed when the model's output resolution differs from the input.
        pred = cv2.resize(pred, (w, h))
        if pred.ndim == 2:
            # cv2.resize drops a singleton channel axis; restore it.
            pred = pred[:, :, np.newaxis]
    return np.ascontiguousarray(pred)
def get_net(net_name):
    """Construct the backbone network identified by *net_name*.

    Args:
        net_name: one of "isnet", "isnet_is", "isnet_gt", "u2net", "u2netl",
            "modnet".

    Returns:
        A newly constructed network (no weights are loaded here).

    Raises:
        NotImplementedError: if *net_name* is not recognised.
    """
    if net_name in ("isnet", "isnet_is"):
        return ISNetDIS()
    if net_name == "isnet_gt":
        return ISNetGTEncoder()
    if net_name == "u2net":
        return U2NET_full2()
    if net_name == "u2netl":
        return U2NET_lite2()
    if net_name == "modnet":
        return MODNet()
    # `raise NotImplemented` (the singleton) is not an exception and fails
    # with an unrelated TypeError; raise the intended exception class.
    raise NotImplementedError(f"unknown net_name: {net_name!r}")
# Adapted from anime-segmentation's train module.
class AnimeSegmentation(pl.LightningModule):
    """LightningModule wrapping one segmentation backbone built by get_net().

    Args:
        net_name: one of "isnet_is", "isnet", "isnet_gt", "u2net", "u2netl",
            "modnet".

    Raises:
        ValueError: if *net_name* is not a supported name.
    """

    def __init__(self, net_name):
        super().__init__()
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if net_name not in ("isnet_is", "isnet", "isnet_gt", "u2net", "u2netl", "modnet"):
            raise ValueError(f"unsupported net_name: {net_name!r}")
        self.net = get_net(net_name)
        if net_name == "isnet_is":
            # Additional frozen ISNet GT encoder (no gradients). forward() does
            # not use it; presumably it exists for training and so checkpoint
            # keys line up — TODO confirm.
            self.gt_encoder = get_net("isnet_gt")
            for param in self.gt_encoder.parameters():
                param.requires_grad = False
        else:
            self.gt_encoder = None

    @classmethod
    def try_load(cls, net_name, ckpt_path, map_location=None):
        """Load a model from a Lightning checkpoint or a plain state dict.

        Args:
            net_name: backbone name (see the class docstring).
            ckpt_path: path of the file to load.
            map_location: forwarded to ``torch.load`` / ``load_from_checkpoint``.

        Returns:
            An AnimeSegmentation instance with the weights loaded.
        """
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location=map_location)
        if "epoch" in checkpoint:
            # A full Lightning checkpoint: let Lightning rebuild the module.
            return cls.load_from_checkpoint(ckpt_path, net_name=net_name, map_location=map_location)

        model = cls(net_name)
        if any(key.startswith("net.") for key in checkpoint):
            # Keys carry the wrapper's "net." prefix: state dict of this module.
            model.load_state_dict(checkpoint)
        else:
            # Otherwise treat it as the bare backbone's state dict.
            model.net.load_state_dict(checkpoint)
        return model

    def forward(self, x):
        """Run the backbone on *x* and return its prediction.

        Raises:
            NotImplementedError: if the backbone type has no output adapter here.
        """
        if isinstance(self.net, (ISNetDIS, ISNetGTEncoder)):
            # First element of the first returned sequence, passed through sigmoid.
            return self.net(x)[0][0].sigmoid()
        if isinstance(self.net, U2NET):
            # NOTE(review): get_net builds U2NET_full2/U2NET_lite2; this branch
            # assumes those produce U2NET instances — confirm in model.py.
            return self.net(x)[0].sigmoid()
        if isinstance(self.net, MODNet):
            # No sigmoid applied: presumably MODNet's output at index 2 is
            # already in [0, 1] — confirm.
            return self.net(x, True)[2]
        # `raise NotImplemented` is not an exception; raise the intended class.
        raise NotImplementedError(
            f"forward() does not support backbone {type(self.net).__name__}")
def load_model(net_name='isnet_is', ckpt_path='anime-seg/isnetis.ckpt', device=None):
    """Load the segmentation model into the module-level ``model`` and return it.

    Args:
        net_name: backbone name passed to ``AnimeSegmentation.try_load``.
        ckpt_path: checkpoint file to load.
        device: torch device string; defaults to 'cuda' when available,
            otherwise 'cpu'.

    Returns:
        The loaded AnimeSegmentation, in eval mode and moved to *device*.
    """
    global model
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = AnimeSegmentation.try_load(net_name, ckpt_path, device)
    model.eval()
    model.to(device)
    # Returning the model (not just setting the global) keeps callers that
    # write `model = load_model()` from replacing it with None.
    return model
def animeseg(image):
    """Cut the foreground out of *image* and return it as an RGBA array.

    The model is loaded on first use via load_model().

    Args:
        image: a PIL image (converted to RGB if needed), an (H, W, 3) array,
            or None.

    Returns:
        None when *image* is None; otherwise a uint8 array of shape (H, W, 4)
        whose RGB channels are ``mask * img + 1 - mask`` and whose alpha
        channel is ``mask * 255``.
    """
    global model
    # `is None` rather than `not image`: truth-testing a numpy array raises.
    if image is None:
        return None
    if model is None:
        # load_model() stores the network in the global `model`; its return
        # value is deliberately not assigned here.
        load_model()
    if isinstance(image, Image.Image) and image.mode != "RGB":
        # get_mask requires exactly 3 channels; RGBA / L / P images would fail.
        image = image.convert("RGB")
    img = np.array(image, dtype=np.uint8)
    mask = get_mask(model, img)
    # NOTE(review): RGB is multiplied by the mask while the mask is also
    # stored as alpha, so semi-transparent edges may render darker than the
    # source; the `+ 1 - mask` term is preserved as-is.
    rgb = mask * img + 1 - mask
    alpha = mask * 255
    return np.concatenate((rgb, alpha), axis=2).astype(np.uint8)
def pil_to_webp(img):
    """Encode *img* in the 'webp' format and return the encoded bytes."""
    with io.BytesIO() as stream:
        img.save(stream, 'webp')
        encoded = stream.getvalue()
    return encoded
def bin_to_base64(bin):
    """Return the standard Base64 encoding of the bytes-like *bin* as a str.

    The parameter name shadows the builtin ``bin`` but is kept so keyword
    callers keep working. Requires ``import base64`` at module level.
    """
    return base64.b64encode(bin).decode('ascii')