# Copyright 2023 ByteDance and/or its affiliates.
#
# Copyright (2023) MagicAnimate Authors
#
# ByteDance, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from ByteDance or
# its affiliates is strictly prohibited.
import argparse
import datetime
import inspect
import os
import random
from collections import OrderedDict
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from magicanimate.models.appearance_encoder import AppearanceEncoderModel
from magicanimate.models.controlnet import ControlNetModel
from magicanimate.models.mutual_self_attention import ReferenceAttentionControl
from magicanimate.models.unet_controlnet import UNet3DConditionModel
from magicanimate.pipelines.pipeline_animation import AnimationPipeline
from magicanimate.utils.dist_tools import distributed_init
from magicanimate.utils.util import save_videos_grid
from magicanimate.utils.videoreader import VideoReader


def main(args):
    *_, func_args = inspect.getargvalues(inspect.currentframe())
    func_args = dict(func_args)

    config = OmegaConf.load(args.config)

    # Set up the device and kwargs for distributed inference
    device = torch.device(f"cuda:{args.rank}")
    dist_kwargs = {"rank": args.rank, "world_size": args.world_size, "dist": args.dist}

    if config.savename is None:
        time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        savedir = f"samples/{Path(args.config).stem}-{time_str}"
    else:
        savedir = f"samples/{config.savename}"

    if args.dist:
        dist.broadcast_object_list([savedir], 0)
        dist.barrier()

    if args.rank == 0:
        os.makedirs(savedir, exist_ok=True)

    inference_config = OmegaConf.load(config.inference_config)

    motion_module = config.motion_module

    ### >>> create animation pipeline >>> ###
    tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder")
    if config.pretrained_unet_path:
        unet = UNet3DConditionModel.from_pretrained_2d(
            config.pretrained_unet_path,
            unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
        )
    else:
        unet = UNet3DConditionModel.from_pretrained_2d(
            config.pretrained_model_path,
            subfolder="unet",
            unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
        )
    appearance_encoder = AppearanceEncoderModel.from_pretrained(
        config.pretrained_appearance_encoder_path, subfolder="appearance_encoder"
    ).to(device)
    reference_control_writer = ReferenceAttentionControl(
        appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks
    )
    reference_control_reader = ReferenceAttentionControl(
        unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks
    )
    if config.pretrained_vae_path is not None:
        vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path)
    else:
        vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae")

    ### Load controlnet
    controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path)
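
    # The memory-efficient attention calls below assume the optional `xformers`
    # package is installed alongside diffusers; they generally fail if it is missing.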
    unet.enable_xformers_memory_efficient_attention()
    appearance_encoder.enable_xformers_memory_efficient_attention()
    controlnet.enable_xformers_memory_efficient_attention()

    vae.to(torch.float16)
    unet.to(torch.float16)
    text_encoder.to(torch.float16)
    appearance_encoder.to(torch.float16)
    controlnet.to(torch.float16)

    pipeline = AnimationPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
        # NOTE: UniPCMultistepScheduler
    )

    # 1. unet ckpt
    # 1.1 motion module
    motion_module_state_dict = torch.load(motion_module, map_location="cpu")
    if "global_step" in motion_module_state_dict:
        func_args.update({"global_step": motion_module_state_dict["global_step"]})
    motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict
    try:
        # extra steps for self-trained models
        state_dict = OrderedDict()
        for key in motion_module_state_dict.keys():
            if key.startswith("module."):
                _key = key.split("module.")[-1]
                state_dict[_key] = motion_module_state_dict[key]
            else:
                state_dict[key] = motion_module_state_dict[key]
        motion_module_state_dict = state_dict
        del state_dict
        missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
        assert len(unexpected) == 0
    except:
        _tmp_ = OrderedDict()
        for key in motion_module_state_dict.keys():
            if "motion_modules" in key:
                if key.startswith("unet."):
                    _key = key.split('unet.')[-1]
                    _tmp_[_key] = motion_module_state_dict[key]
                else:
                    _tmp_[key] = motion_module_state_dict[key]
        missing, unexpected = unet.load_state_dict(_tmp_, strict=False)
        assert len(unexpected) == 0
        del _tmp_
    del motion_module_state_dict

    pipeline.to(device)
    ### <<< create validation pipeline <<< ###

    random_seeds = config.get("seed", [-1])
    random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
    random_seeds = random_seeds * len(config.source_image) if len(random_seeds) == 1 else random_seeds

    # input test videos (either source video / conditions)
    test_videos = config.video_path
    source_images = config.source_image
    num_actual_inference_steps = config.get("num_actual_inference_steps", config.steps)

    # read size, step from yaml file
    sizes = [config.size] * len(test_videos)
    steps = [config.S] * len(test_videos)

    config.random_seed = []
    prompt = n_prompt = ""
    for idx, (source_image, test_video, random_seed, size, step) in tqdm(
        enumerate(zip(source_images, test_videos, random_seeds, sizes, steps)),
        total=len(test_videos),
        disable=(args.rank != 0),
    ):
        samples_per_video = []
        samples_per_clip = []
        # manually set random seed for reproduction
        if random_seed != -1:
            torch.manual_seed(random_seed)
            set_seed(random_seed)
        else:
            torch.seed()
        config.random_seed.append(torch.initial_seed())

        if test_video.endswith('.mp4'):
            control = VideoReader(test_video).read()
            if control[0].shape[0] != size:
                control = [np.array(Image.fromarray(c).resize((size, size))) for c in control]
            if config.max_length is not None:
                control = control[config.offset: (config.offset + config.max_length)]
            control = np.array(control)

        if source_image.endswith(".mp4"):
            source_image = np.array(Image.fromarray(VideoReader(source_image).read()[0]).resize((size, size)))
        else:
            source_image = np.array(Image.open(source_image).resize((size, size)))
        H, W, C = source_image.shape

        print(f"current seed: {torch.initial_seed()}")
        init_latents = None

        # print(f"sampling {prompt} ...")
        original_length = control.shape[0]
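        # Pad the control sequence to a multiple of config.L (presumably the temporal
        # window length the pipeline consumes per pass); the padded frames are trimmed
        # again below via the `[:original_length]` slices.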
        if control.shape[0] % config.L > 0:
            control = np.pad(control, ((0, config.L - control.shape[0] % config.L), (0, 0), (0, 0), (0, 0)), mode='edge')
        generator = torch.Generator(device=torch.device("cuda:0"))
        generator.manual_seed(torch.initial_seed())
        sample = pipeline(
            prompt,
            negative_prompt=n_prompt,
            num_inference_steps=config.steps,
            guidance_scale=config.guidance_scale,
            width=W,
            height=H,
            video_length=len(control),
            controlnet_condition=control,
            init_latents=init_latents,
            generator=generator,
            num_actual_inference_steps=num_actual_inference_steps,
            appearance_encoder=appearance_encoder,
            reference_control_writer=reference_control_writer,
            reference_control_reader=reference_control_reader,
            source_image=source_image,
            **dist_kwargs,
        ).videos

        if args.rank == 0:
            source_images = np.array([source_image] * original_length)
            source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0
            samples_per_video.append(source_images)

            control = control / 255.0
            control = rearrange(control, "t h w c -> 1 c t h w")
            control = torch.from_numpy(control)
            samples_per_video.append(control[:, :, :original_length])

            samples_per_video.append(sample[:, :, :original_length])

            samples_per_video = torch.cat(samples_per_video)

            video_name = os.path.basename(test_video)[:-4]
            source_name = os.path.basename(config.source_image[idx]).split(".")[0]
            save_videos_grid(samples_per_video[-1:], f"{savedir}/videos/{source_name}_{video_name}.mp4")
            save_videos_grid(samples_per_video, f"{savedir}/videos/{source_name}_{video_name}/grid.mp4")

            if config.save_individual_videos:
                save_videos_grid(samples_per_video[1:2], f"{savedir}/videos/{source_name}_{video_name}/ctrl.mp4")
                save_videos_grid(samples_per_video[0:1], f"{savedir}/videos/{source_name}_{video_name}/orig.mp4")

        if args.dist:
            dist.barrier()

    if args.rank == 0:
        OmegaConf.save(config, f"{savedir}/config.yaml")


def distributed_main(device_id, args):
    args.rank = device_id
    args.device_id = device_id
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device_id)
        torch.cuda.init()
    distributed_init(args)
    main(args)


def run(args):
    if args.dist:
        args.world_size = max(1, torch.cuda.device_count())
        assert args.world_size <= torch.cuda.device_count()

        if args.world_size > 0 and torch.cuda.device_count() > 1:
            port = random.randint(10000, 20000)
            args.init_method = f"tcp://localhost:{port}"
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(args,),
                nprocs=args.world_size,
            )
    else:
        main(args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--dist", action="store_true", required=False)
    parser.add_argument("--rank", type=int, default=0, required=False)
    parser.add_argument("--world_size", type=int, default=1, required=False)

    args = parser.parse_args()
    run(args)
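
# Example invocation (illustrative; replace the script path and config with your own):
#   python <this_script>.py --config path/to/config.yaml          # single process
#   python <this_script>.py --config path/to/config.yaml --dist   # spawns one process per visible GPU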