aka7774 commited on
Commit
9b3b1b4
1 Parent(s): 14edfb9

Upload 7 files

Browse files
Files changed (6) hide show
  1. app.py +5 -91
  2. fn.py +101 -0
  3. install.bat +56 -0
  4. main.py +43 -0
  5. requirements.txt +3 -0
  6. venv.sh +7 -0
app.py CHANGED
@@ -1,94 +1,7 @@
 
1
  import gradio as gr
2
- import os
3
- import torch
4
- import numpy as np
5
- from PIL import Image
6
- import cv2
7
- import pytorch_lightning as pl
8
- from model import ISNetDIS, ISNetGTEncoder, U2NET, U2NET_full2, U2NET_lite2, MODNet
9
 
10
- def get_mask(model, input_img):
11
- h, w = input_img.shape[0], input_img.shape[1]
12
- ph, pw = 0, 0
13
- tmpImg = np.zeros([h, w, 3], dtype=np.float16)
14
- tmpImg[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(input_img, (w, h)) / 255
15
- tmpImg = tmpImg.transpose((2, 0, 1))
16
- tmpImg = torch.from_numpy(tmpImg).unsqueeze(0).type(torch.FloatTensor).to(model.device)
17
- with torch.no_grad():
18
- pred = model(tmpImg)
19
- pred = pred[0, :, ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
20
- pred = cv2.resize(pred.cpu().numpy().transpose((1, 2, 0)), (w, h))[:, :, np.newaxis]
21
- return pred
22
-
23
- def get_net(net_name):
24
- if net_name == "isnet":
25
- return ISNetDIS()
26
- elif net_name == "isnet_is":
27
- return ISNetDIS()
28
- elif net_name == "isnet_gt":
29
- return ISNetGTEncoder()
30
- elif net_name == "u2net":
31
- return U2NET_full2()
32
- elif net_name == "u2netl":
33
- return U2NET_lite2()
34
- elif net_name == "modnet":
35
- return MODNet()
36
- raise NotImplemented
37
-
38
- # from anime-segmentation.train
39
- class AnimeSegmentation(pl.LightningModule):
40
- def __init__(self, net_name):
41
- super().__init__()
42
- assert net_name in ["isnet_is", "isnet", "isnet_gt", "u2net", "u2netl", "modnet"]
43
- self.net = get_net(net_name)
44
- if net_name == "isnet_is":
45
- self.gt_encoder = get_net("isnet_gt")
46
- for param in self.gt_encoder.parameters():
47
- param.requires_grad = False
48
- else:
49
- self.gt_encoder = None
50
-
51
- @classmethod
52
- def try_load(cls, net_name, ckpt_path, map_location=None):
53
- state_dict = torch.load(ckpt_path, map_location=map_location)
54
- if "epoch" in state_dict:
55
- return cls.load_from_checkpoint(ckpt_path, net_name=net_name, map_location=map_location)
56
- else:
57
- model = cls(net_name)
58
- if any([k.startswith("net.") for k, v in state_dict.items()]):
59
- model.load_state_dict(state_dict)
60
- else:
61
- model.net.load_state_dict(state_dict)
62
- return model
63
-
64
- def forward(self, x):
65
- if isinstance(self.net, ISNetDIS):
66
- return self.net(x)[0][0].sigmoid()
67
- if isinstance(self.net, ISNetGTEncoder):
68
- return self.net(x)[0][0].sigmoid()
69
- elif isinstance(self.net, U2NET):
70
- return self.net(x)[0].sigmoid()
71
- elif isinstance(self.net, MODNet):
72
- return self.net(x, True)[2]
73
- raise NotImplemented
74
-
75
- def animeseg(image):
76
- if not image:
77
- return None
78
-
79
- if torch.cuda.is_available():
80
- device = 'cuda'
81
- else:
82
- device = 'cpu'
83
-
84
- model = AnimeSegmentation.try_load('isnet_is', 'anime-seg/isnetis.ckpt', device)
85
- model.eval()
86
- model.to(device)
87
-
88
- img = np.array(image, dtype=np.uint8)
89
- mask = get_mask(model, img)
90
- img = np.concatenate((mask * img + 1 - mask, mask * 255), axis=2).astype(np.uint8)
91
- return img
92
 
93
  with gr.Blocks() as demo:
94
  title = gr.Markdown('# katanuki')
@@ -97,9 +10,10 @@ with gr.Blocks() as demo:
97
  dst_image = gr.Image(label="Result", interactive=False, type="numpy")
98
 
99
  src_image.change(
100
- fn=animeseg,
101
  inputs=[src_image],
102
  outputs=[dst_image],
103
  )
104
 
105
- demo.launch()
 
 
1
+ import fn
2
  import gradio as gr
 
 
 
 
 
 
 
3
 
4
+ fn.load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  with gr.Blocks() as demo:
7
  title = gr.Markdown('# katanuki')
 
10
  dst_image = gr.Image(label="Result", interactive=False, type="numpy")
11
 
12
  src_image.change(
13
+ fn=fn.animeseg,
14
  inputs=[src_image],
15
  outputs=[dst_image],
16
  )
17
 
18
+ if __name__ == '__main__':
19
+ demo.launch()
fn.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import cv2
7
+ import pytorch_lightning as pl
8
+ from model import ISNetDIS, ISNetGTEncoder, U2NET, U2NET_full2, U2NET_lite2, MODNet
9
+
10
+ model = None
11
+
12
+ def get_mask(model, input_img):
13
+ h, w = input_img.shape[0], input_img.shape[1]
14
+ ph, pw = 0, 0
15
+ tmpImg = np.zeros([h, w, 3], dtype=np.float16)
16
+ tmpImg[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(input_img, (w, h)) / 255
17
+ tmpImg = tmpImg.transpose((2, 0, 1))
18
+ tmpImg = torch.from_numpy(tmpImg).unsqueeze(0).type(torch.FloatTensor).to(model.device)
19
+ with torch.no_grad():
20
+ pred = model(tmpImg)
21
+ pred = pred[0, :, ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
22
+ pred = cv2.resize(pred.cpu().numpy().transpose((1, 2, 0)), (w, h))[:, :, np.newaxis]
23
+ return pred
24
+
25
+ def get_net(net_name):
26
+ if net_name == "isnet":
27
+ return ISNetDIS()
28
+ elif net_name == "isnet_is":
29
+ return ISNetDIS()
30
+ elif net_name == "isnet_gt":
31
+ return ISNetGTEncoder()
32
+ elif net_name == "u2net":
33
+ return U2NET_full2()
34
+ elif net_name == "u2netl":
35
+ return U2NET_lite2()
36
+ elif net_name == "modnet":
37
+ return MODNet()
38
+ raise NotImplemented
39
+
40
+ # from anime-segmentation.train
41
+ class AnimeSegmentation(pl.LightningModule):
42
+ def __init__(self, net_name):
43
+ super().__init__()
44
+ assert net_name in ["isnet_is", "isnet", "isnet_gt", "u2net", "u2netl", "modnet"]
45
+ self.net = get_net(net_name)
46
+ if net_name == "isnet_is":
47
+ self.gt_encoder = get_net("isnet_gt")
48
+ for param in self.gt_encoder.parameters():
49
+ param.requires_grad = False
50
+ else:
51
+ self.gt_encoder = None
52
+
53
+ @classmethod
54
+ def try_load(cls, net_name, ckpt_path, map_location=None):
55
+ state_dict = torch.load(ckpt_path, map_location=map_location)
56
+ if "epoch" in state_dict:
57
+ return cls.load_from_checkpoint(ckpt_path, net_name=net_name, map_location=map_location)
58
+ else:
59
+ model = cls(net_name)
60
+ if any([k.startswith("net.") for k, v in state_dict.items()]):
61
+ model.load_state_dict(state_dict)
62
+ else:
63
+ model.net.load_state_dict(state_dict)
64
+ return model
65
+
66
+ def forward(self, x):
67
+ if isinstance(self.net, ISNetDIS):
68
+ return self.net(x)[0][0].sigmoid()
69
+ if isinstance(self.net, ISNetGTEncoder):
70
+ return self.net(x)[0][0].sigmoid()
71
+ elif isinstance(self.net, U2NET):
72
+ return self.net(x)[0].sigmoid()
73
+ elif isinstance(self.net, MODNet):
74
+ return self.net(x, True)[2]
75
+ raise NotImplemented
76
+
77
+ def load_model():
78
+ global model
79
+
80
+ if torch.cuda.is_available():
81
+ device = 'cuda'
82
+ else:
83
+ device = 'cpu'
84
+
85
+ model = AnimeSegmentation.try_load('isnet_is', 'anime-seg/isnetis.ckpt', device)
86
+ model.eval()
87
+ model.to(device)
88
+
89
+ def animeseg(image):
90
+ global model
91
+
92
+ if not image:
93
+ return None
94
+
95
+ if not model:
96
+ model = load_model()
97
+
98
+ img = np.array(image, dtype=np.uint8)
99
+ mask = get_mask(model, img)
100
+ img = np.concatenate((mask * img + 1 - mask, mask * 255), axis=2).astype(np.uint8)
101
+ return img
install.bat ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @echo off
2
+
3
+ rem -------------------------------------------
4
+ rem NOT guaranteed to work on Windows
5
+
6
+ set REPOS=https://huggingface.co/spaces/aka7774/katanuki
7
+ set APPDIR=katanuki
8
+ set VENV=venv
9
+
10
+ rem -------------------------------------------
11
+
12
+ set INSTALL_DIR=%~dp0
13
+ cd /d %INSTALL_DIR%
14
+
15
+ :git_clone
16
+ set DL_URL=%REPOS%
17
+ set DL_DST=%APPDIR%
18
+ git clone %DL_URL% %APPDIR%
19
+ if exist %DL_DST% goto install_python
20
+
21
+ set DL_URL=https://github.com/git-for-windows/git/releases/download/v2.41.0.windows.3/PortableGit-2.41.0.3-64-bit.7z.exe
22
+ set DL_DST=PortableGit-2.41.0.3-64-bit.7z.exe
23
+ curl -L -o %DL_DST% %DL_URL%
24
+ if not exist %DL_DST% bitsadmin /transfer dl %DL_URL% %DL_DST%
25
+ %DL_DST% -y
26
+ del %DL_DST%
27
+
28
+ set GIT=%INSTALL_DIR%PortableGit\bin\git
29
+ %GIT% clone %REPOS%
30
+
31
+ :install_python
32
+ set DL_URL=https://github.com/indygreg/python-build-standalone/releases/download/20240107/cpython-3.10.13+20240107-i686-pc-windows-msvc-shared-install_only.tar.gz
33
+ set DL_DST="%INSTALL_DIR%python.tar.gz"
34
+ curl -L -o %DL_DST% %DL_URL%
35
+ if not exist %DL_DST% bitsadmin /transfer dl %DL_URL% %DL_DST%
36
+ tar -xzf %DL_DST%
37
+
38
+ set PYTHON=%INSTALL_DIR%python\python.exe
39
+ set PATH=%PATH%;%INSTALL_DIR%python310\Scripts
40
+
41
+ :install_venv
42
+ cd %APPDIR%
43
+ %PYTHON% -m venv %VENV%
44
+ set PYTHON=%VENV%\Scripts\python.exe
45
+
46
+ :install_pip
47
+ set DL_URL=https://bootstrap.pypa.io/get-pip.py
48
+ set DL_DST=%INSTALL_DIR%get-pip.py
49
+ curl -o %DL_DST% %DL_URL%
50
+ if not exist %DL_DST% bitsadmin /transfer dl %DL_URL% %DL_DST%
51
+ %PYTHON% %DL_DST%
52
+
53
+ %PYTHON% -m pip install gradio
54
+ %PYTHON% -m pip install -r requirements.txt
55
+
56
+ pause
main.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time
4
+ import signal
5
+ import psutil
6
+ import io
7
+
8
+ from fastapi import FastAPI, Request, status, Form, UploadFile
9
+ from fastapi.staticfiles import StaticFiles
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from pydantic import BaseModel, Field
12
+ from fastapi.exceptions import RequestValidationError
13
+ from fastapi.responses import JSONResponse
14
+
15
+ import fn
16
+ import gradio as gr
17
+ from app import demo
18
+
19
+ app = FastAPI()
20
+
21
+ app.add_middleware(
22
+ CORSMiddleware,
23
+ allow_origins=['*'],
24
+ allow_credentials=True,
25
+ allow_methods=["*"],
26
+ allow_headers=["*"],
27
+ )
28
+
29
+ gr.mount_gradio_app(app, demo, path="/gradio")
30
+
31
+ fn.load_model()
32
+
33
+ @app.post("/katanuki")
34
+ async def katanuki_image(file: UploadFile = Form(...)):
35
+ try:
36
+ file_content = await file.read()
37
+ file_stream = io.BytesIO(file_content)
38
+
39
+ dst_image = fn.animeseg(src_image)
40
+
41
+ return Response(content=dst_image, media_type="image/webp")
42
+ except Exception as e:
43
+ return {"error": str(e)}
requirements.txt CHANGED
@@ -1,6 +1,9 @@
 
 
1
  opencv-python
2
  pytorch_lightning
3
  torch
4
  torchvision
5
  numpy
6
  scipy
 
 
1
+ fastapi
2
+ uvicorn
3
  opencv-python
4
  pytorch_lightning
5
  torch
6
  torchvision
7
  numpy
8
  scipy
9
+ python-multipart
venv.sh ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/bash
2
+
3
+ python3 -m venv venv
4
+ curl -kL https://bootstrap.pypa.io/get-pip.py | venv/bin/python
5
+
6
+ venv/bin/python -m pip install gradio
7
+ venv/bin/python -m pip install -r requirements.txt