Yw22 committed on
Commit 7bffd64
1 Parent(s): 665c7d3
Files changed (2):
  1. app.py +153 -146
  2. pipelines/pipeline_imagecoductor.py +1 -8
app.py CHANGED
@@ -295,7 +295,7 @@ class ImageConductor:
         if isinstance(tracking_points, list):
             input_all_points = tracking_points
         else:
-            input_all_points = tracking_points.constructor_args['value']
+            input_all_points = tracking_points


         resized_all_points = [tuple([tuple([float(e1[0]*self.width/original_width), float(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points]
@@ -304,6 +304,10 @@ class ImageConductor:
         id = base.split('_')[-1]


+        # with open(f'{output_dir}/points-{id}.json', 'w') as f:
+        #     json.dump(input_all_points, f)
+
+
         visualized_drag, _ = visualize_drag(first_frame_path, resized_all_points, drag_mode, self.width, self.height, self.model_length)

         ## image condition
@@ -377,16 +381,18 @@ class ImageConductor:
         vis_video = (rearrange(sample[0], 'c t h w -> t h w c') * 255.).clip(0, 255)
         torchvision.io.write_video(outputs_path, vis_video, fps=8, video_codec='h264', options={'crf': '10'})

-        return visualized_drag, outputs_path
+        return {output_image: visualized_drag, output_video: outputs_path}


 def reset_states(first_frame_path, tracking_points):
     first_frame_path = gr.State()
     tracking_points = gr.State([])
-    return None, first_frame_path, tracking_points
+    return {input_image: None, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points}


-def preprocess_image(image):
+def preprocess_image(image, tracking_points):
+    if len(tracking_points) != 0:
+        tracking_points = gr.State([])
     image_pil = image2pil(image.name)
     raw_w, raw_h = image_pil.size
     resize_ratio = max(384/raw_w, 256/raw_h)
@@ -395,7 +401,7 @@ def preprocess_image(image):
     id = str(uuid.uuid4())[:4]
     first_frame_path = os.path.join(output_dir, f"first_frame_{id}.jpg")
     image_pil.save(first_frame_path, quality=95)
-    return first_frame_path, first_frame_path, gr.State([])
+    return {input_image: first_frame_path, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points}


 def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.SelectData): # SelectData is a subclass of EventData
@@ -405,13 +411,13 @@ def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.Se
         color = (0, 0, 255, 255)

     print(f"You selected {evt.value} at {evt.index} from {evt.target}")
-    tracking_points.constructor_args['value'][-1].append(evt.index)
-    print(tracking_points.constructor_args)
+    tracking_points[-1].append(evt.index)
+    print(tracking_points)

     transparent_background = Image.open(first_frame_path).convert('RGBA')
     w, h = transparent_background.size
     transparent_layer = np.zeros((h, w, 4))
-    for track in tracking_points.constructor_args['value']:
+    for track in tracking_points:
         if len(track) > 1:
             for i in range(len(track)-1):
                 start_point = track[i]
@@ -428,13 +434,13 @@ def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.Se

     transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
     trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
-    return tracking_points, trajectory_map
+    return {tracking_points_var: tracking_points, input_image: trajectory_map}


 def add_drag(tracking_points):
-    tracking_points.constructor_args['value'].append([])
-    print(tracking_points.constructor_args)
-    return tracking_points
+    tracking_points.append([])
+    print(tracking_points)
+    return {tracking_points_var: tracking_points}


 def delete_last_drag(tracking_points, first_frame_path, drag_mode):
@@ -442,11 +448,11 @@ def delete_last_drag(tracking_points, first_frame_path, drag_mode):
         color = (255, 0, 0, 255)
     elif drag_mode=='camera':
         color = (0, 0, 255, 255)
-    tracking_points.constructor_args['value'].pop()
+    tracking_points.pop()
     transparent_background = Image.open(first_frame_path).convert('RGBA')
     w, h = transparent_background.size
     transparent_layer = np.zeros((h, w, 4))
-    for track in tracking_points.constructor_args['value']:
+    for track in tracking_points:
         if len(track) > 1:
             for i in range(len(track)-1):
                 start_point = track[i]
@@ -463,7 +469,7 @@ def delete_last_drag(tracking_points, first_frame_path, drag_mode):

     transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
     trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
-    return tracking_points, trajectory_map
+    return {tracking_points_var: tracking_points, input_image: trajectory_map}


 def delete_last_step(tracking_points, first_frame_path, drag_mode):
@@ -471,11 +477,11 @@ def delete_last_step(tracking_points, first_frame_path, drag_mode):
         color = (255, 0, 0, 255)
     elif drag_mode=='camera':
         color = (0, 0, 255, 255)
-    tracking_points.constructor_args['value'][-1].pop()
+    tracking_points[-1].pop()
     transparent_background = Image.open(first_frame_path).convert('RGBA')
     w, h = transparent_background.size
     transparent_layer = np.zeros((h, w, 4))
-    for track in tracking_points.constructor_args['value']:
+    for track in tracking_points:
         if len(track) > 1:
             for i in range(len(track)-1):
                 start_point = track[i]
@@ -492,147 +498,148 @@ def delete_last_step(tracking_points, first_frame_path, drag_mode):

     transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
     trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
-    return tracking_points, trajectory_map
-
-
-block = gr.Blocks(
-            theme=gr.themes.Soft(
-             radius_size=gr.themes.sizes.radius_none,
-             text_size=gr.themes.sizes.text_md
-             )
-            ).queue()
-with block as demo:
-    with gr.Row():
-        with gr.Column():
-            gr.HTML(head)
-
-    gr.Markdown(descriptions)
-
-    with gr.Accordion(label="🛠️ Instructions:", open=True, elem_id="accordion"):
-        with gr.Row(equal_height=True):
-            gr.Markdown(instructions)
-
-
-    # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-    device = torch.device("cuda")
-    unet_path = 'models/unet.ckpt'
-    image_controlnet_path = 'models/image_controlnet.ckpt'
-    flow_controlnet_path = 'models/flow_controlnet.ckpt'
-    ImageConductor_net = ImageConductor(device=device,
-                                        unet_path=unet_path,
-                                        image_controlnet_path=image_controlnet_path,
-                                        flow_controlnet_path=flow_controlnet_path,
-                                        height=256,
-                                        width=384,
-                                        model_length=16
-                                        )
-    first_frame_path = gr.State()
-    tracking_points = gr.State([])
-
-    with gr.Row():
-        with gr.Column(scale=1):
-            image_upload_button = gr.UploadButton(label="Upload Image",file_types=["image"])
-            add_drag_button = gr.Button(value="Add Drag")
-            reset_button = gr.Button(value="Reset")
-            delete_last_drag_button = gr.Button(value="Delete last drag")
-            delete_last_step_button = gr.Button(value="Delete last step")
-
-
-
-        with gr.Column(scale=7):
-            with gr.Row():
-                with gr.Column(scale=6):
-                    input_image = gr.Image(label="Input Image",
-                                           interactive=True,
-                                           height=265,
-                                           width=384,)
-                with gr.Column(scale=6):
-                    output_image = gr.Image(label="Motion Path",
-                                            interactive=False,
+    return {tracking_points_var: tracking_points, input_image: trajectory_map}
+
+
+if __name__=="__main__":
+    block = gr.Blocks(
+                theme=gr.themes.Soft(
+                 radius_size=gr.themes.sizes.radius_none,
+                 text_size=gr.themes.sizes.text_md
+                 )
+                ).queue()
+    with block as demo:
+        with gr.Row():
+            with gr.Column():
+                gr.HTML(head)
+
+        gr.Markdown(descriptions)
+
+        with gr.Accordion(label="🛠️ Instructions:", open=True, elem_id="accordion"):
+            with gr.Row(equal_height=True):
+                gr.Markdown(instructions)
+
+
+        # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        device = torch.device("cuda")
+        unet_path = 'models/unet.ckpt'
+        image_controlnet_path = 'models/image_controlnet.ckpt'
+        flow_controlnet_path = 'models/flow_controlnet.ckpt'
+        ImageConductor_net = ImageConductor(device=device,
+                                            unet_path=unet_path,
+                                            image_controlnet_path=image_controlnet_path,
+                                            flow_controlnet_path=flow_controlnet_path,
                                             height=256,
-                                            width=384,)
-    with gr.Row():
-        with gr.Column(scale=1):
-            prompt = gr.Textbox(value="a wonderful elf.", label="Prompt (highly-recommended)", interactive=True, visible=True)
-            negative_prompt = gr.Text(
-                label="Negative Prompt",
-                max_lines=5,
-                placeholder="Please input your negative prompt",
-                value='worst quality, low quality, letterboxed',lines=1
-            )
-            drag_mode = gr.Radio(['camera', 'object'], label='Drag mode: ', value='object', scale=2)
-            run_button = gr.Button(value="Run")
-
-            with gr.Accordion("More input params", open=False, elem_id="accordion1"):
-                with gr.Group():
-                    seed = gr.Textbox(
-                        label="Seed: ", value=561793204,
-                    )
-                    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
+                                            width=384,
+                                            model_length=16
+                                            )
+        first_frame_path_var = gr.State(value=None)
+        tracking_points_var = gr.State([])
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                image_upload_button = gr.UploadButton(label="Upload Image",file_types=["image"])
+                add_drag_button = gr.Button(value="Add Drag")
+                reset_button = gr.Button(value="Reset")
+                delete_last_drag_button = gr.Button(value="Delete last drag")
+                delete_last_step_button = gr.Button(value="Delete last step")
+

-            with gr.Group():
-                with gr.Row():
-                    guidance_scale = gr.Slider(
-                        label="Guidance scale",
-                        minimum=1,
-                        maximum=12,
-                        step=0.1,
-                        value=8.5,
+
+            with gr.Column(scale=7):
+                with gr.Row():
+                    with gr.Column(scale=6):
+                        input_image = gr.Image(label="Input Image",
+                                               interactive=True,
+                                               height=300,
+                                               width=384,)
+                    with gr.Column(scale=6):
+                        output_image = gr.Image(label="Motion Path",
+                                                interactive=False,
+                                                height=256,
+                                                width=384,)
+                with gr.Row():
+                    with gr.Column(scale=1):
+                        prompt = gr.Textbox(value="a wonderful elf.", label="Prompt (highly-recommended)", interactive=True, visible=True)
+                        negative_prompt = gr.Text(
+                            label="Negative Prompt",
+                            max_lines=5,
+                            placeholder="Please input your negative prompt",
+                            value='worst quality, low quality, letterboxed',lines=1
                         )
-                    num_inference_steps = gr.Slider(
-                        label="Number of inference steps",
-                        minimum=1,
-                        maximum=50,
-                        step=1,
-                        value=25,
-                        )
-
-            with gr.Group():
-                personalized = gr.Dropdown(label="Personalized", choices=['HelloObject', 'TUSUN', ""], value="")
-                examples_type = gr.Textbox(label="Examples Type (Ignore) ", value="", visible=False)
-
-        with gr.Column(scale=7):
-            output_video = gr.Video(value=None,
-                                    label="Output Video",
-                                    width=384,
-                                    height=256)
+                        drag_mode = gr.Radio(['camera', 'object'], label='Drag mode: ', value='object', scale=2)
+                        run_button = gr.Button(value="Run")
+
+                        with gr.Accordion("More input params", open=False, elem_id="accordion1"):
+                            with gr.Group():
+                                seed = gr.Textbox(
+                                    label="Seed: ", value=561793204,
+                                )
+                                randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
+
+                            with gr.Group():
+                                with gr.Row():
+                                    guidance_scale = gr.Slider(
+                                        label="Guidance scale",
+                                        minimum=1,
+                                        maximum=12,
+                                        step=0.1,
+                                        value=8.5,
+                                    )
+                                    num_inference_steps = gr.Slider(
+                                        label="Number of inference steps",
+                                        minimum=1,
+                                        maximum=50,
+                                        step=1,
+                                        value=25,
+                                    )
+
+                            with gr.Group():
+                                personalized = gr.Dropdown(label="Personalized", choices=['HelloObject', 'TUSUN', ""], value="")
+                                examples_type = gr.Textbox(label="Examples Type (Ignore) ", value="", visible=False)
+
+                    with gr.Column(scale=7):
+                        output_video = gr.Video(
+                            label="Output Video",
+                            width=384,
+                            height=256)
+
+
+        with gr.Row():
+            def process_example(input_image, prompt, drag_mode, seed, personalized, examples_type):
+
+                return input_image, prompt, drag_mode, seed, personalized, examples_type
+
+            example = gr.Examples(
+                label="Input Example",
+                examples=image_examples,
+                inputs=[input_image, prompt, drag_mode, seed, personalized, examples_type],
+                outputs=[input_image, prompt, drag_mode, seed, personalized, examples_type],
+                fn=process_example,
+                run_on_click=True,
+                examples_per_page=10,
+                cache_examples=False,
+            )


-    with gr.Row():
-        def process_example(input_image, prompt, drag_mode, seed, personalized, examples_type):
-
-            return input_image, prompt, drag_mode, seed, personalized, examples_type
-
-        example = gr.Examples(
-            label="Input Example",
-            examples=image_examples,
-            inputs=[input_image, prompt, drag_mode, seed, personalized, examples_type],
-            outputs=[input_image, prompt, drag_mode, seed, personalized, examples_type],
-            fn=process_example,
-            run_on_click=True,
-            examples_per_page=10,
-            cache_examples=False,
-        )
-
-
-    with gr.Row():
-        gr.Markdown(citation)
+        with gr.Row():
+            gr.Markdown(citation)

-
-    image_upload_button.upload(preprocess_image, image_upload_button, [input_image, first_frame_path, tracking_points])
+
+        image_upload_button.upload(preprocess_image, [image_upload_button, tracking_points_var], [input_image, first_frame_path_var, tracking_points_var])

-    add_drag_button.click(add_drag, tracking_points, tracking_points)
+        add_drag_button.click(add_drag, [tracking_points_var], tracking_points_var)

-    delete_last_drag_button.click(delete_last_drag, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image])
+        delete_last_drag_button.click(delete_last_drag, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])

-    delete_last_step_button.click(delete_last_step, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image])
+        delete_last_step_button.click(delete_last_step, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])

-    reset_button.click(reset_states, [first_frame_path, tracking_points], [input_image, first_frame_path, tracking_points])
+        reset_button.click(reset_states, [first_frame_path_var, tracking_points_var], [input_image, first_frame_path_var, tracking_points_var])

-    input_image.select(add_tracking_points, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image])
+        input_image.select(add_tracking_points, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])

-    run_button.click(ImageConductor_net.run, [first_frame_path, tracking_points, prompt, drag_mode,
-                                              negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized, examples_type],
-                                              [output_image, output_video])
+        run_button.click(ImageConductor_net.run, [first_frame_path_var, tracking_points_var, prompt, drag_mode,
+                                                  negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized, examples_type],
+                                                  [output_image, output_video])

     demo.launch()
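
Note on the app.py changes: every callback above stops reaching into tracking_points.constructor_args['value'] and instead treats the gr.State value it receives as a plain Python list, returning a dict keyed by output components rather than a positional tuple. The snippet below is a minimal sketch of that pattern with hypothetical component names (add_point, points_var, status); it assumes current Gradio event semantics and is not code from this repository.

import gradio as gr

def add_point(points, evt: gr.SelectData):
    # a gr.State input arrives as its underlying Python value, so it can be mutated directly
    points.append(evt.index)
    # returning a dict keyed by output components updates exactly those components
    return {points_var: points, status: f"{len(points)} point(s) recorded"}

with gr.Blocks() as demo:
    points_var = gr.State([])                      # per-session list of clicked points
    image = gr.Image(label="Canvas", interactive=True)
    status = gr.Textbox(label="Status")
    # the components listed in outputs are the allowed keys of the returned dict
    image.select(add_point, [points_var], [points_var, status])

if __name__ == "__main__":
    demo.launch()
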
pipelines/pipeline_imagecoductor.py CHANGED
@@ -404,7 +404,6 @@ class ImageConductorPipeline(DiffusionPipeline):
         obj_latents = copy.deepcopy(latents)
         cam_latents = copy.deepcopy(latents)

-        print("device", device)
         # Prepare extra step kwargs.
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

@@ -463,13 +462,7 @@ class ImageConductorPipeline(DiffusionPipeline):
                 controlnet_images_mask = controlnet_images_mask.half()
                 controlnet_flows = controlnet_flows.half()
                 text_embeddings = text_embeddings.half()
-                print("controlnet_noisy_latents device", controlnet_noisy_latents.device)
-                print("controlnet_prompt_embeds device", controlnet_prompt_embeds.device)
-                print("controlnet_images device", controlnet_images.device)
-                print("t", t.device)
-
-
-                print("self.image_controlnet", self.image_controlnet.controlnet_mid_block.weight.device)
+

                 img_down_block_additional_residuals, img_mid_block_additional_residuals = self.image_controlnet(
                     controlnet_noisy_latents, t,