Yw22 committed on
Commit 112b465
1 parent: ec1429c
Files changed (1)
  1. app.py +142 -159
app.py CHANGED
@@ -302,30 +302,26 @@ class ImageConductor:
         self.blur_kernel = blur_kernel
 
     @spaces.GPU(duration=120)
-    def run(self, first_frame_path, tracking_points, prompt, drag_mode, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized, examples_type):
-        print("Run!")
-        # if examples_type != "":
-        #     ### for adapting high version gradio
-        #     tracking_points = gr.State([])
-        #     first_frame_path = IMAGE_PATH[examples_type]
-        #     points = json.load(open(POINTS[examples_type]))
-        #     tracking_points.value.extend(points)
-        #     print("example first_frame_path", first_frame_path)
-        #     print("example tracking_points", tracking_points.value)
-
+    def run(self, first_frame_path, tracking_points, prompt, drag_mode, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized,):
+
+
         original_width, original_height=384, 256
         if isinstance(tracking_points, list):
             input_all_points = tracking_points
         else:
             input_all_points = tracking_points.value
 
-        print("input_all_points", input_all_points)
+
         resized_all_points = [tuple([tuple([float(e1[0]*self.width/original_width), float(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points]
-        print("first_frame_path", first_frame_path)
+
         dir, base, ext = split_filename(first_frame_path)
         id = base.split('_')[-1]
 
 
+        # with open(f'{output_dir}/points-{id}.json', 'w') as f:
+        #     json.dump(input_all_points, f)
+
+
         visualized_drag, _ = visualize_drag(first_frame_path, resized_all_points, drag_mode, self.width, self.height, self.model_length)
 
         ## image condition
@@ -337,8 +333,9 @@ class ImageConductor:
             transforms.ToTensor(),
         ])
 
+        image_norm = lambda x: x
         image_paths = [first_frame_path]
-        controlnet_images = [(image_transforms(Image.open(path).convert("RGB"))) for path in image_paths]
+        controlnet_images = [image_norm(image_transforms(Image.open(path).convert("RGB"))) for path in image_paths]
         controlnet_images = torch.stack(controlnet_images).unsqueeze(0).to(device)
         controlnet_images = rearrange(controlnet_images, "b f c h w -> b c f h w")
         num_controlnet_images = controlnet_images.shape[2]
@@ -398,9 +395,10 @@ class ImageConductor:
         # vis_video = (rearrange(sample[0], 'c t h w -> t h w c') * 255.).clip(0, 255)
         # torchvision.io.write_video(outputs_path, vis_video, fps=8, video_codec='h264', options={'crf': '10'})
 
+
         outputs_path = os.path.join(output_dir, f'output_{i}_{id}.gif')
         save_videos_grid(sample[0][None], outputs_path)
-        print("Done!")
+
         return {output_image: visualized_drag, output_video: outputs_path}
 
 
@@ -410,7 +408,7 @@ def reset_states(first_frame_path, tracking_points):
     return {input_image:None, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points}
 
 
-def preprocess_image(image, tracking_points):
+def preprocess_image(image):
     image_pil = image2pil(image.name)
     raw_w, raw_h = image_pil.size
     resize_ratio = max(384/raw_w, 256/raw_h)
@@ -419,8 +417,7 @@ def preprocess_image(image, tracking_points):
     id = str(uuid.uuid4())[:4]
     first_frame_path = os.path.join(output_dir, f"first_frame_{id}.jpg")
     image_pil.save(first_frame_path, quality=95)
-    tracking_points = gr.State([])
-    return {input_image: first_frame_path, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points, personalized:""}
+    return {input_image: first_frame_path, first_frame_path_var: first_frame_path, tracking_points_var: gr.State([]), personalized: ""}
 
 
 def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.SelectData): # SelectData is a subclass of EventData
@@ -429,27 +426,14 @@ def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.Se
     elif drag_mode=='camera':
         color = (0, 0, 255, 255)
 
-    if not isinstance(tracking_points ,list):
-        print(f"You selected {evt.value} at {evt.index} from {evt.target}")
-        tracking_points.value[-1].append(evt.index)
-        print(tracking_points.value)
-        tracking_points_values = tracking_points.value
-    else:
-        try:
-            tracking_points[-1].append(evt.index)
-        except Exception as e:
-            tracking_points.append([])
-            tracking_points[-1].append(evt.index)
-            print(f"Solved Error: {e}")
-
-        tracking_points_values = tracking_points
-
+    print(f"You selected {evt.value} at {evt.index} from {evt.target}")
+    tracking_points.value[-1].append(evt.index)
+    print(tracking_points.value)
 
     transparent_background = Image.open(first_frame_path).convert('RGBA')
     w, h = transparent_background.size
     transparent_layer = np.zeros((h, w, 4))
-
-    for track in tracking_points_values:
+    for track in tracking_points.value:
         if len(track) > 1:
             for i in range(len(track)-1):
                 start_point = track[i]
@@ -470,12 +454,9 @@ def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.Se
 
 
 def add_drag(tracking_points):
-    if not isinstance(tracking_points ,list):
-        # print("before", tracking_points.value)
-        tracking_points.value.append([])
-        # print(tracking_points.value)
-    else:
-        tracking_points.append([])
+    # import ipdb; ipdb.set_trace()
+    tracking_points.value.append([])
+    print(tracking_points.value)
     return {tracking_points_var: tracking_points}
 
 
@@ -537,142 +518,144 @@ def delete_last_step(tracking_points, first_frame_path, drag_mode):
     return {tracking_points_var: tracking_points, input_image: trajectory_map}
 
 
-block = gr.Blocks(
-    theme=gr.themes.Soft(
-        radius_size=gr.themes.sizes.radius_none,
-        text_size=gr.themes.sizes.text_md
-    )
-)
-with block:
-    with gr.Row():
-        with gr.Column():
-            gr.HTML(head)
-
-    gr.Markdown(descriptions)
-
-    with gr.Accordion(label="🛠️ Instructions:", open=True, elem_id="accordion"):
-        with gr.Row(equal_height=True):
-            gr.Markdown(instructions)
-
-
-    # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-    device = torch.device("cuda")
-    unet_path = 'models/unet.ckpt'
-    image_controlnet_path = 'models/image_controlnet.ckpt'
-    flow_controlnet_path = 'models/flow_controlnet.ckpt'
-    ImageConductor_net = ImageConductor(device=device,
-                                        unet_path=unet_path,
-                                        image_controlnet_path=image_controlnet_path,
-                                        flow_controlnet_path=flow_controlnet_path,
-                                        height=256,
-                                        width=384,
-                                        model_length=16
-                                        )
-    first_frame_path_var = gr.State()
-    tracking_points_var = gr.State([])
-
-    with gr.Row():
-        with gr.Column(scale=1):
-            image_upload_button = gr.UploadButton(label="Upload Image",file_types=["image"])
-            add_drag_button = gr.Button(value="Add Drag")
-            reset_button = gr.Button(value="Reset")
-            delete_last_drag_button = gr.Button(value="Delete last drag")
-            delete_last_step_button = gr.Button(value="Delete last step")
-
-
-
-        with gr.Column(scale=7):
-            with gr.Row():
-                with gr.Column(scale=6):
-                    input_image = gr.Image(label="Input Image",
-                                           interactive=True,
-                                           height=300,
-                                           width=384,)
-                with gr.Column(scale=6):
-                    output_image = gr.Image(label="Motion Path",
-                                            interactive=False,
+if __name__=="__main__":
+    block = gr.Blocks(
+        theme=gr.themes.Soft(
+            radius_size=gr.themes.sizes.radius_none,
+            text_size=gr.themes.sizes.text_md
+        )
+    ).queue()
+    with block as demo:
+        with gr.Row():
+            with gr.Column():
+                gr.HTML(head)
+
+        gr.Markdown(descriptions)
+
+        with gr.Accordion(label="🛠️ Instructions:", open=True, elem_id="accordion"):
+            with gr.Row(equal_height=True):
+                gr.Markdown(instructions)
+
+
+        # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        device = torch.device("cuda")
+        unet_path = 'models/unet.ckpt'
+        image_controlnet_path = 'models/image_controlnet.ckpt'
+        flow_controlnet_path = 'models/flow_controlnet.ckpt'
+        ImageConductor_net = ImageConductor(device=device,
+                                            unet_path=unet_path,
+                                            image_controlnet_path=image_controlnet_path,
+                                            flow_controlnet_path=flow_controlnet_path,
                                             height=256,
-                                            width=384,)
-            with gr.Row():
-                with gr.Column(scale=1):
-                    prompt = gr.Textbox(value="a wonderful elf.", label="Prompt (highly-recommended)", interactive=True, visible=True)
-                    negative_prompt = gr.Text(
-                        label="Negative Prompt",
-                        max_lines=5,
-                        placeholder="Please input your negative prompt",
-                        value='worst quality, low quality, letterboxed',lines=1
-                    )
-                    drag_mode = gr.Radio(['camera', 'object'], label='Drag mode: ', value='object', scale=2)
-                    run_button = gr.Button(value="Run")
-
-                    with gr.Accordion("More input params", open=False, elem_id="accordion1"):
-                        with gr.Group():
-                            seed = gr.Textbox(
-                                label="Seed: ", value=561793204,
-                            )
-                            randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
+                                            width=384,
+                                            model_length=16
+                                            )
+        first_frame_path_var = gr.State(value=None)
+        tracking_points_var = gr.State([])
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                image_upload_button = gr.UploadButton(label="Upload Image",file_types=["image"])
+                add_drag_button = gr.Button(value="Add Drag")
+                reset_button = gr.Button(value="Reset")
+                delete_last_drag_button = gr.Button(value="Delete last drag")
+                delete_last_step_button = gr.Button(value="Delete last step")
+
 
-                    with gr.Group():
-                        with gr.Row():
-                            guidance_scale = gr.Slider(
-                                label="Guidance scale",
-                                minimum=1,
-                                maximum=12,
-                                step=0.1,
-                                value=8.5,
+
+            with gr.Column(scale=7):
+                with gr.Row():
+                    with gr.Column(scale=6):
+                        input_image = gr.Image(label="Input Image",
+                                               interactive=True,
+                                               height=300,
+                                               width=384,)
+                    with gr.Column(scale=6):
+                        output_image = gr.Image(label="Motion Path",
+                                                interactive=False,
+                                                height=256,
+                                                width=384,)
+                with gr.Row():
+                    with gr.Column(scale=1):
+                        prompt = gr.Textbox(value="a wonderful elf.", label="Prompt (highly-recommended)", interactive=True, visible=True)
+                        negative_prompt = gr.Text(
+                            label="Negative Prompt",
+                            max_lines=5,
+                            placeholder="Please input your negative prompt",
+                            value='worst quality, low quality, letterboxed',lines=1
                             )
-                            num_inference_steps = gr.Slider(
-                                label="Number of inference steps",
-                                minimum=1,
-                                maximum=50,
-                                step=1,
-                                value=25,
+                        drag_mode = gr.Radio(['camera', 'object'], label='Drag mode: ', value='object', scale=2)
+                        run_button = gr.Button(value="Run")
+
+                        with gr.Accordion("More input params", open=False, elem_id="accordion1"):
+                            with gr.Group():
+                                seed = gr.Textbox(
+                                    label="Seed: ", value=561793204,
                                 )
-
-                    with gr.Group():
-                        personalized = gr.Dropdown(label="Personalized", choices=['HelloObject', 'TUSUN', ""], value="")
-                        examples_type = gr.Textbox(label="Examples Type (Ignore) ", value="", visible=False)
-
-                with gr.Column(scale=7):
-                    # output_video = gr.Video(
-                    #     label="Output Video",
-                    #     width=384,
-                    #     height=256)
-                    output_video = gr.Image(label="Output Video",
+                                randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
+
+                            with gr.Group():
+                                with gr.Row():
+                                    guidance_scale = gr.Slider(
+                                        label="Guidance scale",
+                                        minimum=1,
+                                        maximum=12,
+                                        step=0.1,
+                                        value=8.5,
+                                    )
+                                    num_inference_steps = gr.Slider(
+                                        label="Number of inference steps",
+                                        minimum=1,
+                                        maximum=50,
+                                        step=1,
+                                        value=25,
+                                    )
+
+                            with gr.Group():
+                                personalized = gr.Dropdown(label="Personalized", choices=['HelloObject', 'TUSUN', ""], value="")
+                                # examples_type = gr.Textbox(label="Examples Type (Ignore) ", value="", visible=False)
+
+                        with gr.Column(scale=7):
+                            # output_video = gr.Video(
+                            #     label="Output Video",
+                            #     width=384,
+                            #     height=256)
+                            output_video = gr.Image(label="Output Video",
                                             height=256,
                                             width=384,)
-
-
-    with gr.Row():
-
-
-        example = gr.Examples(
+
+
+        with gr.Row():
+            def process_examples(input_image, prompt, drag_mode, seed, personalized, first_frame_path_var, tracking_points_var):
+                return input_image, prompt, drag_mode, seed, personalized, first_frame_path_var, tracking_points_var
+            example = gr.Examples(
                 label="Input Example",
                 examples=image_examples,
                 inputs=[input_image, prompt, drag_mode, seed, personalized, first_frame_path_var, tracking_points_var],
+                outputs=[input_image, prompt, drag_mode, seed, personalized, first_frame_path_var, tracking_points_var],
+                fn=process_examples,
                 examples_per_page=10,
                 cache_examples=False,
             )
-
-
-    with gr.Row():
-        gr.Markdown(citation)
+
+        with gr.Row():
+            gr.Markdown(citation)
 
-
-    image_upload_button.upload(preprocess_image, [image_upload_button, tracking_points_var], [input_image, first_frame_path_var, tracking_points_var, personalized])
+
+        image_upload_button.upload(preprocess_image, image_upload_button, [input_image, first_frame_path_var, tracking_points_var, personalized])
 
-    add_drag_button.click(add_drag, tracking_points_var, tracking_points_var)
+        add_drag_button.click(add_drag, [tracking_points_var], tracking_points_var)
 
-    delete_last_drag_button.click(delete_last_drag, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])
+        delete_last_drag_button.click(delete_last_drag, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])
 
-    delete_last_step_button.click(delete_last_step, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])
+        delete_last_step_button.click(delete_last_step, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])
 
-    reset_button.click(reset_states, [first_frame_path_var, tracking_points_var], [input_image, first_frame_path_var, tracking_points_var])
+        reset_button.click(reset_states, [first_frame_path_var, tracking_points_var], [input_image, first_frame_path_var, tracking_points_var])
 
-    input_image.select(add_tracking_points, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])
+        input_image.select(add_tracking_points, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])
 
-    run_button.click(ImageConductor_net.run, [first_frame_path_var, tracking_points_var, prompt, drag_mode,
-                                              negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized, examples_type],
-                     [output_image, output_video])
+        run_button.click(ImageConductor_net.run, [first_frame_path_var, tracking_points_var, prompt, drag_mode,
+                                                  negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized],
+                         [output_image, output_video])
 
 block.queue().launch()
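For reference, the coordinate rescaling that run() keeps unchanged maps each clicked (x, y) point from the 384x256 preview frame into the model's width and height before the drag is visualized. Below is a minimal standalone sketch of that transform; rescale_tracks and the sample values are illustrative only and are not part of this commit.

# Standalone sketch of the point-rescaling step kept in ImageConductor.run().
# rescale_tracks is an illustrative helper, not a function from the commit;
# the default sizes mirror the app's 384x256 preview and model resolution.

def rescale_tracks(tracks, src_wh=(384, 256), dst_wh=(384, 256)):
    """Map each (x, y) click from the preview canvas into model coordinates."""
    src_w, src_h = src_wh
    dst_w, dst_h = dst_wh
    return [
        tuple((float(x * dst_w / src_w), float(y * dst_h / src_h)) for x, y in track)
        for track in tracks
    ]

if __name__ == "__main__":
    # One drag made of three clicks on the 384x256 preview canvas.
    tracks = [[(10, 20), (100, 120), (200, 180)]]
    print(rescale_tracks(tracks, dst_wh=(768, 512)))
    # [((20.0, 40.0), (200.0, 240.0), (400.0, 360.0))]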