eeuuia committed (verified)
Commit 79d5be1 · 1 Parent(s): 6c0add0

Rename LTX-Video/ltx_video/pipelines/pipeline_ltx_video (3).py to LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py

LTX-Video/ltx_video/pipelines/{pipeline_ltx_video (3).py → pipeline_ltx_video.py} RENAMED
@@ -59,8 +59,24 @@ logging.set_verbosity_debug()
 
 
 #logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+import logging
+import warnings
+warnings.filterwarnings("ignore", category=UserWarning)
+warnings.filterwarnings("ignore", category=FutureWarning)
+warnings.filterwarnings("ignore", message=".*")
+from huggingface_hub import logging as ll
+ll.set_verbosity_error()
+ll.set_verbosity_warning()
+ll.set_verbosity_info()
+ll.set_verbosity_debug()
 
+from utils.debug_utils import log_function_io
+logger = logging.getLogger("AducDebug")
+logging.basicConfig(level=logging.DEBUG)
+logger.setLevel(logging.DEBUG)
 
+
+@log_function_io
 class SpyLatent:
 
     """
@@ -120,7 +136,7 @@ class SpyLatent:
         # --- Convert to 5D if needed ---
         tensor_5d = self._to_5d(tensor, reference_shape_5d)
         if tensor_5d is not None and tensor.ndim == 3:
-            self._print_stats("Converted to 5D", tensor_5d)
+            pass  # self._print_stats("Converted to 5D", tensor_5d)
 
         # --- Visualization with VAE ---
         if save_visual and self.vae is not None and tensor_5d is not None:
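Note: the committed `+` line was a bare comment under the `if`, which leaves an empty block and is a syntax error; `pass` restores valid Python while keeping the call disabled. `_to_5d` is not part of this hunk; judging from the call sites (a 3-D patchified tensor plus a 5-D reference shape), a plausible sketch, assuming the token axis factors into frames × height × width taken from the reference shape (hypothetical helper, not the file's code):

    import torch

    def _to_5d(tensor: torch.Tensor, reference_shape_5d):
        """Hypothetical: reshape [B, F*H*W, C] tokens to [B, C, F, H, W]."""
        if tensor.ndim == 5:
            return tensor
        if tensor.ndim != 3 or reference_shape_5d is None:
            return None
        b, c, f, h, w = reference_shape_5d
        if tensor.shape[1] != f * h * w:
            return None  # token count does not match the reference grid
        return tensor.transpose(1, 2).reshape(b, c, f, h, w)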
@@ -129,7 +145,7 @@ class SpyLatent:
 
             frame_idx_to_viz = min(1, tensor_5d.shape[2] - 1)
             if frame_idx_to_viz < 0:
-                print(" VISUALIZATION (VAE): tensor has no frames to visualize.")
+                pass  # print(" VISUALIZATION (VAE): tensor has no frames to visualize.")
             else:
                 #print(f" VISUALIZATION (VAE): using frame index {frame_idx_to_viz}.")
                 latent_slice = tensor_5d[:, :, frame_idx_to_viz:frame_idx_to_viz+1, :, :]
@@ -163,7 +179,7 @@ class SpyLatent:
         std = tensor.std().item()
         min_val = tensor.min().item()
         max_val = tensor.max().item()
-        print(f" {prefix}: {tensor.shape}")
+        #print(f" {prefix}: {tensor.shape}")
 
 
 
@@ -240,7 +256,7 @@ ASPECT_RATIO_512_BIN = {
     "4.0": [1024.0, 256.0],
 }
 
-
+@log_function_io
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
@@ -308,9 +324,10 @@ def retrieve_timesteps(
         num_inference_steps = len(timesteps)
 
     try:
-        print(f"[LTX]LATENTS {latents.shape}")
+        print(f"[LTX]timesteps {timesteps}")
+        print(f"[LTX]num_inference_steps {num_inference_steps}")
     except Exception:
-        pass
+        pass
 
 
     return timesteps, num_inference_steps
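Reviewer note: the removed `[LTX]LATENTS` print referenced `latents`, which is never defined inside `retrieve_timesteps`, so it raised `NameError` on every call and was silently swallowed by the `except`; the replacement prints only use local names. For reference, a typical call of this diffusers-style helper looks like (illustrative arguments):

    # Either num_inference_steps or an explicit timesteps list is given, not both.
    timesteps, num_inference_steps = retrieve_timesteps(
        scheduler, num_inference_steps=30, device="cuda"
    )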
@@ -334,8 +351,9 @@ class ConditioningItem:
     conditioning_strength: float
     media_x: Optional[int] = None
     media_y: Optional[int] = None
+
 
-
+@log_function_io
 class LTXVideoPipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using LTX-Video.
@@ -387,6 +405,7 @@ class LTXVideoPipeline(DiffusionPipeline):
     ]
     model_cpu_offload_seq = "prompt_enhancer_image_caption_model->prompt_enhancer_llm_model->text_encoder->transformer->vae"
 
+    @log_function_io
     def __init__(
         self,
         tokenizer: T5Tokenizer,
@@ -425,6 +444,7 @@ class LTXVideoPipeline(DiffusionPipeline):
 
         self.spy = SpyLatent(vae=vae)
 
+    @log_function_io
     def mask_text_embeddings(self, emb, mask):
         if emb.shape[0] == 1:
             keep_index = mask.sum().item()
@@ -434,6 +454,8 @@ class LTXVideoPipeline(DiffusionPipeline):
         return masked_feature, emb.shape[2]
 
     # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
+
+    @log_function_io
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
@@ -628,6 +650,7 @@ class LTXVideoPipeline(DiffusionPipeline):
         extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
+    @log_function_io
     def check_inputs(
         self,
         prompt,
@@ -714,6 +737,7 @@ class LTXVideoPipeline(DiffusionPipeline):
             self.prompt_enhancer_llm_tokenizer is not None
         ), "Text prompt enhancer tokenizer must be initialized if enhance_prompt is True"
 
+    @log_function_io
     def _text_preprocessing(self, text):
         if not isinstance(text, (tuple, list)):
             text = [text]
@@ -725,6 +749,7 @@ class LTXVideoPipeline(DiffusionPipeline):
         return [process(t) for t in text]
 
     @staticmethod
+    @log_function_io
     def add_noise_to_image_conditioning_latents(
         t: float,
         init_latents: torch.Tensor,
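Reviewer note: placing `@log_function_io` under `@staticmethod` is the right order. The logger must receive the plain function; `@staticmethod` then wraps the logged wrapper into a descriptor. Reversed, `log_function_io` would be handed a `staticmethod` object, which is not callable before Python 3.10. A minimal illustration:

    class Example:
        @staticmethod        # outermost: turns the logged function into a descriptor
        @log_function_io     # innermost: receives the plain, still-callable function
        def double(x: int) -> int:
            return 2 * x

    Example.double(21)  # logged call, returns 42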
@@ -751,6 +776,7 @@ class LTXVideoPipeline(DiffusionPipeline):
         return latents
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+    @log_function_io
     def prepare_latents(
         self,
         latents: torch.Tensor | None,
@@ -832,6 +858,7 @@ class LTXVideoPipeline(DiffusionPipeline):
         return latents
 
     @staticmethod
+    @log_function_io
     def classify_height_width_bin(
         height: int, width: int, ratios: dict
     ) -> Tuple[int, int]:
@@ -842,6 +869,7 @@ class LTXVideoPipeline(DiffusionPipeline):
         return int(default_hw[0]), int(default_hw[1])
 
     @staticmethod
+    @log_function_io
     def resize_and_crop_tensor(
         samples: torch.Tensor, new_width: int, new_height: int
     ) -> torch.Tensor:
@@ -868,6 +896,7 @@ class LTXVideoPipeline(DiffusionPipeline):
         return samples
 
     @staticmethod
+    @log_function_io
     def resize_tensor(media_items, height, width):
         n_frames = media_items.shape[2]
         if media_items.shape[-2:] != (height, width):
@@ -882,6 +911,7 @@ class LTXVideoPipeline(DiffusionPipeline):
         return media_items
 
     @torch.no_grad()
+    @log_function_io
     def __call__(
         self,
         height: int,
@@ -1087,7 +1117,7 @@ class LTXVideoPipeline(DiffusionPipeline):
         )
 
         try:
-            print(f"[LTX2]LATENTS {latents.shape}")
+            print(f"[LTX2]timesteps {timesteps}")
         except Exception:
             pass
 
@@ -1160,7 +1190,7 @@ class LTXVideoPipeline(DiffusionPipeline):
         )
 
         try:
-            print(f"[LTX3]LATENTS {latents.shape}")
+            print(f"[LTX3]LATENTS {prompt}")
         except Exception:
             pass
 
@@ -1230,7 +1260,7 @@ class LTXVideoPipeline(DiffusionPipeline):
         )
 
         try:
-            print(f"[LTX4]LATENTS {latents.shape}")
+            print(f"[LTX4]media_items {media_items}")
             original_shape = latents
         except Exception:
             pass
@@ -1252,17 +1282,17 @@ class LTXVideoPipeline(DiffusionPipeline):
         init_latents = latents.clone()  # Used for image_cond_noise_update
 
         try:
-            print(f"[LTXCond]conditioning_mask {conditioning_mask.shape}")
+            print(f"[LTXCond]conditioning_items {conditioning_items.shape}")
         except Exception:
             pass
 
         try:
-            print(f"[LTXCond]pixel_coords {pixel_coords.shape}")
+            print(f"[LTXCond]num_frames {num_frames}")
         except Exception:
             pass
 
         try:
-            print(f"[LTXCond]pixel_coords {pixel_coords.shape}")
+            print(f"[LTXCond]width {width}")
         except Exception:
             pass
 
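Reviewer note: `conditioning_items` is an `Optional[List[ConditioningItem]]`, so `conditioning_items.shape` raises `AttributeError` and the try/except swallows the print every time; nothing is logged. A guard that would actually print something, assuming each item's tensor payload is the `media_item` field of the `ConditioningItem` dataclass (sketch only):

    if conditioning_items:
        for idx, item in enumerate(conditioning_items):
            print(f"[LTXCond]item{idx} shape={tuple(item.media_item.shape)} "
                  f"strength={item.conditioning_strength}")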
@@ -1274,7 +1304,7 @@ class LTXVideoPipeline(DiffusionPipeline):
 
 
         try:
-            print(f"[LTX5]LATENTS {latents.shape}")
+            print(f"[LTX5]width {width}")
         except Exception:
             pass
 
@@ -1283,6 +1313,12 @@ class LTXVideoPipeline(DiffusionPipeline):
             len(timesteps) - num_inference_steps * self.scheduler.order, 0
         )
 
+        try:
+            print(f"[LTX5]num_warmup_steps {num_warmup_steps}")
+        except Exception:
+            pass
+
+
         orig_conditioning_mask = conditioning_mask
 
         # Before compiling this code please be aware:
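For a sense of the value being printed: `num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)`, so it stays 0 for first-order schedulers with one timestep per step, and only becomes positive when the scheduler consumes several timesteps per step (illustrative numbers):

    # Illustrative values only
    num_inference_steps, order = 30, 2
    max(60 - num_inference_steps * order, 0)  # -> 0: no warmup steps
    max(61 - num_inference_steps * order, 0)  # -> 1: one warmup step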
@@ -1337,11 +1373,7 @@ class LTXVideoPipeline(DiffusionPipeline):
             generator,
         )
 
-        try:
-            print(f"[LTX6]LATENTS {latents.shape}")
-            self.spy.inspect(latents, "LTX6_After_Patchify", reference_shape_5d=original_shape)
-        except Exception:
-            pass
+
 
 
 
@@ -1352,11 +1384,7 @@ class LTXVideoPipeline(DiffusionPipeline):
                 latent_model_input, t
             )
 
-            try:
-                print(f"[LTX7]LATENTS {latent_model_input.shape}")
-                self.spy.inspect(latents, "LTX7_After_Patchify", reference_shape_5d=original_shape)
-            except Exception:
-                pass
+
 
             current_timestep = t
             if not torch.is_tensor(current_timestep):
@@ -1473,11 +1501,7 @@ class LTXVideoPipeline(DiffusionPipeline):
                 stochastic_sampling=stochastic_sampling,
             )
 
-            try:
-                print(f"[LTX8]LATENTS {latents.shape}")
-                self.spy.inspect(latents, "LTX8_After_Patchify", reference_shape_5d=original_shape)
-            except Exception:
-                pass
+
 
             # call the callback, if provided
             if i == len(timesteps) - 1 or (
@@ -1504,15 +1528,10 @@ class LTXVideoPipeline(DiffusionPipeline):
             torch.cuda.empty_cache()
 
         # Remove the added conditioning latents
-        latents = latents[:, num_cond_latents:]
+        #latents = latents[:, num_cond_latents:]
+
 
 
-        try:
-            print(f"[LTX10]LATENTS {latents.shape}")
-            self.spy.inspect(latents, "LTX10_After_Patchify", reference_shape_5d=original_shape)
-        except Exception:
-            pass
-
         latents = self.patchifier.unpatchify(
             latents=latents,
             output_height=latent_height,
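Reviewer note: with the slice commented out, the conditioning tokens prepended by `prepare_conditioning` are no longer stripped before `unpatchify`, so the sequence handed to `unpatchify` is `num_cond_latents` tokens longer than the video-token grid it folds back into. A hypothetical illustration of the bookkeeping (names assumed, not from the file):

    # Hypothetical: the patchified sequence is [B, num_cond_latents + T, C],
    # where T must equal latent_frames * latent_height * latent_width
    # for unpatchify to reshape it into (B, C, F, H, W).
    latents = latents[:, num_cond_latents:]  # the line this commit disables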
@@ -1549,7 +1568,7 @@ class LTXVideoPipeline(DiffusionPipeline):
         )
 
         try:
-            print(f"[LTX11]LATENTS {latents.shape}")
+            print(f"[LTX11]LATENTSfim {latents.shape}")
         except Exception:
             pass
 
@@ -1566,6 +1585,7 @@ class LTXVideoPipeline(DiffusionPipeline):
 
         return ImagePipelineOutput(images=image)
 
+    @log_function_io
     def denoising_step(
         self,
         latents: torch.Tensor,
@@ -1601,6 +1621,7 @@ class LTXVideoPipeline(DiffusionPipeline):
         tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).unsqueeze(-1)
         return torch.where(tokens_to_denoise_mask, denoised_latents, latents)
 
+    @log_function_io
     def prepare_conditioning(
         self,
         conditioning_items: Optional[List[ConditioningItem]],
@@ -1808,6 +1829,7 @@ class LTXVideoPipeline(DiffusionPipeline):
         )
 
     @staticmethod
+    @log_function_io
     def _resize_conditioning_item(
         conditioning_item: ConditioningItem,
         height: int,
@@ -1823,6 +1845,7 @@ class LTXVideoPipeline(DiffusionPipeline):
         )
         return new_conditioning_item
 
+    @log_function_io
     def _get_latent_spatial_position(
         self,
         latents: torch.Tensor,
@@ -1871,6 +1894,7 @@ class LTXVideoPipeline(DiffusionPipeline):
         return latents, x_start // scale, y_start // scale
 
     @staticmethod
+    @log_function_io
     def _handle_non_first_conditioning_sequence(
         init_latents: torch.Tensor,
         init_conditioning_mask: torch.Tensor,
@@ -1946,6 +1970,7 @@ class LTXVideoPipeline(DiffusionPipeline):
             latents,
         )
 
+    @log_function_io
     def trim_conditioning_sequence(
         self, start_frame: int, sequence_num_frames: int, target_num_frames: int
     ):
@@ -1967,6 +1992,7 @@ class LTXVideoPipeline(DiffusionPipeline):
         return num_frames
 
     @staticmethod
+    @log_function_io
     def tone_map_latents(
         latents: torch.Tensor,
         compression: float,
@@ -2008,6 +2034,7 @@ class LTXVideoPipeline(DiffusionPipeline):
         return filtered
 
 
+@log_function_io
 def adain_filter_latent(
     latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0
 ):
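For context on `adain_filter_latent` (its body sits between these two hunks): per the `torch.lerp` tail visible in the next hunk and the usual AdaIN formulation, each channel of `latents` is renormalized to the per-channel statistics of `reference_latents`, then blended with the input by `factor`. A sketch under that assumption, not the file's exact body:

    import torch

    def adain_sketch(latents, reference_latents, factor=1.0):
        dims = tuple(range(2, latents.ndim))  # all but batch and channel dims
        mean = latents.mean(dims, keepdim=True)
        std = latents.std(dims, keepdim=True)
        ref_mean = reference_latents.mean(dims, keepdim=True)
        ref_std = reference_latents.std(dims, keepdim=True)
        result = (latents - mean) / (std + 1e-6) * ref_std + ref_mean
        return torch.lerp(latents, result, factor)  # matches the hunk below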
@@ -2038,7 +2065,7 @@ def adain_filter_latent(
     result = torch.lerp(latents, result, factor)
     return result
 
-
+@log_function_io
 class LTXMultiScalePipeline:
     def _upsample_latents(
         self, latest_upsampler: LatentUpsampler, latents: torch.Tensor
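Reviewer note: here `@log_function_io` is applied to a class. Unless the decorator special-cases classes (as in the sketch near the top of this diff), a naive function wrapper would rebind the class name to a function, which silently breaks `isinstance` checks:

    def naive_logger(obj):
        def wrapper(*args, **kwargs):
            print(f"CALL {obj.__name__}")
            return obj(*args, **kwargs)
        return wrapper

    @naive_logger
    class Widget:
        pass

    w = Widget()  # prints "CALL Widget" and still builds a Widget instance
    # but isinstance(w, Widget) now raises TypeError: Widget is a function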
 