Update handler.py
handler.py  +57 -7
@@ -13,7 +13,6 @@ import base64
 from hyvideo.utils.file_utils import save_videos_grid
 from hyvideo.inference import HunyuanVideoSampler
 from hyvideo.constants import NEGATIVE_PROMPT, VAE_PATH, TEXT_ENCODER_PATH, TOKENIZER_PATH
-from hyvideo.modules.attenion import get_attention_modes
 
 try:
     import triton
@@ -37,8 +36,45 @@ DEFAULT_NB_FRAMES = (4 * 30) + 1  # or 129 (note: hunyuan requires an extra +1 frame)
 DEFAULT_NB_STEPS = 22  # Default for standard model
 DEFAULT_FPS = 24
 
+def get_attention_modes():
+    """Get available attention modes - fallback if the module function isn't available"""
+    modes = ["sdpa"]  # Always available
+
+    try:
+        import torch
+        if hasattr(torch.nn.functional, 'scaled_dot_product_attention') and "sdpa" not in modes:
+            modes.append("sdpa")
+    except ImportError:
+        pass
+
+    try:
+        import flash_attn
+        modes.append("flash")
+    except ImportError:
+        pass
+
+    try:
+        import sageattention
+        modes.append("sage")
+        if hasattr(sageattention, 'efficient_attention_v2'):
+            modes.append("sage2")
+    except ImportError:
+        pass
+
+    try:
+        import xformers
+        modes.append("xformers")
+    except ImportError:
+        pass
+
+    return modes
+
 # Get supported attention modes
-attention_modes_supported = get_attention_modes()
+try:
+    from hyvideo.modules.attenion import get_attention_modes
+    attention_modes_supported = get_attention_modes()
+except Exception:
+    attention_modes_supported = get_attention_modes()
 
 def setup_vae_path(vae_path: Path) -> Path:
     """Create a temporary directory with correctly named VAE config file"""
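The prober above only checks importability; it does not confirm that a backend actually runs on the deployed GPU. If the handler lets callers pick an attention mode, it is worth validating the request against attention_modes_supported before initialization. A minimal sketch (the resolve_attention_mode helper is hypothetical, not part of this commit):

    def resolve_attention_mode(requested_mode: str, supported: list) -> str:
        """Fall back to 'sdpa' when the requested backend is not importable."""
        if requested_mode in supported:
            return requested_mode
        return "sdpa"  # always present in the list built by get_attention_modes()

    # e.g. prefer sage2 but degrade gracefully on machines without sageattention
    attention_mode = resolve_attention_mode("sage2", attention_modes_supported)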
@@ -317,10 +353,20 @@ class EndpointHandler:
         try:
             logger.info("Attempting to initialize HunyuanVideoSampler...")
 
-            # …
-            …
+            # Extract necessary paths
+            transformer_path = str(self.args.dit_weight)
+            text_encoder_path = str(Path(self.args.model_base) / "text_encoder")
 
-            …
+            logger.info(f"Transformer path: {transformer_path}")
+            logger.info(f"Text encoder path: {text_encoder_path}")
+
+            # Initialize the model using the exact signature from gradio_server.py
+            self.model = HunyuanVideoSampler.from_pretrained(
+                transformer_path,
+                text_encoder_path,
+                attention_mode=self.attention_mode,
+                args=self.args
+            )
 
             # Set attention mode for transformer blocks
             if hasattr(self.model, 'pipeline') and hasattr(self.model.pipeline, 'transformer'):
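The new initialization assumes self.args exposes at least dit_weight (the transformer checkpoint) and model_base (the directory containing text_encoder). A sketch of the shape of that object, with placeholder paths and argparse.Namespace standing in for however the handler really assembles its args:

    from argparse import Namespace
    from pathlib import Path

    # Hypothetical values; the real handler derives these from its deployment config.
    args = Namespace(
        dit_weight="/repository/ckpts/transformer/mp_rank_00_model_states.pt",
        model_base="/repository/ckpts",
        enable_riflex=False,
    )

    transformer_path = str(args.dit_weight)
    text_encoder_path = str(Path(args.model_base) / "text_encoder")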
@@ -362,7 +408,7 @@ class EndpointHandler:
             logger.error(f"Error initializing model: {str(e)}")
             raise
 
-    def __call__(self, data: Dict[str, Any]) -> …
+    def __call__(self, data: Dict[str, Any]) -> str:
         """Process a single request"""
         # Log incoming request
         logger.info(f"Processing request with data: {data}")
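The -> str annotation matches the data-URI string returned at the end of the method. Under the usual Hugging Face custom-handler convention the class is constructed once and then called per request; a sketch, where the constructor signature and the prompt key are assumptions, since neither appears in this diff:

    handler = EndpointHandler(path="/repository")
    video_data_uri = handler({"inputs": "A panda drumming in the rain"})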
@@ -385,6 +431,7 @@
         flow_shift = float(data.pop("flow_shift", 7.0))
         embedded_guidance_scale = float(data.pop("embedded_guidance_scale", 6.0))
         enable_riflex = data.pop("enable_riflex", self.args.enable_riflex)
+        tea_cache = float(data.pop("tea_cache", 0.0))
 
         logger.info(f"Processing with parameters: width={width}, height={height}, "
                     f"video_length={video_length}, seed={seed}, "
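tea_cache rides along in the request body like the other generation parameters, and the default of 0.0 leaves the cache off. A hedged example payload (only the keys popped in this method are grounded in the diff; the prompt key and the concrete values are illustrative):

    payload = {
        "inputs": "A cat walks on the grass, realistic style.",
        "width": 848,
        "height": 480,
        "video_length": 129,
        "seed": 42,
        "num_inference_steps": 22,
        "flow_shift": 7.0,
        "embedded_guidance_scale": 6.0,
        "tea_cache": 0.15,  # relative-L1 threshold; 0.0 disables TeaCache
    }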
@@ -392,10 +439,12 @@
 
         try:
             # Set up TeaCache for this generation if enabled
-            if hasattr(self.model.pipeline, 'transformer') and …
+            if hasattr(self.model.pipeline, 'transformer') and tea_cache > 0:
                 transformer = self.model.pipeline.transformer
+                transformer.enable_teacache = True
                 transformer.num_steps = num_inference_steps
                 transformer.cnt = 0
+                transformer.rel_l1_thresh = tea_cache
                 transformer.accumulated_rel_l1_distance = 0
                 transformer.previous_modulated_input = None
                 transformer.previous_residual = None
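These fields are the per-generation state TeaCache consumes inside the transformer's forward pass: it accumulates the relative L1 change of the modulated input across denoising steps and, while the running total stays under rel_l1_thresh, skips the blocks and reuses previous_residual. A minimal sketch of that decision under those assumptions (the published TeaCache additionally rescales the distance with a fitted polynomial, omitted here; this is not the repo's actual forward code):

    import torch

    def should_skip_step(t, modulated_input: torch.Tensor) -> bool:
        """Return True when the transformer blocks can be skipped and
        t.previous_residual reused (TeaCache-style heuristic)."""
        skip = False
        if t.cnt == 0 or t.cnt == t.num_steps - 1:
            t.accumulated_rel_l1_distance = 0  # always compute first and last steps
        else:
            prev = t.previous_modulated_input
            rel_l1 = ((modulated_input - prev).abs().mean() / prev.abs().mean()).item()
            t.accumulated_rel_l1_distance += rel_l1
            if t.accumulated_rel_l1_distance < t.rel_l1_thresh:
                skip = True  # drift still small; reuse the cached residual
            else:
                t.accumulated_rel_l1_distance = 0  # recompute and reset
        t.previous_modulated_input = modulated_input
        t.cnt += 1
        return skip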
@@ -450,6 +499,7 @@
 
             logger.info("Successfully generated and encoded video")
 
+            # Return exactly what demo.py expects
             return video_data_uri
 
         except Exception as e: