jbilcke-hf HF staff commited on
Commit
2745124
·
verified ·
1 Parent(s): de858d1

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +39 -11
handler.py CHANGED
@@ -1,5 +1,6 @@
1
  from typing import Dict, Any
2
  import os
 
3
  from pathlib import Path
4
  import time
5
  from datetime import datetime
@@ -12,6 +13,30 @@ from hyvideo.constants import NEGATIVE_PROMPT
12
  # Configure logger
13
  logger.add("handler_debug.log", rotation="500 MB")
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def get_default_args():
16
  """Create default arguments instead of parsing from command line"""
17
  parser = argparse.ArgumentParser()
@@ -99,7 +124,6 @@ def get_default_args():
99
  class EndpointHandler:
100
  def __init__(self, path: str = ""):
101
  """Initialize the handler with model path and default config."""
102
- # Log the initial path
103
  logger.info(f"Initializing EndpointHandler with path: {path}")
104
 
105
  # Use default args instead of parsing from command line
@@ -114,19 +138,28 @@ class EndpointHandler:
114
 
115
  # Set paths for model components
116
  dit_weight_path = Path(path) / "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt"
117
- vae_path = Path(path) / "hunyuan-video-t2v-720p/vae"
118
 
119
  # Log all critical paths
120
  logger.info(f"Model base path: {self.args.model_base}")
121
  logger.info(f"DiT weight path: {dit_weight_path}")
122
- logger.info(f"VAE path: {vae_path}")
123
 
124
  # Verify paths exist
125
  logger.info("Checking if paths exist:")
126
  logger.info(f"DiT weight exists: {dit_weight_path.exists()}")
127
- logger.info(f"VAE path exists: {vae_path.exists()}")
128
- if vae_path.exists():
129
- logger.info(f"VAE path contents: {list(vae_path.glob('*'))}")
 
 
 
 
 
 
 
 
 
130
 
131
  self.args.dit_weight = str(dit_weight_path)
132
 
@@ -135,11 +168,6 @@ class EndpointHandler:
135
  if not models_root_path.exists():
136
  raise ValueError(f"models_root_path does not exist: {models_root_path}")
137
 
138
- # Log directory contents for debugging
139
- logger.info("Directory contents:")
140
- for item in models_root_path.glob("**/*"):
141
- logger.info(f" {item}")
142
-
143
  try:
144
  logger.info("Attempting to initialize HunyuanVideoSampler...")
145
  self.model = HunyuanVideoSampler.from_pretrained(models_root_path, args=self.args)
 
1
  from typing import Dict, Any
2
  import os
3
+ import shutil
4
  from pathlib import Path
5
  import time
6
  from datetime import datetime
 
13
  # Configure logger
14
  logger.add("handler_debug.log", rotation="500 MB")
15
 
16
+ def setup_vae_path(vae_path: Path) -> Path:
17
+ """Create a temporary directory with correctly named VAE config file"""
18
+ tmp_vae_dir = Path("/tmp/vae")
19
+ if tmp_vae_dir.exists():
20
+ shutil.rmtree(tmp_vae_dir)
21
+ tmp_vae_dir.mkdir(parents=True)
22
+
23
+ # Copy files to temp directory
24
+ logger.info(f"Setting up VAE in temporary directory: {tmp_vae_dir}")
25
+
26
+ # Copy and rename config file
27
+ original_config = vae_path / "hunyuan-video-t2v-720p_vae_config.json"
28
+ new_config = tmp_vae_dir / "config.json"
29
+ shutil.copy2(original_config, new_config)
30
+ logger.info(f"Copied VAE config from {original_config} to {new_config}")
31
+
32
+ # Copy model file
33
+ original_model = vae_path / "pytorch_model.pt"
34
+ new_model = tmp_vae_dir / "pytorch_model.pt"
35
+ shutil.copy2(original_model, new_model)
36
+ logger.info(f"Copied VAE model from {original_model} to {new_model}")
37
+
38
+ return tmp_vae_dir
39
+
40
  def get_default_args():
41
  """Create default arguments instead of parsing from command line"""
42
  parser = argparse.ArgumentParser()
 
124
  class EndpointHandler:
125
  def __init__(self, path: str = ""):
126
  """Initialize the handler with model path and default config."""
 
127
  logger.info(f"Initializing EndpointHandler with path: {path}")
128
 
129
  # Use default args instead of parsing from command line
 
138
 
139
  # Set paths for model components
140
  dit_weight_path = Path(path) / "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt"
141
+ original_vae_path = Path(path) / "hunyuan-video-t2v-720p/vae"
142
 
143
  # Log all critical paths
144
  logger.info(f"Model base path: {self.args.model_base}")
145
  logger.info(f"DiT weight path: {dit_weight_path}")
146
+ logger.info(f"Original VAE path: {original_vae_path}")
147
 
148
  # Verify paths exist
149
  logger.info("Checking if paths exist:")
150
  logger.info(f"DiT weight exists: {dit_weight_path.exists()}")
151
+ logger.info(f"VAE path exists: {original_vae_path.exists()}")
152
+
153
+ if original_vae_path.exists():
154
+ logger.info(f"VAE path contents: {list(original_vae_path.glob('*'))}")
155
+
156
+ # Set up VAE in temporary directory with correct file names
157
+ tmp_vae_path = setup_vae_path(original_vae_path)
158
+
159
+ # Override the VAE path in constants to use our temporary directory
160
+ from hyvideo.constants import VAE_PATH
161
+ VAE_PATH["884-16c-hy"] = str(tmp_vae_path)
162
+ logger.info(f"Updated VAE_PATH to: {VAE_PATH['884-16c-hy']}")
163
 
164
  self.args.dit_weight = str(dit_weight_path)
165
 
 
168
  if not models_root_path.exists():
169
  raise ValueError(f"models_root_path does not exist: {models_root_path}")
170
 
 
 
 
 
 
171
  try:
172
  logger.info("Attempting to initialize HunyuanVideoSampler...")
173
  self.model = HunyuanVideoSampler.from_pretrained(models_root_path, args=self.args)