ford442 committed
Commit 58944ea · verified · 1 parent: 78bd780

Update app.py

Files changed (1)
  1. app.py +121 -90
app.py CHANGED
@@ -103,7 +103,7 @@ if PIPELINE_CONFIG_YAML.get("spatial_upscaler_model_path"):
103
  target_inference_device = "cuda"
104
  print(f"Target inference device: {target_inference_device}")
105
  pipeline_instance.to(target_inference_device)
106
- if latent_upsampler_instance:
107
  latent_upsampler_instance.to(target_inference_device)
108
 
109
 
@@ -125,22 +125,22 @@ def calculate_new_dimensions(orig_w, orig_h):
125
  new_h = TARGET_FIXED_SIDE
126
  aspect_ratio = orig_w / orig_h
127
  new_w_ideal = new_h * aspect_ratio
128
-
129
  # Round to nearest multiple of 32
130
  new_w = round(new_w_ideal / 32) * 32
131
-
132
  # Clamp to [MIN_DIM_SLIDER, MAX_IMAGE_SIZE]
133
  new_w = max(MIN_DIM_SLIDER, min(new_w, MAX_IMAGE_SIZE))
134
  # Ensure new_h is also clamped (TARGET_FIXED_SIDE should be within these bounds if configured correctly)
135
- new_h = max(MIN_DIM_SLIDER, min(new_h, MAX_IMAGE_SIZE))
136
  else: # Portrait
137
  new_w = TARGET_FIXED_SIDE
138
  aspect_ratio = orig_h / orig_w # Use H/W ratio for portrait scaling
139
  new_h_ideal = new_w * aspect_ratio
140
-
141
  # Round to nearest multiple of 32
142
  new_h = round(new_h_ideal / 32) * 32
143
-
144
  # Clamp to [MIN_DIM_SLIDER, MAX_IMAGE_SIZE]
145
  new_h = max(MIN_DIM_SLIDER, min(new_h, MAX_IMAGE_SIZE))
146
  # Ensure new_w is also clamped
@@ -165,37 +165,55 @@ def get_duration(prompt, negative_prompt, input_image_filepath, input_video_file
165
  else:
166
  return 45
167
 
168
  @spaces.GPU(duration=get_duration)
169
  def generate(prompt, negative_prompt, input_image_filepath=None, input_video_filepath=None,
170
  height_ui=512, width_ui=704, mode="text-to-video",
171
- duration_ui=2.0,
172
  ui_frames_to_use=9,
173
  seed_ui=42, randomize_seed=True, ui_guidance_scale=3.0, improve_texture_flag=True, num_steps=20, fps=30.0,
174
  progress=gr.Progress(track_tqdm=True)):
175
- """
176
- Generate high-quality videos using LTX Video model with support for text-to-video, image-to-video, and video-to-video modes.
177
-
178
- Args:
179
- prompt (str): Text description of the desired video content. Required for all modes.
180
- negative_prompt (str): Text describing what to avoid in the generated video. Optional, can be empty string.
181
- input_image_filepath (str or None): Path to input image file. Required for image-to-video mode, None for other modes.
182
- input_video_filepath (str or None): Path to input video file. Required for video-to-video mode, None for other modes.
183
- height_ui (int): Height of the output video in pixels, must be divisible by 32. Default: 512.
184
- width_ui (int): Width of the output video in pixels, must be divisible by 32. Default: 704.
185
- mode (str): Generation mode. Required. One of "text-to-video", "image-to-video", or "video-to-video". Default: "text-to-video".
186
- duration_ui (float): Duration of the output video in seconds. Range: 0.3 to 8.5. Default: 2.0.
187
- ui_frames_to_use (int): Number of frames to use from input video. Only used in video-to-video mode. Must be N*8+1. Default: 9.
188
- seed_ui (int): Random seed for reproducible generation. Range: 0 to 2^32-1. Default: 42.
189
- randomize_seed (bool): Whether to use a random seed instead of seed_ui. Default: True.
190
- ui_guidance_scale (float): CFG scale controlling prompt influence. Range: 1.0 to 10.0. Higher values = stronger prompt influence. Default: 3.0.
191
- improve_texture_flag (bool): Whether to use multi-scale generation for better texture quality. Slower but higher quality. Default: True.
192
- progress (gr.Progress): Progress tracker for the generation process. Optional, used for UI updates.
193
-
194
- Returns:
195
- tuple: A tuple containing (output_video_path, used_seed) where output_video_path is the path to the generated video file and used_seed is the actual seed used for generation.
196
- """
197
 
198
- # Validate mode-specific required parameters
199
  if mode == "image-to-video":
200
  if not input_image_filepath:
201
  raise gr.Error("input_image_filepath is required for image-to-video mode")
@@ -203,7 +221,6 @@ def generate(prompt, negative_prompt, input_image_filepath=None, input_video_fil
203
  if not input_video_filepath:
204
  raise gr.Error("input_video_filepath is required for video-to-video mode")
205
  elif mode == "text-to-video":
206
- # No additional file inputs required for text-to-video
207
  pass
208
  else:
209
  raise gr.Error(f"Invalid mode: {mode}. Must be one of: text-to-video, image-to-video, video-to-video")
@@ -211,27 +228,27 @@ def generate(prompt, negative_prompt, input_image_filepath=None, input_video_fil
211
  if randomize_seed:
212
  seed_ui = random.randint(0, 2**32 - 1)
213
  seed_everething(int(seed_ui))
214
-
215
  target_frames_ideal = duration_ui * fps
216
  target_frames_rounded = round(target_frames_ideal)
217
- if target_frames_rounded < 1:
218
  target_frames_rounded = 1
219
-
220
  n_val = round((float(target_frames_rounded) - 1.0) / 8.0)
221
  actual_num_frames = int(n_val * 8 + 1)
222
 
223
  actual_num_frames = max(9, actual_num_frames)
224
  actual_num_frames = min(MAX_NUM_FRAMES, actual_num_frames)
225
-
226
  actual_height = int(height_ui)
227
  actual_width = int(width_ui)
228
 
229
  height_padded = ((actual_height - 1) // 32 + 1) * 32
230
  width_padded = ((actual_width - 1) // 32 + 1) * 32
231
- num_frames_padded = ((actual_num_frames - 2) // 8 + 1) * 8 + 1
232
  if num_frames_padded != actual_num_frames:
233
  print(f"Warning: actual_num_frames ({actual_num_frames}) and num_frames_padded ({num_frames_padded}) differ. Using num_frames_padded for pipeline.")
234
-
235
  padding_values = calculate_padding(actual_height, actual_width, height_padded, width_padded)
236
 
237
  call_kwargs = {
@@ -239,11 +256,11 @@ def generate(prompt, negative_prompt, input_image_filepath=None, input_video_fil
239
  "negative_prompt": negative_prompt,
240
  "height": height_padded,
241
  "width": width_padded,
242
- "num_frames": num_frames_padded,
243
  "num_inference_steps": num_steps,
244
- "frame_rate": int(fps),
245
  "generator": torch.Generator(device=target_inference_device).manual_seed(int(seed_ui)),
246
- "output_type": "pt",
247
  "conditioning_items": None,
248
  "media_items": None,
249
  "decode_timestep": PIPELINE_CONFIG_YAML["decode_timestep"],
@@ -283,9 +300,9 @@ def generate(prompt, negative_prompt, input_image_filepath=None, input_video_fil
283
  try:
284
  call_kwargs["media_items"] = load_media_file(
285
  media_path=input_video_filepath,
286
- height=actual_height,
287
  width=actual_width,
288
- max_frames=int(ui_frames_to_use),
289
  padding=padding_values
290
  ).to(target_inference_device)
291
  except Exception as e:
@@ -293,7 +310,7 @@ def generate(prompt, negative_prompt, input_image_filepath=None, input_video_fil
293
  raise gr.Error(f"Could not load video: {e}")
294
 
295
  print(f"Moving models to {target_inference_device} for inference (if not already there)...")
296
-
297
  active_latent_upsampler = None
298
  if improve_texture_flag and latent_upsampler_instance:
299
  active_latent_upsampler = latent_upsampler_instance
@@ -302,27 +319,24 @@ def generate(prompt, negative_prompt, input_image_filepath=None, input_video_fil
302
  if improve_texture_flag:
303
  if not active_latent_upsampler:
304
  raise gr.Error("Spatial upscaler model not loaded or improve_texture not selected, cannot use multi-scale.")
305
-
306
  multi_scale_pipeline_obj = LTXMultiScalePipeline(pipeline_instance, active_latent_upsampler)
307
-
308
  first_pass_args = PIPELINE_CONFIG_YAML.get("first_pass", {}).copy()
309
- first_pass_args["guidance_scale"] = float(ui_guidance_scale) # UI overrides YAML
310
- # num_inference_steps will be derived from len(timesteps) in the pipeline
311
  first_pass_args.pop("num_inference_steps", None)
312
 
313
-
314
  second_pass_args = PIPELINE_CONFIG_YAML.get("second_pass", {}).copy()
315
- second_pass_args["guidance_scale"] = float(ui_guidance_scale) # UI overrides YAML
316
- # num_inference_steps will be derived from len(timesteps) in the pipeline
317
  second_pass_args.pop("num_inference_steps", None)
318
-
319
  multi_scale_call_kwargs = call_kwargs.copy()
320
  multi_scale_call_kwargs.update({
321
  "downscale_factor": PIPELINE_CONFIG_YAML["downscale_factor"],
322
  "first_pass": first_pass_args,
323
  "second_pass": second_pass_args,
324
  })
325
-
326
  print(f"Calling multi-scale pipeline (eff. HxW: {actual_height}x{actual_width}, Frames: {actual_num_frames} -> Padded: {num_frames_padded}) on {target_inference_device}")
327
  result_images_tensor = multi_scale_pipeline_obj(**multi_scale_call_kwargs).images
328
  else:
@@ -330,17 +344,16 @@ def generate(prompt, negative_prompt, input_image_filepath=None, input_video_fil
330
  first_pass_config_from_yaml = PIPELINE_CONFIG_YAML.get("first_pass", {})
331
 
332
  single_pass_call_kwargs["timesteps"] = first_pass_config_from_yaml.get("timesteps")
333
- single_pass_call_kwargs["guidance_scale"] = float(ui_guidance_scale) # UI overrides YAML
334
  single_pass_call_kwargs["stg_scale"] = first_pass_config_from_yaml.get("stg_scale")
335
  single_pass_call_kwargs["rescaling_scale"] = first_pass_config_from_yaml.get("rescaling_scale")
336
  single_pass_call_kwargs["skip_block_list"] = first_pass_config_from_yaml.get("skip_block_list")
337
-
338
- # Remove keys that might conflict or are not used in single pass / handled by above
339
- single_pass_call_kwargs.pop("num_inference_steps", None)
340
- single_pass_call_kwargs.pop("first_pass", None)
341
  single_pass_call_kwargs.pop("second_pass", None)
342
  single_pass_call_kwargs.pop("downscale_factor", None)
343
-
344
  print(f"Calling base pipeline (padded HxW: {height_padded}x{width_padded}, Frames: {actual_num_frames} -> Padded: {num_frames_padded}) on {target_inference_device}")
345
  result_images_tensor = pipeline_instance(**single_pass_call_kwargs).images
346
 
@@ -350,20 +363,20 @@ def generate(prompt, negative_prompt, input_image_filepath=None, input_video_fil
350
  pad_left, pad_right, pad_top, pad_bottom = padding_values
351
  slice_h_end = -pad_bottom if pad_bottom > 0 else None
352
  slice_w_end = -pad_right if pad_right > 0 else None
353
-
354
  result_images_tensor = result_images_tensor[
355
  :, :, :actual_num_frames, pad_top:slice_h_end, pad_left:slice_w_end
356
  ]
357
 
358
  video_np = result_images_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy()
359
-
360
- video_np = np.clip(video_np, 0, 1)
361
  video_np = (video_np * 255).astype(np.uint8)
362
 
363
  temp_dir = tempfile.mkdtemp()
364
  timestamp = random.randint(10000,99999)
365
  output_video_path = os.path.join(temp_dir, f"output_{timestamp}.mp4")
366
-
367
  try:
368
  with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], macro_block_size=1) as video_writer:
369
  for frame_idx in range(video_np.shape[0]):
@@ -379,8 +392,10 @@ def generate(prompt, negative_prompt, input_image_filepath=None, input_video_fil
379
  except Exception as e2:
380
  print(f"Fallback video saving error: {e2}")
381
  raise gr.Error(f"Failed to save video: {e2}")
382
-
383
- return output_video_path, seed_ui
 
 
384
 
385
  def update_task_image():
386
  return "image-to-video"
@@ -401,11 +416,10 @@ with gr.Blocks(css=css) as demo:
401
  gr.Markdown("# LTX Video 0.9.8 13B Distilled")
402
  gr.Markdown("Fast high quality video generation.**Update (17/07):** now with the new v0.9.8 for improved prompt understanding and detail generation" )
403
  gr.Markdown("[Model](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.8-distilled.safetensors) [GitHub](https://github.com/Lightricks/LTX-Video) [Diffusers](https://huggingface.co/Lightricks/LTX-Video-0.9.8-13B-distilled#diffusers-🧨)")
404
-
405
  with gr.Row():
406
  with gr.Column():
407
  with gr.Tab("image-to-video") as image_tab:
408
- # The hidden textboxes are a good way to manage state for each tab
409
  video_i_hidden = gr.Textbox(label="video_i", visible=False, value=None)
410
  image_i2v = gr.Image(label="Input Image", type="filepath", sources=["upload", "webcam", "clipboard"])
411
  i2v_prompt = gr.Textbox(label="Prompt", value="The creature from the image starts to move", lines=3)
@@ -416,8 +430,7 @@ with gr.Blocks(css=css) as demo:
416
  video_n_hidden = gr.Textbox(label="video_n", visible=False, value=None)
417
  t2v_prompt = gr.Textbox(label="Prompt", value="A majestic dragon flying over a medieval castle", lines=3)
418
  t2v_button = gr.Button("Generate Text-to-Video", variant="primary")
419
-
420
- # This tab was set to visible=False, making it inaccessible. We need to change that.
421
  with gr.Tab("video-to-video") as video_tab:
422
  image_v_hidden = gr.Textbox(label="image_v", visible=False, value=None)
423
  video_v2v = gr.Video(label="Input Video", sources=["upload", "webcam"])
@@ -426,20 +439,22 @@ with gr.Blocks(css=css) as demo:
426
  v2v_button = gr.Button("Generate Video-to-Video", variant="primary")
427
 
428
  duration_input = gr.Slider(
429
- label="Video Duration (seconds)",
430
- minimum=3.0,
431
- maximum=60.0,
432
- value=5.0,
433
- step=0.1,
434
  info=f"Target video duration (0.3s to 8.5s)"
435
  )
436
  improve_texture = gr.Checkbox(label="Improve Texture (multi-scale)", value=True, visible=True, info="Uses a two-pass generation for better quality, but is slower. Recommended for final output.")
437
 
438
  with gr.Column():
439
  output_video = gr.Video(label="Generated Video", interactive=False)
440
-
441
  with gr.Accordion("Advanced settings", open=False):
442
- # We'll use this dropdown to track the currently selected mode from the tabs
443
  mode = gr.Dropdown(["text-to-video", "image-to-video", "video-to-video"], label="task", value="image-to-video", visible=False)
444
  negative_prompt_input = gr.Textbox(label="Negative Prompt", value="worst quality, inconsistent motion, blurry, jittery, distorted", lines=2)
445
  with gr.Row():
@@ -471,11 +486,11 @@ with gr.Blocks(css=css) as demo:
471
  if not video_filepath:
472
  return gr.update(value=current_h), gr.update(value=current_w)
473
  try:
474
- video_filepath_str = str(video_filepath)
475
  if not os.path.exists(video_filepath_str):
476
  print(f"Video file path does not exist for dimension update: {video_filepath_str}")
477
  return gr.update(value=current_h), gr.update(value=current_w)
478
-
479
  orig_w, orig_h = -1, -1
480
  with imageio.get_reader(video_filepath_str) as reader:
481
  meta = reader.get_meta_data()
@@ -488,7 +503,7 @@ with gr.Blocks(css=css) as demo:
488
  except Exception as e_frame:
489
  print(f"Could not get video size from metadata or first frame: {e_frame}")
490
  return gr.update(value=current_h), gr.update(value=current_w)
491
-
492
  if orig_w == -1 or orig_h == -1:
493
  print(f"Could not determine dimensions for video: {video_filepath_str}")
494
  return gr.update(value=current_h), gr.update(value=current_w)
@@ -498,7 +513,7 @@ with gr.Blocks(css=css) as demo:
498
  except Exception as e:
499
  print(f"Error processing video for dimension update: {e} (Path: {video_filepath}, Type: {type(video_filepath)})")
500
  return gr.update(value=current_h), gr.update(value=current_w)
501
-
502
  image_i2v.upload(
503
  fn=handle_image_upload_for_dims,
504
  inputs=[image_i2v, height_input, width_input],
@@ -518,34 +533,50 @@ with gr.Blocks(css=css) as demo:
518
  fn=update_task_text,
519
  outputs=[mode]
520
  )
521
-
522
- # This is the new, crucial event handler for the video tab
523
  video_tab.select(
524
  fn=update_task_video,
525
  outputs=[mode]
526
  )
527
 
528
  t2v_inputs = [t2v_prompt, negative_prompt_input, image_n_hidden, video_n_hidden,
529
  height_input, width_input, mode,
530
- duration_input, frames_to_use,
531
  seed_input, randomize_seed_input, guidance_scale_input, improve_texture, num_steps, fps]
532
-
533
  i2v_inputs = [i2v_prompt, negative_prompt_input, image_i2v, video_i_hidden,
534
  height_input, width_input, mode,
535
- duration_input, frames_to_use,
536
  seed_input, randomize_seed_input, guidance_scale_input, improve_texture, num_steps, fps]
537
-
538
  v2v_inputs = [v2v_prompt, negative_prompt_input, image_v_hidden, video_v2v,
539
  height_input, width_input, mode,
540
- duration_input, frames_to_use,
541
  seed_input, randomize_seed_input, guidance_scale_input, improve_texture, num_steps, fps]
542
 
543
- t2v_button.click(fn=generate, inputs=t2v_inputs, outputs=[output_video, seed_input], api_name="text_to_video")
544
- i2v_button.click(fn=generate, inputs=i2v_inputs, outputs=[output_video, seed_input], api_name="image_to_video")
545
- v2v_button.click(fn=generate, inputs=v2v_inputs, outputs=[output_video, seed_input], api_name="video_to_video")
546
 
547
  if __name__ == "__main__":
548
  if os.path.exists(models_dir) and os.path.isdir(models_dir):
549
  print(f"Model directory: {Path(models_dir).resolve()}")
550
-
551
  demo.queue().launch(debug=True, share=False, mcp_server=True)
 
103
  target_inference_device = "cuda"
104
  print(f"Target inference device: {target_inference_device}")
105
  pipeline_instance.to(target_inference_device)
106
+ if latent_upsampler_instance:
107
  latent_upsampler_instance.to(target_inference_device)
108
 
109
 
 
125
  new_h = TARGET_FIXED_SIDE
126
  aspect_ratio = orig_w / orig_h
127
  new_w_ideal = new_h * aspect_ratio
128
+
129
  # Round to nearest multiple of 32
130
  new_w = round(new_w_ideal / 32) * 32
131
+
132
  # Clamp to [MIN_DIM_SLIDER, MAX_IMAGE_SIZE]
133
  new_w = max(MIN_DIM_SLIDER, min(new_w, MAX_IMAGE_SIZE))
134
  # Ensure new_h is also clamped (TARGET_FIXED_SIDE should be within these bounds if configured correctly)
135
+ new_h = max(MIN_DIM_SLIDER, min(new_h, MAX_IMAGE_SIZE))
136
  else: # Portrait
137
  new_w = TARGET_FIXED_SIDE
138
  aspect_ratio = orig_h / orig_w # Use H/W ratio for portrait scaling
139
  new_h_ideal = new_w * aspect_ratio
140
+
141
  # Round to nearest multiple of 32
142
  new_h = round(new_h_ideal / 32) * 32
143
+
144
  # Clamp to [MIN_DIM_SLIDER, MAX_IMAGE_SIZE]
145
  new_h = max(MIN_DIM_SLIDER, min(new_h, MAX_IMAGE_SIZE))
146
  # Ensure new_w is also clamped
 
165
  else:
166
  return 45
167
 
168
+ # --- NEW ---
169
+ # This function handles the logic for extracting the last frame of the output video.
170
+ def use_last_frame_as_input(video_filepath):
171
+ """
172
+ Extracts the last frame from a video, saves it as a temp image,
173
+ and returns its path to update the image-to-video tab.
174
+ Also returns an update to switch to the image-to-video tab.
175
+ """
176
+ if not video_filepath or not os.path.exists(video_filepath):
177
+ gr.Warning("No video available to get the last frame from.")
178
+ return None, gr.update()
179
+
180
+ try:
181
+ print(f"Extracting last frame from {video_filepath}")
182
+ with imageio.get_reader(video_filepath) as reader:
183
+ # Iterating is a robust way to get the last frame if len() is not available
184
+ last_frame_np = None
185
+ for frame in reader:
186
+ last_frame_np = frame
187
+
188
+ if last_frame_np is None:
189
+ raise ValueError("Could not read any frames from the video.")
190
+
191
+ pil_image = Image.fromarray(last_frame_np)
192
+
193
+ # Save to a temporary file
194
+ temp_dir = tempfile.mkdtemp()
195
+ timestamp = random.randint(10000, 99999)
196
+ output_image_path = os.path.join(temp_dir, f"last_frame_{timestamp}.png")
197
+ pil_image.save(output_image_path)
198
+ print(f"Saved last frame to {output_image_path}")
199
+
200
+ # Return the path to the new image and an update to select the i2v tab
201
+ return output_image_path, gr.Tab(selected=True)
202
+
203
+ except Exception as e:
204
+ print(f"Error extracting last frame: {e}")
205
+ gr.Error(f"Failed to extract the last frame: {e}")
206
+ return None, gr.update()
207
+
208
  @spaces.GPU(duration=get_duration)
209
  def generate(prompt, negative_prompt, input_image_filepath=None, input_video_filepath=None,
210
  height_ui=512, width_ui=704, mode="text-to-video",
211
+ duration_ui=2.0,
212
  ui_frames_to_use=9,
213
  seed_ui=42, randomize_seed=True, ui_guidance_scale=3.0, improve_texture_flag=True, num_steps=20, fps=30.0,
214
  progress=gr.Progress(track_tqdm=True)):
215
 
216
+ # ... (the beginning of the generate function is unchanged)
217
  if mode == "image-to-video":
218
  if not input_image_filepath:
219
  raise gr.Error("input_image_filepath is required for image-to-video mode")
 
221
  if not input_video_filepath:
222
  raise gr.Error("input_video_filepath is required for video-to-video mode")
223
  elif mode == "text-to-video":
 
224
  pass
225
  else:
226
  raise gr.Error(f"Invalid mode: {mode}. Must be one of: text-to-video, image-to-video, video-to-video")
 
228
  if randomize_seed:
229
  seed_ui = random.randint(0, 2**32 - 1)
230
  seed_everething(int(seed_ui))
231
+
232
  target_frames_ideal = duration_ui * fps
233
  target_frames_rounded = round(target_frames_ideal)
234
+ if target_frames_rounded < 1:
235
  target_frames_rounded = 1
236
+
237
  n_val = round((float(target_frames_rounded) - 1.0) / 8.0)
238
  actual_num_frames = int(n_val * 8 + 1)
239
 
240
  actual_num_frames = max(9, actual_num_frames)
241
  actual_num_frames = min(MAX_NUM_FRAMES, actual_num_frames)
242
+
243
  actual_height = int(height_ui)
244
  actual_width = int(width_ui)
245
 
246
  height_padded = ((actual_height - 1) // 32 + 1) * 32
247
  width_padded = ((actual_width - 1) // 32 + 1) * 32
248
+ num_frames_padded = ((actual_num_frames - 2) // 8 + 1) * 8 + 1
249
  if num_frames_padded != actual_num_frames:
250
  print(f"Warning: actual_num_frames ({actual_num_frames}) and num_frames_padded ({num_frames_padded}) differ. Using num_frames_padded for pipeline.")
251
+
252
  padding_values = calculate_padding(actual_height, actual_width, height_padded, width_padded)
253
 
254
  call_kwargs = {
 
256
  "negative_prompt": negative_prompt,
257
  "height": height_padded,
258
  "width": width_padded,
259
+ "num_frames": num_frames_padded,
260
  "num_inference_steps": num_steps,
261
+ "frame_rate": int(fps),
262
  "generator": torch.Generator(device=target_inference_device).manual_seed(int(seed_ui)),
263
+ "output_type": "pt",
264
  "conditioning_items": None,
265
  "media_items": None,
266
  "decode_timestep": PIPELINE_CONFIG_YAML["decode_timestep"],
 
300
  try:
301
  call_kwargs["media_items"] = load_media_file(
302
  media_path=input_video_filepath,
303
+ height=actual_height,
304
  width=actual_width,
305
+ max_frames=int(ui_frames_to_use),
306
  padding=padding_values
307
  ).to(target_inference_device)
308
  except Exception as e:
 
310
  raise gr.Error(f"Could not load video: {e}")
311
 
312
  print(f"Moving models to {target_inference_device} for inference (if not already there)...")
313
+
314
  active_latent_upsampler = None
315
  if improve_texture_flag and latent_upsampler_instance:
316
  active_latent_upsampler = latent_upsampler_instance
 
319
  if improve_texture_flag:
320
  if not active_latent_upsampler:
321
  raise gr.Error("Spatial upscaler model not loaded or improve_texture not selected, cannot use multi-scale.")
322
+
323
  multi_scale_pipeline_obj = LTXMultiScalePipeline(pipeline_instance, active_latent_upsampler)
324
+
325
  first_pass_args = PIPELINE_CONFIG_YAML.get("first_pass", {}).copy()
326
+ first_pass_args["guidance_scale"] = float(ui_guidance_scale)
 
327
  first_pass_args.pop("num_inference_steps", None)
328
 
 
329
  second_pass_args = PIPELINE_CONFIG_YAML.get("second_pass", {}).copy()
330
+ second_pass_args["guidance_scale"] = float(ui_guidance_scale)
 
331
  second_pass_args.pop("num_inference_steps", None)
332
+
333
  multi_scale_call_kwargs = call_kwargs.copy()
334
  multi_scale_call_kwargs.update({
335
  "downscale_factor": PIPELINE_CONFIG_YAML["downscale_factor"],
336
  "first_pass": first_pass_args,
337
  "second_pass": second_pass_args,
338
  })
339
+
340
  print(f"Calling multi-scale pipeline (eff. HxW: {actual_height}x{actual_width}, Frames: {actual_num_frames} -> Padded: {num_frames_padded}) on {target_inference_device}")
341
  result_images_tensor = multi_scale_pipeline_obj(**multi_scale_call_kwargs).images
342
  else:
 
344
  first_pass_config_from_yaml = PIPELINE_CONFIG_YAML.get("first_pass", {})
345
 
346
  single_pass_call_kwargs["timesteps"] = first_pass_config_from_yaml.get("timesteps")
347
+ single_pass_call_kwargs["guidance_scale"] = float(ui_guidance_scale)
348
  single_pass_call_kwargs["stg_scale"] = first_pass_config_from_yaml.get("stg_scale")
349
  single_pass_call_kwargs["rescaling_scale"] = first_pass_config_from_yaml.get("rescaling_scale")
350
  single_pass_call_kwargs["skip_block_list"] = first_pass_config_from_yaml.get("skip_block_list")
351
+
352
+ single_pass_call_kwargs.pop("num_inference_steps", None)
353
+ single_pass_call_kwargs.pop("first_pass", None)
 
354
  single_pass_call_kwargs.pop("second_pass", None)
355
  single_pass_call_kwargs.pop("downscale_factor", None)
356
+
357
  print(f"Calling base pipeline (padded HxW: {height_padded}x{width_padded}, Frames: {actual_num_frames} -> Padded: {num_frames_padded}) on {target_inference_device}")
358
  result_images_tensor = pipeline_instance(**single_pass_call_kwargs).images
359
 
 
363
  pad_left, pad_right, pad_top, pad_bottom = padding_values
364
  slice_h_end = -pad_bottom if pad_bottom > 0 else None
365
  slice_w_end = -pad_right if pad_right > 0 else None
366
+
367
  result_images_tensor = result_images_tensor[
368
  :, :, :actual_num_frames, pad_top:slice_h_end, pad_left:slice_w_end
369
  ]
370
 
371
  video_np = result_images_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy()
372
+
373
+ video_np = np.clip(video_np, 0, 1)
374
  video_np = (video_np * 255).astype(np.uint8)
375
 
376
  temp_dir = tempfile.mkdtemp()
377
  timestamp = random.randint(10000,99999)
378
  output_video_path = os.path.join(temp_dir, f"output_{timestamp}.mp4")
379
+
380
  try:
381
  with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], macro_block_size=1) as video_writer:
382
  for frame_idx in range(video_np.shape[0]):
 
392
  except Exception as e2:
393
  print(f"Fallback video saving error: {e2}")
394
  raise gr.Error(f"Failed to save video: {e2}")
395
+
396
+ # --- MODIFIED ---
397
+ # The function now returns a third value: an update to make the new button visible.
398
+ return output_video_path, seed_ui, gr.update(visible=True)
399
 
400
  def update_task_image():
401
  return "image-to-video"
 
416
  gr.Markdown("# LTX Video 0.9.8 13B Distilled")
417
  gr.Markdown("Fast high quality video generation.**Update (17/07):** now with the new v0.9.8 for improved prompt understanding and detail generation" )
418
  gr.Markdown("[Model](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.8-distilled.safetensors) [GitHub](https://github.com/Lightricks/LTX-Video) [Diffusers](https://huggingface.co/Lightricks/LTX-Video-0.9.8-13B-distilled#diffusers-🧨)")
419
+
420
  with gr.Row():
421
  with gr.Column():
422
  with gr.Tab("image-to-video") as image_tab:
 
423
  video_i_hidden = gr.Textbox(label="video_i", visible=False, value=None)
424
  image_i2v = gr.Image(label="Input Image", type="filepath", sources=["upload", "webcam", "clipboard"])
425
  i2v_prompt = gr.Textbox(label="Prompt", value="The creature from the image starts to move", lines=3)
 
430
  video_n_hidden = gr.Textbox(label="video_n", visible=False, value=None)
431
  t2v_prompt = gr.Textbox(label="Prompt", value="A majestic dragon flying over a medieval castle", lines=3)
432
  t2v_button = gr.Button("Generate Text-to-Video", variant="primary")
433
+
 
434
  with gr.Tab("video-to-video") as video_tab:
435
  image_v_hidden = gr.Textbox(label="image_v", visible=False, value=None)
436
  video_v2v = gr.Video(label="Input Video", sources=["upload", "webcam"])
 
439
  v2v_button = gr.Button("Generate Video-to-Video", variant="primary")
440
 
441
  duration_input = gr.Slider(
442
+ label="Video Duration (seconds)",
443
+ minimum=3.0,
444
+ maximum=60.0,
445
+ value=5.0,
446
+ step=0.1,
447
  info=f"Target video duration (0.3s to 8.5s)"
448
  )
449
  improve_texture = gr.Checkbox(label="Improve Texture (multi-scale)", value=True, visible=True, info="Uses a two-pass generation for better quality, but is slower. Recommended for final output.")
450
 
451
  with gr.Column():
452
  output_video = gr.Video(label="Generated Video", interactive=False)
453
+ # --- NEW ---
454
+ # This is the new button, initially hidden.
455
+ use_last_frame_button = gr.Button("Use Last Frame as Input Image", visible=False)
456
+
457
  with gr.Accordion("Advanced settings", open=False):
 
458
  mode = gr.Dropdown(["text-to-video", "image-to-video", "video-to-video"], label="task", value="image-to-video", visible=False)
459
  negative_prompt_input = gr.Textbox(label="Negative Prompt", value="worst quality, inconsistent motion, blurry, jittery, distorted", lines=2)
460
  with gr.Row():
 
486
  if not video_filepath:
487
  return gr.update(value=current_h), gr.update(value=current_w)
488
  try:
489
+ video_filepath_str = str(video_filepath)
490
  if not os.path.exists(video_filepath_str):
491
  print(f"Video file path does not exist for dimension update: {video_filepath_str}")
492
  return gr.update(value=current_h), gr.update(value=current_w)
493
+
494
  orig_w, orig_h = -1, -1
495
  with imageio.get_reader(video_filepath_str) as reader:
496
  meta = reader.get_meta_data()
 
503
  except Exception as e_frame:
504
  print(f"Could not get video size from metadata or first frame: {e_frame}")
505
  return gr.update(value=current_h), gr.update(value=current_w)
506
+
507
  if orig_w == -1 or orig_h == -1:
508
  print(f"Could not determine dimensions for video: {video_filepath_str}")
509
  return gr.update(value=current_h), gr.update(value=current_w)
 
513
  except Exception as e:
514
  print(f"Error processing video for dimension update: {e} (Path: {video_filepath}, Type: {type(video_filepath)})")
515
  return gr.update(value=current_h), gr.update(value=current_w)
516
+
517
  image_i2v.upload(
518
  fn=handle_image_upload_for_dims,
519
  inputs=[image_i2v, height_input, width_input],
 
533
  fn=update_task_text,
534
  outputs=[mode]
535
  )
 
 
536
  video_tab.select(
537
  fn=update_task_video,
538
  outputs=[mode]
539
  )
540
 
541
+ # --- MODIFIED ---
542
+ # The outputs for the generate buttons now include `use_last_frame_button` to control its visibility.
543
+ # We also use the `.then()` method to chain events. First, we instantly hide the button,
544
+ # then we run the main (and slower) video generation process.
545
  t2v_inputs = [t2v_prompt, negative_prompt_input, image_n_hidden, video_n_hidden,
546
  height_input, width_input, mode,
547
+ duration_input, frames_to_use,
548
  seed_input, randomize_seed_input, guidance_scale_input, improve_texture, num_steps, fps]
 
549
  i2v_inputs = [i2v_prompt, negative_prompt_input, image_i2v, video_i_hidden,
550
  height_input, width_input, mode,
551
+ duration_input, frames_to_use,
552
  seed_input, randomize_seed_input, guidance_scale_input, improve_texture, num_steps, fps]
 
553
  v2v_inputs = [v2v_prompt, negative_prompt_input, image_v_hidden, video_v2v,
554
  height_input, width_input, mode,
555
+ duration_input, frames_to_use,
556
  seed_input, randomize_seed_input, guidance_scale_input, improve_texture, num_steps, fps]
557
 
558
+ t2v_button.click(fn=lambda: gr.update(visible=False), outputs=[use_last_frame_button], queue=False).then(
559
+ fn=generate, inputs=t2v_inputs, outputs=[output_video, seed_input, use_last_frame_button], api_name="text_to_video"
560
+ )
561
+ i2v_button.click(fn=lambda: gr.update(visible=False), outputs=[use_last_frame_button], queue=False).then(
562
+ fn=generate, inputs=i2v_inputs, outputs=[output_video, seed_input, use_last_frame_button], api_name="image_to_video"
563
+ )
564
+ v2v_button.click(fn=lambda: gr.update(visible=False), outputs=[use_last_frame_button], queue=False).then(
565
+ fn=generate, inputs=v2v_inputs, outputs=[output_video, seed_input, use_last_frame_button], api_name="video_to_video"
566
+ )
567
+
568
+ # --- NEW ---
569
+ # This is the event handler for our new button.
570
+ # It takes the generated video as input.
571
+ # It updates the image component in the first tab and also selects that tab.
572
+ use_last_frame_button.click(
573
+ fn=use_last_frame_as_input,
574
+ inputs=[output_video],
575
+ outputs=[image_i2v, image_tab]
576
+ )
577
 
578
  if __name__ == "__main__":
579
  if os.path.exists(models_dir) and os.path.isdir(models_dir):
580
  print(f"Model directory: {Path(models_dir).resolve()}")
581
+
582
  demo.queue().launch(debug=True, share=False, mcp_server=True)
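
A minimal sketch of the pattern this commit wires up, not part of the commit itself: generate() now returns a third value that reveals a hidden "Use Last Frame as Input Image" button, each generate button first hides that button with a fast queue=False lambda and only .then() runs the slow generation, and clicking the revealed button reads the final frame of the output video back into the image input. The self-contained Gradio example below mirrors that wiring under assumed names (fake_generate, last_frame_to_image, and the component variables are placeholders rather than identifiers from app.py); it assumes gradio, imageio, and Pillow are installed.

import os
import random
import tempfile

import gradio as gr
import imageio
from PIL import Image


def fake_generate(video_path):
    # Stand-in for the real, slow generate(): pass the input video through and
    # return a visibility update so the hidden button appears once output exists.
    return video_path, gr.update(visible=True)


def last_frame_to_image(video_path):
    # Iterate the reader so the last frame is found even when len() is unavailable.
    if not video_path or not os.path.exists(video_path):
        gr.Warning("No video available to take a frame from.")
        return None
    last = None
    with imageio.get_reader(video_path) as reader:
        for frame in reader:
            last = frame
    if last is None:
        raise gr.Error("Could not read any frames from the video.")
    out_path = os.path.join(tempfile.mkdtemp(), f"last_frame_{random.randint(10000, 99999)}.png")
    Image.fromarray(last).save(out_path)
    return out_path


with gr.Blocks() as demo:
    src_video = gr.Video(label="Source video")
    run_button = gr.Button("Generate")
    out_video = gr.Video(label="Output", interactive=False)
    use_last_frame = gr.Button("Use Last Frame as Input Image", visible=False)
    next_image = gr.Image(label="Next input image", type="filepath")

    # Hide the button instantly (no queue), then run the slow step; its second
    # return value makes the button visible again once the video is ready.
    run_button.click(
        fn=lambda: gr.update(visible=False), outputs=[use_last_frame], queue=False
    ).then(fn=fake_generate, inputs=[src_video], outputs=[out_video, use_last_frame])

    use_last_frame.click(fn=last_frame_to_image, inputs=[out_video], outputs=[next_image])

if __name__ == "__main__":
    demo.launch()

Hiding the button before the long-running call, rather than only revealing it afterwards, keeps a stale button from lingering while a new video renders; that is the same ordering the commit encodes with its queue=False lambda followed by .then(generate).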