euiiiia committed · Commit dae5014 · verified · 1 Parent(s): 1cd256f

Upload pipeline_ltx_video.py

LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py CHANGED
@@ -45,17 +45,8 @@ from ltx_video.models.autoencoders.vae_encode import (
 )
 
 
-import warnings
-warnings.filterwarnings("ignore", category=UserWarning)
-warnings.filterwarnings("ignore", category=FutureWarning)
-warnings.filterwarnings("ignore", message=".*")
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
-from huggingface_hub import logging as ll
-
-ll.set_verbosity_error()
-ll.set_verbosity_warning()
-ll.set_verbosity_info()
-ll.set_verbosity_debug()
 
 ASPECT_RATIO_1024_BIN = {
     "0.25": [512.0, 2048.0],
@@ -1389,42 +1380,6 @@ class LTXVideoPipeline(DiffusionPipeline):
         tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).unsqueeze(-1)
         return torch.where(tokens_to_denoise_mask, denoised_latents, latents)
 
-    def _prepare_conditioning_media_to_latents(
-        self,
-        conditioning_item: ConditioningItem,
-        height: int,
-        width: int,
-        latent_height: int,
-        latent_width: int,
-        vae_per_channel_normalize: bool,
-        init_latents_dtype: torch.dtype,
-        init_latents_device: torch.device,
-    ) -> Tuple[ConditioningItem, torch.Tensor]:
-
-        media_item = conditioning_item.media_item
-        c = media_item.shape[1]
-
-        print (f"_ltx_prepare_media_item.shape {media_item.shape}")
-
-        # Inside _prepare_conditioning_media_to_latents:
-        c = media_item.shape[1]
-        if c == self.transformer.config.in_channels:
-            latents = media_item.to(dtype=init_latents_dtype, device=init_latents_device)
-            return conditioning_item, latents
-
-        conditioning_item = self._resize_conditioning_item(conditioning_item, height, width)
-        media_item = conditioning_item.media_item
-        latents = vae_encode(
-            media_item.to(dtype=self.vae.dtype, device=self.vae.device),
-            self.vae,
-            vae_per_channel_normalize=vae_per_channel_normalize,
-        ).to(dtype=init_latents_dtype, device=init_latents_device)
-
-        print (f"_ltx_prepare_media_item_vae.shape? {media_item.shape}")
-
-        return conditioning_item, latents
-
-
     def prepare_conditioning(
         self,
         conditioning_items: Optional[List[ConditioningItem]],
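
The deleted `_prepare_conditioning_media_to_latents` helper bundled a single decision: media whose channel count already equals `transformer.config.in_channels` is treated as pre-encoded latents and only moved to the target dtype/device, while anything else is resized and VAE-encoded. The same dispatch reappears inline in `prepare_conditioning` (see the hunk at 1438 below), so removing the helper loses no behavior. A hedged sketch of the check, with `is_latent_media` as a hypothetical name:

```python
import torch

def is_latent_media(media_item: torch.Tensor, latent_channels: int) -> bool:
    # media_item is (b, c, f, h, w): pixel media carries c == 3, while
    # pre-encoded latents carry c == transformer.config.in_channels.
    return media_item.ndim == 5 and media_item.shape[1] == latent_channels
```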
@@ -1434,119 +1389,6 @@ class LTXVideoPipeline(DiffusionPipeline):
         width: int,
         vae_per_channel_normalize: bool = False,
         generator=None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
-        assert isinstance(self.vae, CausalVideoAutoencoder)
-        if not conditioning_items:
-            # existing behavior
-            # ...
-            return init_latents, init_pixel_coords, None, 0
-
-        # Derive latent dims
-        latent_height = height // self.vae_scale_factor
-        latent_width = width // self.vae_scale_factor
-
-        batch_size, _, num_latent_frames = init_latents.shape[:3]
-        init_conditioning_mask = torch.zeros(
-            init_latents[:, 0, :, :, :].shape,
-            dtype=torch.float32,
-            device=init_latents.device,
-        )
-
-        extra_conditioning_latents = []
-        extra_conditioning_pixel_coords = []
-        extra_conditioning_mask = []
-        extra_conditioning_num_latents = 0
-
-        for conditioning_item in conditioning_items:
-            # NEW: centralizes resize/encode/latents detection
-            conditioning_item, media_item_latents = self._prepare_conditioning_media_to_latents(
-                conditioning_item=conditioning_item,
-                height=height,
-                width=width,
-                latent_height=latent_height,
-                latent_width=latent_width,
-                vae_per_channel_normalize=vae_per_channel_normalize,
-                init_latents_dtype=init_latents.dtype,
-                init_latents_device=init_latents.device,
-            )
-
-            media_frame_number = conditioning_item.media_frame_number
-            strength = conditioning_item.conditioning_strength
-
-            # Shape/frame validations (now on latents)
-            b, c_l, f_l, h_l, w_l = media_item_latents.shape
-            assert c_l == self.transformer.config.in_channels
-            assert (h_l, w_l) == (latent_height, latent_width), "Latents with incompatible HxW"
-            assert f_l % 8 == 1, "latent n_frames must be a multiple of the temporal scale, plus 1"
-            assert media_frame_number >= 0 and (media_frame_number + f_l) <= num_frames
-
-            print(f"media_item_latents, {media_item_latents.shape}")
-
-            # From here on, unchanged (positioning, masks, patchify, etc.)
-            if media_frame_number == 0:
-                media_item_latents, l_x, l_y = self._get_latent_spatial_position(
-                    media_item_latents, conditioning_item, height, width, strip_latent_border=True
-                )
-                init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = torch.lerp(
-                    init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l],
-                    media_item_latents,
-                    strength,
-                )
-                init_conditioning_mask[:, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = strength
-            else:
-                if f_l > 1:
-                    init_latents, init_conditioning_mask, media_item_latents = self._handle_non_first_conditioning_sequence(
-                        init_latents, init_conditioning_mask, media_item_latents, media_frame_number, strength,
-                    )
-                if media_item_latents is not None:
-                    noise = randn_tensor(
-                        media_item_latents.shape, generator=generator,
-                        device=media_item_latents.device, dtype=media_item_latents.dtype,
-                    )
-                    media_item_latents = torch.lerp(noise, media_item_latents, strength)
-                    media_item_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
-                    pixel_coords = latent_to_pixel_coords(
-                        latent_coords, self.vae,
-                        causal_fix=self.transformer.config.causal_temporal_positioning,
-                    )
-                    pixel_coords[:, 0] += media_frame_number
-                    extra_conditioning_num_latents += media_item_latents.shape[1]
-                    conditioning_mask = torch.full(
-                        media_item_latents.shape[:2], strength,
-                        dtype=torch.float32, device=init_latents.device,
-                    )
-                    extra_conditioning_latents.append(media_item_latents)
-                    extra_conditioning_pixel_coords.append(pixel_coords)
-                    extra_conditioning_mask.append(conditioning_mask)
-
-        # Patchify and concat, same as the existing code...
-        init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
-        init_pixel_coords = latent_to_pixel_coords(
-            init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning,
-        )
-        init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
-        init_conditioning_mask = init_conditioning_mask.squeeze(-1)
-        if extra_conditioning_latents:
-            init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
-            init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
-            init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)
-            if self.transformer.use_tpu_flash_attention:
-                init_latents = init_latents[:, :-extra_conditioning_num_latents]
-                init_pixel_coords = init_pixel_coords[:, :, :-extra_conditioning_num_latents]
-                init_conditioning_mask = init_conditioning_mask[:, :-extra_conditioning_num_latents]
-        return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
-
-
-
-    def prepare_conditioning_old(
-        self,
-        conditioning_items: Optional[List[ConditioningItem]],
-        init_latents: torch.Tensor,
-        num_frames: int,
-        height: int,
-        width: int,
-        vae_per_channel_normalize: bool = False,
-        generator=None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
         """
         Prepare conditioning tokens based on the provided conditioning items.
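
Both the deleted override and the restored method assert `n_frames % 8 == 1` on conditioning media, reflecting the causal temporal compression of the VAE. Assuming a temporal stride of 8 (consistent with the asserts here and the removed code's "temporal scale, plus 1" message; the authoritative strides live in the autoencoder config), `8k + 1` pixel frames map to `k + 1` latent frames:

```python
def num_latent_frames(n_frames: int, temporal_stride: int = 8) -> int:
    # Causal mapping sketch: the first frame encodes on its own, and each
    # following group of `temporal_stride` frames adds one latent frame.
    assert n_frames % temporal_stride == 1, "n_frames must be 8*k + 1"
    return (n_frames - 1) // temporal_stride + 1

print(num_latent_frames(121))  # 16
print(num_latent_frames(9))    # 2
```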
@@ -1596,33 +1438,41 @@ class LTXVideoPipeline(DiffusionPipeline):
 
         # Process each conditioning item
         for conditioning_item in conditioning_items:
-            conditioning_item = self._resize_conditioning_item(
-                conditioning_item, height, width
-            )
-            media_item = conditioning_item.media_item
-            media_frame_number = conditioning_item.media_frame_number
-            strength = conditioning_item.conditioning_strength
-            assert media_item.ndim == 5  # (b, c, f, h, w)
-            b, c, n_frames, h, w = media_item.shape
-            assert (
-                height == h and width == w
-            ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0"
-            assert n_frames % 8 == 1
-            assert (
-                media_frame_number >= 0
-                and media_frame_number + n_frames <= num_frames
-            )
-
-            # Encode the provided conditioning media item
-            media_item_latents = vae_encode(
-                media_item.to(dtype=self.vae.dtype, device=self.vae.device),
-                self.vae,
-                vae_per_channel_normalize=vae_per_channel_normalize,
-            ).to(dtype=init_latents.dtype)
-
-            print(f"media_item_latents, {media_item_latents.shape}")
-
+
+            print(f"media_item_latents ini {conditioning_item.media_item.shape}")
 
+            c = conditioning_item.media_item.shape[1]
+            if c == self.transformer.config.in_channels:
+                media_item_latents = conditioning_item.media_item.to(dtype=init_latents_dtype, device=init_latents_device)
+                strength = conditioning_item.conditioning_strength
+                media_frame_number = conditioning_item.media_frame_number
+            else:
+                conditioning_item = self._resize_conditioning_item(
+                    conditioning_item, height, width
+                )
+                media_item = conditioning_item.media_item
+                media_frame_number = conditioning_item.media_frame_number
+                strength = conditioning_item.conditioning_strength
+                assert media_item.ndim == 5  # (b, c, f, h, w)
+                b, c, n_frames, h, w = media_item.shape
+                assert (
+                    height == h and width == w
+                ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0"
+                assert n_frames % 8 == 1
+                assert (
+                    media_frame_number >= 0
+                    and media_frame_number + n_frames <= num_frames
+                )
+
+                # Encode the provided conditioning media item
+                media_item_latents = vae_encode(
+                    media_item.to(dtype=self.vae.dtype, device=self.vae.device),
+                    self.vae,
+                    vae_per_channel_normalize=vae_per_channel_normalize,
+                ).to(dtype=init_latents.dtype)
+
+                print(f"media_item_latents encode vae {conditioning_item.media_item.shape}")
+
             # Handle the different conditioning cases
             if media_frame_number == 0:
                 # Get the target spatial position of the latent conditioning item
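
One caveat in the added branch above: `init_latents_dtype` and `init_latents_device` were parameters of the deleted helper, but they are not defined anywhere in `prepare_conditioning`, whose signature only has `init_latents`. As committed, the pre-encoded-latents path would raise `NameError` the first time it runs. A hedged reading of the intent, as a standalone helper:

```python
import torch

def latents_passthrough(media_item: torch.Tensor, init_latents: torch.Tensor) -> torch.Tensor:
    # Presumed intent (hypothetical fix-up): move pre-encoded latents to the
    # dtype/device of init_latents, which is what init_latents_dtype and
    # init_latents_device most plausibly referred to.
    return media_item.to(dtype=init_latents.dtype, device=init_latents.device)
```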
@@ -1763,7 +1613,6 @@ class LTXVideoPipeline(DiffusionPipeline):
         )
         return new_conditioning_item
 
-    @staticmethod
     def _get_latent_spatial_position(
         self,
         latents: torch.Tensor,
@@ -1887,7 +1736,6 @@ class LTXVideoPipeline(DiffusionPipeline):
             latents,
         )
 
-    @staticmethod
     def trim_conditioning_sequence(
         self, start_frame: int, sequence_num_frames: int, target_num_frames: int
     ):
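
These last two hunks drop `@staticmethod` from `_get_latent_spatial_position` and `trim_conditioning_sequence`, both of which declare `self`. That combination breaks at call time: a static method never receives the instance, so `self` silently swallows the first real argument. In miniature:

```python
class Demo:
    @staticmethod
    def f(self, x):  # 'self' is just an ordinary first parameter here
        return x

d = Demo()
d.f(1)  # TypeError: f() missing 1 required positional argument: 'x'
# Without @staticmethod, d.f(1) binds self=d and x=1, returning 1.
```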
 