File size: 8,335 Bytes
9185209
 
2f53bb4
db85898
9185209
 
db85898
 
975f84e
9185209
db85898
9185209
9ee0aed
db85898
975f84e
9185209
 
 
 
 
db85898
 
9185209
db85898
 
9185209
 
c41e649
9bfadf5
c41e649
9185209
db85898
9185209
 
 
 
db85898
 
 
9185209
db85898
9185209
db85898
 
9185209
 
 
 
db85898
9185209
 
 
 
 
 
 
db85898
9185209
 
db85898
 
9185209
db85898
 
9185209
 
 
 
 
 
 
 
 
 
 
 
 
db4bbd0
 
ece26e1
 
 
 
02633f1
 
 
db4bbd0
9185209
 
 
 
 
 
db4bbd0
9185209
 
 
db85898
9185209
 
 
 
 
 
 
 
 
 
db85898
9185209
 
 
 
 
 
 
 
ae38dbc
 
 
 
 
9185209
 
 
b65c38b
 
02633f1
b65c38b
02633f1
9185209
 
 
 
 
ae38dbc
db85898
 
 
2f53bb4
db85898
9185209
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f53bb4
9185209
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f53bb4
9185209
 
 
 
 
2f53bb4
9185209
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
# FILE: api/ltx/ltx_utils.py
# DESCRIPTION: Comprehensive, self-contained utility module for the LTX pipeline.
# Handles dependency path injection, model loading, pipeline creation, and tensor preparation.

import os
import random
import json
import logging
import time
import sys
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
from huggingface_hub import hf_hub_download

import numpy as np
import torch
import torchvision.transforms.functional as TVF
from PIL import Image
from safetensors import safe_open
from transformers import T5EncoderModel, T5Tokenizer

# ==============================================================================
# --- CRITICAL: DEPENDENCY PATH INJECTION ---
# ==============================================================================

# Path to the cloned LTX-Video repository (added to sys.path below).
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
# Hugging Face Hub repository id used as `repo_id` when downloading checkpoints.
LTX_REPO_ID = "Lightricks/LTX-Video"
# Value of the HF_HOME environment variable, used as `cache_dir` for downloads
# (None when the variable is unset).
CACHE_DIR = os.environ.get("HF_HOME")


def add_deps_to_path():
    """
    Make the cloned LTX-Video repository importable.

    Resolves LTX_VIDEO_REPO_DIR to an absolute path and inserts it at the
    front of sys.path. Nothing happens (and nothing is logged) when that
    path is already listed.
    """
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path in sys.path:
        # Already registered; avoid duplicate entries.
        return
    sys.path.insert(0, repo_path)
    logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")

# Run the function immediately so the environment is configured before any
# LTX-Video import is attempted.
add_deps_to_path()


# ==============================================================================
# --- LTX-VIDEO LIBRARY IMPORTS (after path configuration) ---
# ==============================================================================
# These imports only resolve once the LTX-Video repository directory is on
# sys.path, which the add_deps_to_path() call above takes care of.
try:
    from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
    from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
    from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
    from ltx_video.models.transformers.transformer3d import Transformer3DModel
    from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
    from ltx_video.schedulers.rf import RectifiedFlowScheduler
    import ltx_video.pipelines.crf_compressor as crf_compressor
except ImportError as e:
    # Re-raise with a more actionable message, explicitly chaining the
    # original error so the real failing import stays visible as the cause.
    raise ImportError(f"Could not import from LTX-Video library even after setting sys.path. Check repo integrity at '{LTX_VIDEO_REPO_DIR}'. Error: {e}") from e


# ==============================================================================
# --- MODEL AND PIPELINE CONSTRUCTION FUNCTIONS ---
# ==============================================================================

def create_latent_upsampler(latent_upsampler_model_path: str, device: str) -> LatentUpsampler:
    """
    Load a LatentUpsampler checkpoint, move it to `device`, and switch it to eval mode.

    Args:
        latent_upsampler_model_path: Location handed to `LatentUpsampler.from_pretrained`.
        device: Target device passed to the model's `.to()`.

    Returns:
        The loaded upsampler instance.
    """
    logging.info(f"Loading Latent Upsampler from: {latent_upsampler_model_path} to device: {device}")
    upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
    # Placement first, then inference mode; the same object is returned.
    upsampler.to(device)
    upsampler.eval()
    return upsampler

def build_ltx_pipeline_on_cpu(
    config: Dict,
) -> Tuple[LTXVideoPipeline, Optional[torch.nn.Module], CausalVideoAutoencoder]:
    """Build the LTX pipeline, an optional latent upsampler, and a standalone VAE on the CPU.

    Args:
        config: Settings dictionary. Keys read here:
            - "checkpoint_path" (required): filename of the main checkpoint
              inside the LTX_REPO_ID Hub repository.
            - "text_encoder_model_name_or_path" (required): location passed to
              `T5EncoderModel.from_pretrained` (subfolder "text_encoder") and
              `T5Tokenizer.from_pretrained` (subfolder "tokenizer").
            - "precision" (optional, default "bfloat16"): only the value
              "bfloat16" triggers a dtype cast; any other value leaves the
              modules in the dtype they were loaded with.
            - "spatial_upscaler_model_path" (optional): filename of the latent
              upsampler checkpoint in the same Hub repository. When missing or
              falsy, no upsampler is built.

    Returns:
        A 3-tuple `(pipeline, latent_upsampler, vae)`:
            - pipeline: the assembled LTXVideoPipeline.
            - latent_upsampler: the loaded upsampler, or None when not configured.
            - vae: a second CausalVideoAutoencoder loaded from the same
              checkpoint; it is a different object from the VAE handed to
              the pipeline.

    Raises:
        KeyError: If a required config key is missing.
        FileNotFoundError: If a downloaded checkpoint path is not a file.
    """
    t0 = time.perf_counter()
    logging.info("Building LTX pipeline on CPU...")

    ckpt_path_str = hf_hub_download(repo_id=LTX_REPO_ID, filename=config["checkpoint_path"], cache_dir=CACHE_DIR)
    ckpt_path = Path(ckpt_path_str)
    if not ckpt_path.is_file():
        raise FileNotFoundError(f"Main checkpoint file not found: {ckpt_path}")

    logging.info(f"Building LTX pipeline ckpt:{ckpt_path_str}")

    # The safetensors metadata may carry a JSON "config" entry; only
    # "allowed_inference_steps" is read from it (None when absent).
    with safe_open(ckpt_path, framework="pt") as f:
        metadata = f.metadata() or {}
        configs = json.loads(metadata.get("config", "{}"))
        allowed_inference_steps = configs.get("allowed_inference_steps")

    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path).to("cpu")
    transformer = Transformer3DModel.from_pretrained(ckpt_path).to("cpu")
    scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)

    text_encoder_path = config["text_encoder_model_name_or_path"]
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_path, subfolder="text_encoder").to("cpu")
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_path, subfolder="tokenizer")
    patchifier = SymmetricPatchifier(patch_size=1)

    precision = config.get("precision", "bfloat16")
    use_bf16 = precision == "bfloat16"
    if use_bf16:
        for module in (vae, transformer, text_encoder):
            module.to(torch.bfloat16)

    pipeline = LTXVideoPipeline(
        transformer=transformer, patchifier=patchifier, text_encoder=text_encoder,
        tokenizer=tokenizer, scheduler=scheduler, vae=vae,
        allowed_inference_steps=allowed_inference_steps,
        prompt_enhancer_image_caption_model=None, prompt_enhancer_image_caption_processor=None,
        prompt_enhancer_llm_model=None, prompt_enhancer_llm_tokenizer=None,
    )

    # NOTE(review): the checkpoint is loaded a second time so the returned VAE
    # is independent of the one owned by `pipeline` — presumably so callers can
    # place it on a different device. Confirm this is intentional, since it
    # doubles the VAE's memory footprint.
    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path).to("cpu")
    if use_bf16:
        vae.to(torch.bfloat16)

    latent_upsampler = None
    upscaler_filename = config.get("spatial_upscaler_model_path")
    if upscaler_filename:
        spatial_path_str = hf_hub_download(repo_id=LTX_REPO_ID, filename=upscaler_filename, cache_dir=CACHE_DIR)
        spatial_path = Path(spatial_path_str)
        if not spatial_path.is_file():
            raise FileNotFoundError(f"Main checkpoint upscaler file not found: {spatial_path_str}")
        logging.info(f"Building UPSCALER pipeline ckpt:{spatial_path_str}")
        latent_upsampler = create_latent_upsampler(spatial_path, device="cpu")
        if use_bf16:
            latent_upsampler.to(torch.bfloat16)

    logging.info(f"LTX pipeline built on CPU in {time.perf_counter() - t0:.2f}s")
    return pipeline, latent_upsampler, vae


# ==============================================================================
# --- HELPER FUNCTIONS (Seeding, Image Preparation) ---
# ==============================================================================

def seed_everything(seed: int):
    """
    Seed every random number source used by this module with the same value.

    Covers Python's `random`, NumPy's global generator, PyTorch (CPU and all
    CUDA devices), exports PYTHONHASHSEED, and configures cuDNN with
    deterministic mode on and benchmark mode off.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Apply the same seed to each generator in turn.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, os.PathLike, Image.Image],
    target_height: int,
    target_width: int,
) -> torch.Tensor:
    """Load an image, center-crop it to the target aspect ratio, resize it, and return a 5D tensor.

    Processing steps: convert to RGB, center-crop to the aspect ratio of
    `target_width / target_height`, resize with LANCZOS, convert to a tensor,
    apply a 3x3 Gaussian blur, run `crf_compressor.compress`, and rescale
    values from [0, 1] to [-1, 1].

    Args:
        image_input: A file path (str or os.PathLike) or a PIL image. PIL
            images in any mode are converted to RGB, matching the behavior
            for file paths.
        target_height: Output height in pixels; must be positive.
        target_width: Output width in pixels; must be positive.

    Returns:
        A tensor produced by adding a leading batch dimension and a frame
        dimension at index 2 to the processed (C, H, W) image, i.e.
        (1, C, 1, H, W).

    Raises:
        ValueError: If `image_input` is neither a path nor a PIL image, or if
            a target dimension is not positive.
    """
    if target_height <= 0 or target_width <= 0:
        raise ValueError("target_height and target_width must be positive")

    if isinstance(image_input, (str, os.PathLike)):
        # The context manager releases the file handle; convert() returns a
        # fully loaded copy that remains valid after the file is closed.
        with Image.open(image_input) as opened:
            image = opened.convert("RGB")
    elif isinstance(image_input, Image.Image):
        # Normalise RGBA / grayscale / palette inputs to three channels so
        # both input kinds yield the same tensor layout.
        image = image_input.convert("RGB")
    else:
        raise ValueError("image_input must be a file path or a PIL Image object")

    input_width, input_height = image.size
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height

    if aspect_ratio_frame > aspect_ratio_target:
        # Source is wider than the target: keep full height, trim the sides.
        new_width, new_height = max(1, int(input_height * aspect_ratio_target)), input_height
        x_start, y_start = (input_width - new_width) // 2, 0
    else:
        # Source is taller (or equal): keep full width, trim top and bottom.
        new_width, new_height = input_width, max(1, int(input_width / aspect_ratio_target))
        x_start, y_start = 0, (input_height - new_height) // 2

    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
    image = image.resize((target_width, target_height), Image.Resampling.LANCZOS)

    frame_tensor = TVF.to_tensor(image)  # PIL -> tensor (C, H, W) in [0, 1] range
    frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=(3, 3))

    # crf_compressor.compress is called on a channels-last (H, W, C) view.
    # NOTE(review): presumably this simulates video-codec compression
    # artifacts — confirm against ltx_video.pipelines.crf_compressor.
    frame_tensor_hwc = frame_tensor.permute(1, 2, 0)
    frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
    frame_tensor = frame_tensor_hwc.permute(2, 0, 1)
    # Normalize to [-1, 1] range, which the VAE expects for encoding
    frame_tensor = (frame_tensor * 2.0) - 1.0

    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
    return frame_tensor.unsqueeze(0).unsqueeze(2)