Upload 7 files
Browse files- app.py +5 -91
- fn.py +101 -0
- install.bat +56 -0
- main.py +43 -0
- requirements.txt +3 -0
- venv.sh +7 -0
app.py
CHANGED
@@ -1,94 +1,7 @@
|
|
|
|
1 |
import gradio as gr
|
2 |
-
import os
|
3 |
-
import torch
|
4 |
-
import numpy as np
|
5 |
-
from PIL import Image
|
6 |
-
import cv2
|
7 |
-
import pytorch_lightning as pl
|
8 |
-
from model import ISNetDIS, ISNetGTEncoder, U2NET, U2NET_full2, U2NET_lite2, MODNet
|
9 |
|
10 |
-
|
11 |
-
h, w = input_img.shape[0], input_img.shape[1]
|
12 |
-
ph, pw = 0, 0
|
13 |
-
tmpImg = np.zeros([h, w, 3], dtype=np.float16)
|
14 |
-
tmpImg[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(input_img, (w, h)) / 255
|
15 |
-
tmpImg = tmpImg.transpose((2, 0, 1))
|
16 |
-
tmpImg = torch.from_numpy(tmpImg).unsqueeze(0).type(torch.FloatTensor).to(model.device)
|
17 |
-
with torch.no_grad():
|
18 |
-
pred = model(tmpImg)
|
19 |
-
pred = pred[0, :, ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
|
20 |
-
pred = cv2.resize(pred.cpu().numpy().transpose((1, 2, 0)), (w, h))[:, :, np.newaxis]
|
21 |
-
return pred
|
22 |
-
|
23 |
-
def get_net(net_name):
|
24 |
-
if net_name == "isnet":
|
25 |
-
return ISNetDIS()
|
26 |
-
elif net_name == "isnet_is":
|
27 |
-
return ISNetDIS()
|
28 |
-
elif net_name == "isnet_gt":
|
29 |
-
return ISNetGTEncoder()
|
30 |
-
elif net_name == "u2net":
|
31 |
-
return U2NET_full2()
|
32 |
-
elif net_name == "u2netl":
|
33 |
-
return U2NET_lite2()
|
34 |
-
elif net_name == "modnet":
|
35 |
-
return MODNet()
|
36 |
-
raise NotImplemented
|
37 |
-
|
38 |
-
# from anime-segmentation.train
|
39 |
-
class AnimeSegmentation(pl.LightningModule):
|
40 |
-
def __init__(self, net_name):
|
41 |
-
super().__init__()
|
42 |
-
assert net_name in ["isnet_is", "isnet", "isnet_gt", "u2net", "u2netl", "modnet"]
|
43 |
-
self.net = get_net(net_name)
|
44 |
-
if net_name == "isnet_is":
|
45 |
-
self.gt_encoder = get_net("isnet_gt")
|
46 |
-
for param in self.gt_encoder.parameters():
|
47 |
-
param.requires_grad = False
|
48 |
-
else:
|
49 |
-
self.gt_encoder = None
|
50 |
-
|
51 |
-
@classmethod
|
52 |
-
def try_load(cls, net_name, ckpt_path, map_location=None):
|
53 |
-
state_dict = torch.load(ckpt_path, map_location=map_location)
|
54 |
-
if "epoch" in state_dict:
|
55 |
-
return cls.load_from_checkpoint(ckpt_path, net_name=net_name, map_location=map_location)
|
56 |
-
else:
|
57 |
-
model = cls(net_name)
|
58 |
-
if any([k.startswith("net.") for k, v in state_dict.items()]):
|
59 |
-
model.load_state_dict(state_dict)
|
60 |
-
else:
|
61 |
-
model.net.load_state_dict(state_dict)
|
62 |
-
return model
|
63 |
-
|
64 |
-
def forward(self, x):
|
65 |
-
if isinstance(self.net, ISNetDIS):
|
66 |
-
return self.net(x)[0][0].sigmoid()
|
67 |
-
if isinstance(self.net, ISNetGTEncoder):
|
68 |
-
return self.net(x)[0][0].sigmoid()
|
69 |
-
elif isinstance(self.net, U2NET):
|
70 |
-
return self.net(x)[0].sigmoid()
|
71 |
-
elif isinstance(self.net, MODNet):
|
72 |
-
return self.net(x, True)[2]
|
73 |
-
raise NotImplemented
|
74 |
-
|
75 |
-
def animeseg(image):
|
76 |
-
if not image:
|
77 |
-
return None
|
78 |
-
|
79 |
-
if torch.cuda.is_available():
|
80 |
-
device = 'cuda'
|
81 |
-
else:
|
82 |
-
device = 'cpu'
|
83 |
-
|
84 |
-
model = AnimeSegmentation.try_load('isnet_is', 'anime-seg/isnetis.ckpt', device)
|
85 |
-
model.eval()
|
86 |
-
model.to(device)
|
87 |
-
|
88 |
-
img = np.array(image, dtype=np.uint8)
|
89 |
-
mask = get_mask(model, img)
|
90 |
-
img = np.concatenate((mask * img + 1 - mask, mask * 255), axis=2).astype(np.uint8)
|
91 |
-
return img
|
92 |
|
93 |
with gr.Blocks() as demo:
|
94 |
title = gr.Markdown('# katanuki')
|
@@ -97,9 +10,10 @@ with gr.Blocks() as demo:
|
|
97 |
dst_image = gr.Image(label="Result", interactive=False, type="numpy")
|
98 |
|
99 |
src_image.change(
|
100 |
-
fn=animeseg,
|
101 |
inputs=[src_image],
|
102 |
outputs=[dst_image],
|
103 |
)
|
104 |
|
105 |
-
|
|
|
|
1 |
+
import fn
|
2 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
+
fn.load_model()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
with gr.Blocks() as demo:
|
7 |
title = gr.Markdown('# katanuki')
|
|
|
10 |
dst_image = gr.Image(label="Result", interactive=False, type="numpy")
|
11 |
|
12 |
src_image.change(
|
13 |
+
fn=fn.animeseg,
|
14 |
inputs=[src_image],
|
15 |
outputs=[dst_image],
|
16 |
)
|
17 |
|
18 |
+
if __name__ == '__main__':
|
19 |
+
demo.launch()
|
fn.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import cv2
|
7 |
+
import pytorch_lightning as pl
|
8 |
+
from model import ISNetDIS, ISNetGTEncoder, U2NET, U2NET_full2, U2NET_lite2, MODNet
|
9 |
+
|
10 |
+
model = None
|
11 |
+
|
12 |
+
def get_mask(model, input_img):
    """Run ``model`` on ``input_img`` and return its prediction as a mask.

    Args:
        model: Callable network exposing a ``device`` attribute.
        input_img: Image array; assumed to be H x W x 3 with values in
            0..255 (TODO confirm against callers).

    Returns:
        The prediction resized to the input's width/height with a trailing
        axis appended via ``np.newaxis``.
    """
    height, width = input_img.shape[:2]
    # Padding is fixed at zero, so the offsets below are always 0.
    pad_h = pad_w = 0
    top, left = pad_h // 2, pad_w // 2

    # Stage the image, scaled by 1/255, in a float16 buffer.
    canvas = np.zeros([height, width, 3], dtype=np.float16)
    canvas[top:top + height, left:left + width] = cv2.resize(input_img, (width, height)) / 255

    # Channels first, add a leading batch axis, then move to the model's device.
    batch = torch.from_numpy(canvas.transpose((2, 0, 1))).unsqueeze(0)
    batch = batch.type(torch.FloatTensor).to(model.device)

    with torch.no_grad():
        pred = model(batch)
        pred = pred[0, :, top:top + height, left:left + width]
        pred = cv2.resize(pred.cpu().numpy().transpose((1, 2, 0)), (width, height))[:, :, np.newaxis]
    return pred
|
24 |
+
|
25 |
+
def get_net(net_name):
    """Construct a new network for the given architecture name.

    Args:
        net_name: One of "isnet", "isnet_is", "isnet_gt", "u2net",
            "u2netl" or "modnet".

    Returns:
        A newly constructed network object.

    Raises:
        NotImplementedError: If ``net_name`` is not one of the names above.
    """
    # "isnet" and "isnet_is" both map to ISNetDIS.
    if net_name in ("isnet", "isnet_is"):
        return ISNetDIS()
    if net_name == "isnet_gt":
        return ISNetGTEncoder()
    if net_name == "u2net":
        return U2NET_full2()
    if net_name == "u2netl":
        return U2NET_lite2()
    if net_name == "modnet":
        return MODNet()
    # ``NotImplemented`` is a comparison sentinel, not an exception class;
    # raising it fails with a TypeError instead of the intended error.
    raise NotImplementedError(f"unknown net_name: {net_name!r}")
|
39 |
+
|
40 |
+
# from anime-segmentation.train
|
41 |
+
class AnimeSegmentation(pl.LightningModule):
    """Lightning module wrapping one of the supported segmentation networks.

    ``self.net`` holds the network built by ``get_net``. For "isnet_is" an
    additional "isnet_gt" encoder with gradients disabled is kept in
    ``self.gt_encoder``; for every other name ``self.gt_encoder`` is None.
    """

    def __init__(self, net_name):
        """Build the network (and, for "isnet_is", the frozen GT encoder)."""
        super().__init__()
        assert net_name in ["isnet_is", "isnet", "isnet_gt", "u2net", "u2netl", "modnet"]
        self.net = get_net(net_name)
        if net_name == "isnet_is":
            self.gt_encoder = get_net("isnet_gt")
            for param in self.gt_encoder.parameters():
                param.requires_grad = False
        else:
            self.gt_encoder = None

    @classmethod
    def try_load(cls, net_name, ckpt_path, map_location=None):
        """Load weights from ``ckpt_path``, accepting several file layouts.

        Args:
            net_name: Architecture name passed to the constructor.
            ckpt_path: Path of the checkpoint file.
            map_location: Forwarded to ``torch.load`` / ``load_from_checkpoint``.

        Returns:
            An instance with the weights loaded.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(ckpt_path, map_location=map_location)
        if "epoch" in state_dict:
            # An "epoch" key is taken to mean a Lightning checkpoint, so the
            # framework's own loader is used.
            return cls.load_from_checkpoint(ckpt_path, net_name=net_name, map_location=map_location)

        model = cls(net_name)
        if any(k.startswith("net.") for k in state_dict):
            # Keys are prefixed with "net.": weights for the whole wrapper.
            model.load_state_dict(state_dict)
        else:
            # Unprefixed keys: weights for the inner network only.
            model.net.load_state_dict(state_dict)
        return model

    def forward(self, x):
        """Run the wrapped network and return its primary output.

        Raises:
            NotImplementedError: If ``self.net`` is not a supported type.
        """
        if isinstance(self.net, (ISNetDIS, ISNetGTEncoder)):
            return self.net(x)[0][0].sigmoid()
        if isinstance(self.net, U2NET):
            return self.net(x)[0].sigmoid()
        if isinstance(self.net, MODNet):
            return self.net(x, True)[2]
        # ``raise NotImplemented`` would fail with a TypeError; raise the
        # real exception class instead.
        raise NotImplementedError(f"unsupported network type: {type(self.net).__name__}")
|
76 |
+
|
77 |
+
def load_model():
    """Load the "isnet_is" checkpoint into the module-level ``model``.

    Uses CUDA when ``torch.cuda.is_available()`` is true, otherwise the CPU,
    and puts the model in eval mode.

    Returns:
        The loaded model. Returning it keeps ``model = load_model()`` at call
        sites from resetting the global to None.
    """
    global model

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = AnimeSegmentation.try_load('isnet_is', 'anime-seg/isnetis.ckpt', device)
    model.eval()
    model.to(device)
    return model
|
88 |
+
|
89 |
+
def animeseg(image):
    """Segment ``image`` and return it with the mask appended as a channel.

    Args:
        image: Input image convertible with ``np.array(image, dtype=np.uint8)``,
            or None.

    Returns:
        None when ``image`` is None; otherwise a uint8 array whose last axis
        holds the composited colour channels followed by ``mask * 255``.
    """
    global model

    # Identity check rather than truthiness: ``not image`` raises ValueError
    # for multi-element numpy arrays.
    if image is None:
        return None

    if model is None:
        # load_model() stores the result in the ``model`` global itself, so
        # its return value is deliberately not assigned here.
        load_model()

    img = np.array(image, dtype=np.uint8)
    mask = get_mask(model, img)
    img = np.concatenate((mask * img + 1 - mask, mask * 255), axis=2).astype(np.uint8)
    return img
|
install.bat
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
@echo off

rem -------------------------------------------
rem NOT guaranteed to work on Windows
rem
rem Clones the app, unpacks a standalone CPython next to this script,
rem creates a venv inside the app directory and installs the dependencies.

set REPOS=https://huggingface.co/spaces/aka7774/katanuki
set APPDIR=katanuki
set VENV=venv

rem -------------------------------------------

set INSTALL_DIR=%~dp0
cd /d %INSTALL_DIR%

:git_clone
rem Try a git that is already on PATH first.
set DL_URL=%REPOS%
set DL_DST=%APPDIR%
git clone %DL_URL% %APPDIR%
if exist %DL_DST% goto install_python

rem No usable git: fetch PortableGit (curl first, bitsadmin as fallback).
set DL_URL=https://github.com/git-for-windows/git/releases/download/v2.41.0.windows.3/PortableGit-2.41.0.3-64-bit.7z.exe
set DL_DST=PortableGit-2.41.0.3-64-bit.7z.exe
curl -L -o %DL_DST% %DL_URL%
if not exist %DL_DST% bitsadmin /transfer dl %DL_URL% %DL_DST%
%DL_DST% -y
del %DL_DST%

set GIT=%INSTALL_DIR%PortableGit\bin\git
rem Clone into %APPDIR% explicitly so both clone paths agree on the target.
%GIT% clone %REPOS% %APPDIR%

:install_python
rem Use the 64-bit build: torch does not publish 32-bit Windows wheels.
set DL_URL=https://github.com/indygreg/python-build-standalone/releases/download/20240107/cpython-3.10.13+20240107-x86_64-pc-windows-msvc-shared-install_only.tar.gz
set DL_DST="%INSTALL_DIR%python.tar.gz"
curl -L -o %DL_DST% %DL_URL%
if not exist %DL_DST% bitsadmin /transfer dl %DL_URL% %DL_DST%
tar -xzf %DL_DST%

rem The archive unpacks into "python\", so Scripts lives under that folder.
set PYTHON=%INSTALL_DIR%python\python.exe
set PATH=%PATH%;%INSTALL_DIR%python\Scripts

:install_venv
cd %APPDIR%
%PYTHON% -m venv %VENV%
rem From here on use the interpreter inside the venv.
set PYTHON=%VENV%\Scripts\python.exe

:install_pip
set DL_URL=https://bootstrap.pypa.io/get-pip.py
set DL_DST=%INSTALL_DIR%get-pip.py
curl -o %DL_DST% %DL_URL%
if not exist %DL_DST% bitsadmin /transfer dl %DL_URL% %DL_DST%
%PYTHON% %DL_DST%

%PYTHON% -m pip install gradio
%PYTHON% -m pip install -r requirements.txt

pause
|
main.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import time
|
4 |
+
import signal
|
5 |
+
import psutil
|
6 |
+
import io
|
7 |
+
|
8 |
+
from fastapi import FastAPI, Request, status, Form, UploadFile
|
9 |
+
from fastapi.staticfiles import StaticFiles
|
10 |
+
from fastapi.middleware.cors import CORSMiddleware
|
11 |
+
from pydantic import BaseModel, Field
|
12 |
+
from fastapi.exceptions import RequestValidationError
|
13 |
+
from fastapi.responses import JSONResponse
|
14 |
+
|
15 |
+
import fn
|
16 |
+
import gradio as gr
|
17 |
+
from app import demo
|
18 |
+
|
19 |
+
# FastAPI application that hosts both the Gradio UI and the REST endpoint.
app = FastAPI()

# CORS settings: any origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended before exposing the service publicly.
_cors_options = dict(
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)

# Serve the Gradio demo under /gradio on the same application.
gr.mount_gradio_app(app, demo, path="/gradio")

# Load the segmentation model at import time.
fn.load_model()
|
32 |
+
|
33 |
+
@app.post("/katanuki")
async def katanuki_image(file: UploadFile = Form(...)):
    """Segment an uploaded image and return the cut-out as WebP.

    Args:
        file: The uploaded image file.

    Returns:
        A ``Response`` with ``image/webp`` content on success, or a dict of
        the form ``{"error": <message>}`` if anything fails.
    """
    # Imported locally so this handler does not depend on names missing from
    # the module's import block (``Response`` and ``Image`` were never imported).
    from fastapi import Response
    from PIL import Image

    try:
        file_content = await file.read()
        file_stream = io.BytesIO(file_content)

        # Decode the upload; force RGB so the three colour channels line up
        # with what fn.animeseg concatenates the mask onto.
        src_image = Image.open(file_stream).convert("RGB")
        dst_array = fn.animeseg(src_image)

        # fn.animeseg returns a uint8 array, which must be encoded to WebP
        # bytes before it can be used as a response body.
        encoded = io.BytesIO()
        Image.fromarray(dst_array).save(encoded, format="WEBP")

        return Response(content=encoded.getvalue(), media_type="image/webp")
    except Exception as e:
        # Best-effort boundary: report the failure to the client as JSON.
        return {"error": str(e)}
|
requirements.txt
CHANGED
@@ -1,6 +1,9 @@
|
|
|
|
|
|
1 |
opencv-python
|
2 |
pytorch_lightning
|
3 |
torch
|
4 |
torchvision
|
5 |
numpy
|
6 |
scipy
|
|
|
|
1 |
+
fastapi
|
2 |
+
uvicorn
|
3 |
opencv-python
|
4 |
pytorch_lightning
|
5 |
torch
|
6 |
torchvision
|
7 |
numpy
|
8 |
scipy
|
9 |
+
python-multipart
|
venv.sh
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
# Create a local virtual environment in ./venv and install the dependencies.

# Stop on the first failing command, including a failed download in the pipe.
set -euo pipefail

python3 -m venv venv

# Bootstrap pip. TLS verification stays enabled (no -k) because the
# downloaded script is executed directly; -f makes HTTP errors fail.
curl -fsSL https://bootstrap.pypa.io/get-pip.py | venv/bin/python

venv/bin/python -m pip install gradio
venv/bin/python -m pip install -r requirements.txt
|