jbilcke-hf (HF staff) committed
Commit 6752edd · verified · 1 Parent(s): 1a61ca3

Update handler.py

Files changed (1): handler.py +57 -7
handler.py CHANGED
@@ -13,7 +13,6 @@ import base64
 from hyvideo.utils.file_utils import save_videos_grid
 from hyvideo.inference import HunyuanVideoSampler
 from hyvideo.constants import NEGATIVE_PROMPT, VAE_PATH, TEXT_ENCODER_PATH, TOKENIZER_PATH
-from hyvideo.modules.attenion import get_attention_modes
 
 try:
     import triton
@@ -37,8 +36,45 @@ DEFAULT_NB_FRAMES = (4 * 30) + 1 # or 129 (note: hunyuan requires an extra +1 frame)
 DEFAULT_NB_STEPS = 22 # Default for standard model
 DEFAULT_FPS = 24
 
+def get_attention_modes():
+    """Get available attention modes - fallback if the hyvideo helper isn't importable"""
+    modes = []
+
+    try:
+        import torch
+        if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
+            modes.append("sdpa")
+    except ImportError:
+        pass
+
+    try:
+        import flash_attn  # noqa: F401
+        modes.append("flash")
+    except ImportError:
+        pass
+
+    try:
+        import sageattention
+        modes.append("sage")
+        if hasattr(sageattention, 'efficient_attention_v2'):
+            modes.append("sage2")
+    except ImportError:
+        pass
+
+    try:
+        import xformers  # noqa: F401
+        modes.append("xformers")
+    except ImportError:
+        pass
+
+    return modes
+
 # Get supported attention modes
-attention_modes_supported = get_attention_modes()
+try:
+    # Prefer the upstream helper; alias it so a failure here still falls
+    # back to the local implementation above rather than shadowing it.
+    from hyvideo.modules.attenion import get_attention_modes as _get_attention_modes
+    attention_modes_supported = _get_attention_modes()
+except Exception:
+    attention_modes_supported = get_attention_modes()
 
 def setup_vae_path(vae_path: Path) -> Path:
     """Create a temporary directory with correctly named VAE config file"""
@@ -317,10 +353,20 @@ class EndpointHandler:
         try:
             logger.info("Attempting to initialize HunyuanVideoSampler...")
 
-            # Apply attention mode setting
-            self.args.attention = self.attention_mode
-            self.model = HunyuanVideoSampler.from_pretrained(models_root_path, args=self.args)
+            # Extract necessary paths
+            transformer_path = str(self.args.dit_weight)
+            text_encoder_path = str(Path(self.args.model_base) / "text_encoder")
+
+            logger.info(f"Transformer path: {transformer_path}")
+            logger.info(f"Text encoder path: {text_encoder_path}")
+
+            # Initialize the model using the exact signature from gradio_server.py
+            self.model = HunyuanVideoSampler.from_pretrained(
+                transformer_path,
+                text_encoder_path,
+                attention_mode=self.attention_mode,
+                args=self.args
+            )
 
             # Set attention mode for transformer blocks
             if hasattr(self.model, 'pipeline') and hasattr(self.model.pipeline, 'transformer'):
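As a sanity check before construction, a minimal sketch of validating those two paths; the layout it assumes (a `text_encoder/` folder under `model_base`, plus a transformer weight path in `dit_weight`) is inferred from the code above, not confirmed elsewhere in this commit:

```python
from pathlib import Path

# Hypothetical pre-flight check mirroring the paths built in the handler.
def check_model_layout(model_base: str, dit_weight: str) -> None:
    if not Path(dit_weight).exists():
        raise FileNotFoundError(f"Transformer weights not found: {dit_weight}")
    text_encoder_dir = Path(model_base) / "text_encoder"
    if not text_encoder_dir.is_dir():
        raise FileNotFoundError(f"Text encoder directory not found: {text_encoder_dir}")
```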
@@ -362,7 +408,7 @@ class EndpointHandler:
             logger.error(f"Error initializing model: {str(e)}")
             raise
 
-    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+    def __call__(self, data: Dict[str, Any]) -> str:
         """Process a single request"""
         # Log incoming request
         logger.info(f"Processing request with data: {data}")
@@ -385,6 +431,7 @@
         flow_shift = float(data.pop("flow_shift", 7.0))
         embedded_guidance_scale = float(data.pop("embedded_guidance_scale", 6.0))
         enable_riflex = data.pop("enable_riflex", self.args.enable_riflex)
+        tea_cache = float(data.pop("tea_cache", 0.0))
 
         logger.info(f"Processing with parameters: width={width}, height={height}, "
                     f"video_length={video_length}, seed={seed}, "
@@ -392,10 +439,12 @@
 
         try:
             # Set up TeaCache for this generation if enabled
-            if hasattr(self.model.pipeline, 'transformer') and self.model.pipeline.transformer.enable_teacache:
+            if hasattr(self.model.pipeline, 'transformer') and tea_cache > 0:
                 transformer = self.model.pipeline.transformer
+                transformer.enable_teacache = True
                 transformer.num_steps = num_inference_steps
                 transformer.cnt = 0
+                transformer.rel_l1_thresh = tea_cache
                 transformer.accumulated_rel_l1_distance = 0
                 transformer.previous_modulated_input = None
                 transformer.previous_residual = None
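To make the counters above concrete: TeaCache skips transformer steps whose modulated input barely changed since the previous step. The sketch below paraphrases that decision using the attribute names set here; it is an illustration of the published TeaCache idea, not the actual hyvideo implementation (the caller would still update `cnt`, `previous_modulated_input`, and `previous_residual` each step):

```python
import torch

def should_skip_step(transformer, modulated_input: torch.Tensor) -> bool:
    # Always compute the first and last steps; in between, accumulate the
    # relative L1 change of the modulated input and skip the step (reusing
    # transformer.previous_residual) while the total stays under the
    # threshold taken from the request's tea_cache value.
    if transformer.cnt == 0 or transformer.cnt == transformer.num_steps - 1:
        transformer.accumulated_rel_l1_distance = 0
        return False
    prev = transformer.previous_modulated_input
    rel_l1 = ((modulated_input - prev).abs().mean() / prev.abs().mean()).item()
    transformer.accumulated_rel_l1_distance += rel_l1
    if transformer.accumulated_rel_l1_distance < transformer.rel_l1_thresh:
        return True
    transformer.accumulated_rel_l1_distance = 0
    return False
```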
@@ -450,6 +499,7 @@
 
             logger.info("Successfully generated and encoded video")
 
+            # Return exactly what the demo.py expects
             return video_data_uri
 
         except Exception as e:
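Since `__call__` now returns a plain string, a client can decode the result directly; this minimal sketch assumes the handler encodes the video as a `data:video/mp4;base64,...` URI, which matches the variable name but isn't shown in this hunk:

```python
import base64

def save_video_data_uri(video_data_uri: str, out_path: str = "output.mp4") -> None:
    # Expected shape: "data:video/mp4;base64,<payload>"
    header, b64_payload = video_data_uri.split(",", 1)
    if not header.startswith("data:video/mp4;base64"):
        raise ValueError(f"Unexpected data URI header: {header}")
    with open(out_path, "wb") as f:
        f.write(base64.b64decode(b64_payload))
```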
 