jbilcke-hf HF staff committed on
Commit
afceeed
·
verified ·
1 Parent(s): f08eddf

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +67 -69
handler.py CHANGED
@@ -3,87 +3,88 @@ import os
3
  from pathlib import Path
4
  import time
5
  from datetime import datetime
6
- import torch
7
- import base64
8
- from io import BytesIO
9
-
10
  from hyvideo.utils.file_utils import save_videos_grid
11
- from hyvideo.config import parse_args
12
  from hyvideo.inference import HunyuanVideoSampler
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  class EndpointHandler:
15
  def __init__(self, path: str = ""):
16
- """Initialize the handler with the model path.
 
 
 
17
 
18
- Args:
19
- path: Path to the model weights directory
20
- """
21
- self.args = parse_args()
22
  models_root_path = Path(path)
23
  if not models_root_path.exists():
24
  raise ValueError(f"`models_root` not exists: {models_root_path}")
25
 
26
- # Initialize model
27
  self.model = HunyuanVideoSampler.from_pretrained(models_root_path, args=self.args)
28
 
29
- # Default parameters
30
- self.default_params = {
31
- "num_inference_steps": 50,
32
- "guidance_scale": 1.0,
33
- "flow_shift": 7.0,
34
- "embedded_guidance_scale": 6.0,
35
- "video_length": 129, # 5s
36
- "resolution": "1280x720"
37
- }
38
-
39
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
40
- """Process the input data and generate video.
41
 
42
  Args:
43
- data: Dictionary containing the input parameters
44
- Required:
45
- - inputs (str): The prompt text
46
- Optional:
47
- - resolution (str): Video resolution like "1280x720"
48
- - video_length (int): Number of frames
49
- - seed (int): Random seed (-1 for random)
50
- - num_inference_steps (int): Number of inference steps
51
- - guidance_scale (float): Guidance scale value
52
- - flow_shift (float): Flow shift value
53
- - embedded_guidance_scale (float): Embedded guidance scale value
54
 
55
  Returns:
56
- Dictionary containing the base64 encoded video
57
  """
58
- # Get prompt
59
  prompt = data.pop("inputs", None)
60
  if prompt is None:
61
  raise ValueError("No prompt provided in the 'inputs' field")
62
-
63
- # Get optional parameters with defaults
64
- resolution = data.pop("resolution", self.default_params["resolution"])
65
- video_length = int(data.pop("video_length", self.default_params["video_length"]))
66
- seed = int(data.pop("seed", -1))
67
- num_inference_steps = int(data.pop("num_inference_steps", self.default_params["num_inference_steps"]))
68
- guidance_scale = float(data.pop("guidance_scale", self.default_params["guidance_scale"]))
69
- flow_shift = float(data.pop("flow_shift", self.default_params["flow_shift"]))
70
- embedded_guidance_scale = float(data.pop("embedded_guidance_scale", self.default_params["embedded_guidance_scale"]))
71
-
72
- # Process resolution
73
- width, height = resolution.split("x")
74
- width, height = int(width), int(height)
75
-
76
- # Set seed
77
- seed = None if seed == -1 else seed
78
-
79
- # Generate video
80
  outputs = self.model.predict(
81
  prompt=prompt,
82
  height=height,
83
  width=width,
84
  video_length=video_length,
85
  seed=seed,
86
- negative_prompt="", # not applicable in inference
87
  infer_steps=num_inference_steps,
88
  guidance_scale=guidance_scale,
89
  num_videos_per_prompt=1,
@@ -91,27 +92,24 @@ class EndpointHandler:
91
  batch_size=1,
92
  embedded_guidance_scale=embedded_guidance_scale
93
  )
94
-
95
- # Process output video
96
  samples = outputs['samples']
97
  sample = samples[0].unsqueeze(0)
98
-
99
- # Save video to temporary file
100
- temp_dir = "/tmp/video_output"
101
- os.makedirs(temp_dir, exist_ok=True)
102
 
103
- time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
104
- video_path = f"{temp_dir}/{time_flag}_seed{outputs['seeds'][0]}.mp4"
105
- save_videos_grid(sample, video_path, fps=24)
106
-
107
  # Read video file and convert to base64
108
- with open(video_path, "rb") as f:
109
  video_bytes = f.read()
 
110
  video_base64 = base64.b64encode(video_bytes).decode()
111
-
112
- # Clean up
113
- os.remove(video_path)
114
-
115
  return {
116
  "video_base64": video_base64,
117
  "seed": outputs['seeds'][0],
 
3
  from pathlib import Path
4
  import time
5
  from datetime import datetime
6
+ import argparse
 
 
 
7
  from hyvideo.utils.file_utils import save_videos_grid
 
8
  from hyvideo.inference import HunyuanVideoSampler
9
+ from hyvideo.config import parse_args
10
+ from hyvideo.constants import NEGATIVE_PROMPT
11
+
12
+ def get_default_args():
13
+ """Create default arguments instead of parsing from command line"""
14
+ parser = argparse.ArgumentParser()
15
+
16
+ # Add all the arguments that were in the original parser
17
+ parser.add_argument("--model", type=str, default="HYVideo-T/2")
18
+ parser.add_argument("--model-resolution", type=str, default="720p", choices=["540p", "720p"])
19
+ parser.add_argument("--latent-channels", type=int, default=4)
20
+ parser.add_argument("--precision", type=str, default="bf16", choices=["bf16", "fp32", "fp16"])
21
+ parser.add_argument("--batch-size", type=int, default=1)
22
+ parser.add_argument("--infer-steps", type=int, default=50)
23
+ parser.add_argument("--model-base", type=str, default=None)
24
+ parser.add_argument("--save-path", type=str, default="outputs")
25
+ parser.add_argument("--video-length", type=int, default=129) # 5 seconds
26
+
27
+ # Parse with empty args list to avoid reading sys.argv
28
+ args = parser.parse_args([])
29
+ return args
30
 
31
  class EndpointHandler:
32
  def __init__(self, path: str = ""):
33
+ """Initialize the handler with model path and default config."""
34
+ # Use default args instead of parsing from command line
35
+ self.args = get_default_args()
36
+ self.args.model_base = path # Use the provided model path
37
 
38
+ # Initialize model
 
 
 
39
  models_root_path = Path(path)
40
  if not models_root_path.exists():
41
  raise ValueError(f"`models_root` not exists: {models_root_path}")
42
 
 
43
  self.model = HunyuanVideoSampler.from_pretrained(models_root_path, args=self.args)
44
 
 
 
 
 
 
 
 
 
 
 
45
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
46
+ """Process a single request
47
 
48
  Args:
49
+ data: Dictionary containing:
50
+ - inputs (str): The prompt text
51
+ - resolution (str, optional): Video resolution like "1280x720"
52
+ - video_length (int, optional): Number of frames
53
+ - num_inference_steps (int, optional): Number of inference steps
54
+ - seed (int, optional): Random seed (-1 for random)
55
+ - guidance_scale (float, optional): Guidance scale value
56
+ - flow_shift (float, optional): Flow shift value
57
+ - embedded_guidance_scale (float, optional): Embedded guidance scale
 
 
58
 
59
  Returns:
60
+ Dictionary containing the generated video as base64 string
61
  """
62
+ # Get inputs from request data
63
  prompt = data.pop("inputs", None)
64
  if prompt is None:
65
  raise ValueError("No prompt provided in the 'inputs' field")
66
+
67
+ # Parse resolution
68
+ resolution = data.pop("resolution", "1280x720")
69
+ width, height = map(int, resolution.split("x"))
70
+
71
+ # Get other parameters with defaults
72
+ video_length = int(data.pop("video_length", 129))
73
+ seed = data.pop("seed", -1)
74
+ seed = None if seed == -1 else int(seed)
75
+ num_inference_steps = int(data.pop("num_inference_steps", 50))
76
+ guidance_scale = float(data.pop("guidance_scale", 1.0))
77
+ flow_shift = float(data.pop("flow_shift", 7.0))
78
+ embedded_guidance_scale = float(data.pop("embedded_guidance_scale", 6.0))
79
+
80
+ # Run inference
 
 
 
81
  outputs = self.model.predict(
82
  prompt=prompt,
83
  height=height,
84
  width=width,
85
  video_length=video_length,
86
  seed=seed,
87
+ negative_prompt="",
88
  infer_steps=num_inference_steps,
89
  guidance_scale=guidance_scale,
90
  num_videos_per_prompt=1,
 
92
  batch_size=1,
93
  embedded_guidance_scale=embedded_guidance_scale
94
  )
95
+
96
+ # Get the video tensor
97
  samples = outputs['samples']
98
  sample = samples[0].unsqueeze(0)
 
 
 
 
99
 
100
+ # Save to temporary file
101
+ temp_path = "/tmp/temp_video.mp4"
102
+ save_videos_grid(sample, temp_path, fps=24)
103
+
104
  # Read video file and convert to base64
105
+ with open(temp_path, "rb") as f:
106
  video_bytes = f.read()
107
+ import base64
108
  video_base64 = base64.b64encode(video_bytes).decode()
109
+
110
+ # Cleanup
111
+ os.remove(temp_path)
112
+
113
  return {
114
  "video_base64": video_base64,
115
  "seed": outputs['seeds'][0],