skytnt committed on
Commit
2f3d724
·
1 Parent(s): e7f3d00

update model

Browse files
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import imageio
3
  import numpy as np
@@ -93,14 +94,15 @@ class Model:
93
  detector_path = huggingface_hub.hf_hub_download("skytnt/fbanime-gan", "waifu_dect.onnx")
94
  anime_seg_path = huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.onnx")
95
 
96
- providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
 
97
  g_mapping = onnx.load(g_mapping_path)
98
  w_avg = [x for x in g_mapping.graph.initializer if x.name == "w_avg"][0]
99
  w_avg = np.frombuffer(w_avg.raw_data, dtype=np.float32)[np.newaxis, :]
100
  w_avg = w_avg.repeat(16, axis=0)[np.newaxis, :]
101
  self.w_avg = w_avg
102
- self.g_mapping = rt.InferenceSession(g_mapping_path, providers=providers)
103
- self.g_synthesis = rt.InferenceSession(g_synthesis_path, providers=providers)
104
  self.encoder = rt.InferenceSession(encoder_path, providers=providers)
105
  self.detector = rt.InferenceSession(detector_path, providers=providers)
106
  detector_meta = self.detector.get_modelmeta().custom_metadata_map
@@ -130,7 +132,7 @@ class Model:
130
  mask = np.transpose(mask, (1, 2, 0))
131
  mask = mask[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
132
  mask = transform.resize(mask, (h0, w0))
133
- img0 = (img0*mask + 255*(1-mask)).astype(np.uint8)
134
  return img0
135
 
136
  def encode_img(self, img):
@@ -247,10 +249,12 @@ def get_thumbnail(img):
247
 
248
 
249
  def gen_fn(method, seed, psi1, psi2, noise):
250
- z = RandomState(int(seed) + 2 ** 31).randn(1, 512) if method == 1 else np.random.randn(1, 512)
 
 
251
  w = model.get_w(z.astype(dtype=np.float32), psi1, psi2)
252
  img_out = model.get_img(w, noise)
253
- return img_out, w, get_thumbnail(img_out)
254
 
255
 
256
  def encode_img_fn(img, noise):
@@ -259,7 +263,7 @@ def encode_img_fn(img, noise):
259
  img = model.remove_bg(img)
260
  imgs = model.detect(img, 0.2, 0.03)
261
  if len(imgs) == 0:
262
- return "failed to detect waifu", None, None, None, None
263
  w = model.encode_img(imgs[0])
264
  img_out = model.get_img(w, noise)
265
  return "success", imgs[0], img_out, w, get_thumbnail(img_out)
@@ -278,8 +282,7 @@ if __name__ == '__main__':
278
  app = gr.Blocks()
279
  with app:
280
  gr.Markdown("# full-body anime GAN\n\n"
281
- "![visitor badge](https://visitor-badge.glitch.me/badge?page_id=skytnt.full-body-anime-gan)\n\n"
282
- "the model is not well, just use for fun.")
283
  with gr.Tabs():
284
  with gr.TabItem("generate image"):
285
  with gr.Row():
@@ -287,9 +290,9 @@ if __name__ == '__main__':
287
  gr.Markdown("generate image randomly or by seed")
288
  with gr.Row():
289
  gen_input1 = gr.Radio(label="method", value="random",
290
- choices=["random", "use seed"], type="index")
291
- gen_input2 = gr.Number(value=1, label="seed ( int between -2^31 and 2^31 - 1 )")
292
- gen_input3 = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.6, label="truncation psi 1")
293
  gen_input4 = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="truncation psi 2")
294
  gen_input5 = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="noise strength")
295
  with gr.Group():
@@ -304,7 +307,7 @@ if __name__ == '__main__':
304
  with gr.Column():
305
  gr.Markdown("you'd better upload a standing full-body image")
306
  encode_img_input = gr.Image(label="input image")
307
- examples_data = [[f"examples/{x:02d}.png"] for x in range(1, 5)]
308
  encode_img_examples = gr.Dataset(components=[encode_img_input], samples=examples_data)
309
  with gr.Group():
310
  encode_img_submit = gr.Button("Run", variant="primary")
@@ -319,11 +322,10 @@ if __name__ == '__main__':
319
  with gr.TabItem("generate video"):
320
  with gr.Row():
321
  with gr.Column():
322
- gr.Markdown("## generate video between 2 images")
323
  with gr.Row():
324
  with gr.Column():
325
- gr.Markdown("please select image 1")
326
- select_img1_dropdown = gr.Radio(label="source", value="current generated image",
327
  choices=["current generated image",
328
  "current encoded image"], type="index")
329
  with gr.Group():
@@ -331,8 +333,7 @@ if __name__ == '__main__':
331
  select_img1_output_img = gr.Image(label="selected image 1")
332
  select_img1_output_w = gr.Variable()
333
  with gr.Column():
334
- gr.Markdown("please select image 2")
335
- select_img2_dropdown = gr.Radio(label="source", value="current generated image",
336
  choices=["current generated image",
337
  "current encoded image"], type="index")
338
  with gr.Group():
@@ -345,7 +346,7 @@ if __name__ == '__main__':
345
  with gr.Column():
346
  generate_video_output = gr.Video(label="output video")
347
  gen_submit.click(gen_fn, [gen_input1, gen_input2, gen_input3, gen_input4, gen_input5],
348
- [gen_output1, select_img_input_w1, select_img_input_img1])
349
  encode_img_submit.click(encode_img_fn, [encode_img_input, gen_input5],
350
  [encode_img_output1, encode_img_output2, encode_img_output3, select_img_input_w2,
351
  select_img_input_img2])
 
1
+ import random
2
  import gradio as gr
3
  import imageio
4
  import numpy as np
 
94
  detector_path = huggingface_hub.hf_hub_download("skytnt/fbanime-gan", "waifu_dect.onnx")
95
  anime_seg_path = huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.onnx")
96
 
97
+ providers = ['CPUExecutionProvider']
98
+ gpu_providers = ['CUDAExecutionProvider']
99
  g_mapping = onnx.load(g_mapping_path)
100
  w_avg = [x for x in g_mapping.graph.initializer if x.name == "w_avg"][0]
101
  w_avg = np.frombuffer(w_avg.raw_data, dtype=np.float32)[np.newaxis, :]
102
  w_avg = w_avg.repeat(16, axis=0)[np.newaxis, :]
103
  self.w_avg = w_avg
104
+ self.g_mapping = rt.InferenceSession(g_mapping_path, providers=gpu_providers + providers)
105
+ self.g_synthesis = rt.InferenceSession(g_synthesis_path, providers=gpu_providers + providers)
106
  self.encoder = rt.InferenceSession(encoder_path, providers=providers)
107
  self.detector = rt.InferenceSession(detector_path, providers=providers)
108
  detector_meta = self.detector.get_modelmeta().custom_metadata_map
 
132
  mask = np.transpose(mask, (1, 2, 0))
133
  mask = mask[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
134
  mask = transform.resize(mask, (h0, w0))
135
+ img0 = (img0 * mask + 255 * (1 - mask)).astype(np.uint8)
136
  return img0
137
 
138
  def encode_img(self, img):
 
249
 
250
 
251
  def gen_fn(method, seed, psi1, psi2, noise):
252
+ if method == 0:
253
+ seed = random.randint(0, 2 ** 32 - 1)
254
+ z = RandomState(int(seed)).randn(1, 1024)
255
  w = model.get_w(z.astype(dtype=np.float32), psi1, psi2)
256
  img_out = model.get_img(w, noise)
257
+ return img_out, seed, w, get_thumbnail(img_out)
258
 
259
 
260
  def encode_img_fn(img, noise):
 
263
  img = model.remove_bg(img)
264
  imgs = model.detect(img, 0.2, 0.03)
265
  if len(imgs) == 0:
266
+ return "failed to detect anime character", None, None, None, None
267
  w = model.encode_img(imgs[0])
268
  img_out = model.get_img(w, noise)
269
  return "success", imgs[0], img_out, w, get_thumbnail(img_out)
 
282
  app = gr.Blocks()
283
  with app:
284
  gr.Markdown("# full-body anime GAN\n\n"
285
+ "![visitor badge](https://visitor-badge.glitch.me/badge?page_id=skytnt.full-body-anime-gan)\n\n")
 
286
  with gr.Tabs():
287
  with gr.TabItem("generate image"):
288
  with gr.Row():
 
290
  gr.Markdown("generate image randomly or by seed")
291
  with gr.Row():
292
  gen_input1 = gr.Radio(label="method", value="random",
293
+ choices=["random", "seed"], type="index")
294
+ gen_input2 = gr.Slider(minimum=0, maximum=2 ** 32 - 1, step=1, value=0, label="seed")
295
+ gen_input3 = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="truncation psi 1")
296
  gen_input4 = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="truncation psi 2")
297
  gen_input5 = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="noise strength")
298
  with gr.Group():
 
307
  with gr.Column():
308
  gr.Markdown("you'd better upload a standing full-body image")
309
  encode_img_input = gr.Image(label="input image")
310
+ examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 5)]
311
  encode_img_examples = gr.Dataset(components=[encode_img_input], samples=examples_data)
312
  with gr.Group():
313
  encode_img_submit = gr.Button("Run", variant="primary")
 
322
  with gr.TabItem("generate video"):
323
  with gr.Row():
324
  with gr.Column():
325
+ gr.Markdown("generate video between 2 images")
326
  with gr.Row():
327
  with gr.Column():
328
+ select_img1_dropdown = gr.Radio(label="Select image 1", value="current generated image",
 
329
  choices=["current generated image",
330
  "current encoded image"], type="index")
331
  with gr.Group():
 
333
  select_img1_output_img = gr.Image(label="selected image 1")
334
  select_img1_output_w = gr.Variable()
335
  with gr.Column():
336
+ select_img2_dropdown = gr.Radio(label="Select image 2", value="current generated image",
 
337
  choices=["current generated image",
338
  "current encoded image"], type="index")
339
  with gr.Group():
 
346
  with gr.Column():
347
  generate_video_output = gr.Video(label="output video")
348
  gen_submit.click(gen_fn, [gen_input1, gen_input2, gen_input3, gen_input4, gen_input5],
349
+ [gen_output1, gen_input2, select_img_input_w1, select_img_input_img1])
350
  encode_img_submit.click(encode_img_fn, [encode_img_input, gen_input5],
351
  [encode_img_output1, encode_img_output2, encode_img_output3, select_img_input_w2,
352
  select_img_input_img2])
examples/01.jpg ADDED
examples/01.png DELETED
Binary file (405 kB)
 
examples/02.jpg ADDED
examples/02.png DELETED
Binary file (331 kB)
 
examples/03.jpg ADDED
examples/03.png DELETED
Binary file (369 kB)
 
examples/04.jpg ADDED
examples/04.png DELETED
Binary file (452 kB)