eeuuia committed
Commit c75608c · verified · 1 Parent(s): 2f53bb4

Delete api/ltx/inference.py

Files changed (1)
  api/ltx/inference.py +0 -785
api/ltx/inference.py DELETED
@@ -1,785 +0,0 @@
-import argparse
-import os
-import random
-from datetime import datetime
-from pathlib import Path
-from diffusers.utils import logging
-from typing import Optional, List, Union
-import yaml
-
-from huggingface_hub import logging
-
-
-
-logging.set_verbosity_error()
-logging.set_verbosity_warning()
-logging.set_verbosity_info()
-logging.set_verbosity_debug()
-
-
-
-import imageio
-import json
-import numpy as np
-import torch
-import cv2
-from safetensors import safe_open
-from PIL import Image
-from transformers import (
-    T5EncoderModel,
-    T5Tokenizer,
-    AutoModelForCausalLM,
-    AutoProcessor,
-    AutoTokenizer,
-)
-from huggingface_hub import hf_hub_download
-
-from ltx_video.models.autoencoders.causal_video_autoencoder import (
-    CausalVideoAutoencoder,
-)
-from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
-from ltx_video.models.transformers.transformer3d import Transformer3DModel
-from ltx_video.pipelines.pipeline_ltx_video import (
-    ConditioningItem,
-    LTXVideoPipeline,
-    LTXMultiScalePipeline,
-)
-from ltx_video.schedulers.rf import RectifiedFlowScheduler
-from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
-from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
-import ltx_video.pipelines.crf_compressor as crf_compressor
-
-MAX_HEIGHT = 720
-MAX_WIDTH = 1280
-MAX_NUM_FRAMES = 257
-
-logger = logging.get_logger("LTX-Video")
-
-
-def get_total_gpu_memory():
-    if torch.cuda.is_available():
-        total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
-        return total_memory
-    return 0
-
-
-def get_device():
-    if torch.cuda.is_available():
-        return "cuda"
-    elif torch.backends.mps.is_available():
-        return "mps"
-    return "cpu"
-
-
-def load_image_to_tensor_with_resize_and_crop(
-    image_input: Union[str, Image.Image],
-    target_height: int = 512,
-    target_width: int = 768,
-    just_crop: bool = False,
-) -> torch.Tensor:
-    """Load and process an image into a tensor.
-
-    Args:
-        image_input: Either a file path (str) or a PIL Image object
-        target_height: Desired height of output tensor
-        target_width: Desired width of output tensor
-        just_crop: If True, only crop the image to the target size without resizing
-    """
-    if isinstance(image_input, str):
-        image = Image.open(image_input).convert("RGB")
-    elif isinstance(image_input, Image.Image):
-        image = image_input
-    else:
-        raise ValueError("image_input must be either a file path or a PIL Image object")
-
-    input_width, input_height = image.size
-    aspect_ratio_target = target_width / target_height
-    aspect_ratio_frame = input_width / input_height
-    if aspect_ratio_frame > aspect_ratio_target:
-        new_width = int(input_height * aspect_ratio_target)
-        new_height = input_height
-        x_start = (input_width - new_width) // 2
-        y_start = 0
-    else:
-        new_width = input_width
-        new_height = int(input_width / aspect_ratio_target)
-        x_start = 0
-        y_start = (input_height - new_height) // 2
-
-    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
-    if not just_crop:
-        image = image.resize((target_width, target_height))
-
-    image = np.array(image)
-    image = cv2.GaussianBlur(image, (3, 3), 0)
-    frame_tensor = torch.from_numpy(image).float()
-    frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0
-    frame_tensor = frame_tensor.permute(2, 0, 1)
-    frame_tensor = (frame_tensor / 127.5) - 1.0
-    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
-    return frame_tensor.unsqueeze(0).unsqueeze(2)
-
-
-def calculate_padding(
-    source_height: int, source_width: int, target_height: int, target_width: int
-) -> tuple[int, int, int, int]:
-
-    # Calculate total padding needed
-    pad_height = target_height - source_height
-    pad_width = target_width - source_width
-
-    # Calculate padding for each side
-    pad_top = pad_height // 2
-    pad_bottom = pad_height - pad_top  # Handles odd padding
-    pad_left = pad_width // 2
-    pad_right = pad_width - pad_left  # Handles odd padding
-
-    # Return padded tensor
-    # Padding format is (left, right, top, bottom)
-    padding = (pad_left, pad_right, pad_top, pad_bottom)
-    return padding
-
-
-def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
-    # Remove non-letters and convert to lowercase
-    clean_text = "".join(
-        char.lower() for char in text if char.isalpha() or char.isspace()
-    )
-
-    # Split into words
-    words = clean_text.split()
-
-    # Build result string keeping track of length
-    result = []
-    current_length = 0
-
-    for word in words:
-        # Add word length plus 1 for underscore (except for first word)
-        new_length = current_length + len(word)
-
-        if new_length <= max_len:
-            result.append(word)
-            current_length += len(word)
-        else:
-            break
-
-    return "-".join(result)
-
-
-# Generate output video name
-def get_unique_filename(
-    base: str,
-    ext: str,
-    prompt: str,
-    seed: int,
-    resolution: tuple[int, int, int],
-    dir: Path,
-    endswith=None,
-    index_range=1000,
-) -> Path:
-    base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
-    for i in range(index_range):
-        filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
-        if not os.path.exists(filename):
-            return filename
-    raise FileExistsError(
-        f"Could not find a unique filename after {index_range} attempts."
-    )
-
-
-def seed_everething(seed: int):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-    if torch.backends.mps.is_available():
-        torch.mps.manual_seed(seed)
-
-
-def main():
-    parser = argparse.ArgumentParser(
-        description="Load models from separate directories and run the pipeline."
-    )
-
-    # Directories
-    parser.add_argument(
-        "--output_path",
-        type=str,
-        default=None,
-        help="Path to the folder to save output video, if None will save in outputs/ directory.",
-    )
-    parser.add_argument("--seed", type=int, default="171198")
-
-    # Pipeline parameters
-    parser.add_argument(
-        "--num_images_per_prompt",
-        type=int,
-        default=1,
-        help="Number of images per prompt",
-    )
-    parser.add_argument(
-        "--image_cond_noise_scale",
-        type=float,
-        default=0.15,
-        help="Amount of noise to add to the conditioned image",
-    )
-    parser.add_argument(
-        "--height",
-        type=int,
-        default=704,
-        help="Height of the output video frames. Optional if an input image provided.",
-    )
-    parser.add_argument(
-        "--width",
-        type=int,
-        default=1216,
-        help="Width of the output video frames. If None will infer from input image.",
-    )
-    parser.add_argument(
-        "--num_frames",
-        type=int,
-        default=121,
-        help="Number of frames to generate in the output video",
-    )
-    parser.add_argument(
-        "--frame_rate", type=int, default=30, help="Frame rate for the output video"
-    )
-    parser.add_argument(
-        "--device",
-        default=None,
-        help="Device to run inference on. If not specified, will automatically detect and use CUDA or MPS if available, else CPU.",
-    )
-    parser.add_argument(
-        "--pipeline_config",
-        type=str,
-        default="configs/ltxv-13b-0.9.7-dev.yaml",
-        help="The path to the config file for the pipeline, which contains the parameters for the pipeline",
-    )
-
-    # Prompts
-    parser.add_argument(
-        "--prompt",
-        type=str,
-        help="Text prompt to guide generation",
-    )
-    parser.add_argument(
-        "--negative_prompt",
-        type=str,
-        default="worst quality, inconsistent motion, blurry, jittery, distorted",
-        help="Negative prompt for undesired features",
-    )
-
-    parser.add_argument(
-        "--offload_to_cpu",
-        action="store_true",
-        help="Offload unnecessary computations to the CPU.",
-    )
-
-    # video-to-video arguments:
-    parser.add_argument(
-        "--input_media_path",
-        type=str,
-        default=None,
-        help="Path to the input video (or image) to be modified using the video-to-video pipeline",
-    )
-
-    # Conditioning arguments
-    parser.add_argument(
-        "--conditioning_media_paths",
-        type=str,
-        nargs="*",
-        help="List of paths to conditioning media (images or videos). Each path will be used as a conditioning item.",
-    )
-    parser.add_argument(
-        "--conditioning_strengths",
-        type=float,
-        nargs="*",
-        help="List of conditioning strengths (between 0 and 1) for each conditioning item. Must match the number of conditioning items.",
-    )
-    parser.add_argument(
-        "--conditioning_start_frames",
-        type=int,
-        nargs="*",
-        help="List of frame indices where each conditioning item should be applied. Must match the number of conditioning items.",
-    )
-
-    args = parser.parse_args()
-    logger.warning(f"Running generation with arguments: {args}")
-    infer(**vars(args))
-
-
-def create_ltx_video_pipeline(
-    ckpt_path: str,
-    precision: str,
-    text_encoder_model_name_or_path: str,
-    sampler: Optional[str] = None,
-    device: Optional[str] = None,
-    enhance_prompt: bool = False,
-    prompt_enhancer_image_caption_model_name_or_path: Optional[str] = None,
-    prompt_enhancer_llm_model_name_or_path: Optional[str] = None,
-) -> LTXVideoPipeline:
-    ckpt_path = Path(ckpt_path)
-    assert os.path.exists(
-        ckpt_path
-    ), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist"
-
-    with safe_open(ckpt_path, framework="pt") as f:
-        metadata = f.metadata()
-        config_str = metadata.get("config")
-        configs = json.loads(config_str)
-        allowed_inference_steps = configs.get("allowed_inference_steps", None)
-
-    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
-    transformer = Transformer3DModel.from_pretrained(ckpt_path)
-
-    # Use constructor if sampler is specified, otherwise use from_pretrained
-    if sampler == "from_checkpoint" or not sampler:
-        scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
-    else:
-        scheduler = RectifiedFlowScheduler(
-            sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic")
-        )
-
-    text_encoder = T5EncoderModel.from_pretrained(
-        text_encoder_model_name_or_path, subfolder="text_encoder"
-    )
-    patchifier = SymmetricPatchifier(patch_size=1)
-    tokenizer = T5Tokenizer.from_pretrained(
-        text_encoder_model_name_or_path, subfolder="tokenizer"
-    )
-
-    transformer = transformer.to(device)
-    vae = vae.to(device)
-    text_encoder = text_encoder.to(device)
-
-    if enhance_prompt:
-        prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
-            prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
-        )
-        prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
-            prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
-        )
-        prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
-            prompt_enhancer_llm_model_name_or_path,
-            torch_dtype="bfloat16",
-        )
-        prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
-            prompt_enhancer_llm_model_name_or_path,
-        )
-    else:
-        prompt_enhancer_image_caption_model = None
-        prompt_enhancer_image_caption_processor = None
-        prompt_enhancer_llm_model = None
-        prompt_enhancer_llm_tokenizer = None
-
-    vae = vae.to(torch.bfloat16)
-    if precision == "bfloat16" and transformer.dtype != torch.bfloat16:
-        transformer = transformer.to(torch.bfloat16)
-    text_encoder = text_encoder.to(torch.bfloat16)
-
-    # Use submodels for the pipeline
-    submodel_dict = {
-        "transformer": transformer,
-        "patchifier": patchifier,
-        "text_encoder": text_encoder,
-        "tokenizer": tokenizer,
-        "scheduler": scheduler,
-        "vae": vae,
-        "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
-        "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
-        "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
-        "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer,
-        "allowed_inference_steps": allowed_inference_steps,
-    }
-
-    pipeline = LTXVideoPipeline(**submodel_dict)
-    pipeline = pipeline.to(device)
-    return pipeline
-
-
-def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
-    latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
-    latent_upsampler.to(device)
-    latent_upsampler.eval()
-    return latent_upsampler
-
-
-def infer(
-    output_path: Optional[str],
-    seed: int,
-    pipeline_config: str,
-    image_cond_noise_scale: float,
-    height: Optional[int],
-    width: Optional[int],
-    num_frames: int,
-    frame_rate: int,
-    prompt: str,
-    negative_prompt: str,
-    offload_to_cpu: bool,
-    input_media_path: Optional[str] = None,
-    conditioning_media_paths: Optional[List[str]] = None,
-    conditioning_strengths: Optional[List[float]] = None,
-    conditioning_start_frames: Optional[List[int]] = None,
-    device: Optional[str] = None,
-    **kwargs,
-):
-    # check if pipeline_config is a file
-    if not os.path.isfile(pipeline_config):
-        raise ValueError(f"Pipeline config file {pipeline_config} does not exist")
-    with open(pipeline_config, "r") as f:
-        pipeline_config = yaml.safe_load(f)
-
-    models_dir = "MODEL_DIR"
-
-    ltxv_model_name_or_path = pipeline_config["checkpoint_path"]
-    if not os.path.isfile(ltxv_model_name_or_path):
-        ltxv_model_path = hf_hub_download(
-            repo_id="Lightricks/LTX-Video",
-            filename=ltxv_model_name_or_path,
-            local_dir=models_dir,
-            repo_type="model",
-        )
-    else:
-        ltxv_model_path = ltxv_model_name_or_path
-
-    spatial_upscaler_model_name_or_path = pipeline_config.get(
-        "spatial_upscaler_model_path"
-    )
-    if spatial_upscaler_model_name_or_path and not os.path.isfile(
-        spatial_upscaler_model_name_or_path
-    ):
-        spatial_upscaler_model_path = hf_hub_download(
-            repo_id="Lightricks/LTX-Video",
-            filename=spatial_upscaler_model_name_or_path,
-            local_dir=models_dir,
-            repo_type="model",
-        )
-    else:
-        spatial_upscaler_model_path = spatial_upscaler_model_name_or_path
-
-    if kwargs.get("input_image_path", None):
-        logger.warning(
-            "Please use conditioning_media_paths instead of input_image_path."
-        )
-        assert not conditioning_media_paths and not conditioning_start_frames
-        conditioning_media_paths = [kwargs["input_image_path"]]
-        conditioning_start_frames = [0]
-
-    # Validate conditioning arguments
-    if conditioning_media_paths:
-        # Use default strengths of 1.0
-        if not conditioning_strengths:
-            conditioning_strengths = [1.0] * len(conditioning_media_paths)
-        if not conditioning_start_frames:
-            raise ValueError(
-                "If `conditioning_media_paths` is provided, "
-                "`conditioning_start_frames` must also be provided"
-            )
-        if len(conditioning_media_paths) != len(conditioning_strengths) or len(
-            conditioning_media_paths
-        ) != len(conditioning_start_frames):
-            raise ValueError(
-                "`conditioning_media_paths`, `conditioning_strengths`, "
-                "and `conditioning_start_frames` must have the same length"
-            )
-        if any(s < 0 or s > 1 for s in conditioning_strengths):
-            raise ValueError("All conditioning strengths must be between 0 and 1")
-        if any(f < 0 or f >= num_frames for f in conditioning_start_frames):
-            raise ValueError(
-                f"All conditioning start frames must be between 0 and {num_frames-1}"
-            )
-
-    seed_everething(seed)
-    if offload_to_cpu and not torch.cuda.is_available():
-        logger.warning(
-            "offload_to_cpu is set to True, but offloading will not occur since the model is already running on CPU."
-        )
-        offload_to_cpu = False
-    else:
-        offload_to_cpu = offload_to_cpu and get_total_gpu_memory() < 30
-
-    output_dir = (
-        Path(output_path)
-        if output_path
-        else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
-    )
-    output_dir.mkdir(parents=True, exist_ok=True)
-
-    # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1)
-    height_padded = ((height - 1) // 32 + 1) * 32
-    width_padded = ((width - 1) // 32 + 1) * 32
-    num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1
-
-    padding = calculate_padding(height, width, height_padded, width_padded)
-
-    logger.warning(
-        f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
-    )
-
-    prompt_enhancement_words_threshold = pipeline_config[
-        "prompt_enhancement_words_threshold"
-    ]
-
-    prompt_word_count = len(prompt.split())
-    enhance_prompt = (
-        prompt_enhancement_words_threshold > 0
-        and prompt_word_count < prompt_enhancement_words_threshold
-    )
-
-    if prompt_enhancement_words_threshold > 0 and not enhance_prompt:
-        logger.info(
-            f"Prompt has {prompt_word_count} words, which exceeds the threshold of {prompt_enhancement_words_threshold}. Prompt enhancement disabled."
-        )
-
-    precision = pipeline_config["precision"]
-    text_encoder_model_name_or_path = pipeline_config["text_encoder_model_name_or_path"]
-    sampler = pipeline_config["sampler"]
-    prompt_enhancer_image_caption_model_name_or_path = pipeline_config[
-        "prompt_enhancer_image_caption_model_name_or_path"
-    ]
-    prompt_enhancer_llm_model_name_or_path = pipeline_config[
-        "prompt_enhancer_llm_model_name_or_path"
-    ]
-
-    pipeline = create_ltx_video_pipeline(
-        ckpt_path=ltxv_model_path,
-        precision=precision,
-        text_encoder_model_name_or_path=text_encoder_model_name_or_path,
-        sampler=sampler,
-        device=kwargs.get("device", get_device()),
-        enhance_prompt=enhance_prompt,
-        prompt_enhancer_image_caption_model_name_or_path=prompt_enhancer_image_caption_model_name_or_path,
-        prompt_enhancer_llm_model_name_or_path=prompt_enhancer_llm_model_name_or_path,
-    )
-
-    if pipeline_config.get("pipeline_type", None) == "multi-scale":
-        if not spatial_upscaler_model_path:
-            raise ValueError(
-                "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
-            )
-        latent_upsampler = create_latent_upsampler(
-            spatial_upscaler_model_path, pipeline.device
-        )
-        pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)
-
-    media_item = None
-    if input_media_path:
-        media_item = load_media_file(
-            media_path=input_media_path,
-            height=height,
-            width=width,
-            max_frames=num_frames_padded,
-            padding=padding,
-        )
-
-    conditioning_items = (
-        prepare_conditioning(
-            conditioning_media_paths=conditioning_media_paths,
-            conditioning_strengths=conditioning_strengths,
-            conditioning_start_frames=conditioning_start_frames,
-            height=height,
-            width=width,
-            num_frames=num_frames,
-            padding=padding,
-            pipeline=pipeline,
-        )
-        if conditioning_media_paths
-        else None
-    )
-
-    stg_mode = pipeline_config.get("stg_mode", "attention_values")
-    del pipeline_config["stg_mode"]
-    if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values":
-        skip_layer_strategy = SkipLayerStrategy.AttentionValues
-    elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip":
-        skip_layer_strategy = SkipLayerStrategy.AttentionSkip
-    elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual":
-        skip_layer_strategy = SkipLayerStrategy.Residual
-    elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block":
-        skip_layer_strategy = SkipLayerStrategy.TransformerBlock
-    else:
-        raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}")
-
-    # Prepare input for the pipeline
-    sample = {
-        "prompt": prompt,
-        "prompt_attention_mask": None,
-        "negative_prompt": negative_prompt,
-        "negative_prompt_attention_mask": None,
-    }
-
-    device = device or get_device()
-    generator = torch.Generator(device=device).manual_seed(seed)
-
-    images = pipeline(
-        **pipeline_config,
-        skip_layer_strategy=skip_layer_strategy,
-        generator=generator,
-        output_type="pt",
-        callback_on_step_end=None,
-        height=height_padded,
-        width=width_padded,
-        num_frames=num_frames_padded,
-        frame_rate=frame_rate,
-        **sample,
-        media_items=media_item,
-        conditioning_items=conditioning_items,
-        is_video=True,
-        vae_per_channel_normalize=True,
-        image_cond_noise_scale=image_cond_noise_scale,
-        mixed_precision=(precision == "mixed_precision"),
-        offload_to_cpu=offload_to_cpu,
-        device=device,
-        enhance_prompt=enhance_prompt,
-    ).images
-
-    # Crop the padded images to the desired resolution and number of frames
-    (pad_left, pad_right, pad_top, pad_bottom) = padding
-    pad_bottom = -pad_bottom
-    pad_right = -pad_right
-    if pad_bottom == 0:
-        pad_bottom = images.shape[3]
-    if pad_right == 0:
-        pad_right = images.shape[4]
-    images = images[:, :, :num_frames, pad_top:pad_bottom, pad_left:pad_right]
-
-    for i in range(images.shape[0]):
-        # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
-        video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
-        # Unnormalizing images to [0, 255] range
-        video_np = (video_np * 255).astype(np.uint8)
-        fps = frame_rate
-        height, width = video_np.shape[1:3]
-        # In case a single image is generated
-        if video_np.shape[0] == 1:
-            output_filename = get_unique_filename(
-                f"image_output_{i}",
-                ".png",
-                prompt=prompt,
-                seed=seed,
-                resolution=(height, width, num_frames),
-                dir=output_dir,
-            )
-            imageio.imwrite(output_filename, video_np[0])
-        else:
-            output_filename = get_unique_filename(
-                f"video_output_{i}",
-                ".mp4",
-                prompt=prompt,
-                seed=seed,
-                resolution=(height, width, num_frames),
-                dir=output_dir,
-            )
-
-            # Write video
-            with imageio.get_writer(output_filename, fps=fps) as video:
-                for frame in video_np:
-                    video.append_data(frame)
-
-        logger.warning(f"Output saved to {output_filename}")
-
-
-def prepare_conditioning(
-    conditioning_media_paths: List[str],
-    conditioning_strengths: List[float],
-    conditioning_start_frames: List[int],
-    height: int,
-    width: int,
-    num_frames: int,
-    padding: tuple[int, int, int, int],
-    pipeline: LTXVideoPipeline,
-) -> Optional[List[ConditioningItem]]:
-    """Prepare conditioning items based on input media paths and their parameters.
-
-    Args:
-        conditioning_media_paths: List of paths to conditioning media (images or videos)
-        conditioning_strengths: List of conditioning strengths for each media item
-        conditioning_start_frames: List of frame indices where each item should be applied
-        height: Height of the output frames
-        width: Width of the output frames
-        num_frames: Number of frames in the output video
-        padding: Padding to apply to the frames
-        pipeline: LTXVideoPipeline object used for condition video trimming
-
-    Returns:
-        A list of ConditioningItem objects.
-    """
-    conditioning_items = []
-    for path, strength, start_frame in zip(
-        conditioning_media_paths, conditioning_strengths, conditioning_start_frames
-    ):
-        num_input_frames = orig_num_input_frames = get_media_num_frames(path)
-        if hasattr(pipeline, "trim_conditioning_sequence") and callable(
-            getattr(pipeline, "trim_conditioning_sequence")
-        ):
-            num_input_frames = pipeline.trim_conditioning_sequence(
-                start_frame, orig_num_input_frames, num_frames
-            )
-        if num_input_frames < orig_num_input_frames:
-            logger.warning(
-                f"Trimming conditioning video {path} from {orig_num_input_frames} to {num_input_frames} frames."
-            )
-
-        media_tensor = load_media_file(
-            media_path=path,
-            height=height,
-            width=width,
-            max_frames=num_input_frames,
-            padding=padding,
-            just_crop=True,
-        )
-        conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength))
-    return conditioning_items
-
-
-def get_media_num_frames(media_path: str) -> int:
-    is_video = any(
-        media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
-    )
-    num_frames = 1
-    if is_video:
-        reader = imageio.get_reader(media_path)
-        num_frames = reader.count_frames()
-        reader.close()
-    return num_frames
-
-
-def load_media_file(
-    media_path: str,
-    height: int,
-    width: int,
-    max_frames: int,
-    padding: tuple[int, int, int, int],
-    just_crop: bool = False,
-) -> torch.Tensor:
-    is_video = any(
-        media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
-    )
-    if is_video:
-        reader = imageio.get_reader(media_path)
-        num_input_frames = min(reader.count_frames(), max_frames)
-
-        # Read and preprocess the relevant frames from the video file.
-        frames = []
-        for i in range(num_input_frames):
-            frame = Image.fromarray(reader.get_data(i))
-            frame_tensor = load_image_to_tensor_with_resize_and_crop(
-                frame, height, width, just_crop=just_crop
-            )
-            frame_tensor = torch.nn.functional.pad(frame_tensor, padding)
-            frames.append(frame_tensor)
-        reader.close()
-
-        # Stack frames along the temporal dimension
-        media_tensor = torch.cat(frames, dim=2)
-    else:  # Input image
-        media_tensor = load_image_to_tensor_with_resize_and_crop(
-            media_path, height, width, just_crop=just_crop
-        )
-        media_tensor = torch.nn.functional.pad(media_tensor, padding)
-    return media_tensor
-
-
-if __name__ == "__main__":
-    main()