import logging
import os
import sys
import warnings

warnings.filterwarnings('ignore')

import torch
import torch.distributed as dist
from easydict import EasyDict
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from diffusers_lite import wan
from diffusers_lite.wan.configs import WAN_CONFIGS, MAX_AREA_CONFIGS, SIZE_CONFIGS
from diffusers_lite.wan.utils.utils import cache_video
from diffusers_lite.arguments import args_wan_init
from diffusers_lite.datasets.image2video_dataset import Image2VideoEvalDataset


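# Rank-aware logging: only rank 0 prints INFO-level progress; all other
# ranks are silenced to ERROR so multi-process runs don't duplicate output.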
def _init_logging(rank):
    if rank == 0:
        logging.basicConfig(
            level=logging.INFO,
            format="[%(asctime)s] %(levelname)s: %(message)s",
            handlers=[logging.StreamHandler(stream=sys.stdout)])
    else:
        logging.basicConfig(level=logging.ERROR)


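# Sets up the process group and shared config. RANK / WORLD_SIZE / LOCAL_RANK
# are the environment variables set by torchrun-style launchers; in a single
# unlaunched process they fall back to a one-GPU setup.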
def basic_init(args):
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    device = local_rank
    _init_logging(rank)

    if rank == 0:
        os.makedirs(args.save_folder, exist_ok=True)
        logging.info(f"Creating save directory: {args.save_folder}")

    if args.offload_model is None:
        # Offloading the model to CPU between steps only pays off on a
        # single device; multi-GPU runs keep the model resident.
        args.offload_model = world_size == 1
        logging.info(
            f"offload_model is not specified, set to {args.offload_model}.")

    if args.ulysses_size == 1 and args.ring_size == 1:
        args.ddp_mode = True
        logging.info("DDP mode enabled.")

    if world_size > 1:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size)
    else:
        assert not (
            args.t5_fsdp or args.dit_fsdp
        ), "t5_fsdp and dit_fsdp are not supported in non-distributed environments."
        assert not (
            args.ulysses_size > 1 or args.ring_size > 1
        ), "context parallel is not supported in non-distributed environments."

    if args.ulysses_size > 1 or args.ring_size > 1:
        assert args.ulysses_size * args.ring_size == world_size, "The product of ulysses_size and ring_size must equal the world size."
        from xfuser.core.distributed import (initialize_model_parallel,
                                             init_distributed_environment)
        init_distributed_environment(
            rank=dist.get_rank(), world_size=dist.get_world_size())

        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=args.ring_size,
            ulysses_degree=args.ulysses_size,
        )

    cfg = WAN_CONFIGS[args.task]
    if args.ulysses_size > 1:
        assert cfg.num_heads % args.ulysses_size == 0, "`num_heads` must be divisible by `ulysses_size`."

    logging.info(f"Generation job args: {args}")
    logging.info(f"Generation model config: {cfg}")

    if dist.is_initialized():
        # Broadcast the seed from rank 0 so every process starts from the
        # same base seed.
        base_seed = [args.base_seed] if rank == 0 else [None]
        dist.broadcast_object_list(base_seed, src=0)
        args.base_seed = base_seed[0]

    basic_kwargs = EasyDict({
        "rank": rank,
        "local_rank": local_rank,
        "world_size": world_size,
        "device": device,
        "cfg": cfg,
    })
    return basic_kwargs


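# Builds the evaluation dataset. In DDP mode each rank sees a disjoint,
# non-shuffled shard via DistributedSampler, so ranks generate different
# clips in parallel; otherwise the raw dataset is iterated directly.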
def dataset_init(args, basic_kwargs):
    dataset = Image2VideoEvalDataset(
        args.dataset_path,
        do_scale=True,
        resolution=SIZE_CONFIGS[args.size]
    )
    logging.info(f"Dataset length: {len(dataset)}")

    if args.ddp_mode:
        sampler = DistributedSampler(
            dataset,
            num_replicas=basic_kwargs.world_size,
            rank=basic_kwargs.rank,
            shuffle=False,
            drop_last=False,
        )
        dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=False,
            sampler=sampler,
            drop_last=False
        )
        dataset = dataloader

    return dataset


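# The three pipeline factories below are identical apart from the pipeline
# class (WanT2V / WanI2V / WanFLF2V); each forwards checkpoint, LoRA, FSDP,
# sequence-parallel (USP), and TeaCache settings from the CLI arguments.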
def pipeline_t2v_init(args, basic_kwargs):
    logging.info("Creating WanT2V pipeline.")
    wan_t2v = wan.WanT2V(
        config=basic_kwargs.cfg,
        checkpoint_dir=args.ckpt_dir,
        transformer_path=args.transformer_path,
        lora_path=args.lora_path,
        lora_alpha=args.lora_alpha,
        distill_lora_path=args.distill_lora_path,
        distill_lora_alpha=args.distill_lora_alpha,
        device_id=basic_kwargs.device,
        rank=basic_kwargs.rank,
        t5_fsdp=args.t5_fsdp,
        dit_fsdp=args.dit_fsdp,
        use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
        t5_cpu=args.t5_cpu,
        teacache_thresh=args.teacache_thresh,
        sample_steps=args.sample_steps,
        ckpt_dir=args.ckpt_dir,
    )
    return wan_t2v


def pipeline_i2v_init(args, basic_kwargs):
    logging.info("Creating WanI2V pipeline.")
    wan_i2v = wan.WanI2V(
        config=basic_kwargs.cfg,
        checkpoint_dir=args.ckpt_dir,
        transformer_path=args.transformer_path,
        lora_path=args.lora_path,
        lora_alpha=args.lora_alpha,
        distill_lora_path=args.distill_lora_path,
        distill_lora_alpha=args.distill_lora_alpha,
        device_id=basic_kwargs.device,
        rank=basic_kwargs.rank,
        t5_fsdp=args.t5_fsdp,
        dit_fsdp=args.dit_fsdp,
        use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
        t5_cpu=args.t5_cpu,
        teacache_thresh=args.teacache_thresh,
        sample_steps=args.sample_steps,
        ckpt_dir=args.ckpt_dir,
    )
    return wan_i2v


def pipeline_flf2v_init(args, basic_kwargs):
    logging.info("Creating WanFLF2V pipeline.")
    wan_flf2v = wan.WanFLF2V(
        config=basic_kwargs.cfg,
        checkpoint_dir=args.ckpt_dir,
        transformer_path=args.transformer_path,
        lora_path=args.lora_path,
        lora_alpha=args.lora_alpha,
        distill_lora_path=args.distill_lora_path,
        distill_lora_alpha=args.distill_lora_alpha,
        device_id=basic_kwargs.device,
        rank=basic_kwargs.rank,
        t5_fsdp=args.t5_fsdp,
        dit_fsdp=args.dit_fsdp,
        use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
        t5_cpu=args.t5_cpu,
        teacache_thresh=args.teacache_thresh,
        sample_steps=args.sample_steps,
        ckpt_dir=args.ckpt_dir,
    )
    return wan_flf2v


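# Inference loops. In DDP mode batches come from a DataLoader, so every
# field carries a leading batch dimension that is stripped with [0]
# (assuming batch_size=1 per rank); without DDP the dataset yields
# unbatched samples.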
def inference_t2v_loop(args, pipeline, batch):
    if args.ddp_mode:
        prompt = batch["prompt"][0]
        image_id = batch["image_id"][0]
    else:
        prompt = batch["prompt"]
        image_id = batch["image_id"]

    info_str = f"""
height: {args.resolution[1]}
width: {args.resolution[0]}
video_length: {args.frame_num}
prompt: {prompt}
neg_prompt: {args.negative_prompt}
seed: {int(batch["seed"])}
infer_steps: {args.sample_steps}
guidance_scale: {args.sample_guide_scale}
flow_shift: {args.sample_shift}"""
    logging.info(info_str)

    video = pipeline.generate(
        prompt,
        n_prompt=args.negative_prompt,
        size=args.resolution,
        frame_num=args.frame_num,
        shift=args.sample_shift,
        sample_solver=args.sample_solver,
        sampling_steps=args.sample_steps,
        guide_scale=args.sample_guide_scale,
        seed=int(batch["seed"]),
        offload_model=args.offload_model,
        ddp_mode=args.ddp_mode,
    )

    return video, image_id


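# The image-to-video loop additionally decodes the conditioning image back
# to a PIL image (the dataset returns tensors) and bounds the output by the
# task's max_area instead of a fixed size.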
def inference_i2v_loop(args, pipeline, batch):
    if args.ddp_mode:
        prompt = batch["prompt"][0]
        image_id = batch["image_id"][0]
        cond_image = transforms.ToPILImage()(batch["image"][0])
    else:
        prompt = batch["prompt"]
        image_id = batch["image_id"]
        cond_image = transforms.ToPILImage()(batch["image"])

    width, height = cond_image.size

    info_str = f"""
height: {height}
width: {width}
current_area: {height} * {width}
max_area: {MAX_AREA_CONFIGS[args.size]}
video_length: {args.frame_num}
prompt: {prompt}
neg_prompt: {args.negative_prompt}
seed: {int(batch["seed"])}
infer_steps: {args.sample_steps}
guidance_scale: {args.sample_guide_scale}
flow_shift: {args.sample_shift}"""
    logging.info(info_str)

    video = pipeline.generate(
        prompt,
        cond_image,
        n_prompt=args.negative_prompt,
        max_area=MAX_AREA_CONFIGS[args.size],
        frame_num=args.frame_num,
        shift=args.sample_shift,
        sample_solver=args.sample_solver,
        sampling_steps=args.sample_steps,
        guide_scale=args.sample_guide_scale,
        seed=int(batch["seed"]),
        offload_model=args.offload_model,
        ddp_mode=args.ddp_mode,
    )

    return video, image_id


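# The first-last-frame loop conditions on both a first and a last image and,
# unlike the loops above, seeds from args.base_seed (broadcast in
# basic_init) rather than the per-sample seed.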
def inference_flf2v_loop(args, pipeline, batch):
    if args.ddp_mode:
        prompt = batch["prompt"][0]
        image_id = batch["image_id"][0]
        cond_image = transforms.ToPILImage()(batch["image"][0])
        last_image = transforms.ToPILImage()(batch["last_image"][0])
    else:
        prompt = batch["prompt"]
        image_id = batch["image_id"]
        cond_image = transforms.ToPILImage()(batch["image"])
        last_image = transforms.ToPILImage()(batch["last_image"])

    width, height = cond_image.size

    info_str = f"""
height: {height}
width: {width}
max_area: {MAX_AREA_CONFIGS[args.size]}
video_length: {args.frame_num}
prompt: {prompt}
neg_prompt: {args.negative_prompt}
seed: {args.base_seed}
infer_steps: {args.sample_steps}
guidance_scale: {args.sample_guide_scale}
flow_shift: {args.sample_shift}"""
    logging.info(info_str)

    video = pipeline.generate(
        prompt,
        cond_image,
        last_image,
        n_prompt=args.negative_prompt,
        max_area=MAX_AREA_CONFIGS[args.size],
        frame_num=args.frame_num,
        shift=args.sample_shift,
        sample_solver=args.sample_solver,
        sampling_steps=args.sample_steps,
        guide_scale=args.sample_guide_scale,
        seed=args.base_seed,
        offload_model=args.offload_model,
        ddp_mode=args.ddp_mode,
    )

    return video, image_id


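# Driver: picks the pipeline matching the task name, then iterates the
# dataset, skipping any clip whose output .mp4 already exists so that
# interrupted runs can resume where they left off.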
def main(args):
    basic_kwargs = basic_init(args)
    dataset = dataset_init(args, basic_kwargs)

    if "t2v" in args.task:
        pipeline = pipeline_t2v_init(args, basic_kwargs)
    elif "i2v" in args.task:
        pipeline = pipeline_i2v_init(args, basic_kwargs)
    elif "flf2v" in args.task:
        pipeline = pipeline_flf2v_init(args, basic_kwargs)

    for i, batch in enumerate(dataset):
        # In DDP mode the DataLoader batches fields, so index out the first
        # (and only) element; otherwise image_id is already a plain string
        # and indexing [0] would take its first character.
        image_id = batch["image_id"][0] if args.ddp_mode else batch["image_id"]
        save_path = os.path.join(args.save_folder, f"{image_id}.mp4")
        if os.path.exists(save_path):
            continue

        if "t2v" in args.task:
            video, image_id = inference_t2v_loop(args, pipeline, batch)
        elif "i2v" in args.task:
            video, image_id = inference_i2v_loop(args, pipeline, batch)
        elif "flf2v" in args.task:
            video, image_id = inference_flf2v_loop(args, pipeline, batch)

        # With sequence parallel every rank holds the same video, so only
        # rank 0 writes; in DDP mode each rank generated its own clip and
        # writes it itself.
        if basic_kwargs.rank == 0 or args.ddp_mode:
            save_path = os.path.join(args.save_folder, f"{image_id}.mp4")
            cache_video(
                tensor=video[None],
                save_file=save_path,
                fps=basic_kwargs.cfg.sample_fps,
                nrow=1,
                normalize=True,
                value_range=(-1, 1)
            )
            logging.info(f"Saving generated video to {save_path}")

    logging.info("Finished.")


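# Example launch (illustrative only: the exact flag names are defined by
# args_wan_init, and the script name and paths below are placeholders):
#   torchrun --nproc_per_node=8 generate.py \
#       --task i2v-14B --size 1280*720 \
#       --ckpt_dir ./Wan2.1-I2V-14B-720P --dataset_path ./eval_data \
#       --save_folder ./outputs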
if __name__ == "__main__":
    args = args_wan_init()
    main(args)