skytnt committed on
Commit
5f2c171
1 Parent(s): 17e3440

Add background removal before encoding the image

Browse files
Files changed (1) hide show
  1. app.py +27 -6
app.py CHANGED
@@ -82,14 +82,16 @@ class Model:
82
  self.detector_stride = None
83
  self.detector_imgsz = None
84
  self.detector_class_names = None
 
85
  self.w_avg = None
86
- self.load_models("skytnt/fbanime-gan")
87
 
88
- def load_models(self, repo):
89
- g_mapping_path = huggingface_hub.hf_hub_download(repo, "g_mapping.onnx")
90
- g_synthesis_path = huggingface_hub.hf_hub_download(repo, "g_synthesis.onnx")
91
- encoder_path = huggingface_hub.hf_hub_download(repo, "encoder.onnx")
92
- detector_path = huggingface_hub.hf_hub_download(repo, "waifu_dect.onnx")
 
93
 
94
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
95
  g_mapping = onnx.load(g_mapping_path)
@@ -105,6 +107,7 @@ class Model:
105
  self.detector_stride = int(detector_meta['stride'])
106
  self.detector_imgsz = 1088
107
  self.detector_class_names = eval(detector_meta['names'])
 
108
 
109
  def get_img(self, w, noise=0):
110
  img = self.g_synthesis.run(None, {'w': w, "noise": np.asarray([noise], dtype=np.float32)})[0]
@@ -113,6 +116,23 @@ class Model:
113
  def get_w(self, z, psi1, psi2):
114
  return self.g_mapping.run(None, {'z': z, 'psi': np.asarray([psi1, psi2], dtype=np.float32)})[0]
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  def encode_img(self, img):
117
  img = transform.resize(((img / 255 - 0.5) / 0.5), (256, 256)).transpose(2, 0, 1)[np.newaxis, :].astype(
118
  np.float32)
@@ -236,6 +256,7 @@ def gen_fn(method, seed, psi1, psi2, noise):
236
  def encode_img_fn(img, noise):
237
  if img is None:
238
  return "please upload a image", None, None, None, None
 
239
  imgs = model.detect(img, 0.2, 0.03)
240
  if len(imgs) == 0:
241
  return "failed to detect waifu", None, None, None, None
 
82
  self.detector_stride = None
83
  self.detector_imgsz = None
84
  self.detector_class_names = None
85
+ self.anime_seg = None
86
  self.w_avg = None
87
+ self.load_models()
88
 
89
+ def load_models(self):
90
+ g_mapping_path = huggingface_hub.hf_hub_download("skytnt/fbanime-gan", "g_mapping.onnx")
91
+ g_synthesis_path = huggingface_hub.hf_hub_download("skytnt/fbanime-gan", "g_synthesis.onnx")
92
+ encoder_path = huggingface_hub.hf_hub_download("skytnt/fbanime-gan", "encoder.onnx")
93
+ detector_path = huggingface_hub.hf_hub_download("skytnt/fbanime-gan", "waifu_dect.onnx")
94
+ anime_seg_path = huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.onnx")
95
 
96
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
97
  g_mapping = onnx.load(g_mapping_path)
 
107
  self.detector_stride = int(detector_meta['stride'])
108
  self.detector_imgsz = 1088
109
  self.detector_class_names = eval(detector_meta['names'])
110
+ self.anime_seg = rt.InferenceSession(anime_seg_path, providers=providers)
111
 
112
  def get_img(self, w, noise=0):
113
  img = self.g_synthesis.run(None, {'w': w, "noise": np.asarray([noise], dtype=np.float32)})[0]
 
116
  def get_w(self, z, psi1, psi2):
117
  return self.g_mapping.run(None, {'z': z, 'psi': np.asarray([psi1, psi2], dtype=np.float32)})[0]
118
 
119
def remove_bg(self, img, s=1024):
    """Blend the background of ``img`` to white using the segmentation model.

    The image is scaled so that its longer side equals ``s``, centred on a
    zero-filled ``s`` x ``s`` canvas and passed to ``self.anime_seg``. The
    returned mask is cropped back to the image region, resized to the
    original resolution and used to blend the input with white (255).

    Args:
        img: Array whose last axis has 3 channels (it is written into an
            ``s`` x ``s`` x 3 canvas). Values are divided by 255 before
            inference, so presumably uint8 RGB -- TODO confirm with callers.
        s: Side length of the square model input.

    Returns:
        ``np.uint8`` array: ``img * mask + 255 * (1 - mask)``.
    """
    original = img
    scaled = (img / 255).astype(np.float32)
    h0, w0 = scaled.shape[:-1]

    # Fit the longer side to `s`, keeping the aspect ratio. The shorter side
    # is clamped to at least 1 pixel: on extreme aspect ratios the integer
    # truncation would otherwise give a zero-sized resize target.
    if h0 > w0:
        h, w = s, max(1, int(s * w0 / h0))
    else:
        h, w = max(1, int(s * h0 / w0)), s

    # Centre the resized image on the square canvas; the same offsets are
    # reused below to crop the mask back to the image region.
    top, left = (s - h) // 2, (s - w) // 2
    canvas = np.zeros([s, s, 3], dtype=np.float32)
    canvas[top:top + h, left:left + w] = transform.resize(scaled, (h, w))

    # HWC -> CHW, then add a leading batch axis for the model input 'img'.
    model_input = np.transpose(canvas, (2, 0, 1))[np.newaxis, :]
    mask = self.anime_seg.run(None, {'img': model_input})[0][0]

    # CHW -> HWC, crop away the padding and return to the input resolution.
    mask = np.transpose(mask, (1, 2, 0))
    mask = mask[top:top + h, left:left + w]
    mask = transform.resize(mask, (h0, w0))

    # Keep the blend weights inside [0, 1]; a value outside that range would
    # push the result outside 0..255 and wrap around in the uint8 cast.
    mask = np.clip(mask, 0, 1)
    return (original * mask + 255 * (1 - mask)).astype(np.uint8)
135
+
136
  def encode_img(self, img):
137
  img = transform.resize(((img / 255 - 0.5) / 0.5), (256, 256)).transpose(2, 0, 1)[np.newaxis, :].astype(
138
  np.float32)
 
256
  def encode_img_fn(img, noise):
257
  if img is None:
258
  return "please upload a image", None, None, None, None
259
+ img = model.remove_bg(img)
260
  imgs = model.detect(img, 0.2, 0.03)
261
  if len(imgs) == 0:
262
  return "failed to detect waifu", None, None, None, None