jbilcke-hf (HF staff) committed
Commit ef15707
1 parent: e349e43

Update handler.py

Files changed (1):
    handler.py  (+108 −38)
handler.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Dict, Any, Union, Optional
+from typing import Dict, Any, Union, Optional, Tuple
 import torch
 from diffusers import LTXPipeline, LTXImageToVideoPipeline
 from PIL import Image
@@ -15,6 +15,19 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

 class EndpointHandler:
+    # Default configuration
+    DEFAULT_FPS = 24
+    DEFAULT_DURATION = 4  # seconds
+    DEFAULT_NUM_FRAMES = (DEFAULT_DURATION * DEFAULT_FPS) + 1  # 97 frames
+    DEFAULT_NUM_STEPS = 25
+    DEFAULT_WIDTH = 768
+    DEFAULT_HEIGHT = 512
+
+    # Constraints
+    MAX_WIDTH = 1280
+    MAX_HEIGHT = 720
+    MAX_FRAMES = 257
+
     def __init__(self, path: str = ""):
         """Initialize the LTX Video handler with both text-to-video and image-to-video pipelines.

@@ -35,11 +48,55 @@ class EndpointHandler:
         # Enable memory optimizations
         self.text_to_video.enable_model_cpu_offload()
         self.image_to_video.enable_model_cpu_offload()
+
+    def _validate_and_adjust_resolution(self, width: int, height: int) -> Tuple[int, int]:
+        """Validate and adjust resolution to meet constraints.
+
+        Args:
+            width (int): Requested width
+            height (int): Requested height
+
+        Returns:
+            Tuple[int, int]: Adjusted (width, height)
+        """
+        # Round to nearest multiple of 32
+        width = round(width / 32) * 32
+        height = round(height / 32) * 32
+
+        # Enforce maximum dimensions
+        width = min(width, self.MAX_WIDTH)
+        height = min(height, self.MAX_HEIGHT)
+
+        # Enforce minimum dimensions
+        width = max(width, 32)
+        height = max(height, 32)
+
+        return width, height
+
+    def _validate_and_adjust_frames(self, num_frames: Optional[int] = None, fps: Optional[int] = None) -> Tuple[int, int]:
+        """Validate and adjust frame count and FPS to meet constraints.
+
+        Args:
+            num_frames (Optional[int]): Requested number of frames
+            fps (Optional[int]): Requested frames per second
+
+        Returns:
+            Tuple[int, int]: Adjusted (num_frames, fps)
+        """
+        # Use defaults if not provided
+        fps = fps or self.DEFAULT_FPS
+        num_frames = num_frames or self.DEFAULT_NUM_FRAMES
+
+        # Adjust frames to be in format 8k + 1
+        k = (num_frames - 1) // 8
+        num_frames = (k * 8) + 1
+
+        # Enforce maximum frame count
+        num_frames = min(num_frames, self.MAX_FRAMES)

-        # Set default FPS
-        self.fps = 24
+        return num_frames, fps

-    def _create_video_file(self, frames: torch.Tensor, fps: int = 24) -> bytes:
+    def _create_video_file(self, frames: torch.Tensor, fps: int = DEFAULT_FPS) -> bytes:
         """Convert frames to an MP4 video file.

         Args:
@@ -50,11 +107,11 @@
             bytes: MP4 video file content
         """
         # Log frame information
-        num_frames = frames.shape[1]  # Shape should be [1, num_frames, channels, height, width]
+        num_frames = frames.shape[1]
         duration = num_frames / fps
         logger.info(f"Creating video with {num_frames} frames at {fps} FPS (duration: {duration:.2f} seconds)")

-        # Convert tensor to numpy array - remove batch dimension and rearrange to [num_frames, height, width, channels]
+        # Convert tensor to numpy array
         video_np = frames.squeeze(0).permute(0, 2, 3, 1).cpu().float().numpy()
         video_np = (video_np * 255).astype(np.uint8)

@@ -68,8 +125,7 @@
         try:
             # Create video clip and write to file
             clip = ImageSequenceClip(list(video_np), fps=fps)
-            resized = clip.resize((width, height))
-            resized.write_videofile(output_path, codec="libx264", audio=False)
+            clip.write_videofile(output_path, codec="libx264", audio=False)

             # Read the video file
             with open(output_path, "rb") as f:
@@ -93,60 +149,66 @@
             data (Dict[str, Any]): Input data containing:
                 - prompt (str): Text description for video generation
                 - image (Optional[str]): Base64 encoded image for image-to-video generation
-                - num_frames (Optional[int]): Number of frames to generate (default: 24)
+                - width (Optional[int]): Video width (default: 768)
+                - height (Optional[int]): Video height (default: 512)
+                - num_frames (Optional[int]): Number of frames (default: 97)
                 - fps (Optional[int]): Frames per second (default: 24)
+                - num_inference_steps (Optional[int]): Number of inference steps (default: 25)
                 - guidance_scale (Optional[float]): Guidance scale (default: 7.5)
-                - num_inference_steps (Optional[int]): Number of inference steps (default: 50)

         Returns:
             Dict[str, Any]: Dictionary containing:
                 - video: Base64 encoded MP4 video
                 - content-type: MIME type of the video
+                - metadata: Dictionary with actual values used for generation
         """
-        # Extract parameters
+        # Extract and validate prompt
         prompt = data.get("prompt")
         if not prompt:
             raise ValueError("'prompt' is required in the input data")

-        # Get optional parameters with defaults
-        num_frames = data.get("num_frames", 24)
-        fps = data.get("fps", self.fps)
+        # Get and validate resolution
+        width = data.get("width", self.DEFAULT_WIDTH)
+        height = data.get("height", self.DEFAULT_HEIGHT)
+        width, height = self._validate_and_adjust_resolution(width, height)
+
+        # Get and validate frames and FPS
+        num_frames = data.get("num_frames", self.DEFAULT_NUM_FRAMES)
+        fps = data.get("fps", self.DEFAULT_FPS)
+        num_frames, fps = self._validate_and_adjust_frames(num_frames, fps)
+
+        # Get other parameters with defaults
         guidance_scale = data.get("guidance_scale", 7.5)
-        num_inference_steps = data.get("num_inference_steps", 50)
+        num_inference_steps = data.get("num_inference_steps", self.DEFAULT_NUM_STEPS)

         logger.info(f"Generating video with prompt: '{prompt}'")
-        logger.info(f"Parameters: num_frames={num_frames}, fps={fps}, guidance_scale={guidance_scale}, num_inference_steps={num_inference_steps}")
+        logger.info(f"Parameters: size={width}x{height}, num_frames={num_frames}, fps={fps}")
+        logger.info(f"Additional params: guidance_scale={guidance_scale}, num_inference_steps={num_inference_steps}")

-        # Check if image is provided for image-to-video generation
-        image_data = data.get("image")
-
         try:
             with torch.no_grad():
+                generation_kwargs = {
+                    "prompt": prompt,
+                    "height": height,
+                    "width": width,
+                    "num_frames": num_frames,
+                    "guidance_scale": guidance_scale,
+                    "num_inference_steps": num_inference_steps,
+                    "output_type": "pt"
+                }
+
+                # Check if image is provided for image-to-video generation
+                image_data = data.get("image")
                 if image_data:
                     # Decode base64 image
                     image_bytes = base64.b64decode(image_data)
                     image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
                     logger.info("Using image-to-video generation mode")
-
-                    # Generate video from image
-                    output = self.image_to_video(
-                        prompt=prompt,
-                        image=image,
-                        num_frames=num_frames,
-                        guidance_scale=guidance_scale,
-                        num_inference_steps=num_inference_steps,
-                        output_type="pt"
-                    ).frames  # Remove [0] to keep all frames
+                    generation_kwargs["image"] = image
+                    output = self.image_to_video(**generation_kwargs).frames
                 else:
                     logger.info("Using text-to-video generation mode")
-                    # Generate video from text only
-                    output = self.text_to_video(
-                        prompt=prompt,
-                        num_frames=num_frames,
-                        guidance_scale=guidance_scale,
-                        num_inference_steps=num_inference_steps,
-                        output_type="pt"
-                    ).frames  # Remove [0] to keep all frames
+                    output = self.text_to_video(**generation_kwargs).frames

             # Convert frames to video file
             video_content = self._create_video_file(output, fps=fps)
@@ -156,7 +218,15 @@

             return {
                 "video": video_base64,
-                "content-type": "video/mp4"
+                "content-type": "video/mp4",
+                "metadata": {
+                    "width": width,
+                    "height": height,
+                    "num_frames": num_frames,
+                    "fps": fps,
+                    "duration": num_frames / fps,
+                    "num_inference_steps": num_inference_steps
+                }
             }

         except Exception as e:
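
For quick reference, here is a minimal standalone sketch of the sizing rules this commit introduces: resolution is rounded to the nearest multiple of 32 and clamped between 32 and the 1280x720 maximum, and the frame count is snapped down onto the 8k + 1 grid and capped at 257, so the default request of 4 seconds at 24 FPS comes out to 97 frames. The function names below are hypothetical stand-ins for the handler's private methods, shown outside the class purely for illustration, with the class constants hard-coded.

    # Hypothetical mirror of _validate_and_adjust_resolution / _validate_and_adjust_frames
    def adjust_resolution(width: int, height: int) -> tuple:
        width = round(width / 32) * 32      # round to nearest multiple of 32
        height = round(height / 32) * 32
        width = max(min(width, 1280), 32)   # clamp to [32, MAX_WIDTH]
        height = max(min(height, 720), 32)  # clamp to [32, MAX_HEIGHT]
        return width, height

    def adjust_frames(num_frames: int = 97, fps: int = 24) -> tuple:
        num_frames = ((num_frames - 1) // 8) * 8 + 1  # snap down onto the 8k + 1 grid
        return min(num_frames, 257), fps              # cap at MAX_FRAMES

    print(adjust_resolution(1000, 512))  # -> (992, 512)
    print(adjust_frames(100))            # -> (97, 24)
    print(adjust_frames(4 * 24 + 1))     # -> (97, 24): the default 4 s at 24 FPS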
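
The frame-to-MP4 path in _create_video_file can also be exercised on its own. Below is a sketch with a dummy tensor standing in for pipeline output; the [1, num_frames, channels, height, width] layout comes from the comment removed in this diff, and the MoviePy 1.x import path is an assumption, since the import lines fall outside the shown hunks.

    import numpy as np
    import torch
    from moviepy.editor import ImageSequenceClip  # assumed import; not shown in the diff

    # Dummy pipeline output: [batch, num_frames, channels, height, width], values in 0..1
    frames = torch.rand(1, 25, 3, 64, 96)

    # Same conversion as _create_video_file: drop batch dim, go to [frames, height, width, channels]
    video_np = frames.squeeze(0).permute(0, 2, 3, 1).cpu().float().numpy()
    video_np = (video_np * 255).astype(np.uint8)

    # MoviePy takes one uint8 array per frame and writes an H.264 MP4
    clip = ImageSequenceClip(list(video_np), fps=24)
    clip.write_videofile("sample.mp4", codec="libx264", audio=False)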
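
Finally, a hedged example of invoking the updated handler locally. On Inference Endpoints the runtime constructs and calls EndpointHandler itself; the payload keys and the metadata field follow the docstring in this commit, while the prompt, the path argument, and the output filename are placeholders.

    import base64
    from handler import EndpointHandler  # assumes this script sits next to handler.py

    handler = EndpointHandler(path=".")  # placeholder path; the endpoint runtime supplies the real one

    result = handler({
        "prompt": "a red panda walking through fresh snow, cinematic",  # placeholder prompt
        "width": 768,               # the class defaults, passed explicitly for clarity
        "height": 512,
        "num_frames": 97,           # 4 s at 24 FPS, already on the 8k + 1 grid
        "fps": 24,
        "num_inference_steps": 25,
        "guidance_scale": 7.5,
    })

    print(result["content-type"])   # video/mp4
    print(result["metadata"])       # actual width/height/num_frames/fps/duration used

    with open("output.mp4", "wb") as f:
        f.write(base64.b64decode(result["video"]))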