dagloop5 commited on
Commit
f6ce92e
·
verified ·
1 Parent(s): 2b4573b

Delete app(draft).py

Browse files
Files changed (1) hide show
  1. app(draft).py +0 -373
app(draft).py DELETED
@@ -1,373 +0,0 @@
1
- import os
2
- import subprocess
3
- import sys
4
-
5
- # Disable torch.compile / dynamo before any torch import
6
- os.environ["TORCH_COMPILE_DISABLE"] = "1"
7
- os.environ["TORCHDYNAMO_DISABLE"] = "1"
8
-
9
- # Install xformers for memory-efficient attention
10
- subprocess.run([sys.executable, "-m", "pip", "install", "xformers==0.0.32.post2", "--no-build-isolation"], check=False)
11
-
12
- # Clone LTX-2 repo and install packages
13
- LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
14
- LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
15
- LTX_COMMIT_SHA = "a2c3f24078eb918171967f74b6f66b756b29ee45"
16
-
17
- if not os.path.exists(LTX_REPO_DIR):
18
- print(f"Cloning {LTX_REPO_URL}...")
19
- os.makedirs(LTX_REPO_DIR)
20
- subprocess.run(["git", "init", LTX_REPO_DIR], check=True)
21
- subprocess.run(["git", "remote", "add", "origin", LTX_REPO_URL], cwd=LTX_REPO_DIR, check=True)
22
- subprocess.run(["git", "fetch", "--depth", "1", "origin", LTX_COMMIT_SHA], cwd=LTX_REPO_DIR, check=True)
23
- subprocess.run(["git", "checkout", LTX_COMMIT_SHA], cwd=LTX_REPO_DIR, check=True)
24
-
25
-
26
- print("Installing ltx-core and ltx-pipelines from cloned repo...")
27
- subprocess.run(
28
- [sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps", "-e",
29
- os.path.join(LTX_REPO_DIR, "packages", "ltx-core"),
30
- "-e", os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines")],
31
- check=True,
32
- )
33
-
34
- sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines", "src"))
35
- sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-core", "src"))
36
-
37
- import logging
38
- import random
39
- import tempfile
40
- from pathlib import Path
41
- from collections.abc import Iterator
42
-
43
- import torch
44
- torch._dynamo.config.suppress_errors = True
45
- torch._dynamo.config.disable = True
46
-
47
- import spaces
48
- import gradio as gr
49
- import numpy as np
50
- from huggingface_hub import hf_hub_download, snapshot_download
51
-
52
- from ltx_core.components.diffusion_steps import Res2sDiffusionStep
53
- from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams
54
- from ltx_core.components.noisers import GaussianNoiser
55
- from ltx_core.components.schedulers import LTX2Scheduler
56
- from ltx_core.loader import LoraPathStrengthAndSDOps
57
- from ltx_core.loader.registry import Registry
58
- from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
59
- from ltx_core.quantization import QuantizationPolicy
60
- from ltx_core.types import Audio, VideoLatentShape, VideoPixelShape
61
- from ltx_pipelines.utils.args import ImageConditioningInput, hq_2_stage_arg_parser
62
- from ltx_pipelines.utils.blocks import (
63
- AudioDecoder,
64
- DiffusionStage,
65
- ImageConditioner,
66
- PromptEncoder,
67
- VideoDecoder,
68
- VideoUpsampler,
69
- )
70
- from ltx_pipelines.utils.constants import (
71
- LTX_2_3_HQ_PARAMS,
72
- STAGE_2_DISTILLED_SIGMAS,
73
- )
74
- from ltx_pipelines.utils.denoisers import GuidedDenoiser, SimpleDenoiser
75
- from ltx_pipelines.utils.helpers import (
76
- assert_resolution,
77
- combined_image_conditionings,
78
- get_device,
79
- )
80
- from ltx_pipelines.utils.media_io import encode_video
81
- from ltx_pipelines.utils.samplers import res2s_audio_video_denoising_loop
82
- from ltx_pipelines.utils.types import ModalitySpec
83
-
84
- # Force-patch xformers attention into the LTX attention module.
85
- from ltx_core.model.transformer import attention as _attn_mod
86
- print(f"[ATTN] Before patch: memory_efficient_attention={_attn_mod.memory_efficient_attention}")
87
- try:
88
- from xformers.ops import memory_efficient_attention as _mea
89
- _attn_mod.memory_efficient_attention = _mea
90
- print(f"[ATTN] After patch: memory_efficient_attention={_attn_mod.memory_efficient_attention}")
91
- except Exception as e:
92
- print(f"[ATTN] xformers patch FAILED: {type(e).__name__}: {e}")
93
-
94
- logging.getLogger().setLevel(logging.INFO)
95
-
96
- MAX_SEED = np.iinfo(np.int32).max
97
- DEFAULT_PROMPT = (
98
- "An astronaut hatches from a fragile egg on the surface of the Moon, "
99
- "the shell cracking and peeling apart in gentle low-gravity motion. "
100
- "Fine lunar dust lifts and drifts outward with each movement, floating "
101
- "in slow arcs before settling back onto the ground."
102
- )
103
- DEFAULT_FRAME_RATE = 24.0
104
-
105
- # Resolution presets: (width, height)
106
- RESOLUTIONS = {
107
- "high": {"16:9": (1536, 1024), "9:16": (1024, 1536), "1:1": (1024, 1024)},
108
- "low": {"16:9": (768, 512), "9:16": (512, 768), "1:1": (768, 768)},
109
- }
110
-
111
- class TI2VidTwoStagesHQPipeline:
112
- """
113
- Two-stage text/image-to-video generation pipeline using the res_2s sampler.
114
- Same structure as :class:`TI2VidTwoStagesPipeline`: stage 1 generates video at
115
- half of the target resolution with CFG guidance (assuming full model is used),
116
- then Stage 2 upsamples by 2x and refines using a distilled LoRA for higher
117
- quality output.
118
- Uses the res_2s second-order sampler instead of Euler, allowing fewer
119
- steps for comparable quality. Supports optional image conditioning via
120
- the images parameter.
121
- """
122
-
123
- def __init__( # noqa: PLR0913
124
- self,
125
- checkpoint_path: str,
126
- distilled_lora: list[LoraPathStrengthAndSDOps],
127
- distilled_lora_strength_stage_1: float,
128
- distilled_lora_strength_stage_2: float,
129
- spatial_upsampler_path: str,
130
- gemma_root: str,
131
- loras: tuple[LoraPathStrengthAndSDOps, ...],
132
- device: torch.device | None = None,
133
- quantization: QuantizationPolicy | None = None,
134
- registry: Registry | None = None,
135
- torch_compile: bool = False,
136
- ):
137
- self.device = device or get_device()
138
- self.dtype = torch.bfloat16
139
- self._scheduler = LTX2Scheduler()
140
-
141
- distilled_lora_stage_1 = LoraPathStrengthAndSDOps(
142
- path=distilled_lora[0].path,
143
- strength=distilled_lora_strength_stage_1,
144
- sd_ops=distilled_lora[0].sd_ops,
145
- )
146
- distilled_lora_stage_2 = LoraPathStrengthAndSDOps(
147
- path=distilled_lora[0].path,
148
- strength=distilled_lora_strength_stage_2,
149
- sd_ops=distilled_lora[0].sd_ops,
150
- )
151
-
152
- self.prompt_encoder = PromptEncoder(checkpoint_path, gemma_root, self.dtype, self.device, registry=registry)
153
- self.image_conditioner = ImageConditioner(checkpoint_path, self.dtype, self.device, registry=registry)
154
- self.upsampler = VideoUpsampler(
155
- checkpoint_path, spatial_upsampler_path, self.dtype, self.device, registry=registry
156
- )
157
- self.video_decoder = VideoDecoder(checkpoint_path, self.dtype, self.device, registry=registry)
158
- self.audio_decoder = AudioDecoder(checkpoint_path, self.dtype, self.device, registry=registry)
159
-
160
- self.stage_1 = DiffusionStage(
161
- checkpoint_path,
162
- self.dtype,
163
- self.device,
164
- loras=(*loras, distilled_lora_stage_1),
165
- quantization=quantization,
166
- registry=registry,
167
- torch_compile=torch_compile,
168
- )
169
- self.stage_2 = DiffusionStage(
170
- checkpoint_path,
171
- self.dtype,
172
- self.device,
173
- loras=(*loras, distilled_lora_stage_2),
174
- quantization=quantization,
175
- registry=registry,
176
- torch_compile=torch_compile,
177
- )
178
-
179
- @torch.inference_mode()
180
- def __call__( # noqa: PLR0913
181
- self,
182
- prompt: str,
183
- negative_prompt: str,
184
- seed: int,
185
- height: int,
186
- width: int,
187
- num_frames: int,
188
- frame_rate: float,
189
- num_inference_steps: int,
190
- video_guider_params: MultiModalGuiderParams,
191
- audio_guider_params: MultiModalGuiderParams,
192
- images: list[ImageConditioningInput],
193
- tiling_config: TilingConfig | None = None,
194
- enhance_prompt: bool = False,
195
- streaming_prefetch_count: int | None = None,
196
- max_batch_size: int = 1,
197
- stage_1_sigmas: torch.Tensor | None = None,
198
- stage_2_sigmas: torch.Tensor = STAGE_2_DISTILLED_SIGMAS,
199
- ) -> tuple[Iterator[torch.Tensor], Audio]:
200
- assert_resolution(height=height, width=width, is_two_stage=True)
201
-
202
- generator = torch.Generator(device=self.device).manual_seed(seed)
203
- noiser = GaussianNoiser(generator=generator)
204
- dtype = torch.bfloat16
205
-
206
- ctx_p, ctx_n = self.prompt_encoder(
207
- [prompt, negative_prompt],
208
- enhance_first_prompt=enhance_prompt,
209
- enhance_prompt_image=images[0][0] if len(images) > 0 else None,
210
- enhance_prompt_seed=seed,
211
- streaming_prefetch_count=streaming_prefetch_count,
212
- )
213
- v_context_p, a_context_p = ctx_p.video_encoding, ctx_p.audio_encoding
214
- v_context_n, a_context_n = ctx_n.video_encoding, ctx_n.audio_encoding
215
-
216
- # Stage 1: Generate video at half resolution with CFG guidance using res2s sampler.
217
- stage_1_output_shape = VideoPixelShape(
218
- batch=1,
219
- frames=num_frames,
220
- width=width // 2,
221
- height=height // 2,
222
- fps=frame_rate,
223
- )
224
- stage_1_conditionings = self.image_conditioner(
225
- lambda enc: combined_image_conditionings(
226
- images=images,
227
- height=stage_1_output_shape.height,
228
- width=stage_1_output_shape.width,
229
- video_encoder=enc,
230
- dtype=dtype,
231
- device=self.device,
232
- )
233
- )
234
-
235
- stepper = Res2sDiffusionStep()
236
-
237
- if stage_1_sigmas is None:
238
- empty_latent = torch.empty(VideoLatentShape.from_pixel_shape(stage_1_output_shape).to_torch_shape())
239
- stage_1_sigmas = self._scheduler.execute(latent=empty_latent, steps=num_inference_steps)
240
- sigmas = stage_1_sigmas.to(dtype=torch.float32, device=self.device)
241
-
242
- video_state, audio_state = self.stage_1(
243
- denoiser=GuidedDenoiser(
244
- v_context=v_context_p,
245
- a_context=a_context_p,
246
- video_guider=MultiModalGuider(
247
- params=video_guider_params,
248
- negative_context=v_context_n,
249
- ),
250
- audio_guider=MultiModalGuider(
251
- params=audio_guider_params,
252
- negative_context=a_context_n,
253
- ),
254
- ),
255
- sigmas=sigmas,
256
- noiser=noiser,
257
- stepper=stepper,
258
- width=stage_1_output_shape.width,
259
- height=stage_1_output_shape.height,
260
- frames=num_frames,
261
- fps=frame_rate,
262
- video=ModalitySpec(context=v_context_p, conditionings=stage_1_conditionings),
263
- audio=ModalitySpec(context=a_context_p),
264
- loop=res2s_audio_video_denoising_loop,
265
- streaming_prefetch_count=streaming_prefetch_count,
266
- max_batch_size=max_batch_size,
267
- )
268
-
269
- # Stage 2: Upsample and refine the video at higher resolution with distilled LoRA.
270
- upscaled_video_latent = self.upsampler(video_state.latent[:1])
271
-
272
- stage_2_sigmas = stage_2_sigmas.to(dtype=torch.float32, device=self.device)
273
- stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
274
- stage_2_conditionings = self.image_conditioner(
275
- lambda enc: combined_image_conditionings(
276
- images=images,
277
- height=stage_2_output_shape.height,
278
- width=stage_2_output_shape.width,
279
- video_encoder=enc,
280
- dtype=dtype,
281
- device=self.device,
282
- )
283
- )
284
-
285
- video_state, audio_state = self.stage_2(
286
- denoiser=SimpleDenoiser(v_context=v_context_p, a_context=a_context_p),
287
- sigmas=stage_2_sigmas,
288
- noiser=noiser,
289
- stepper=stepper,
290
- width=width,
291
- height=height,
292
- frames=num_frames,
293
- fps=frame_rate,
294
- video=ModalitySpec(
295
- context=v_context_p,
296
- conditionings=stage_2_conditionings,
297
- noise_scale=stage_2_sigmas[0].item(),
298
- initial_latent=upscaled_video_latent,
299
- ),
300
- audio=ModalitySpec(
301
- context=a_context_p,
302
- noise_scale=stage_2_sigmas[0].item(),
303
- initial_latent=audio_state.latent,
304
- ),
305
- loop=res2s_audio_video_denoising_loop,
306
- streaming_prefetch_count=streaming_prefetch_count,
307
- )
308
-
309
- decoded_video = self.video_decoder(video_state.latent, tiling_config, generator)
310
- decoded_audio = self.audio_decoder(audio_state.latent)
311
- return decoded_video, decoded_audio
312
-
313
-
314
- @torch.inference_mode()
315
- def main() -> None:
316
- logging.getLogger().setLevel(logging.INFO)
317
- parser = hq_2_stage_arg_parser(params=LTX_2_3_HQ_PARAMS)
318
- args = parser.parse_args()
319
- pipeline = TI2VidTwoStagesHQPipeline(
320
- checkpoint_path=args.checkpoint_path,
321
- distilled_lora=args.distilled_lora,
322
- distilled_lora_strength_stage_1=args.distilled_lora_strength_stage_1,
323
- distilled_lora_strength_stage_2=args.distilled_lora_strength_stage_2,
324
- spatial_upsampler_path=args.spatial_upsampler_path,
325
- gemma_root=args.gemma_root,
326
- loras=tuple(args.lora) if args.lora else (),
327
- quantization=args.quantization,
328
- torch_compile=args.compile,
329
- )
330
- tiling_config = TilingConfig.default()
331
- video_chunks_number = get_video_chunks_number(args.num_frames, tiling_config)
332
- video, audio = pipeline(
333
- prompt=args.prompt,
334
- negative_prompt=args.negative_prompt,
335
- seed=args.seed,
336
- height=args.height,
337
- width=args.width,
338
- num_frames=args.num_frames,
339
- frame_rate=args.frame_rate,
340
- num_inference_steps=args.num_inference_steps,
341
- video_guider_params=MultiModalGuiderParams(
342
- cfg_scale=args.video_cfg_guidance_scale,
343
- stg_scale=args.video_stg_guidance_scale,
344
- rescale_scale=args.video_rescale_scale,
345
- modality_scale=args.a2v_guidance_scale,
346
- skip_step=args.video_skip_step,
347
- stg_blocks=args.video_stg_blocks,
348
- ),
349
- audio_guider_params=MultiModalGuiderParams(
350
- cfg_scale=args.audio_cfg_guidance_scale,
351
- stg_scale=args.audio_stg_guidance_scale,
352
- rescale_scale=args.audio_rescale_scale,
353
- modality_scale=args.v2a_guidance_scale,
354
- skip_step=args.audio_skip_step,
355
- stg_blocks=args.audio_stg_blocks,
356
- ),
357
- images=args.images,
358
- tiling_config=tiling_config,
359
- streaming_prefetch_count=args.streaming_prefetch_count,
360
- max_batch_size=args.max_batch_size,
361
- )
362
-
363
- encode_video(
364
- video=video,
365
- fps=args.frame_rate,
366
- audio=audio,
367
- output_path=args.output_path,
368
- video_chunks_number=video_chunks_number,
369
- )
370
-
371
-
372
- if __name__ == "__main__":
373
- main()