File size: 26,215 Bytes
f08eddf be29b01 f08eddf be29b01 f08eddf be29b01 f08eddf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 |
import os
import time
import random
import functools
from typing import List, Optional, Tuple, Union
from pathlib import Path
from loguru import logger
import torch
import torch.distributed as dist
from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE
from hyvideo.vae import load_vae
from hyvideo.modules import load_model
from hyvideo.text_encoder import TextEncoder
from hyvideo.utils.data_utils import align_to
from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed
from hyvideo.modules.fp8_optimization import convert_fp8_linear
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
from hyvideo.diffusion.pipelines import HunyuanVideoPipeline
# xfuser is an optional dependency that is only needed for Ulysses / Ring
# sequence-parallel attention.  When it cannot be imported, every symbol
# taken from it is bound to None so that the rest of the module can test
# ``xfuser is not None`` before using the parallel code path.
try:
    import xfuser
    from xfuser.core.distributed import (
        get_sequence_parallel_world_size,
        get_sequence_parallel_rank,
        get_sp_group,
        initialize_model_parallel,
        init_distributed_environment,
    )
except Exception:
    # Any import-time failure disables the parallel path.  ``Exception`` is
    # used instead of a bare ``except`` so that KeyboardInterrupt and
    # SystemExit are not swallowed.
    xfuser = None
    get_sequence_parallel_world_size = None
    get_sequence_parallel_rank = None
    get_sp_group = None
    initialize_model_parallel = None
    init_distributed_environment = None
def parallelize_transformer(pipe):
    """Patch ``pipe.transformer.forward`` to run with sequence parallelism.

    The replacement forward keeps only this rank's chunk of the latent ``x``
    (split along height, or along width when height does not divide evenly),
    slices the rotary embeddings the same way, installs an
    ``xFuserLongContextAttention`` instance on every double and single block,
    calls the original forward, and finally all-gathers the result along the
    split dimension.

    Args:
        pipe: Object exposing a ``transformer`` attribute. The transformer must
            provide ``double_blocks``, ``single_blocks`` and ``forward``.

    Returns:
        None. ``pipe.transformer.forward`` is replaced in place.
    """
    transformer = pipe.transformer
    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,  # Should be in range(0, 1000).
        text_states: torch.Tensor = None,
        text_mask: torch.Tensor = None,  # Now we don't use it.
        text_states_2: Optional[torch.Tensor] = None,  # Text embedding for modulation.
        freqs_cos: Optional[torch.Tensor] = None,
        freqs_sin: Optional[torch.Tensor] = None,
        guidance: torch.Tensor = None,  # Guidance for modulation, should be cfg_scale x 1000.
        return_dict: bool = True,
    ):
        # Query the sequence-parallel topology once per call.
        sp_world_size = get_sequence_parallel_world_size()
        sp_rank = get_sequence_parallel_rank()

        if x.shape[-2] // 2 % sp_world_size == 0:
            # try to split x by height
            split_dim = -2
        elif x.shape[-1] // 2 % sp_world_size == 0:
            # try to split x by width
            split_dim = -1
        else:
            raise ValueError(f"Cannot split video sequence into ulysses_degree x ring_degree ({sp_world_size}) parts evenly")

        # patch sizes for the temporal, height, and width dimensions are 1, 2, and 2.
        temporal_size, h, w = x.shape[2], x.shape[3] // 2, x.shape[4] // 2

        x = torch.chunk(x, sp_world_size, dim=split_dim)[sp_rank]

        # The rotary tables are reshaped to (T, H, W, D); the trailing feature
        # axis shifts the spatial axes by one compared with ``x``, hence
        # ``split_dim - 1``.
        dim_thw = freqs_cos.shape[-1]
        freqs_cos = freqs_cos.reshape(temporal_size, h, w, dim_thw)
        freqs_cos = torch.chunk(freqs_cos, sp_world_size, dim=split_dim - 1)[sp_rank]
        freqs_cos = freqs_cos.reshape(-1, dim_thw)

        dim_thw = freqs_sin.shape[-1]
        freqs_sin = freqs_sin.reshape(temporal_size, h, w, dim_thw)
        freqs_sin = torch.chunk(freqs_sin, sp_world_size, dim=split_dim - 1)[sp_rank]
        freqs_sin = freqs_sin.reshape(-1, dim_thw)

        from xfuser.core.long_ctx_attention import xFuserLongContextAttention

        for block in transformer.double_blocks + transformer.single_blocks:
            block.hybrid_seq_parallel_attn = xFuserLongContextAttention()

        output = original_forward(
            x,
            t,
            text_states,
            text_mask,
            text_states_2,
            freqs_cos,
            freqs_sin,
            guidance,
            return_dict,
        )

        # Re-assemble the full sample from all ranks.  The original code
        # indexed ``output["x"]`` unconditionally, which fails whenever the
        # wrapped forward does not return a dict (``return_dict=False``).
        if isinstance(output, dict):
            output["x"] = get_sp_group().all_gather(output["x"], dim=split_dim)
            return output
        if isinstance(output, tuple):
            # NOTE(review): assumes the sample is the first tuple element --
            # confirm against the transformer's forward.
            gathered = get_sp_group().all_gather(output[0], dim=split_dim)
            return (gathered,) + tuple(output[1:])
        # NOTE(review): assumes a bare return value is the sample tensor.
        return get_sp_group().all_gather(output, dim=split_dim)

    # Bind as a method so that ``self`` is the transformer instance.
    new_forward = new_forward.__get__(transformer)
    transformer.forward = new_forward
class Inference(object):
    """Holds the models and settings shared by the inference pipelines.

    Instances are normally created through :meth:`from_pretrained`, which
    builds the transformer, the VAE and the text encoder(s) from ``args``.
    """

    def __init__(
        self,
        args,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        use_cpu_offload=False,
        device=None,
        logger=None,
        parallel_args=None,
    ):
        """Store the given components on the instance.

        Args:
            args: Parsed run arguments.
            vae: The VAE model.
            vae_kwargs (dict): Extra VAE information (``s_ratio``, ``t_ratio``
                when built by :meth:`from_pretrained`).
            text_encoder: Primary text encoder.
            model: The main (transformer) model.
            text_encoder_2: Optional secondary text encoder.
            pipeline: Optional ready-made pipeline.
            use_cpu_offload (bool): Whether CPU offloading is requested.
            device: Target device. When ``None``, ``"cuda"`` is used if
                available, otherwise ``"cpu"``.
            logger: Logger object kept on the instance.
            parallel_args (dict): ``ulysses_degree`` / ``ring_degree`` settings.
        """
        self.vae = vae
        self.vae_kwargs = vae_kwargs

        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2

        self.model = model
        self.pipeline = pipeline
        self.use_cpu_offload = use_cpu_offload

        self.args = args
        self.device = (
            device
            if device is not None
            else "cuda"
            if torch.cuda.is_available()
            else "cpu"
        )
        self.logger = logger
        self.parallel_args = parallel_args

    @classmethod
    def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs):
        """
        Initialize the Inference pipeline.

        Args:
            pretrained_model_path (str or pathlib.Path): The model path, including t2v, text encoder and vae checkpoints.
            args (argparse.Namespace): The arguments for the pipeline.
            device: The device for inference. When ``None``, ``"cuda"`` is used
                if available, otherwise ``"cpu"``. Ignored in distributed mode,
                where the device is derived from the ``LOCAL_RANK`` environment
                variable.

        Returns:
            An instance of ``cls`` holding the loaded models.
        """
        # ========================================================================
        logger.info(f"Got text-to-video model root path: {pretrained_model_path}")

        # ==================== Initialize Distributed Environment ================
        if args.ulysses_degree > 1 or args.ring_degree > 1:
            assert xfuser is not None, \
                "Ulysses Attention and Ring Attention requires xfuser package."

            assert args.use_cpu_offload is False, \
                "Cannot enable use_cpu_offload in the distributed environment."

            dist.init_process_group("nccl")

            assert dist.get_world_size() == args.ring_degree * args.ulysses_degree, \
                "number of GPUs should be equal to ring_degree * ulysses_degree."

            init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())

            initialize_model_parallel(
                sequence_parallel_degree=dist.get_world_size(),
                ring_degree=args.ring_degree,
                ulysses_degree=args.ulysses_degree,
            )
            device = torch.device(f"cuda:{os.environ['LOCAL_RANK']}")
        else:
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"

        parallel_args = {"ulysses_degree": args.ulysses_degree, "ring_degree": args.ring_degree}

        # ======================== Get the args path =============================

        # Disable gradient
        torch.set_grad_enabled(False)

        # =========================== Build main model ===========================
        logger.info("Building model...")
        factor_kwargs = {"device": device, "dtype": PRECISION_TO_TYPE[args.precision]}
        in_channels = args.latent_channels
        out_channels = args.latent_channels

        model = load_model(
            args,
            in_channels=in_channels,
            out_channels=out_channels,
            factor_kwargs=factor_kwargs,
        )
        if args.use_fp8:
            convert_fp8_linear(model, args.dit_weight, original_dtype=PRECISION_TO_TYPE[args.precision])
        model = model.to(device)
        model = Inference.load_state_dict(args, model, pretrained_model_path)
        model.eval()

        # ============================= Build extra models ========================
        # The VAE and the text encoders stay on the CPU when offloading is on.
        aux_device = device if not args.use_cpu_offload else "cpu"

        # VAE
        vae, _, s_ratio, t_ratio = load_vae(
            args.vae,
            args.vae_precision,
            logger=logger,
            device=aux_device,
        )
        vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}

        # Text encoder
        # The video template takes precedence over the plain prompt template
        # when choosing ``crop_start``.
        if args.prompt_template_video is not None:
            crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get(
                "crop_start", 0
            )
        elif args.prompt_template is not None:
            crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
        else:
            crop_start = 0
        max_length = args.text_len + crop_start

        # prompt_template
        prompt_template = (
            PROMPT_TEMPLATE[args.prompt_template]
            if args.prompt_template is not None
            else None
        )

        # prompt_template_video
        prompt_template_video = (
            PROMPT_TEMPLATE[args.prompt_template_video]
            if args.prompt_template_video is not None
            else None
        )

        text_encoder = TextEncoder(
            text_encoder_type=args.text_encoder,
            max_length=max_length,
            text_encoder_precision=args.text_encoder_precision,
            tokenizer_type=args.tokenizer,
            prompt_template=prompt_template,
            prompt_template_video=prompt_template_video,
            hidden_state_skip_layer=args.hidden_state_skip_layer,
            apply_final_norm=args.apply_final_norm,
            reproduce=args.reproduce,
            logger=logger,
            device=aux_device,
        )
        text_encoder_2 = None
        if args.text_encoder_2 is not None:
            text_encoder_2 = TextEncoder(
                text_encoder_type=args.text_encoder_2,
                max_length=args.text_len_2,
                text_encoder_precision=args.text_encoder_precision_2,
                tokenizer_type=args.tokenizer_2,
                reproduce=args.reproduce,
                logger=logger,
                device=aux_device,
            )

        return cls(
            args=args,
            vae=vae,
            vae_kwargs=vae_kwargs,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            model=model,
            use_cpu_offload=args.use_cpu_offload,
            device=device,
            logger=logger,
            parallel_args=parallel_args,
        )

    @staticmethod
    def _find_checkpoint_in_dir(weight_dir, load_key):
        """Pick the checkpoint file to load from ``weight_dir``.

        Args:
            weight_dir (pathlib.Path): Directory searched for ``*.pt`` files.
            load_key (str): Suffix used for ``pytorch_model_{load_key}.pt``.

        Returns:
            tuple: ``(model_path, bare_model)``. ``bare_model`` is ``True`` for
            ``pytorch_model_*.pt`` files and ``False`` for
            ``*_model_states.pt`` files.

        Raises:
            ValueError: If the directory holds no ``*.pt`` file or only files
                with an unrecognized naming scheme.
        """
        # Sorted so that the selection does not depend on directory order.
        files = sorted(weight_dir.glob("*.pt"))
        if len(files) == 0:
            raise ValueError(f"No model weights found in {weight_dir}")

        # Compare file *names*: the string form of a path includes its parent
        # directories and therefore practically never starts with the prefix.
        if any(f.name.startswith("pytorch_model_") for f in files):
            return weight_dir / f"pytorch_model_{load_key}.pt", True

        state_files = [f for f in files if f.name.endswith("_model_states.pt")]
        if state_files:
            model_path = state_files[0]
            if len(state_files) > 1:
                logger.warning(
                    f"Multiple model weights found in {weight_dir}, using {model_path}"
                )
            return model_path, False

        raise ValueError(
            f"Invalid model path: {weight_dir} with unrecognized weight format: "
            f"{list(map(str, files))}. When given a directory as --dit-weight, only "
            f"`pytorch_model_*.pt`(provided by HunyuanDiT official) and "
            f"`*_model_states.pt`(saved by deepspeed) can be parsed. If you want to load a "
            f"specific weight file, please provide the full path to the file."
        )

    @staticmethod
    def load_state_dict(args, model, pretrained_model_path):
        """Load the transformer weights into ``model``.

        The checkpoint is taken from ``args.dit_weight`` (a file or a
        directory). When ``args.dit_weight`` is ``None`` the directory
        ``<pretrained_model_path>/t2v_<args.model_resolution>`` is searched
        instead.

        Args:
            args: Run arguments; reads ``load_key``, ``dit_weight`` and
                ``model_resolution``.
            model: Module whose ``load_state_dict`` receives the weights.
            pretrained_model_path (str or pathlib.Path): Model root directory.

        Returns:
            The same ``model`` with weights loaded.

        Raises:
            ValueError: If no usable checkpoint file can be located.
            KeyError: If a wrapped checkpoint lacks ``args.load_key``.
        """
        load_key = args.load_key

        # Test for None *before* building a Path: ``Path(None)`` raises
        # TypeError, which previously made the default branch unreachable.
        if args.dit_weight is None:
            model_dir = Path(pretrained_model_path) / f"t2v_{args.model_resolution}"
            model_path, bare_model = Inference._find_checkpoint_in_dir(model_dir, load_key)
        else:
            dit_weight = Path(args.dit_weight)
            if dit_weight.is_dir():
                model_path, bare_model = Inference._find_checkpoint_in_dir(dit_weight, load_key)
            elif dit_weight.is_file():
                model_path = dit_weight
                # Decided after loading, from the keys found in the checkpoint.
                bare_model = "unknown"
            else:
                raise ValueError(f"Invalid model path: {dit_weight}")

        if not model_path.exists():
            raise ValueError(f"model_path not exists: {model_path}")
        logger.info(f"Loading torch model {model_path}...")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)

        if bare_model == "unknown" and ("ema" in state_dict or "module" in state_dict):
            bare_model = False
        if bare_model is False:
            if load_key in state_dict:
                state_dict = state_dict[load_key]
            else:
                raise KeyError(
                    f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
                    f"are: {list(state_dict.keys())}."
                )
        model.load_state_dict(state_dict, strict=True)
        return model

    @staticmethod
    def parse_size(size):
        """Normalize ``size`` to a two-element ``(height, width)`` sequence.

        An ``int`` or a one-element sequence is expanded to a square size; a
        two-element list or tuple is returned unchanged.

        Raises:
            ValueError: If ``size`` is not an int, list or tuple, or does not
                have one or two elements.
        """
        if isinstance(size, int):
            size = [size]
        if not isinstance(size, (list, tuple)):
            raise ValueError(f"Size must be an integer or (height, width), got {size}.")
        if len(size) == 1:
            size = [size[0], size[0]]
        if len(size) != 2:
            raise ValueError(f"Size must be an integer or (height, width), got {size}.")
        return size
class HunyuanVideoSampler(Inference):
    """Text-to-video sampler built on top of :class:`Inference`.

    On construction it wraps the loaded models in a ``HunyuanVideoPipeline``
    and, when sequence parallelism is configured, patches the transformer via
    ``parallelize_transformer``.
    """

    def __init__(
        self,
        args,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        use_cpu_offload=False,
        device=0,
        logger=None,
        parallel_args=None,
    ):
        """Store the components and build the diffusion pipeline.

        Parameters are forwarded unchanged to ``Inference.__init__``.
        ``parallel_args`` may be ``None``, in which case no sequence
        parallelism is set up.
        """
        super().__init__(
            args,
            vae,
            vae_kwargs,
            text_encoder,
            model,
            text_encoder_2=text_encoder_2,
            pipeline=pipeline,
            use_cpu_offload=use_cpu_offload,
            device=device,
            logger=logger,
            parallel_args=parallel_args,
        )

        self.pipeline = self.load_diffusion_pipeline(
            args=args,
            vae=self.vae,
            text_encoder=self.text_encoder,
            text_encoder_2=self.text_encoder_2,
            model=self.model,
            device=self.device,
        )

        self.default_negative_prompt = NEGATIVE_PROMPT

        # ``parallel_args`` defaults to None; treat a missing dict or missing
        # keys as degree 1 (no parallelism) instead of raising TypeError.
        degrees = self.parallel_args or {}
        if degrees.get("ulysses_degree", 1) > 1 or degrees.get("ring_degree", 1) > 1:
            parallelize_transformer(self.pipeline)

    def load_diffusion_pipeline(
        self,
        args,
        vae,
        text_encoder,
        text_encoder_2,
        model,
        scheduler=None,
        device=None,
        progress_bar_config=None,
        data_type="video",
    ):
        """Build the ``HunyuanVideoPipeline`` used for inference.

        Args:
            args: Run arguments; reads ``denoise_type`` and the ``flow_*``
                scheduler settings when no scheduler is supplied.
            vae: The VAE model.
            text_encoder: Primary text encoder.
            text_encoder_2: Optional secondary text encoder.
            model: The transformer model.
            scheduler: Optional scheduler. When ``None`` a
                ``FlowMatchDiscreteScheduler`` is created.
            device: Device the pipeline is moved to when CPU offload is off.
            progress_bar_config: Forwarded to the pipeline.
            data_type (str): Currently unused by this method.

        Returns:
            The constructed pipeline.

        Raises:
            ValueError: If ``args.denoise_type`` is not ``"flow"`` and no
                scheduler was given.
        """
        if scheduler is None:
            if args.denoise_type == "flow":
                scheduler = FlowMatchDiscreteScheduler(
                    shift=args.flow_shift,
                    reverse=args.flow_reverse,
                    solver=args.flow_solver,
                )
            else:
                raise ValueError(f"Invalid denoise type {args.denoise_type}")

        pipeline = HunyuanVideoPipeline(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            transformer=model,
            scheduler=scheduler,
            progress_bar_config=progress_bar_config,
            args=args,
        )
        if self.use_cpu_offload:
            pipeline.enable_sequential_cpu_offload()
        else:
            pipeline = pipeline.to(device)

        return pipeline

    def get_rotary_pos_embed(self, video_length, height, width):
        """Compute the rotary position embedding tables for a target size.

        Args:
            video_length (int): Number of frames.
            height (int): Frame height in pixels.
            width (int): Frame width in pixels.

        Returns:
            tuple: ``(freqs_cos, freqs_sin)`` as returned by
            ``get_nd_rotary_pos_embed``.

        Raises:
            TypeError: If ``self.model.patch_size`` is neither an int nor a
                list/tuple.
        """
        target_ndim = 3
        ndim = 5 - 2
        patch_size = self.model.patch_size

        # The VAE name encodes its compression: "884" divides time by 4,
        # "888" by 8; height and width are always divided by 8.
        if "884" in self.args.vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        elif "888" in self.args.vae:
            latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]

        if isinstance(patch_size, int):
            assert all(s % patch_size == 0 for s in latents_size), (
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [s // patch_size for s in latents_size]
        elif isinstance(patch_size, (list, tuple)):
            assert all(
                s % patch_size[idx] == 0
                for idx, s in enumerate(latents_size)
            ), (
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [
                s // patch_size[idx] for idx, s in enumerate(latents_size)
            ]
        else:
            # Previously this fell through and failed later with a NameError.
            raise TypeError(
                f"patch_size must be an int or a list/tuple of ints, got {type(patch_size)}"
            )

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis
        head_dim = self.model.hidden_size // self.model.heads_num
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert (
            sum(rope_dim_list) == head_dim
        ), "sum(rope_dim_list) should equal to head_dim of attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list,
            rope_sizes,
            theta=self.args.rope_theta,
            use_real=True,
            theta_rescale_factor=1,
        )
        return freqs_cos, freqs_sin

    @torch.no_grad()
    def predict(
        self,
        prompt,
        height=192,
        width=336,
        video_length=129,
        seed=None,
        negative_prompt=None,
        infer_steps=50,
        guidance_scale=6,
        flow_shift=5.0,
        embedded_guidance_scale=None,
        batch_size=1,
        num_videos_per_prompt=1,
        **kwargs,
    ):
        """
        Predict the image/video from the given text.

        Args:
            prompt (str): The input text.
            height (int): The height of the output video. Default is 192.
            width (int): The width of the output video. Default is 336.
            video_length (int): The frame number of the output video. Default is 129.
                ``video_length - 1`` must be a multiple of 4.
            seed (int, list, tuple, torch.Tensor or None): The random seed(s) for
                the generation. Default is a random integer per video.
            negative_prompt (str): The negative text prompt. ``None`` or an empty
                string selects the default negative prompt.
            infer_steps (int): The number of inference steps. Default is 50.
            guidance_scale (float): The guidance scale for the generation. Default is 6.
            flow_shift (float): Shift of the flow-matching scheduler. Default is 5.0.
            embedded_guidance_scale (float): Forwarded to the pipeline.
            batch_size (int): Number of prompts, used when expanding seeds. Default is 1.
            num_videos_per_prompt (int): The number of videos per prompt. Default is 1.

        Returns:
            dict: Contains ``seeds``, ``size`` (height, width, video_length),
            ``samples`` and ``prompts``.

        Raises:
            ValueError: On an invalid seed, size or video length.
            TypeError: If ``prompt`` or ``negative_prompt`` is not a string.
        """
        out_dict = dict()

        # ========================================================================
        # Arguments: seed
        # ========================================================================
        if isinstance(seed, torch.Tensor):
            seed = seed.tolist()
        if seed is None:
            seeds = [
                random.randint(0, 1_000_000)
                for _ in range(batch_size * num_videos_per_prompt)
            ]
        elif isinstance(seed, int):
            seeds = [
                seed + i
                for _ in range(batch_size)
                for i in range(num_videos_per_prompt)
            ]
        elif isinstance(seed, (list, tuple)):
            if len(seed) == batch_size:
                seeds = [
                    int(seed[i]) + j
                    for i in range(batch_size)
                    for j in range(num_videos_per_prompt)
                ]
            elif len(seed) == batch_size * num_videos_per_prompt:
                seeds = [int(s) for s in seed]
            else:
                raise ValueError(
                    f"Length of seed must be equal to number of prompt(batch_size) or "
                    f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
                )
        else:
            raise ValueError(
                f"Seed must be an integer, a list of integers, or None, got {seed}."
            )
        # One generator per output video.
        generator = [torch.Generator(self.device).manual_seed(s) for s in seeds]
        out_dict["seeds"] = seeds

        # ========================================================================
        # Arguments: target_width, target_height, target_video_length
        # ========================================================================
        if width <= 0 or height <= 0 or video_length <= 0:
            raise ValueError(
                f"`height` and `width` and `video_length` must be positive integers, got height={height}, width={width}, video_length={video_length}"
            )
        if (video_length - 1) % 4 != 0:
            raise ValueError(
                f"`video_length-1` must be a multiple of 4, got {video_length}"
            )

        logger.info(
            f"Input (height, width, video_length) = ({height}, {width}, {video_length})"
        )

        target_height = align_to(height, 16)
        target_width = align_to(width, 16)
        target_video_length = video_length

        out_dict["size"] = (target_height, target_width, target_video_length)

        # ========================================================================
        # Arguments: prompt, new_prompt, negative_prompt
        # ========================================================================
        if not isinstance(prompt, str):
            raise TypeError(f"`prompt` must be a string, but got {type(prompt)}")
        prompt = [prompt.strip()]

        # negative prompt
        if negative_prompt is None or negative_prompt == "":
            negative_prompt = self.default_negative_prompt
        if not isinstance(negative_prompt, str):
            raise TypeError(
                f"`negative_prompt` must be a string, but got {type(negative_prompt)}"
            )
        negative_prompt = [negative_prompt.strip()]

        # ========================================================================
        # Scheduler
        # ========================================================================
        # A fresh scheduler is created per call so that ``flow_shift`` can vary.
        scheduler = FlowMatchDiscreteScheduler(
            shift=flow_shift,
            reverse=self.args.flow_reverse,
            solver=self.args.flow_solver,
        )
        self.pipeline.scheduler = scheduler

        # ========================================================================
        # Build Rope freqs
        # ========================================================================
        freqs_cos, freqs_sin = self.get_rotary_pos_embed(
            target_video_length, target_height, target_width
        )
        n_tokens = freqs_cos.shape[0]

        # ========================================================================
        # Print infer args
        # ========================================================================
        debug_str = f"""
height: {target_height}
width: {target_width}
video_length: {target_video_length}
prompt: {prompt}
neg_prompt: {negative_prompt}
seed: {seed}
infer_steps: {infer_steps}
num_videos_per_prompt: {num_videos_per_prompt}
guidance_scale: {guidance_scale}
n_tokens: {n_tokens}
flow_shift: {flow_shift}
embedded_guidance_scale: {embedded_guidance_scale}"""
        logger.debug(debug_str)

        # ========================================================================
        # Pipeline inference
        # ========================================================================
        start_time = time.time()
        samples = self.pipeline(
            prompt=prompt,
            height=target_height,
            width=target_width,
            video_length=target_video_length,
            num_inference_steps=infer_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            generator=generator,
            output_type="pil",
            freqs_cis=(freqs_cos, freqs_sin),
            n_tokens=n_tokens,
            embedded_guidance_scale=embedded_guidance_scale,
            data_type="video" if target_video_length > 1 else "image",
            is_progress_bar=True,
            vae_ver=self.args.vae,
            enable_tiling=self.args.vae_tiling,
        )[0]
        out_dict["samples"] = samples
        out_dict["prompts"] = prompt

        gen_time = time.time() - start_time
        logger.info(f"Success, time: {gen_time}")

        return out_dict
|