"""Anime foreground segmentation: model loading and RGBA matting for a Gradio app."""
import os

import cv2
import gradio as gr
import numpy as np
import pytorch_lightning as pl
import torch
from PIL import Image

from model import ISNetDIS, ISNetGTEncoder, U2NET, U2NET_full2, U2NET_lite2, MODNet

# Lazily-loaded AnimeSegmentation instance; populated by load_model().
model = None


def get_mask(model, input_img, s=None):
    """Predict a soft foreground mask for an RGB image.

    Args:
        model: a module exposing ``.device`` whose forward maps a
            (1, 3, H, W) float tensor to a (1, C, H, W) prediction
            (AnimeSegmentation.forward fits this).
        input_img: uint8 array of shape (h, w, 3).
        s: optional square inference size. When given, the image is scaled so
            its longer side is ``s`` and centred (letterboxed) on an ``s``x``s``
            canvas; the prediction is cropped back and resized to (h, w).
            ``None`` (default) runs the network at the native resolution.

    Returns:
        float32 array of shape (h, w, 1) with the network's output.
    """
    h0, w0 = input_img.shape[0], input_img.shape[1]
    if s is None:
        h, w = h0, w0
    elif h0 > w0:
        h, w = s, max(1, int(s * w0 / h0))
    else:
        h, w = max(1, int(s * h0 / w0)), s
    # Padding needed to fill the square canvas; zero at native resolution.
    ph, pw = (0, 0) if s is None else (s - h, s - w)

    # float32 throughout: the old float16 buffer quantised the normalised pixels.
    canvas = np.zeros([h + ph, w + pw, 3], dtype=np.float32)
    canvas[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = (
        cv2.resize(input_img, (w, h)).astype(np.float32) / 255
    )
    tensor = torch.from_numpy(np.ascontiguousarray(canvas.transpose((2, 0, 1))))
    tensor = tensor.unsqueeze(0).to(model.device)

    with torch.no_grad():
        pred = model(tensor)
    # Drop the batch dim and crop away the letterbox padding.
    pred = pred[0, :, ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
    # cv2.resize drops a singleton channel, so re-add it for broadcasting later.
    pred = cv2.resize(pred.cpu().numpy().transpose((1, 2, 0)), (w0, h0))[:, :, np.newaxis]
    return pred


def get_net(net_name):
    """Instantiate the backbone network named ``net_name``.

    Raises:
        NotImplementedError: if ``net_name`` is not a known network.
    """
    if net_name == "isnet":
        return ISNetDIS()
    elif net_name == "isnet_is":
        return ISNetDIS()
    elif net_name == "isnet_gt":
        return ISNetGTEncoder()
    elif net_name == "u2net":
        return U2NET_full2()
    elif net_name == "u2netl":
        return U2NET_lite2()
    elif net_name == "modnet":
        return MODNet()
    # `raise NotImplemented` would itself fail with TypeError.
    raise NotImplementedError(f"unknown net_name: {net_name!r}")


# from anime-segmentation.train
class AnimeSegmentation(pl.LightningModule):
    """LightningModule wrapper around one segmentation backbone.

    For ``isnet_is`` a frozen ``ISNetGTEncoder`` is also built as
    ``gt_encoder``; for every other network ``gt_encoder`` is ``None``.
    """

    def __init__(self, net_name):
        super().__init__()
        assert net_name in ["isnet_is", "isnet", "isnet_gt", "u2net", "u2netl", "modnet"]
        self.net = get_net(net_name)
        if net_name == "isnet_is":
            self.gt_encoder = get_net("isnet_gt")
            # Frozen: its parameters are excluded from gradient updates.
            for param in self.gt_encoder.parameters():
                param.requires_grad = False
        else:
            self.gt_encoder = None

    @classmethod
    def try_load(cls, net_name, ckpt_path, map_location=None):
        """Load weights from ``ckpt_path`` in any of three layouts.

        - A dict containing an ``"epoch"`` key is treated as a Lightning
          checkpoint and loaded via ``load_from_checkpoint``.
        - A state dict whose keys start with ``"net."`` is loaded into the
          whole module.
        - Any other state dict is loaded into ``self.net`` only.

        NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
        """
        state_dict = torch.load(ckpt_path, map_location=map_location)
        if "epoch" in state_dict:
            return cls.load_from_checkpoint(ckpt_path, net_name=net_name, map_location=map_location)
        model = cls(net_name)
        if any(k.startswith("net.") for k in state_dict):
            model.load_state_dict(state_dict)
        else:
            model.net.load_state_dict(state_dict)
        return model

    def forward(self, x):
        """Run the backbone and return its foreground prediction.

        ISNet/U2NET outputs go through a sigmoid here; MODNet's third
        output (index 2) is returned as-is.

        Raises:
            NotImplementedError: for an unsupported backbone type.
        """
        if isinstance(self.net, ISNetDIS):
            return self.net(x)[0][0].sigmoid()
        elif isinstance(self.net, ISNetGTEncoder):
            return self.net(x)[0][0].sigmoid()
        elif isinstance(self.net, U2NET):
            return self.net(x)[0].sigmoid()
        elif isinstance(self.net, MODNet):
            return self.net(x, True)[2]
        raise NotImplementedError(f"unsupported network: {type(self.net).__name__}")


def load_model():
    """Load the ``isnet_is`` checkpoint on CUDA if available, else CPU.

    Sets the module-level ``model`` global AND returns the model. Previously
    it returned ``None``, which made ``animeseg`` overwrite the global with None.
    """
    global model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = AnimeSegmentation.try_load('isnet_is', 'anime-seg/isnetis.ckpt', device)
    model.eval()
    model.to(device)
    return model


def animeseg(image):
    """Cut the foreground out of ``image`` and return it as an RGBA uint8 array.

    Args:
        image: a PIL image or array convertible with ``np.array``; ``None``
            (no upload) yields ``None``. Grayscale inputs are expanded to 3
            channels and an existing alpha channel is discarded.
    """
    global model
    if image is None:
        return None
    if model is None:
        model = load_model()
    img = np.array(image, dtype=np.uint8)
    # get_mask expects exactly 3 channels.
    if img.ndim == 2:
        img = np.repeat(img[:, :, np.newaxis], 3, axis=2)
    elif img.shape[2] == 4:
        img = img[:, :, :3]
    mask = get_mask(model, img)
    # NOTE(review): RGB is premultiplied by the mask (and offset by `1 - mask`)
    # while alpha is also the mask; kept as-is, but confirm that straight-alpha
    # viewers' darkened edges are acceptable.
    img = np.concatenate((mask * img + 1 - mask, mask * 255), axis=2).astype(np.uint8)
    return img