hysts HF staff committed on
Commit
46ce1fe
1 Parent(s): 69a5598

Support raw pose image and depth image

Browse files
Files changed (3) hide show
  1. app_depth.py +3 -0
  2. app_pose.py +3 -0
  3. model.py +26 -15
app_depth.py CHANGED
@@ -13,6 +13,8 @@ def create_demo(process, max_images=12, default_num_images=3):
13
  prompt = gr.Textbox(label='Prompt')
14
  run_button = gr.Button(label='Run')
15
  with gr.Accordion('Advanced options', open=False):
 
 
16
  num_samples = gr.Slider(label='Images',
17
  minimum=1,
18
  maximum=max_images,
@@ -67,6 +69,7 @@ def create_demo(process, max_images=12, default_num_images=3):
67
  num_steps,
68
  guidance_scale,
69
  seed,
 
70
  ]
71
  prompt.submit(fn=process, inputs=inputs, outputs=result)
72
  run_button.click(fn=process,
 
13
  prompt = gr.Textbox(label='Prompt')
14
  run_button = gr.Button(label='Run')
15
  with gr.Accordion('Advanced options', open=False):
16
+ is_depth_image = gr.Checkbox(label='Is depth image',
17
+ value=False)
18
  num_samples = gr.Slider(label='Images',
19
  minimum=1,
20
  maximum=max_images,
 
69
  num_steps,
70
  guidance_scale,
71
  seed,
72
+ is_depth_image,
73
  ]
74
  prompt.submit(fn=process, inputs=inputs, outputs=result)
75
  run_button.click(fn=process,
app_pose.py CHANGED
@@ -13,6 +13,8 @@ def create_demo(process, max_images=12, default_num_images=3):
13
  prompt = gr.Textbox(label='Prompt')
14
  run_button = gr.Button(label='Run')
15
  with gr.Accordion('Advanced options', open=False):
 
 
16
  num_samples = gr.Slider(label='Images',
17
  minimum=1,
18
  maximum=max_images,
@@ -67,6 +69,7 @@ def create_demo(process, max_images=12, default_num_images=3):
67
  num_steps,
68
  guidance_scale,
69
  seed,
 
70
  ]
71
  prompt.submit(fn=process, inputs=inputs, outputs=result)
72
  run_button.click(fn=process,
 
13
  prompt = gr.Textbox(label='Prompt')
14
  run_button = gr.Button(label='Run')
15
  with gr.Accordion('Advanced options', open=False):
16
+ is_pose_image = gr.Checkbox(label='Is pose image',
17
+ value=False)
18
  num_samples = gr.Slider(label='Images',
19
  minimum=1,
20
  maximum=max_images,
 
69
  num_steps,
70
  guidance_scale,
71
  seed,
72
+ is_pose_image,
73
  ]
74
  prompt.submit(fn=process, inputs=inputs, outputs=result)
75
  run_button.click(fn=process,
model.py CHANGED
@@ -438,16 +438,19 @@ class Model:
438
  input_image: np.ndarray,
439
  image_resolution: int,
440
  detect_resolution: int,
 
441
  ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
442
  input_image = HWC3(input_image)
443
- control_image, _ = apply_openpose(
444
- resize_image(input_image, detect_resolution))
445
- control_image = HWC3(control_image)
446
- image = resize_image(input_image, image_resolution)
447
- H, W = image.shape[:2]
448
-
449
- control_image = cv2.resize(control_image, (W, H),
450
- interpolation=cv2.INTER_NEAREST)
 
 
451
 
452
  return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
453
  control_image)
@@ -465,11 +468,13 @@ class Model:
465
  num_steps: int,
466
  guidance_scale: float,
467
  seed: int,
 
468
  ) -> list[PIL.Image.Image]:
469
  control_image, vis_control_image = self.preprocess_pose(
470
  input_image=input_image,
471
  image_resolution=image_resolution,
472
  detect_resolution=detect_resolution,
 
473
  )
474
  return self.process(
475
  task_name='pose',
@@ -537,15 +542,19 @@ class Model:
537
  input_image: np.ndarray,
538
  image_resolution: int,
539
  detect_resolution: int,
 
540
  ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
541
  input_image = HWC3(input_image)
542
- control_image, _ = apply_midas(
543
- resize_image(input_image, detect_resolution))
544
- control_image = HWC3(control_image)
545
- image = resize_image(input_image, image_resolution)
546
- H, W = image.shape[:2]
547
- control_image = cv2.resize(control_image, (W, H),
548
- interpolation=cv2.INTER_LINEAR)
 
 
 
549
  return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
550
  control_image)
551
 
@@ -562,11 +571,13 @@ class Model:
562
  num_steps: int,
563
  guidance_scale: float,
564
  seed: int,
 
565
  ) -> list[PIL.Image.Image]:
566
  control_image, vis_control_image = self.preprocess_depth(
567
  input_image=input_image,
568
  image_resolution=image_resolution,
569
  detect_resolution=detect_resolution,
 
570
  )
571
  return self.process(
572
  task_name='depth',
 
438
  input_image: np.ndarray,
439
  image_resolution: int,
440
  detect_resolution: int,
441
+ is_pose_image: bool,
442
  ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
443
  input_image = HWC3(input_image)
444
+ if not is_pose_image:
445
+ control_image, _ = apply_openpose(
446
+ resize_image(input_image, detect_resolution))
447
+ control_image = HWC3(control_image)
448
+ image = resize_image(input_image, image_resolution)
449
+ H, W = image.shape[:2]
450
+ control_image = cv2.resize(control_image, (W, H),
451
+ interpolation=cv2.INTER_NEAREST)
452
+ else:
453
+ control_image = input_image
454
 
455
  return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
456
  control_image)
 
468
  num_steps: int,
469
  guidance_scale: float,
470
  seed: int,
471
+ is_pose_image: bool,
472
  ) -> list[PIL.Image.Image]:
473
  control_image, vis_control_image = self.preprocess_pose(
474
  input_image=input_image,
475
  image_resolution=image_resolution,
476
  detect_resolution=detect_resolution,
477
+ is_pose_image=is_pose_image,
478
  )
479
  return self.process(
480
  task_name='pose',
 
542
  input_image: np.ndarray,
543
  image_resolution: int,
544
  detect_resolution: int,
545
+ is_depth_image: bool,
546
  ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
547
  input_image = HWC3(input_image)
548
+ if not is_depth_image:
549
+ control_image, _ = apply_midas(
550
+ resize_image(input_image, detect_resolution))
551
+ control_image = HWC3(control_image)
552
+ image = resize_image(input_image, image_resolution)
553
+ H, W = image.shape[:2]
554
+ control_image = cv2.resize(control_image, (W, H),
555
+ interpolation=cv2.INTER_LINEAR)
556
+ else:
557
+ control_image = input_image
558
  return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
559
  control_image)
560
 
 
571
  num_steps: int,
572
  guidance_scale: float,
573
  seed: int,
574
+ is_depth_image: bool,
575
  ) -> list[PIL.Image.Image]:
576
  control_image, vis_control_image = self.preprocess_depth(
577
  input_image=input_image,
578
  image_resolution=image_resolution,
579
  detect_resolution=detect_resolution,
580
+ is_depth_image=is_depth_image,
581
  )
582
  return self.process(
583
  task_name='depth',