# Source: Hugging Face repository by jbilcke-hf (HF staff).
# Commit 3c589b4 (verified): "Rename handler_legacy.py to handler.py".
from typing import Dict, Any
import os
import shutil
from pathlib import Path
import time
from datetime import datetime
import argparse
from loguru import logger
from hyvideo.utils.file_utils import save_videos_grid
from hyvideo.inference import HunyuanVideoSampler
from hyvideo.constants import NEGATIVE_PROMPT
# Configure logger: in addition to loguru's default sink, write everything to
# handler_debug.log and rotate the file once it reaches 500 MB.
logger.add("handler_debug.log", rotation="500 MB")

# Default generation settings used when a request omits a parameter.
DEFAULT_RESOLUTION = "720p"
DEFAULT_WIDTH = 1280
DEFAULT_HEIGHT = 720
# 121 frames (129 is the other common choice). HunyuanVideo needs an extra +1
# frame; presumably frame counts must be of the form 4k + 1 (both 121 and 129
# are) -- TODO confirm against the VAE's temporal compression. Note "4 * 30" is
# not "4 s at 30 fps": at DEFAULT_FPS = 24 this is roughly 5 s of video.
DEFAULT_NB_FRAMES = (4 * 30) + 1
DEFAULT_NB_STEPS = 22  # denoising steps; 50 trades speed for quality
DEFAULT_FPS = 24  # frame rate of the encoded MP4
def setup_vae_path(vae_path: Path, tmp_vae_dir: Path = Path("/tmp/vae")) -> Path:
    """Create a staging directory holding the VAE under correctly named files.

    The checkpoint ships its config as ``hunyuan-video-t2v-720p_vae_config.json``.
    It is copied to ``config.json`` in ``tmp_vae_dir``, next to an unchanged
    copy of ``pytorch_model.pt``.

    Args:
        vae_path: Directory containing the original VAE files.
        tmp_vae_dir: Staging directory. Any existing directory at this path is
            deleted and recreated. Defaults to ``/tmp/vae``.

    Returns:
        ``tmp_vae_dir`` as a ``Path``.

    Raises:
        FileNotFoundError: If either source file is missing. This is raised
            before ``tmp_vae_dir`` is touched, so a previously staged copy
            survives a bad ``vae_path``.
    """
    vae_path = Path(vae_path)
    tmp_vae_dir = Path(tmp_vae_dir)
    original_config = vae_path / "hunyuan-video-t2v-720p_vae_config.json"
    original_model = vae_path / "pytorch_model.pt"

    # Validate the sources *before* deleting the staging directory.
    for source in (original_config, original_model):
        if not source.is_file():
            raise FileNotFoundError(f"Required VAE file not found: {source}")

    if tmp_vae_dir.exists():
        shutil.rmtree(tmp_vae_dir)
    tmp_vae_dir.mkdir(parents=True)
    logger.info(f"Setting up VAE in temporary directory: {tmp_vae_dir}")

    # Copy and rename the config file.
    new_config = tmp_vae_dir / "config.json"
    shutil.copy2(original_config, new_config)
    logger.info(f"Copied VAE config from {original_config} to {new_config}")

    # Copy the weights under their original name.
    new_model = tmp_vae_dir / "pytorch_model.pt"
    shutil.copy2(original_model, new_model)
    logger.info(f"Copied VAE model from {original_model} to {new_model}")

    return tmp_vae_dir
def get_default_args(argv=None):
    """Build the HunyuanVideo argument namespace without reading ``sys.argv``.

    Args:
        argv: Optional sequence of CLI-style tokens (e.g.
            ``["--infer-steps", "50"]``) that override individual defaults.
            ``None`` (the default) parses an empty list, so every option keeps
            its default value. Invalid tokens make argparse raise
            ``SystemExit``.

    Returns:
        ``argparse.Namespace`` with one attribute per option defined below.

    Note:
        ``--vae-tiling``, ``--use-guidance-net`` and ``--flow-reverse`` combine
        ``action="store_true"`` with ``default=True``. They are therefore always
        ``True`` and cannot be switched off through ``argv``.
    """
    parser = argparse.ArgumentParser()

    # Model configuration
    parser.add_argument("--model", type=str, default="HYVideo-T/2-cfgdistill")
    parser.add_argument("--model-resolution", type=str, default=DEFAULT_RESOLUTION, choices=["540p", "720p"])
    parser.add_argument("--latent-channels", type=int, default=16)
    parser.add_argument("--precision", type=str, default="bf16", choices=["bf16", "fp32", "fp16"])
    parser.add_argument("--rope-theta", type=int, default=256)
    parser.add_argument("--load-key", type=str, default="module")
    parser.add_argument("--use-fp8", action="store_true", default=False)

    # VAE settings
    parser.add_argument("--vae", type=str, default="884-16c-hy")
    parser.add_argument("--vae-precision", type=str, default="fp16")
    parser.add_argument("--vae-tiling", action="store_true", default=True)  # always True (see Note)

    # Text encoder settings
    parser.add_argument("--text-encoder", type=str, default="llm")
    parser.add_argument("--text-encoder-precision", type=str, default="fp16")
    parser.add_argument("--text-states-dim", type=int, default=4096)
    parser.add_argument("--text-len", type=int, default=256)
    parser.add_argument("--tokenizer", type=str, default="llm")

    # Prompt template settings
    parser.add_argument("--prompt-template", type=str, default="dit-llm-encode")
    parser.add_argument("--prompt-template-video", type=str, default="dit-llm-encode-video")

    # Additional text encoder settings
    parser.add_argument("--hidden-state-skip-layer", type=int, default=2)
    parser.add_argument("--apply-final-norm", action="store_true")
    parser.add_argument("--text-encoder-2", type=str, default="clipL")
    parser.add_argument("--text-encoder-precision-2", type=str, default="fp16")
    parser.add_argument("--text-states-dim-2", type=int, default=768)
    parser.add_argument("--tokenizer-2", type=str, default="clipL")
    parser.add_argument("--text-len-2", type=int, default=77)

    # Model architecture settings
    parser.add_argument("--hidden-size", type=int, default=1024)
    parser.add_argument("--heads-num", type=int, default=16)
    parser.add_argument("--layers-num", type=int, default=24)
    parser.add_argument("--mlp-ratio", type=float, default=4.0)
    parser.add_argument("--use-guidance-net", action="store_true", default=True)  # always True (see Note)

    # Inference settings
    parser.add_argument("--denoise-type", type=str, default="flow")
    parser.add_argument("--flow-shift", type=float, default=7.0)
    parser.add_argument("--flow-reverse", action="store_true", default=True)  # always True (see Note)
    parser.add_argument("--flow-solver", type=str, default="euler")
    parser.add_argument("--use-linear-quadratic-schedule", action="store_true")
    parser.add_argument("--linear-schedule-end", type=int, default=25)

    # Hardware settings
    parser.add_argument("--use-cpu-offload", action="store_true", default=False)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--infer-steps", type=int, default=DEFAULT_NB_STEPS)
    parser.add_argument("--disable-autocast", action="store_true")

    # Output settings
    parser.add_argument("--save-path", type=str, default="outputs")
    parser.add_argument("--save-path-suffix", type=str, default="")
    parser.add_argument("--name-suffix", type=str, default="")

    # Generation settings
    parser.add_argument("--num-videos", type=int, default=1)
    parser.add_argument("--video-size", nargs="+", type=int, default=[DEFAULT_HEIGHT, DEFAULT_WIDTH])
    parser.add_argument("--video-length", type=int, default=DEFAULT_NB_FRAMES)
    parser.add_argument("--prompt", type=str, default=None)
    parser.add_argument("--seed-type", type=str, default="auto", choices=["file", "random", "fixed", "auto"])
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--neg-prompt", type=str, default="")
    parser.add_argument("--cfg-scale", type=float, default=1.0)
    parser.add_argument("--embedded-cfg-scale", type=float, default=6.0)
    parser.add_argument("--reproduce", action="store_true")

    # Parallel settings
    parser.add_argument("--ulysses-degree", type=int, default=1)
    parser.add_argument("--ring-degree", type=int, default=1)

    # Always parse an explicit list (never sys.argv), so the serving process's
    # own command line cannot leak into the model configuration.
    return parser.parse_args([] if argv is None else list(argv))
class EndpointHandler:
    """Endpoint handler that turns a text prompt into an MP4 video.

    ``__init__`` loads a ``HunyuanVideoSampler`` from a local model directory.
    ``__call__`` runs one generation request and returns the video as a base64
    ``data:video/mp4`` URI string.
    """

    def __init__(self, path: str = ""):
        """Initialize the handler with the model path and default config.

        Args:
            path: Model root directory; relative paths resolve against the
                current working directory. Expected to contain
                ``hunyuan-video-t2v-720p/`` (fp8 transformer weights and
                ``vae/``), ``text_encoder/`` and ``text_encoder_2/``.

        Raises:
            ValueError: If ``path`` does not exist.
            Exception: Errors from VAE staging or model loading propagate
                unchanged (model-loading errors are logged with a traceback
                first).
        """
        logger.info(f"Initializing EndpointHandler with path: {path}")
        # Use default args instead of parsing from command line.
        self.args = get_default_args()

        root = Path(path).absolute()
        logger.info(f"Absolute path: {root}")
        # Fail fast: every step below reads from this directory. Checking it
        # here (rather than after the VAE copy) makes the check reachable.
        if not root.exists():
            raise ValueError(f"models_root_path does not exist: {root}")

        self.args.model_base = str(root)
        model_dir = root / "hunyuan-video-t2v-720p"
        # To save memory, load the fp8 transformer weights.
        self.args.use_fp8 = True
        dit_weight_path = model_dir / "transformers" / "mp_rank_00_model_states_fp8.pt"
        original_vae_path = model_dir / "vae"

        # Log all critical paths and whether they exist.
        logger.info(f"Model base path: {self.args.model_base}")
        logger.info(f"DiT weight path: {dit_weight_path}")
        logger.info(f"Use fp8: {self.args.use_fp8}")
        logger.info(f"Original VAE path: {original_vae_path}")
        logger.info("Checking if paths exist:")
        logger.info(f"DiT weight exists: {dit_weight_path.exists()}")
        logger.info(f"VAE path exists: {original_vae_path.exists()}")
        if original_vae_path.exists():
            logger.info(f"VAE path contents: {list(original_vae_path.glob('*'))}")

        # Stage the VAE under correctly named files, then point hyvideo at it
        # and at the local text encoders.
        tmp_vae_path = setup_vae_path(original_vae_path)
        self._configure_component_paths(root, tmp_vae_path)
        self.args.dit_weight = str(dit_weight_path)

        try:
            logger.info("Attempting to initialize HunyuanVideoSampler...")
            self.model = HunyuanVideoSampler.from_pretrained(root, args=self.args)
        except Exception as e:
            logger.exception(f"Error initializing model: {e}")
            raise
        logger.info("Successfully initialized HunyuanVideoSampler")

    @staticmethod
    def _configure_component_paths(root: Path, vae_dir: Path) -> None:
        """Point hyvideo's path registries at the VAE and text encoders.

        Mutates the dicts imported from ``hyvideo.constants`` in place, so the
        override is visible to every user of that module in this process.
        """
        from hyvideo.constants import VAE_PATH, TEXT_ENCODER_PATH, TOKENIZER_PATH

        # "884-16c-hy" is the default ``--vae`` value in get_default_args.
        VAE_PATH["884-16c-hy"] = str(vae_dir)
        logger.info(f"Updated VAE_PATH to: {VAE_PATH['884-16c-hy']}")

        # Each tokenizer is loaded from the same directory as its encoder.
        encoder_paths = {
            "llm": str(root / "text_encoder"),
            "clipL": str(root / "text_encoder_2"),
        }
        TEXT_ENCODER_PATH.update(encoder_paths)
        TOKENIZER_PATH.update(encoder_paths)

        logger.info("Updated text encoder paths:")
        logger.info(f"TEXT_ENCODER_PATH['llm']: {TEXT_ENCODER_PATH['llm']}")
        logger.info(f"TEXT_ENCODER_PATH['clipL']: {TEXT_ENCODER_PATH['clipL']}")
        logger.info(f"TOKENIZER_PATH['llm']: {TOKENIZER_PATH['llm']}")
        logger.info(f"TOKENIZER_PATH['clipL']: {TOKENIZER_PATH['clipL']}")

    @staticmethod
    def _parse_resolution(resolution: Any) -> tuple:
        """Parse ``"WIDTHxHEIGHT"`` (``x`` is case-insensitive) into ``(width, height)``.

        Raises:
            ValueError: If the value is not two positive integers separated by ``x``.
        """
        try:
            width_text, height_text = str(resolution).lower().split("x")
            width, height = int(width_text), int(height_text)
        except ValueError as err:
            raise ValueError(
                f"Invalid resolution {resolution!r}: expected 'WIDTHxHEIGHT', e.g. '1280x720'"
            ) from err
        if width <= 0 or height <= 0:
            raise ValueError(f"Resolution must be positive, got {resolution!r}")
        return width, height

    @staticmethod
    def _parse_seed(raw: Any):
        """Map a request seed to the value passed to ``predict``.

        A missing, ``None`` or ``-1`` seed (int or string) becomes ``None``,
        which leaves the choice to the sampler (presumably random; confirm
        against ``HunyuanVideoSampler.predict``). Anything else is ``int(raw)``.
        """
        if raw is None:
            return None
        seed = int(raw)
        return None if seed == -1 else seed

    @staticmethod
    def _encode_mp4_file(video_path: str) -> str:
        """Read an MP4 file and return it as a base64 ``data:video/mp4`` URI."""
        import base64

        with open(video_path, "rb") as handle:
            video_base64 = base64.b64encode(handle.read()).decode()
        return f"data:video/mp4;base64,{video_base64}"

    @classmethod
    def _render_data_uri(cls, sample: Any, fps: int) -> str:
        """Encode ``sample`` to MP4 in a private temp dir and return a data URI.

        A fresh directory per call keeps concurrent requests from overwriting
        each other's output. It is removed even if encoding or reading fails.
        """
        import tempfile

        with tempfile.TemporaryDirectory() as tmp_dir:
            video_path = os.path.join(tmp_dir, "video.mp4")
            save_videos_grid(sample, video_path, fps=fps)
            return cls._encode_mp4_file(video_path)

    def __call__(self, data: Dict[str, Any]) -> str:
        """Process a single generation request.

        Args:
            data: Request payload. ``inputs`` (the prompt) is required.
                Optional keys: ``resolution`` (``"WIDTHxHEIGHT"``),
                ``video_length``, ``seed`` (``-1``/``None`` for unseeded),
                ``num_inference_steps``, ``guidance_scale``, ``flow_shift`` and
                ``embedded_guidance_scale``. The dict is not modified.

        Returns:
            The generated video as a ``data:video/mp4;base64,...`` string.

        Raises:
            ValueError: If ``inputs`` is missing or a parameter cannot be parsed.
        """
        logger.info(f"Processing request with data: {data}")

        prompt = data.get("inputs")
        if prompt is None:
            raise ValueError("No prompt provided in the 'inputs' field")

        width, height = self._parse_resolution(
            data.get("resolution", f"{DEFAULT_WIDTH}x{DEFAULT_HEIGHT}")
        )
        video_length = int(data.get("video_length", DEFAULT_NB_FRAMES))
        seed = self._parse_seed(data.get("seed", -1))
        num_inference_steps = int(data.get("num_inference_steps", DEFAULT_NB_STEPS))
        guidance_scale = float(data.get("guidance_scale", 1.0))
        flow_shift = float(data.get("flow_shift", 7.0))
        embedded_guidance_scale = float(data.get("embedded_guidance_scale", 6.0))

        logger.info(f"Processing with parameters: width={width}, height={height}, "
                    f"video_length={video_length}, seed={seed}, "
                    f"num_inference_steps={num_inference_steps}")

        try:
            outputs = self.model.predict(
                prompt=prompt,
                height=height,
                width=width,
                video_length=video_length,
                seed=seed,
                negative_prompt="",
                infer_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                num_videos_per_prompt=1,
                flow_shift=flow_shift,
                batch_size=1,
                embedded_guidance_scale=embedded_guidance_scale,
            )
            # Keep only the first video, re-adding a leading batch axis for
            # save_videos_grid.
            sample = outputs['samples'][0].unsqueeze(0)
            video_data_uri = self._render_data_uri(sample, DEFAULT_FPS)
        except Exception as e:
            logger.exception(f"Error during video generation: {e}")
            raise

        logger.info("Successfully generated and encoded video")
        return video_data_uri