Spaces:
Running
Running
add remove background before encode img
Browse files
app.py
CHANGED
@@ -82,14 +82,16 @@ class Model:
|
|
82 |
self.detector_stride = None
|
83 |
self.detector_imgsz = None
|
84 |
self.detector_class_names = None
|
|
|
85 |
self.w_avg = None
|
86 |
-
self.load_models(
|
87 |
|
88 |
-
def load_models(self
|
89 |
-
g_mapping_path = huggingface_hub.hf_hub_download(
|
90 |
-
g_synthesis_path = huggingface_hub.hf_hub_download(
|
91 |
-
encoder_path = huggingface_hub.hf_hub_download(
|
92 |
-
detector_path = huggingface_hub.hf_hub_download(
|
|
|
93 |
|
94 |
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
95 |
g_mapping = onnx.load(g_mapping_path)
|
@@ -105,6 +107,7 @@ class Model:
|
|
105 |
self.detector_stride = int(detector_meta['stride'])
|
106 |
self.detector_imgsz = 1088
|
107 |
self.detector_class_names = eval(detector_meta['names'])
|
|
|
108 |
|
109 |
def get_img(self, w, noise=0):
|
110 |
img = self.g_synthesis.run(None, {'w': w, "noise": np.asarray([noise], dtype=np.float32)})[0]
|
@@ -113,6 +116,23 @@ class Model:
|
|
113 |
def get_w(self, z, psi1, psi2):
|
114 |
return self.g_mapping.run(None, {'z': z, 'psi': np.asarray([psi1, psi2], dtype=np.float32)})[0]
|
115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
def encode_img(self, img):
|
117 |
img = transform.resize(((img / 255 - 0.5) / 0.5), (256, 256)).transpose(2, 0, 1)[np.newaxis, :].astype(
|
118 |
np.float32)
|
@@ -236,6 +256,7 @@ def gen_fn(method, seed, psi1, psi2, noise):
|
|
236 |
def encode_img_fn(img, noise):
|
237 |
if img is None:
|
238 |
return "please upload a image", None, None, None, None
|
|
|
239 |
imgs = model.detect(img, 0.2, 0.03)
|
240 |
if len(imgs) == 0:
|
241 |
return "failed to detect waifu", None, None, None, None
|
|
|
82 |
self.detector_stride = None
|
83 |
self.detector_imgsz = None
|
84 |
self.detector_class_names = None
|
85 |
+
self.anime_seg = None
|
86 |
self.w_avg = None
|
87 |
+
self.load_models()
|
88 |
|
89 |
+
def load_models(self):
|
90 |
+
g_mapping_path = huggingface_hub.hf_hub_download("skytnt/fbanime-gan", "g_mapping.onnx")
|
91 |
+
g_synthesis_path = huggingface_hub.hf_hub_download("skytnt/fbanime-gan", "g_synthesis.onnx")
|
92 |
+
encoder_path = huggingface_hub.hf_hub_download("skytnt/fbanime-gan", "encoder.onnx")
|
93 |
+
detector_path = huggingface_hub.hf_hub_download("skytnt/fbanime-gan", "waifu_dect.onnx")
|
94 |
+
anime_seg_path = huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.onnx")
|
95 |
|
96 |
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
97 |
g_mapping = onnx.load(g_mapping_path)
|
|
|
107 |
self.detector_stride = int(detector_meta['stride'])
|
108 |
self.detector_imgsz = 1088
|
109 |
self.detector_class_names = eval(detector_meta['names'])
|
110 |
+
self.anime_seg = rt.InferenceSession(anime_seg_path, providers=providers)
|
111 |
|
112 |
def get_img(self, w, noise=0):
|
113 |
img = self.g_synthesis.run(None, {'w': w, "noise": np.asarray([noise], dtype=np.float32)})[0]
|
|
|
116 |
def get_w(self, z, psi1, psi2):
|
117 |
return self.g_mapping.run(None, {'z': z, 'psi': np.asarray([psi1, psi2], dtype=np.float32)})[0]
|
118 |
|
119 |
+
def remove_bg(self, img, s=1024):
|
120 |
+
img0 = img
|
121 |
+
img = (img / 255).astype(np.float32)
|
122 |
+
h, w = h0, w0 = img.shape[:-1]
|
123 |
+
h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
|
124 |
+
ph, pw = s - h, s - w
|
125 |
+
img_input = np.zeros([s, s, 3], dtype=np.float32)
|
126 |
+
img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = transform.resize(img, (h, w))
|
127 |
+
img_input = np.transpose(img_input, (2, 0, 1))
|
128 |
+
img_input = img_input[np.newaxis, :]
|
129 |
+
mask = self.anime_seg.run(None, {'img': img_input})[0][0]
|
130 |
+
mask = np.transpose(mask, (1, 2, 0))
|
131 |
+
mask = mask[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
|
132 |
+
mask = transform.resize(mask, (h0, w0))
|
133 |
+
img0 = (img0*mask + 255*(1-mask)).astype(np.uint8)
|
134 |
+
return img0
|
135 |
+
|
136 |
def encode_img(self, img):
|
137 |
img = transform.resize(((img / 255 - 0.5) / 0.5), (256, 256)).transpose(2, 0, 1)[np.newaxis, :].astype(
|
138 |
np.float32)
|
|
|
256 |
def encode_img_fn(img, noise):
|
257 |
if img is None:
|
258 |
return "please upload a image", None, None, None, None
|
259 |
+
img = model.remove_bg(img)
|
260 |
imgs = model.detect(img, 0.2, 0.03)
|
261 |
if len(imgs) == 0:
|
262 |
return "failed to detect waifu", None, None, None, None
|