hysts HF staff commited on
Commit
85f6a01
·
1 Parent(s): d578b5a

Support background replacement

Browse files
Files changed (1) hide show
  1. app.py +76 -9
app.py CHANGED
@@ -44,9 +44,49 @@ def update_trimap(foreground_mask: dict[str, np.ndarray], unknown_mask: dict[str
44
  return trimap
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  @spaces.GPU
48
  @torch.inference_mode()
49
- def run(image: PIL.Image.Image, trimap: PIL.Image.Image) -> tuple[PIL.Image.Image, PIL.Image.Image]:
 
 
 
 
 
50
  if image.size != trimap.size:
51
  raise gr.Error("Image and trimap must have the same size.")
52
  if max(image.size) > MAX_IMAGE_SIZE:
@@ -67,7 +107,12 @@ def run(image: PIL.Image.Image, trimap: PIL.Image.Image) -> tuple[PIL.Image.Imag
67
  foreground = (foreground * 255).astype(np.uint8)
68
  foreground = PIL.Image.fromarray(foreground)
69
 
70
- return alpha, foreground
 
 
 
 
 
71
 
72
 
73
  with gr.Blocks(css="style.css") as demo:
@@ -104,19 +149,33 @@ with gr.Blocks(css="style.css") as demo:
104
  height=500,
105
  )
106
  set_trimap_button = gr.Button("Set trimap")
 
 
107
  run_button = gr.Button("Run")
108
  with gr.Column():
109
  with gr.Box():
110
  out_alpha = gr.Image(label="Alpha", height=500)
111
  out_foreground = gr.Image(label="Foreground", height=500)
112
-
 
 
 
 
 
 
 
 
 
 
 
 
113
  gr.Examples(
114
  examples=[
115
- ["assets/bulb_rgb.png", "assets/bulb_trimap.png"],
116
- ["assets/retriever_rgb.png", "assets/retriever_trimap.png"],
117
  ],
118
- inputs=[image, trimap],
119
- outputs=[out_alpha, out_foreground],
120
  fn=run,
121
  cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
122
  )
@@ -141,10 +200,18 @@ with gr.Blocks(css="style.css") as demo:
141
  queue=False,
142
  api_name=False,
143
  )
 
 
 
 
 
 
 
 
144
  run_button.click(
145
  fn=run,
146
- inputs=[image, trimap],
147
- outputs=[out_alpha, out_foreground],
148
  api_name="run",
149
  )
150
 
 
44
  return trimap
45
 
46
 
47
+ def adjust_background_image(background_image: PIL.Image.Image, target_size: tuple[int, int]) -> PIL.Image.Image:
48
+ target_w, target_h = target_size
49
+ bg_w, bg_h = background_image.size
50
+
51
+ scale = max(target_w / bg_w, target_h / bg_h)
52
+ new_bg_w = int(bg_w * scale)
53
+ new_bg_h = int(bg_h * scale)
54
+ background_image = background_image.resize((new_bg_w, new_bg_h))
55
+ left = (new_bg_w - target_w) // 2
56
+ top = (new_bg_h - target_h) // 2
57
+ right = left + target_w
58
+ bottom = top + target_h
59
+ background_image = background_image.crop((left, top, right, bottom))
60
+ return background_image
61
+
62
+
63
+ def replace_background(
64
+ image: PIL.Image.Image, alpha: np.ndarray, background_image: PIL.Image.Image | None
65
+ ) -> PIL.Image.Image | None:
66
+ if background_image is None:
67
+ return None
68
+
69
+ if image.mode != "RGB":
70
+ raise gr.Error("Image must be RGB.")
71
+
72
+ background_image = background_image.convert("RGB")
73
+ background_image = adjust_background_image(background_image, image.size)
74
+
75
+ image = np.array(image).astype(float) / 255
76
+ background_image = np.array(background_image).astype(float) / 255
77
+ result = image * alpha[:, :, None] + background_image * (1 - alpha[:, :, None])
78
+ result = (result * 255).astype(np.uint8)
79
+ return result
80
+
81
+
82
  @spaces.GPU
83
  @torch.inference_mode()
84
+ def run(
85
+ image: PIL.Image.Image,
86
+ trimap: PIL.Image.Image,
87
+ apply_background_replacement: bool,
88
+ background_image: PIL.Image.Image | None,
89
+ ) -> tuple[np.ndarray, PIL.Image.Image, PIL.Image.Image | None]:
90
  if image.size != trimap.size:
91
  raise gr.Error("Image and trimap must have the same size.")
92
  if max(image.size) > MAX_IMAGE_SIZE:
 
107
  foreground = (foreground * 255).astype(np.uint8)
108
  foreground = PIL.Image.fromarray(foreground)
109
 
110
+ if apply_background_replacement:
111
+ res_bg_replacement = replace_background(image, alpha, background_image)
112
+ else:
113
+ res_bg_replacement = None
114
+
115
+ return alpha, foreground, res_bg_replacement
116
 
117
 
118
  with gr.Blocks(css="style.css") as demo:
 
149
  height=500,
150
  )
151
  set_trimap_button = gr.Button("Set trimap")
152
+ apply_background_replacement = gr.Checkbox(label="Apply background replacement", checked=False)
153
+ background_image = gr.Image(label="Background image", type="pil", height=500, visible=False)
154
  run_button = gr.Button("Run")
155
  with gr.Column():
156
  with gr.Box():
157
  out_alpha = gr.Image(label="Alpha", height=500)
158
  out_foreground = gr.Image(label="Foreground", height=500)
159
+ out_background_replacement = gr.Image(label="Background replacement", height=500, visible=False)
160
+
161
+ inputs = [
162
+ image,
163
+ trimap,
164
+ apply_background_replacement,
165
+ background_image,
166
+ ]
167
+ outputs = [
168
+ out_alpha,
169
+ out_foreground,
170
+ out_background_replacement,
171
+ ]
172
  gr.Examples(
173
  examples=[
174
+ ["assets/retriever_rgb.png", "assets/retriever_trimap.png", False, None],
175
+ ["assets/bulb_rgb.png", "assets/bulb_trimap.png", True, "assets/new_bg.jpg"],
176
  ],
177
+ inputs=inputs,
178
+ outputs=outputs,
179
  fn=run,
180
  cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
181
  )
 
200
  queue=False,
201
  api_name=False,
202
  )
203
+ apply_background_replacement.change(
204
+ fn=lambda checked: (gr.update(visible=checked), gr.update(visible=checked)),
205
+ inputs=apply_background_replacement,
206
+ outputs=[background_image, out_background_replacement],
207
+ queue=False,
208
+ api_name=False,
209
+ )
210
+
211
  run_button.click(
212
  fn=run,
213
+ inputs=inputs,
214
+ outputs=outputs,
215
  api_name="run",
216
  )
217