Upload pipeline_ltx_video.py
LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py
CHANGED
@@ -45,17 +45,8 @@ from ltx_video.models.autoencoders.vae_encode import (
 )


-
-warnings.filterwarnings("ignore", category=UserWarning)
-warnings.filterwarnings("ignore", category=FutureWarning)
-warnings.filterwarnings("ignore", message=".*")
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

-from huggingface_hub import logging as ll
-
-ll.set_verbosity_error()
-ll.set_verbosity_warning()
-ll.set_verbosity_info()
-ll.set_verbosity_debug()

 ASPECT_RATIO_1024_BIN = {
     "0.25": [512.0, 2048.0],
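This hunk swaps the blanket warnings.filterwarnings suppression and the leftover huggingface_hub verbosity experiments for the standard module-level logger. A minimal sketch of controlling output after the change, assuming `logging` here is `diffusers.utils.logging` as in other diffusers pipelines (the import itself is outside this hunk, so that is an assumption):

    from diffusers.utils import logging

    logging.set_verbosity_warning()        # one global switch instead of filtering all warnings
    logger = logging.get_logger(__name__)  # same pattern this hunk adds
    logger.info("hidden at WARNING verbosity")
    logger.warning("still shown")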
@@ -1389,42 +1380,6 @@ class LTXVideoPipeline(DiffusionPipeline):
         tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).unsqueeze(-1)
         return torch.where(tokens_to_denoise_mask, denoised_latents, latents)

-    def _prepare_conditioning_media_to_latents(
-        self,
-        conditioning_item: ConditioningItem,
-        height: int,
-        width: int,
-        latent_height: int,
-        latent_width: int,
-        vae_per_channel_normalize: bool,
-        init_latents_dtype: torch.dtype,
-        init_latents_device: torch.device,
-    ) -> Tuple[ConditioningItem, torch.Tensor]:
-
-        media_item = conditioning_item.media_item
-        c = media_item.shape[1]
-
-        print(f"_ltx_prepare_media_item.shape {media_item.shape}")
-
-        # Inside _prepare_conditioning_media_to_latents:
-        c = media_item.shape[1]
-        if c == self.transformer.config.in_channels:
-            latents = media_item.to(dtype=init_latents_dtype, device=init_latents_device)
-            return conditioning_item, latents
-
-        conditioning_item = self._resize_conditioning_item(conditioning_item, height, width)
-        media_item = conditioning_item.media_item
-        latents = vae_encode(
-            media_item.to(dtype=self.vae.dtype, device=self.vae.device),
-            self.vae,
-            vae_per_channel_normalize=vae_per_channel_normalize,
-        ).to(dtype=init_latents_dtype, device=init_latents_device)
-
-        print(f"_ltx_prepare_media_item_vae.shape? {media_item.shape}")
-
-        return conditioning_item, latents
-
-
     def prepare_conditioning(
         self,
         conditioning_items: Optional[List[ConditioningItem]],
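The deleted helper's core idea survives inline in a later hunk: conditioning media whose channel count already equals transformer.config.in_channels is treated as latents and passed through, and anything else is VAE-encoded. A standalone sketch of that dispatch rule, with a hypothetical encode callable standing in for vae_encode and 128 assumed as the latent channel count:

    import torch

    LATENT_CHANNELS = 128  # assumption: transformer.config.in_channels for this model

    def media_to_latents(media: torch.Tensor, encode) -> torch.Tensor:
        # media is (b, c, f, h, w); latent input is detected purely by channel count
        if media.shape[1] == LATENT_CHANNELS:
            return media      # already latents: skip the VAE round-trip
        return encode(media)  # pixel media: encode to latents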
@@ -1434,119 +1389,6 @@ class LTXVideoPipeline(DiffusionPipeline):
         width: int,
         vae_per_channel_normalize: bool = False,
         generator=None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
-        assert isinstance(self.vae, CausalVideoAutoencoder)
-        if not conditioning_items:
-            # existing behavior
-            # ...
-            return init_latents, init_pixel_coords, None, 0
-
-        # Derive latent dims
-        latent_height = height // self.vae_scale_factor
-        latent_width = width // self.vae_scale_factor
-
-        batch_size, _, num_latent_frames = init_latents.shape[:3]
-        init_conditioning_mask = torch.zeros(
-            init_latents[:, 0, :, :, :].shape,
-            dtype=torch.float32,
-            device=init_latents.device,
-        )
-
-        extra_conditioning_latents = []
-        extra_conditioning_pixel_coords = []
-        extra_conditioning_mask = []
-        extra_conditioning_num_latents = 0
-
-        for conditioning_item in conditioning_items:
-            # NEW: centralizes resize/encode/latent detection
-            conditioning_item, media_item_latents = self._prepare_conditioning_media_to_latents(
-                conditioning_item=conditioning_item,
-                height=height,
-                width=width,
-                latent_height=latent_height,
-                latent_width=latent_width,
-                vae_per_channel_normalize=vae_per_channel_normalize,
-                init_latents_dtype=init_latents.dtype,
-                init_latents_device=init_latents.device,
-            )
-
-            media_frame_number = conditioning_item.media_frame_number
-            strength = conditioning_item.conditioning_strength
-
-            # Shape/frame validations (now on latents)
-            b, c_l, f_l, h_l, w_l = media_item_latents.shape
-            assert c_l == self.transformer.config.in_channels
-            assert (h_l, w_l) == (latent_height, latent_width), "Latents have incompatible HxW"
-            assert f_l % 8 == 1, "latent n_frames must be a multiple of the temporal scale, plus 1"
-            assert media_frame_number >= 0 and (media_frame_number + f_l) <= num_frames
-
-            print(f"media_item_latents, {media_item_latents.shape}")
-
-            # From here on, unchanged (positioning, masks, patchify, etc.)
-            if media_frame_number == 0:
-                media_item_latents, l_x, l_y = self._get_latent_spatial_position(
-                    media_item_latents, conditioning_item, height, width, strip_latent_border=True
-                )
-                init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = torch.lerp(
-                    init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l],
-                    media_item_latents,
-                    strength,
-                )
-                init_conditioning_mask[:, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = strength
-            else:
-                if f_l > 1:
-                    init_latents, init_conditioning_mask, media_item_latents = self._handle_non_first_conditioning_sequence(
-                        init_latents, init_conditioning_mask, media_item_latents, media_frame_number, strength,
-                    )
-                if media_item_latents is not None:
-                    noise = randn_tensor(
-                        media_item_latents.shape, generator=generator,
-                        device=media_item_latents.device, dtype=media_item_latents.dtype,
-                    )
-                    media_item_latents = torch.lerp(noise, media_item_latents, strength)
-                    media_item_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
-                    pixel_coords = latent_to_pixel_coords(
-                        latent_coords, self.vae,
-                        causal_fix=self.transformer.config.causal_temporal_positioning,
-                    )
-                    pixel_coords[:, 0] += media_frame_number
-                    extra_conditioning_num_latents += media_item_latents.shape[1]
-                    conditioning_mask = torch.full(
-                        media_item_latents.shape[:2], strength,
-                        dtype=torch.float32, device=init_latents.device,
-                    )
-                    extra_conditioning_latents.append(media_item_latents)
-                    extra_conditioning_pixel_coords.append(pixel_coords)
-                    extra_conditioning_mask.append(conditioning_mask)
-
-        # Patchify and concat, same as the existing code...
-        init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
-        init_pixel_coords = latent_to_pixel_coords(
-            init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning,
-        )
-        init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
-        init_conditioning_mask = init_conditioning_mask.squeeze(-1)
-        if extra_conditioning_latents:
-            init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
-            init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
-            init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)
-            if self.transformer.use_tpu_flash_attention:
-                init_latents = init_latents[:, :-extra_conditioning_num_latents]
-                init_pixel_coords = init_pixel_coords[:, :, :-extra_conditioning_num_latents]
-                init_conditioning_mask = init_conditioning_mask[:, :-extra_conditioning_num_latents]
-        return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
-
-
-    def prepare_conditioning_old(
-        self,
-        conditioning_items: Optional[List[ConditioningItem]],
-        init_latents: torch.Tensor,
-        num_frames: int,
-        height: int,
-        width: int,
-        vae_per_channel_normalize: bool = False,
-        generator=None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
         """
         Prepare conditioning tokens based on the provided conditioning items.
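Both the removed loop and the surviving one blend conditioning latents into the initial latents with torch.lerp, so conditioning_strength acts as a plain linear mix: with strength s the result is (1 - s) * init + s * cond. A quick self-contained check of that identity:

    import torch

    init = torch.zeros(2, 3)
    cond = torch.ones(2, 3)
    s = 0.75
    out = torch.lerp(init, cond, s)  # elementwise init + s * (cond - init)
    assert torch.allclose(out, torch.full((2, 3), s))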
@@ -1596,33 +1438,41 @@ class LTXVideoPipeline(DiffusionPipeline):

         # Process each conditioning item
         for conditioning_item in conditioning_items:
-
-
-            )
-            media_item = conditioning_item.media_item
-            media_frame_number = conditioning_item.media_frame_number
-            strength = conditioning_item.conditioning_strength
-            assert media_item.ndim == 5  # (b, c, f, h, w)
-            b, c, n_frames, h, w = media_item.shape
-            assert (
-                height == h and width == w
-            ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0"
-            assert n_frames % 8 == 1
-            assert (
-                media_frame_number >= 0
-                and media_frame_number + n_frames <= num_frames
-            )
-
-            # Encode the provided conditioning media item
-            media_item_latents = vae_encode(
-                media_item.to(dtype=self.vae.dtype, device=self.vae.device),
-                self.vae,
-                vae_per_channel_normalize=vae_per_channel_normalize,
-            ).to(dtype=init_latents.dtype)
-
-            print(f"media_item_latents, {media_item_latents.shape}")
-
-
+
+            print(f"media_item_latents ini {conditioning_item.media_item.shape}")
+
+            c = conditioning_item.media_item.shape[1]
+            if c == self.transformer.config.in_channels:
+                media_item_latents = conditioning_item.media_item.to(dtype=init_latents.dtype, device=init_latents.device)
+                strength = conditioning_item.conditioning_strength
+                media_frame_number = conditioning_item.media_frame_number
+            else:
+                conditioning_item = self._resize_conditioning_item(
+                    conditioning_item, height, width
+                )
+                media_item = conditioning_item.media_item
+                media_frame_number = conditioning_item.media_frame_number
+                strength = conditioning_item.conditioning_strength
+                assert media_item.ndim == 5  # (b, c, f, h, w)
+                b, c, n_frames, h, w = media_item.shape
+                assert (
+                    height == h and width == w
+                ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0"
+                assert n_frames % 8 == 1
+                assert (
+                    media_frame_number >= 0
+                    and media_frame_number + n_frames <= num_frames
+                )
+
+                # Encode the provided conditioning media item
+                media_item_latents = vae_encode(
+                    media_item.to(dtype=self.vae.dtype, device=self.vae.device),
+                    self.vae,
+                    vae_per_channel_normalize=vae_per_channel_normalize,
+                ).to(dtype=init_latents.dtype)
+
+                print(f"media_item_latents encode vae {conditioning_item.media_item.shape}")
+
             # Handle the different conditioning cases
             if media_frame_number == 0:
                 # Get the target spatial position of the latent conditioning item
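Net effect of this hunk: prepare_conditioning now accepts conditioning media that is already in latent space and only resizes and VAE-encodes pixel media. A hedged usage sketch; the field names come from this diff, while the import path, shapes, and the 128-channel latent count are illustrative assumptions:

    import torch
    from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem

    # Pixel-space item: (b, 3, f, h, w) with f % 8 == 1; gets resized and VAE-encoded.
    pixel_item = ConditioningItem(
        media_item=torch.randn(1, 3, 9, 512, 768),
        media_frame_number=0,
        conditioning_strength=1.0,
    )

    # Latent-space item: channel count matches transformer.config.in_channels,
    # so it now bypasses _resize_conditioning_item and vae_encode entirely.
    latent_item = ConditioningItem(
        media_item=torch.randn(1, 128, 2, 16, 24),
        media_frame_number=0,
        conditioning_strength=1.0,
    )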
@@ -1763,7 +1613,6 @@ class LTXVideoPipeline(DiffusionPipeline):
         )
         return new_conditioning_item

-    @staticmethod
     def _get_latent_spatial_position(
         self,
         latents: torch.Tensor,
@@ -1887,7 +1736,6 @@ class LTXVideoPipeline(DiffusionPipeline):
             latents,
         )

-    @staticmethod
     def trim_conditioning_sequence(
         self, start_frame: int, sequence_num_frames: int, target_num_frames: int
     ):