jhj0517 committed on
Commit
de2727e
·
1 Parent(s): da162f8

add progress for musepose inference

Browse files
musepose/pipelines/pipeline_pose2vid_long.py CHANGED
@@ -3,6 +3,7 @@ import inspect
3
  import math
4
  from dataclasses import dataclass
5
  from typing import Callable, List, Optional, Union
 
6
 
7
  import numpy as np
8
  import torch
@@ -53,6 +54,7 @@ class Pose2VideoPipeline(DiffusionPipeline):
53
  image_proj_model=None,
54
  tokenizer=None,
55
  text_encoder=None,
 
56
  ):
57
  super().__init__()
58
 
@@ -77,6 +79,7 @@ class Pose2VideoPipeline(DiffusionPipeline):
77
  do_convert_rgb=True,
78
  do_normalize=False,
79
  )
 
80
 
81
  def enable_vae_slicing(self):
82
  self.vae.enable_slicing()
@@ -117,6 +120,8 @@ class Pose2VideoPipeline(DiffusionPipeline):
117
  video = []
118
  for frame_idx in tqdm(range(latents.shape[0])):
119
  video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
 
 
120
  video = torch.cat(video)
121
  video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
122
  video = (video / 2 + 0.5).clamp(0, 1)
@@ -448,6 +453,8 @@ class Pose2VideoPipeline(DiffusionPipeline):
448
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
449
  with self.progress_bar(total=num_inference_steps) as progress_bar:
450
  for i, t in enumerate(timesteps):
 
 
451
  noise_pred = torch.zeros(
452
  (
453
  latents.shape[0] * (2 if do_classifier_free_guidance else 1),
 
3
  import math
4
  from dataclasses import dataclass
5
  from typing import Callable, List, Optional, Union
6
+ import gradio as gr
7
 
8
  import numpy as np
9
  import torch
 
54
  image_proj_model=None,
55
  tokenizer=None,
56
  text_encoder=None,
57
+ gradio_progress: gr.Progress = gr.Progress(),
58
  ):
59
  super().__init__()
60
 
 
79
  do_convert_rgb=True,
80
  do_normalize=False,
81
  )
82
+ self.gradio_progress = gradio_progress
83
 
84
  def enable_vae_slicing(self):
85
  self.vae.enable_slicing()
 
120
  video = []
121
  for frame_idx in tqdm(range(latents.shape[0])):
122
  video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
123
+ self.gradio_progress(frame_idx /latents.shape[0], "Writing Video..")
124
+
125
  video = torch.cat(video)
126
  video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
127
  video = (video / 2 + 0.5).clamp(0, 1)
 
453
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
454
  with self.progress_bar(total=num_inference_steps) as progress_bar:
455
  for i, t in enumerate(timesteps):
456
+ self.gradio_progress(i/num_inference_steps, "Making Image Move..")
457
+
458
  noise_pred = torch.zeros(
459
  (
460
  latents.shape[0] * (2 if do_classifier_free_guidance else 1),
musepose_inference.py CHANGED
@@ -11,6 +11,7 @@ from transformers import CLIPVisionModelWithProjection
11
  import torch.nn.functional as F
12
  import gc
13
  from huggingface_hub import hf_hub_download
 
14
 
15
  from musepose.models.pose_guider import PoseGuider
16
  from musepose.models.unet_2d_condition import UNet2DConditionModel
@@ -65,7 +66,8 @@ class MusePoseInference:
65
  seed: int,
66
  steps: int,
67
  fps: int,
68
- skip: int
 
69
  ):
70
  download_models(model_dir=self.model_dir)
71
  print(f"Model Paths: {self.musepose_model_paths}\n{self.image_gen_model_paths}\n{self.inference_config_path}")
@@ -144,6 +146,7 @@ class MusePoseInference:
144
  denoising_unet=self.denoising_unet,
145
  pose_guider=self.pose_guider,
146
  scheduler=scheduler,
 
147
  )
148
  self.pipe = self.pipe.to("cuda", dtype=weight_dtype)
149
 
 
11
  import torch.nn.functional as F
12
  import gc
13
  from huggingface_hub import hf_hub_download
14
+ import gradio as gr
15
 
16
  from musepose.models.pose_guider import PoseGuider
17
  from musepose.models.unet_2d_condition import UNet2DConditionModel
 
66
  seed: int,
67
  steps: int,
68
  fps: int,
69
+ skip: int,
70
+ gradio_progress=gr.Progress()
71
  ):
72
  download_models(model_dir=self.model_dir)
73
  print(f"Model Paths: {self.musepose_model_paths}\n{self.image_gen_model_paths}\n{self.inference_config_path}")
 
146
  denoising_unet=self.denoising_unet,
147
  pose_guider=self.pose_guider,
148
  scheduler=scheduler,
149
+ gradio_progress=gradio_progress
150
  )
151
  self.pipe = self.pipe.to("cuda", dtype=weight_dtype)
152