|
import logging |
|
import os |
|
import torch |
|
import torch.distributed as dist |
|
from PIL import Image |
|
from datetime import datetime |
|
from tqdm import tqdm |
|
|
|
def generate(args):
    """Run one Wan generation job and save the result on rank 0.

    The function reads ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` from the
    environment, optionally initialises ``torch.distributed`` (and xfuser
    sequence parallelism), optionally extends the prompt, builds either a
    ``wan.WanT2V`` or a ``wan.WanI2V`` pipeline depending on ``args.task``,
    runs it and, on rank 0, writes the output to ``args.save_file``.

    Args:
        args: Namespace-like object. Attributes read here: ``t5_cpu``,
            ``t5_fsdp``, ``dit_fsdp``, ``ulysses_size``, ``ring_size``,
            ``use_prompt_extend``, ``prompt_extend_method``,
            ``prompt_extend_model``, ``prompt_extend_target_lang``, ``task``,
            ``base_seed``, ``prompt``, ``image``, ``ckpt_dir``, ``size``,
            ``sample_shift``, ``sample_solver``, ``sample_steps``,
            ``sample_guide_scale``, ``offload_model`` and ``save_file``.
            ``base_seed``, ``prompt``, ``image`` and ``save_file`` may be
            overwritten in place.

    Raises:
        AssertionError: If FSDP is requested without a distributed launch, or
            ``ulysses_size * ring_size`` does not equal the world size.
        NotImplementedError: If ``args.prompt_extend_method`` is unknown.
        Exception: Any error raised by generation or saving is logged and
            re-raised unchanged.
    """
    print("call generate")
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))

    # NOTE(review): the CPU branch is taken for dit_fsdp as well as t5_cpu,
    # while the process group below always uses the NCCL backend - confirm
    # this combination is intended.
    if args.t5_cpu or args.dit_fsdp:
        device = torch.device("cpu")
        print("Using CPU for model inference.")
    else:
        device = local_rank
        torch.cuda.set_device(local_rank)
        print(f"Using GPU: {device}")

    _init_logging(rank)

    if world_size > 1:
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size)
    else:
        assert not (
            args.t5_fsdp or args.dit_fsdp
        ), "t5_fsdp and dit_fsdp are not supported in non-distributed environments."

    use_usp = args.ulysses_size > 1 or args.ring_size > 1
    if use_usp:
        assert args.ulysses_size * args.ring_size == world_size, "The number of ulysses_size and ring_size should be equal to the world size."
        # Imported lazily so xfuser is only required when sequence
        # parallelism is actually requested.
        from xfuser.core.distributed import (initialize_model_parallel,
                                             init_distributed_environment)
        init_distributed_environment(
            rank=dist.get_rank(), world_size=dist.get_world_size())

        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=args.ring_size,
            ulysses_degree=args.ulysses_size,
        )

    prompt_expander = None
    if args.use_prompt_extend:
        if args.prompt_extend_method == "dashscope":
            prompt_expander = DashScopePromptExpander(
                model_name=args.prompt_extend_model, is_vl="i2v" in args.task)
        elif args.prompt_extend_method == "local_qwen":
            prompt_expander = QwenPromptExpander(
                model_name=args.prompt_extend_model,
                is_vl="i2v" in args.task,
                device=rank)
        else:
            raise NotImplementedError(f"Unsupported prompt_extend_method: {args.prompt_extend_method}")

    cfg = WAN_CONFIGS[args.task]
    print(f"Generation job args: {args}")
    print(f"Generation model config: {cfg}")

    # Every rank adopts rank 0's seed.
    if dist.is_initialized():
        base_seed = [args.base_seed] if rank == 0 else [None]
        dist.broadcast_object_list(base_seed, src=0)
        args.base_seed = base_seed[0]

    if "t2v" in args.task or "t2i" in args.task:
        print("tect to x ")
        if args.prompt is None:
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
        print(f"Input prompt: {args.prompt}")

        if args.use_prompt_extend:
            logging.info("Extending prompt ...")
            args.prompt = _extend_prompt(args, prompt_expander, rank)
            logging.info(f"Extended prompt: {args.prompt}")

        logging.info("Creating WanT2V pipeline.")
        wan_t2v = wan.WanT2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_usp=use_usp,
            t5_cpu=args.t5_cpu,
        )

        print(f"Generating {'image' if 't2i' in args.task else 'video'} ...")
        try:
            # NOTE(review): frame_num is hard-coded to 33, including for
            # t2i tasks - confirm this is intended.
            video = wan_t2v.generate(
                args.prompt,
                size=SIZE_CONFIGS[args.size],
                frame_num=33,
                shift=args.sample_shift,
                sample_solver=args.sample_solver,
                sampling_steps=args.sample_steps,
                guide_scale=args.sample_guide_scale,
                seed=args.base_seed,
                offload_model=args.offload_model)
        except Exception as e:
            logging.error(f"Error during video generation: {e}")
            raise

    else:
        if args.prompt is None:
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
        if args.image is None:
            args.image = EXAMPLE_PROMPT[args.task]["image"]
        logging.info(f"Input prompt: {args.prompt}")
        logging.info(f"Input image: {args.image}")

        img = Image.open(args.image).convert("RGB")
        if args.use_prompt_extend:
            logging.info("Extending prompt ...")
            args.prompt = _extend_prompt(
                args, prompt_expander, rank, image=img)
            logging.info(f"Extended prompt: {args.prompt}")

        logging.info("Creating WanI2V pipeline.")
        wan_i2v = wan.WanI2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_usp=use_usp,
            t5_cpu=args.t5_cpu,
        )

        print("Generating video ..6666666666666666")
        try:
            video = wan_i2v.generate(
                args.prompt,
                img,
                max_area=MAX_AREA_CONFIGS[args.size],
                frame_num=33,
                shift=args.sample_shift,
                sample_solver=args.sample_solver,
                sampling_steps=args.sample_steps,
                guide_scale=args.sample_guide_scale,
                seed=args.base_seed,
                offload_model=args.offload_model)
        except Exception as e:
            logging.error(f"Error during video generation: {e}")
            raise

    # Only rank 0 writes the result.
    if rank == 0:
        if args.save_file is None:
            # NOTE(review): the default name has an .mp4 suffix even for
            # t2i (image) tasks - confirm cache_image copes with that.
            args.save_file = "generated_video.mp4"

        try:
            if "t2i" in args.task:
                logging.info(f"Saving generated image to {args.save_file}")
                cache_image(tensor=video.squeeze(1)[None], save_file=args.save_file, nrow=1, normalize=True)
            else:
                logging.info(f"Saving generated video to {args.save_file}")
                cache_video(tensor=video, save_file=args.save_file)
        except Exception as e:
            logging.error(f"Error saving output: {e}")
            raise


def _extend_prompt(args, prompt_expander, rank, image=None):
    """Return the prompt produced on rank 0, shared with every rank.

    Rank 0 calls ``prompt_expander`` and falls back to ``args.prompt`` when
    the expander reports ``status == False``. The result is always carried in
    a one-element list: that is the container ``dist.broadcast_object_list``
    exchanges, and it keeps the final ``[0]`` lookup from slicing the first
    character off a bare string.

    Args:
        args: Job arguments; reads ``prompt``, ``prompt_extend_target_lang``
            and ``base_seed``.
        prompt_expander: Callable invoked as
            ``prompt_expander(prompt, tar_lang=..., seed=..., [image=...])``.
        rank: Rank of the current process; only rank 0 runs the expander.
        image: Optional image forwarded to the expander as ``image=``; the
            keyword is omitted entirely when this is ``None``.

    Returns:
        The extended prompt, or the original prompt if extension failed.
    """
    if rank == 0:
        expander_kwargs = {
            "tar_lang": args.prompt_extend_target_lang,
            "seed": args.base_seed,
        }
        if image is not None:
            expander_kwargs["image"] = image
        prompt_output = prompt_expander(args.prompt, **expander_kwargs)
        # Comparison kept exactly as in the original so a status of None is
        # still treated as success.
        if prompt_output.status == False:  # noqa: E712
            logging.info(f"Prompt extension failed: {prompt_output.message}")
            payload = [args.prompt]
        else:
            payload = [prompt_output.prompt]
    else:
        payload = [None]

    if dist.is_initialized():
        dist.broadcast_object_list(payload, src=0)
    return payload[0]
|
|