Qdssa commited on
Commit
8b8b073
·
verified ·
1 Parent(s): 3072485

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -19
app.py CHANGED
@@ -227,6 +227,33 @@ def run_rmbg(img, sigma=0.0):
227
  result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
228
  return result.clip(0, 255).astype(np.uint8), alpha
229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
  @torch.inference_mode()
232
  def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
@@ -256,6 +283,7 @@ def process(input_fg, prompt, image_width, image_height, num_samples, seed, step
256
 
257
  rng = torch.Generator(device=device).manual_seed(int(seed))
258
 
 
259
  fg = resize_and_center_crop(input_fg, image_width, image_height)
260
 
261
  concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
@@ -277,7 +305,8 @@ def process(input_fg, prompt, image_width, image_height, num_samples, seed, step
277
  cross_attention_kwargs={'concat_conds': concat_conds},
278
  ).images.to(vae.dtype) / vae.config.scaling_factor
279
  else:
280
- bg = resize_and_center_crop(input_bg, image_width, image_height)
 
281
  bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
282
  bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
283
  latents = i2i_pipe(
@@ -333,10 +362,11 @@ def process(input_fg, prompt, image_width, image_height, num_samples, seed, step
333
  return pytorch2numpy(pixels)
334
 
335
 
336
- @spaces.GPU
337
  @torch.inference_mode()
338
  def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
339
- input_fg, matting = run_rmbg(input_fg)
 
340
  results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
341
  return input_fg, results
342
 
@@ -378,14 +408,12 @@ class BGSource(Enum):
378
  block = gr.Blocks().queue()
379
  with block:
380
  with gr.Row():
381
- gr.Markdown("## IC-Light (Relighting with Foreground Condition)")
382
- with gr.Row():
383
- gr.Markdown("See also https://github.com/lllyasviel/IC-Light for background-conditioned model and normal estimation")
384
  with gr.Row():
385
  with gr.Column():
386
  with gr.Row():
387
- input_fg = gr.Image(sources='upload', type="numpy", label="Image", height=480)
388
- output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
389
  prompt = gr.Textbox(label="Prompt")
390
  bg_source = gr.Radio(choices=[e.value for e in BGSource],
391
  value=BGSource.NONE.value,
@@ -400,8 +428,8 @@ with block:
400
  seed = gr.Number(label="Seed", value=12345, precision=0)
401
 
402
  with gr.Row():
403
- image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
404
- image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
405
 
406
  with gr.Accordion("Advanced options", open=False):
407
  steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1)
@@ -415,15 +443,7 @@ with block:
415
  result_gallery = gr.Gallery(height=832, object_fit='contain', label='Outputs')
416
  with gr.Row():
417
  dummy_image_for_outputs = gr.Image(visible=False, label='Result')
418
- gr.Examples(
419
- fn=lambda *args: [[args[-1]], "imgs/dummy.png"],
420
- examples=db_examples.foreground_conditioned_examples,
421
- inputs=[
422
- input_fg, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs
423
- ],
424
- outputs=[result_gallery, output_bg],
425
- run_on_click=True, examples_per_page=1024
426
- )
427
  ips = [input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
428
  relight_button.click(fn=process_relight, inputs=ips, outputs=[output_bg, result_gallery])
429
  example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
 
227
  result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
228
  return result.clip(0, 255).astype(np.uint8), alpha
229
 
230
+ @torch.inference_mode()
231
+ def merge_alpha(img, sigma=0.0):
232
+ if img is None:
233
+ return None
234
+
235
+ if len(img.shape) == 2:
236
+ img = np.stack((img,)*3, axis=-1)
237
+
238
+ H, W, C = img.shape
239
+ print(f"img.shape: {img.shape}")
240
+
241
+ if C == 3:
242
+ img, _ = run_rmbg(img)
243
+ return img
244
+ elif C == 4:
245
+ rgb = img[:, :, :3].astype(np.float32)
246
+ alpha = img[:, :, 3].astype(np.float32) / 255.0
247
+
248
+ result = rgb * alpha[:, :, np.newaxis] + 255 * (1 - alpha[:, :, np.newaxis])
249
+
250
+ if sigma != 0:
251
+ result += sigma * alpha[:, :, np.newaxis]
252
+
253
+ return np.clip(result, 0, 255).astype(np.uint8)
254
+ else:
255
+ raise ValueError(f"Unexpected number of channels: {C}")
256
+
257
 
258
  @torch.inference_mode()
259
  def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
 
283
 
284
  rng = torch.Generator(device=device).manual_seed(int(seed))
285
 
286
+ #fg = input_fg
287
  fg = resize_and_center_crop(input_fg, image_width, image_height)
288
 
289
  concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
 
305
  cross_attention_kwargs={'concat_conds': concat_conds},
306
  ).images.to(vae.dtype) / vae.config.scaling_factor
307
  else:
308
+ #bg = input_bg
309
+ bg = resize_and_center_crop(input_bg, image_width, image_height)
310
  bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
311
  bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
312
  latents = i2i_pipe(
 
362
  return pytorch2numpy(pixels)
363
 
364
 
365
+ @spaces.GPU(duration=240)
366
  @torch.inference_mode()
367
  def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
368
+ #input_fg, matting = run_rmbg(input_fg)
369
+ input_fg = merge_alpha(input_fg)
370
  results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
371
  return input_fg, results
372
 
 
408
  block = gr.Blocks().queue()
409
  with block:
410
  with gr.Row():
411
+ gr.Markdown("##ICLight without mask")
 
 
412
  with gr.Row():
413
  with gr.Column():
414
  with gr.Row():
415
+ input_fg = gr.Image(sources='upload', type="numpy", label="Image", image_mode='RGBA')
416
+ output_bg = gr.Image(type="numpy", label="Preprocessed Foreground")
417
  prompt = gr.Textbox(label="Prompt")
418
  bg_source = gr.Radio(choices=[e.value for e in BGSource],
419
  value=BGSource.NONE.value,
 
428
  seed = gr.Number(label="Seed", value=12345, precision=0)
429
 
430
  with gr.Row():
431
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=2048, value=512, step=64)
432
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=2048, value=640, step=64)
433
 
434
  with gr.Accordion("Advanced options", open=False):
435
  steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1)
 
443
  result_gallery = gr.Gallery(height=832, object_fit='contain', label='Outputs')
444
  with gr.Row():
445
  dummy_image_for_outputs = gr.Image(visible=False, label='Result')
446
+
 
 
 
 
 
 
 
 
447
  ips = [input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
448
  relight_button.click(fn=process_relight, inputs=ips, outputs=[output_bg, result_gallery])
449
  example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)