jhj0517
commited on
Commit
•
de2727e
1
Parent(s):
da162f8
add progress for musepose inference
Browse files
musepose/pipelines/pipeline_pose2vid_long.py
CHANGED
@@ -3,6 +3,7 @@ import inspect
|
|
3 |
import math
|
4 |
from dataclasses import dataclass
|
5 |
from typing import Callable, List, Optional, Union
|
|
|
6 |
|
7 |
import numpy as np
|
8 |
import torch
|
@@ -53,6 +54,7 @@ class Pose2VideoPipeline(DiffusionPipeline):
|
|
53 |
image_proj_model=None,
|
54 |
tokenizer=None,
|
55 |
text_encoder=None,
|
|
|
56 |
):
|
57 |
super().__init__()
|
58 |
|
@@ -77,6 +79,7 @@ class Pose2VideoPipeline(DiffusionPipeline):
|
|
77 |
do_convert_rgb=True,
|
78 |
do_normalize=False,
|
79 |
)
|
|
|
80 |
|
81 |
def enable_vae_slicing(self):
|
82 |
self.vae.enable_slicing()
|
@@ -117,6 +120,8 @@ class Pose2VideoPipeline(DiffusionPipeline):
|
|
117 |
video = []
|
118 |
for frame_idx in tqdm(range(latents.shape[0])):
|
119 |
video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
|
|
|
|
|
120 |
video = torch.cat(video)
|
121 |
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
|
122 |
video = (video / 2 + 0.5).clamp(0, 1)
|
@@ -448,6 +453,8 @@ class Pose2VideoPipeline(DiffusionPipeline):
|
|
448 |
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
449 |
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
450 |
for i, t in enumerate(timesteps):
|
|
|
|
|
451 |
noise_pred = torch.zeros(
|
452 |
(
|
453 |
latents.shape[0] * (2 if do_classifier_free_guidance else 1),
|
|
|
3 |
import math
|
4 |
from dataclasses import dataclass
|
5 |
from typing import Callable, List, Optional, Union
|
6 |
+
import gradio as gr
|
7 |
|
8 |
import numpy as np
|
9 |
import torch
|
|
|
54 |
image_proj_model=None,
|
55 |
tokenizer=None,
|
56 |
text_encoder=None,
|
57 |
+
gradio_progress: gr.Progress = gr.Progress(),
|
58 |
):
|
59 |
super().__init__()
|
60 |
|
|
|
79 |
do_convert_rgb=True,
|
80 |
do_normalize=False,
|
81 |
)
|
82 |
+
self.gradio_progress = gradio_progress
|
83 |
|
84 |
def enable_vae_slicing(self):
|
85 |
self.vae.enable_slicing()
|
|
|
120 |
video = []
|
121 |
for frame_idx in tqdm(range(latents.shape[0])):
|
122 |
video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
|
123 |
+
self.gradio_progress(frame_idx /latents.shape[0], "Writing Video..")
|
124 |
+
|
125 |
video = torch.cat(video)
|
126 |
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
|
127 |
video = (video / 2 + 0.5).clamp(0, 1)
|
|
|
453 |
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
454 |
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
455 |
for i, t in enumerate(timesteps):
|
456 |
+
self.gradio_progress(i/num_inference_steps, "Making Image Move..")
|
457 |
+
|
458 |
noise_pred = torch.zeros(
|
459 |
(
|
460 |
latents.shape[0] * (2 if do_classifier_free_guidance else 1),
|
musepose_inference.py
CHANGED
@@ -11,6 +11,7 @@ from transformers import CLIPVisionModelWithProjection
|
|
11 |
import torch.nn.functional as F
|
12 |
import gc
|
13 |
from huggingface_hub import hf_hub_download
|
|
|
14 |
|
15 |
from musepose.models.pose_guider import PoseGuider
|
16 |
from musepose.models.unet_2d_condition import UNet2DConditionModel
|
@@ -65,7 +66,8 @@ class MusePoseInference:
|
|
65 |
seed: int,
|
66 |
steps: int,
|
67 |
fps: int,
|
68 |
-
skip: int
|
|
|
69 |
):
|
70 |
download_models(model_dir=self.model_dir)
|
71 |
print(f"Model Paths: {self.musepose_model_paths}\n{self.image_gen_model_paths}\n{self.inference_config_path}")
|
@@ -144,6 +146,7 @@ class MusePoseInference:
|
|
144 |
denoising_unet=self.denoising_unet,
|
145 |
pose_guider=self.pose_guider,
|
146 |
scheduler=scheduler,
|
|
|
147 |
)
|
148 |
self.pipe = self.pipe.to("cuda", dtype=weight_dtype)
|
149 |
|
|
|
11 |
import torch.nn.functional as F
|
12 |
import gc
|
13 |
from huggingface_hub import hf_hub_download
|
14 |
+
import gradio as gr
|
15 |
|
16 |
from musepose.models.pose_guider import PoseGuider
|
17 |
from musepose.models.unet_2d_condition import UNet2DConditionModel
|
|
|
66 |
seed: int,
|
67 |
steps: int,
|
68 |
fps: int,
|
69 |
+
skip: int,
|
70 |
+
gradio_progress=gr.Progress()
|
71 |
):
|
72 |
download_models(model_dir=self.model_dir)
|
73 |
print(f"Model Paths: {self.musepose_model_paths}\n{self.image_gen_model_paths}\n{self.inference_config_path}")
|
|
|
146 |
denoising_unet=self.denoising_unet,
|
147 |
pose_guider=self.pose_guider,
|
148 |
scheduler=scheduler,
|
149 |
+
gradio_progress=gradio_progress
|
150 |
)
|
151 |
self.pipe = self.pipe.to("cuda", dtype=weight_dtype)
|
152 |
|