Spaces:
Runtime error
Runtime error
| import os | |
| import argparse | |
| import logging | |
| import math | |
| from omegaconf import OmegaConf | |
| from datetime import datetime | |
| from pathlib import Path | |
| import numpy as np | |
| import torch.jit | |
| from torchvision.datasets.folder import pil_loader | |
| from torchvision.transforms.functional import pil_to_tensor, resize, center_crop | |
| from torchvision.transforms.functional import to_pil_image | |
| from mimicmotion.utils.geglu_patch import patch_geglu_inplace | |
| patch_geglu_inplace() | |
| from constants import ASPECT_RATIO | |
| from mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline | |
| from mimicmotion.utils.loader import create_pipeline | |
| from mimicmotion.utils.utils import save_to_mp4 | |
| from mimicmotion.dwpose.preprocess import get_video_pose, get_image_pose | |
# Root logging configuration: INFO level with a timestamped record format.
logging.basicConfig(format="%(asctime)s: [%(levelname)s] %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# Use CUDA when torch reports it as available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def preprocess(video_path, image_path, resolution=576, sample_stride=2):
    """Prepare the pose sequence and the reference image for one task.

    The reference image is loaded, resized so that it fully covers the target
    canvas, center-cropped to that canvas, and then used together with the
    driving video to extract pose maps.

    Args:
        video_path (str): path of the video handed to ``get_video_pose``.
        image_path (str): path of the reference image.
        resolution (int, optional): target width when the image is taller
            than wide, otherwise target height. The other side is
            ``resolution / ASPECT_RATIO`` floored to a multiple of 64.
            Defaults to 576.
        sample_stride (int, optional): forwarded to ``get_video_pose``.
            Defaults to 2.

    Returns:
        tuple: ``(pose_pixels, image_pixels)`` tensors, both mapped with
        ``x / 127.5 - 1``. The pose tensor starts with the reference-image
        pose followed by the video poses; the image tensor is channel-first
        with a leading axis of size one.
    """
    ref_tensor = pil_to_tensor(pil_loader(image_path))  # (c, h, w)
    src_h, src_w = ref_tensor.shape[-2:]

    # ---- target canvas: one side is `resolution`, the other follows ASPECT_RATIO ----
    scaled_side = int(resolution / ASPECT_RATIO // 64) * 64
    if src_h > src_w:
        target_w, target_h = resolution, scaled_side
    else:
        target_w, target_h = scaled_side, resolution

    # ---- resize so the image covers the canvas, then crop the overflow ----
    source_ratio = src_h / src_w
    if source_ratio < target_h / target_w:
        # Source is relatively wider than the canvas: match heights.
        resize_hw = [target_h, math.ceil(target_h / source_ratio)]
    else:
        # Source is relatively taller (or equal): match widths.
        resize_hw = [math.ceil(target_w * source_ratio), target_w]
    ref_tensor = resize(ref_tensor, resize_hw, antialias=None)
    ref_tensor = center_crop(ref_tensor, [target_h, target_w])
    ref_hwc = ref_tensor.permute((1, 2, 0)).numpy()  # (h, w, c)

    # ---- pose extraction for the reference image and the driving video ----
    ref_pose = get_image_pose(ref_hwc)
    driving_pose = get_video_pose(video_path, ref_hwc, sample_stride=sample_stride)
    pose_stack = np.concatenate([np.expand_dims(ref_pose, 0), driving_pose])
    ref_nchw = np.transpose(np.expand_dims(ref_hwc, 0), (0, 3, 1, 2))

    pose_pixels = torch.from_numpy(pose_stack.copy()) / 127.5 - 1
    image_pixels = torch.from_numpy(ref_nchw) / 127.5 - 1
    return pose_pixels, image_pixels
def run_pipeline(pipeline: MimicMotionPipeline, image_pixels, pose_pixels, device, task_config):
    """Generate the output frames for one task with the MimicMotion pipeline.

    Args:
        pipeline (MimicMotionPipeline): pipeline object; it is called directly.
        image_pixels: reference image tensor. It is iterated along its first
            axis and each entry is mapped with ``(x + 1.0) * 127.5`` before
            being converted to a PIL image (the inverse of the scaling applied
            by ``preprocess``).
        pose_pixels: pose tensor. Its first axis gives the number of frames
            and its last two axes give the output height and width.
        device: device used for the random generator and passed to the pipeline.
        task_config: object providing ``seed``, ``num_frames``,
            ``frames_overlap``, ``noise_aug_strength``,
            ``num_inference_steps`` and ``guidance_scale``.

    Returns:
        torch.Tensor: uint8 frames of the last generated video, without its
        first frame.

    Raises:
        ValueError: if the pipeline output contains no video.
    """
    ref_images = [to_pil_image(img.to(torch.uint8)) for img in (image_pixels + 1.0) * 127.5]

    # Seed a dedicated generator so a given task seed reproduces the same noise.
    generator = torch.Generator(device=device)
    generator.manual_seed(task_config.seed)

    frames = pipeline(
        ref_images, image_pose=pose_pixels, num_frames=pose_pixels.size(0),
        tile_size=task_config.num_frames, tile_overlap=task_config.frames_overlap,
        height=pose_pixels.shape[-2], width=pose_pixels.shape[-1], fps=7,
        noise_aug_strength=task_config.noise_aug_strength, num_inference_steps=task_config.num_inference_steps,
        generator=generator, min_guidance_scale=task_config.guidance_scale,
        max_guidance_scale=task_config.guidance_scale, decode_chunk_size=8, output_type="pt", device=device
    ).frames.cpu()
    video_frames = (frames * 255.0).to(torch.uint8)

    if video_frames.shape[0] == 0:
        # Previously this case surfaced as an UnboundLocalError from the
        # result loop; fail with an explicit message instead.
        raise ValueError("pipeline returned no video frames")

    # The former loop overwrote its result on every iteration, so only the
    # last video was ever returned; select it directly. The first frame is
    # dropped because it corresponds to the reference image.
    return video_frames[-1, 1:]
def main(args):
    """Run every test case of the inference config and save each result as MP4.

    Args:
        args: parsed command-line namespace; reads ``no_use_float16``,
            ``inference_config`` and ``output_dir``.
    """
    use_float16 = not args.no_use_float16
    if use_float16:
        torch.set_default_dtype(torch.float16)

    infer_config = OmegaConf.load(args.inference_config)
    pipeline = create_pipeline(infer_config, device)

    for task in infer_config.test_case:
        # ---- pre-process: reference image and pose sequence ----
        pose_pixels, image_pixels = preprocess(
            task.ref_video_path,
            task.ref_image_path,
            resolution=task.resolution,
            sample_stride=task.sample_stride,
        )

        # ---- run the MimicMotion pipeline ----
        generated = run_pipeline(pipeline, image_pixels, pose_pixels, device, task)

        # ---- save: <output_dir>/<video name up to first dot>_<timestamp>.mp4 ----
        video_stem = os.path.basename(task.ref_video_path).split('.')[0]
        timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
        save_to_mp4(
            generated,
            f"{args.output_dir}/{video_stem}_{timestamp}.mp4",
            fps=task.fps,
        )
def set_logger(log_file=None, log_level=logging.INFO):
    """Attach a file handler to the module-level ``logger``.

    Args:
        log_file (str, optional): path of the log file. It is opened in
            ``"w"`` mode, so an existing file is overwritten. When ``None``
            (the default) no file handler is added.
        log_level (int, optional): level set on the file handler.
            Defaults to ``logging.INFO``.
    """
    if log_file is None:
        # logging.FileHandler cannot open a ``None`` path (it raises
        # TypeError), so the default means "no file logging" rather than a crash.
        return

    log_handler = logging.FileHandler(log_file, "w")
    log_handler.setFormatter(
        logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s]: %(message)s")
    )
    log_handler.setLevel(log_level)
    logger.addHandler(log_handler)
if __name__ == "__main__":
    # Command-line entry point: parse options, make sure the output folder
    # exists, attach the file logger, then run every configured test case.
    parser = argparse.ArgumentParser()
    parser.add_argument("--log_file", type=str, default=None)
    parser.add_argument("--inference_config", type=str, default="configs/test.yaml")  # ToDo
    parser.add_argument("--output_dir", type=str, default="outputs/", help="path to output")
    parser.add_argument(
        "--no_use_float16",
        action="store_true",
        # The flag switches float16 OFF; the previous help text described the opposite.
        help="Disable float16 inference (float16 is used by default to speed up inference)",
    )
    args = parser.parse_args()

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Without an explicit --log_file, write a timestamped log into the output folder.
    log_file = args.log_file
    if log_file is None:
        log_file = f"{args.output_dir}/{datetime.now().strftime('%Y%m%d%H%M%S')}.log"
    set_logger(log_file)

    main(args)
    logger.info("--- Finished ---")