musepose / musepose_inference.py
jhj0517
increase duration for musepose inference
768f6bf
raw
history blame
No virus
9.94 kB
import os
from datetime import datetime
from pathlib import Path
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from einops import repeat
from omegaconf import OmegaConf, DictConfig
from PIL import Image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection
import torch.nn.functional as F
import gc
from huggingface_hub import hf_hub_download
import gradio as gr
from musepose.models.pose_guider import PoseGuider
from musepose.models.unet_2d_condition import UNet2DConditionModel
from musepose.models.unet_3d import UNet3DConditionModel
from musepose.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from musepose.utils.util import get_fps, read_frames, save_videos_grid
from downloading_weights import download_models
# ZeroGPU
import spaces
class MusePoseInference:
    """Animate a reference image with the motion of a pose video using MusePose.

    Models are loaded lazily by ``init_model`` (called from ``infer_musepose``)
    and cached on the instance; ``release_vram`` drops the cached references.
    """

    def __init__(self,
                 model_dir,
                 output_dir):
        """Record model/config paths and create the output directory.

        Args:
            model_dir: Directory holding (or receiving, via ``download_models``)
                the base image-variation model, the VAE, the CLIP image encoder
                and the ``MusePose`` checkpoints.
            output_dir: Parent directory; results go to its
                ``musepose_inference`` subdirectory.
        """
        self.image_gen_model_paths = {
            "pretrained_base_model": os.path.join(model_dir, "sd-image-variations-diffusers"),
            "pretrained_vae": os.path.join(model_dir, "sd-vae-ft-mse"),
            "image_encoder": os.path.join(model_dir, "image_encoder"),
        }
        self.musepose_model_paths = {
            "denoising_unet": os.path.join(model_dir, "MusePose", "denoising_unet.pth"),
            "reference_unet": os.path.join(model_dir, "MusePose", "reference_unet.pth"),
            "pose_guider": os.path.join(model_dir, "MusePose", "pose_guider.pth"),
            "motion_module": os.path.join(model_dir, "MusePose", "motion_module.pth"),
        }
        # Relative path: resolved against the process's current working directory.
        self.inference_config_path = os.path.join("configs", "inference_v2.yaml")

        # Lazily-loaded models (populated by init_model, cleared by release_vram).
        self.vae = None
        self.reference_unet = None
        self.denoising_unet = None
        self.pose_guider = None
        self.image_enc = None
        self.pipe = None

        self.model_dir = model_dir
        self.output_dir = os.path.join(output_dir, "musepose_inference")
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(self.output_dir, exist_ok=True)

    @spaces.GPU(duration=180)
    def infer_musepose(
        self,
        ref_image_path: str,
        pose_video_path: str,
        weight_dtype: str,
        W: int,
        H: int,
        L: int,
        S: int,
        O: int,
        cfg: float,
        seed: int,
        steps: int,
        fps: int,
        skip: int,
        gradio_progress=gr.Progress()
    ):
        """Generate a video of the reference image driven by the pose video.

        Args:
            ref_image_path: Path of the reference (appearance) image.
            pose_video_path: Path of the pose video that drives the motion.
            weight_dtype: ``"fp16"`` selects ``torch.float16``; anything else
                selects ``torch.float32``.
            W: Generation width in pixels.
            H: Generation height in pixels.
            L: Maximum number of source frames to use; clamped to the video
                length and then divided by ``skip + 1``.
            S: Context-window (slice) length in frames given to the pipeline.
            O: Overlap between consecutive context windows; must be < ``S``.
            cfg: Classifier-free guidance scale passed to the pipeline.
            seed: Seed given to ``torch.manual_seed``.
            steps: Number of denoising steps passed to the pipeline.
            fps: Output fps; ``None`` or a negative value means "use the source
                fps floor-divided by ``skip + 1``".
            skip: Number of source frames dropped between kept frames (>= 0).
            gradio_progress: Gradio progress tracker forwarded to the pipeline
                (the default value is how Gradio injects the tracker).

        Returns:
            ``(output_path, output_path_demo)``: absolute paths of the generated
            video and of a comparison video built from the reference, pose and
            generated clips (saved with ``n_rows=3``).

        Raises:
            ValueError: If ``S <= O``, ``skip < 0``, or no pose frames remain
                to process after clamping/skipping.
        """
        # Validate cheap arguments before downloading or loading any model.
        if S <= O:
            # The last-window padding below uses (S - O) as a modulus.
            raise ValueError(f"Slice length S ({S}) must be greater than overlap O ({O}).")
        if skip < 0:
            raise ValueError(f"skip must be >= 0, got {skip}.")

        download_models(model_dir=self.model_dir)
        print(f"Model Paths: {self.musepose_model_paths}\n{self.image_gen_model_paths}\n{self.inference_config_path}")
        print(f"Input Image Path: {ref_image_path}")
        print(f"Pose Video Path: {pose_video_path}")
        print(f"Dtype: {weight_dtype}")
        print(f"Width: {W}")
        print(f"Height: {H}")
        print(f"Video Frame Length: {L}")
        print(f"VIDEO SLICE FRAME LENGTH:: {S}")
        print(f"VIDEO SLICE OVERLAP_FRAME NUMBER: {O}")
        print(f"CFG: {cfg}")
        print(f"Seed: {seed}")
        print(f"Steps: {steps}")
        print(f"FPS: {fps}")
        print(f"Skip: {skip}")

        # NOTE(review): fixed file name - every call (and every concurrent user)
        # writes to, and overwrites, the same two output files.
        output_filename = "output_temp"
        output_path = os.path.abspath(os.path.join(self.output_dir, f'{output_filename}.mp4'))
        output_path_demo = os.path.abspath(os.path.join(self.output_dir, f'{output_filename}_demo.mp4'))

        torch_dtype = torch.float16 if weight_dtype == "fp16" else torch.float32

        infer_config = OmegaConf.load(self.inference_config_path)
        sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
        scheduler = DDIMScheduler(**sched_kwargs)
        # Seeds the global CPU RNG and returns the default generator.
        generator = torch.manual_seed(seed)
        width, height = W, H

        self.init_model(weight_dtype=torch_dtype, infer_config=infer_config)
        self.pipe = Pose2VideoPipeline(
            vae=self.vae,
            image_encoder=self.image_enc,
            reference_unet=self.reference_unet,
            denoising_unet=self.denoising_unet,
            pose_guider=self.pose_guider,
            scheduler=scheduler,
            gradio_progress=gradio_progress
        )
        self.pipe = self.pipe.to("cuda", dtype=torch_dtype)

        print("image: ", ref_image_path, "pose_video: ", pose_video_path)
        # Close the source file promptly; convert() returns an independent image.
        with Image.open(ref_image_path) as ref_image_file:
            ref_image_pil = ref_image_file.convert("RGB")

        pose_images = read_frames(pose_video_path)
        src_fps = get_fps(pose_video_path)
        print(f"pose video has {len(pose_images)} frames, with {src_fps} fps")
        L = min(L, len(pose_images))
        pose_transform = transforms.Compose(
            [transforms.Resize((height, width)), transforms.ToTensor()]
        )

        # Keep every (skip + 1)-th frame; the frame budget and fps shrink to match.
        pose_images = pose_images[::skip + 1]
        print("processing length:", len(pose_images))
        src_fps = src_fps // (skip + 1)
        print("fps", src_fps)
        L = L // (skip + 1)
        if L <= 0:
            raise ValueError(
                f"No pose frames to process (video frames={len(pose_images)}, L={L}, skip={skip})."
            )

        pose_list = list(pose_images[:L])
        pose_tensor_list = [pose_transform(pose_image_pil) for pose_image_pil in pose_list]
        # Outputs are resized back to the size of the last kept pose frame.
        original_width, original_height = pose_list[-1].size

        # Repeat the last frame until (len(pose_list) - S) is a multiple of
        # S - O, i.e. the final context window is complete. The padded frames
        # are cut off again with [:, :, :L] after generation.
        window_step = S - O
        last_segment_frame_num = (L - S) % window_step
        pad_frame_num = (window_step - last_segment_frame_num) % window_step
        pose_list.extend([pose_list[-1]] * pad_frame_num)
        pose_tensor_list.extend([pose_tensor_list[-1]] * pad_frame_num)

        ref_image_tensor = pose_transform(ref_image_pil)  # (c, h, w)
        ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0)  # (1, c, 1, h, w)
        ref_image_tensor = repeat(ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=L)  # (1, c, L, h, w)
        pose_tensor = torch.stack(pose_tensor_list, dim=0)  # (f, c, h, w)
        pose_tensor = pose_tensor.transpose(0, 1).unsqueeze(0)  # (1, c, f, h, w)

        video = self.pipe(
            ref_image_pil,
            pose_list,
            width,
            height,
            len(pose_list),
            steps,
            cfg,
            generator=generator,
            context_frames=S,
            context_stride=1,
            context_overlap=O,
        ).videos

        out_fps = src_fps if fps is None or fps < 0 else fps
        generated = video[:, :, :L]  # drop the padding frames
        result = self.scale_video(generated, original_width, original_height)
        save_videos_grid(
            result,
            output_path,
            n_rows=1,
            fps=out_fps,
        )

        demo = torch.cat([ref_image_tensor, pose_tensor[:, :, :L], generated], dim=0)
        demo = self.scale_video(demo, original_width, original_height)
        save_videos_grid(
            demo,
            output_path_demo,
            n_rows=3,
            fps=out_fps,
        )
        return output_path, output_path_demo

    @spaces.GPU(duration=120)
    def init_model(self,
                   weight_dtype: torch.dtype,
                   infer_config: DictConfig
                   ):
        """Load any model that is not cached yet and move it to CUDA.

        Already-loaded models are reused as-is; they are not re-cast here even
        if ``weight_dtype`` differs from the dtype they were loaded with.

        Args:
            weight_dtype: dtype used for newly loaded models.
            infer_config: Inference config; its ``unet_additional_kwargs`` are
                passed to the 3D denoising UNet.
        """
        # NOTE(review): torch.load unpickles the checkpoint files. Consider
        # weights_only=True if the installed torch supports it and the
        # checkpoints are plain state dicts.
        if self.vae is None:
            self.vae = AutoencoderKL.from_pretrained(
                self.image_gen_model_paths["pretrained_vae"],
            ).to("cuda", dtype=weight_dtype)

        if self.reference_unet is None:
            self.reference_unet = UNet2DConditionModel.from_pretrained(
                self.image_gen_model_paths["pretrained_base_model"],
                subfolder="unet",
            ).to(dtype=weight_dtype, device="cuda")
            self.reference_unet.load_state_dict(
                torch.load(self.musepose_model_paths["reference_unet"], map_location="cpu"),
            )

        if self.denoising_unet is None:
            self.denoising_unet = UNet3DConditionModel.from_pretrained_2d(
                Path(self.image_gen_model_paths["pretrained_base_model"]),
                Path(self.musepose_model_paths["motion_module"]),
                subfolder="unet",
                unet_additional_kwargs=infer_config.unet_additional_kwargs,
            ).to(dtype=weight_dtype, device="cuda")
            # strict=False: presumably the checkpoint does not cover every
            # parameter of the 3D UNet - confirm against the checkpoint keys.
            self.denoising_unet.load_state_dict(
                torch.load(self.musepose_model_paths["denoising_unet"], map_location="cpu"),
                strict=False,
            )

        if self.pose_guider is None:
            self.pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
                dtype=weight_dtype, device="cuda"
            )
            self.pose_guider.load_state_dict(
                torch.load(self.musepose_model_paths["pose_guider"], map_location="cpu"),
            )

        if self.image_enc is None:
            self.image_enc = CLIPVisionModelWithProjection.from_pretrained(
                self.image_gen_model_paths["image_encoder"]
            ).to(dtype=weight_dtype, device="cuda")

    def release_vram(self):
        """Drop every cached model/pipeline reference and free cached CUDA memory."""
        for model_name in ('vae', 'reference_unet', 'denoising_unet',
                           'pose_guider', 'image_enc', 'pipe'):
            setattr(self, model_name, None)
        # Collect first: tensors kept alive only by reference cycles are freed
        # here, so empty_cache() can then return their blocks to the driver.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def scale_video(video, width, height):
        """Bilinearly resize every frame of a 5-D video tensor.

        Args:
            video: Tensor laid out as ``(batch, channels, frames, h, w)`` (the
                layout used throughout this class).
            width: Target frame width in pixels.
            height: Target frame height in pixels.

        Returns:
            Tensor of shape ``(batch, channels, frames, height, width)``.
        """
        # Fold batch and channels so F.interpolate sees a 4-D (N, C, H, W)
        # input; the frame axis plays the role of C and is left untouched.
        # reshape (unlike view) also accepts non-contiguous inputs such as slices.
        folded = video.reshape(-1, *video.shape[2:])  # (batch*channels, frames, h, w)
        scaled = F.interpolate(folded, size=(height, width), mode='bilinear', align_corners=False)
        return scaled.reshape(*video.shape[:-2], height, width)