dagloop5 committed on
Commit
ed1c038
·
verified ·
1 Parent(s): ec187f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -9
app.py CHANGED
@@ -51,9 +51,10 @@ from safetensors import safe_open
51
  import json
52
  import requests
53
 
54
- from ltx_core.components.diffusion_steps import EulerDiffusionStep
55
  from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams
56
  from ltx_core.components.noisers import GaussianNoiser
 
57
  from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
58
  from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
59
  from ltx_core.model.upsampler import upsample_video
@@ -71,6 +72,7 @@ from ltx_pipelines.utils.helpers import (
71
  encode_prompts,
72
  simple_denoising_func,
73
  multi_modal_guider_denoising_func,
 
74
  )
75
  from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
76
  from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
@@ -152,6 +154,7 @@ class LTX23DistilledA2VPipeline:
152
  video_guider_params: MultiModalGuiderParams,
153
  audio_guider_params: MultiModalGuiderParams,
154
  images: list[ImageConditioningInput],
 
155
  audio_path: str | None = None,
156
  tiling_config: TilingConfig | None = None,
157
  enhance_prompt: bool = False,
@@ -160,7 +163,7 @@ class LTX23DistilledA2VPipeline:
160
 
161
  generator = torch.Generator(device=self.device).manual_seed(seed)
162
  noiser = GaussianNoiser(generator=generator)
163
- stepper = EulerDiffusionStep()
164
  dtype = torch.bfloat16
165
 
166
  ctx_p, ctx_n = encode_prompts(
@@ -201,10 +204,19 @@ class LTX23DistilledA2VPipeline:
201
 
202
  video_encoder = self.model_ledger.video_encoder()
203
  transformer = self.model_ledger.transformer()
204
- stage_1_sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, device=self.device)
205
 
206
- def stage1_denoising_loop(sigmas, video_state, audio_state, stepper):
207
- return euler_denoising_loop(
 
 
 
 
 
 
 
 
 
 
208
  sigmas=sigmas,
209
  video_state=video_state,
210
  audio_state=audio_state,
@@ -224,8 +236,8 @@ class LTX23DistilledA2VPipeline:
224
  ),
225
  )
226
 
227
- def stage2_denoising_loop(sigmas, video_state, audio_state, stepper):
228
- return euler_denoising_loop(
229
  sigmas=sigmas,
230
  video_state=video_state,
231
  audio_state=audio_state,
@@ -674,9 +686,12 @@ def get_gpu_duration(
674
  voice_strength: float = 0.0,
675
  realism_strength: float = 0.0,
676
  transition_strength: float = 0.0,
 
677
  progress=None,
678
  ):
679
- return int(gpu_duration)
 
 
680
 
681
  @spaces.GPU(duration=get_gpu_duration)
682
  @torch.inference_mode()
@@ -713,6 +728,7 @@ def generate_video(
713
  voice_strength: float = 0.0,
714
  realism_strength: float = 0.0,
715
  transition_strength: float = 0.0,
 
716
  progress=gr.Progress(track_tqdm=True),
717
  ):
718
  try:
@@ -783,6 +799,7 @@ def generate_video(
783
  video_guider_params=video_guider_params,
784
  audio_guider_params=audio_guider_params,
785
  images=images,
 
786
  audio_path=input_audio,
787
  tiling_config=tiling_config,
788
  enhance_prompt=enhance_prompt,
@@ -860,6 +877,13 @@ with gr.Blocks(title="LTX-2.3 Distilled with LoRAs, Negative Prompting, and Adva
860
  with gr.Row():
861
  width = gr.Number(label="Width", value=1536, precision=0)
862
  height = gr.Number(label="Height", value=1024, precision=0)
 
 
 
 
 
 
 
863
 
864
  generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
865
 
@@ -973,7 +997,7 @@ with gr.Blocks(title="LTX-2.3 Distilled with LoRAs, Negative Prompting, and Adva
973
  pose_strength, general_strength, motion_strength,
974
  dreamlay_strength, mself_strength, dramatic_strength, fluid_strength,
975
  liquid_strength, demopose_strength, voice_strength, realism_strength,
976
- transition_strength,
977
  ],
978
  outputs=[output_video, seed],
979
  )
 
51
  import json
52
  import requests
53
 
54
+ from ltx_core.components.diffusion_steps import Res2sDiffusionStep
55
  from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams
56
  from ltx_core.components.noisers import GaussianNoiser
57
+ from ltx_core.components.schedulers import LTX2Scheduler
58
  from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
59
  from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
60
  from ltx_core.model.upsampler import upsample_video
 
72
  encode_prompts,
73
  simple_denoising_func,
74
  multi_modal_guider_denoising_func,
75
+ res2s_audio_video_denoising_loop,
76
  )
77
  from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
78
  from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
 
154
  video_guider_params: MultiModalGuiderParams,
155
  audio_guider_params: MultiModalGuiderParams,
156
  images: list[ImageConditioningInput],
157
+ num_inference_steps: int = 8,
158
  audio_path: str | None = None,
159
  tiling_config: TilingConfig | None = None,
160
  enhance_prompt: bool = False,
 
163
 
164
  generator = torch.Generator(device=self.device).manual_seed(seed)
165
  noiser = GaussianNoiser(generator=generator)
166
+ stepper = Res2sDiffusionStep()
167
  dtype = torch.bfloat16
168
 
169
  ctx_p, ctx_n = encode_prompts(
 
204
 
205
  video_encoder = self.model_ledger.video_encoder()
206
  transformer = self.model_ledger.transformer()
 
207
 
208
+ # Stage 1: Generate sigmas using LTX2Scheduler with user-specified steps
209
+ empty_latent = torch.empty(VideoLatentShape.from_pixel_shape(
210
+ VideoPixelShape(batch=1, frames=num_frames, width=width // 2, height=height // 2, fps=frame_rate)
211
+ ).to_torch_shape())
212
+ stage_1_sigmas = (
213
+ LTX2Scheduler()
214
+ .execute(latent=empty_latent, steps=num_inference_steps)
215
+ .to(dtype=torch.float32, device=self.device)
216
+ )
217
+
218
+ def stage1_denoising_loop(sigmas: torch.Tensor, video_state, audio_state, stepper: DiffusionStepProtocol):
219
+ return res2s_audio_video_denoising_loop(
220
  sigmas=sigmas,
221
  video_state=video_state,
222
  audio_state=audio_state,
 
236
  ),
237
  )
238
 
239
+ def stage2_denoising_loop(sigmas: torch.Tensor, video_state, audio_state, stepper: DiffusionStepProtocol):
240
+ return res2s_audio_video_denoising_loop(
241
  sigmas=sigmas,
242
  video_state=video_state,
243
  audio_state=audio_state,
 
686
  voice_strength: float = 0.0,
687
  realism_strength: float = 0.0,
688
  transition_strength: float = 0.0,
689
+ num_inference_steps: int = 8,
690
  progress=None,
691
  ):
692
+ base_duration = int(gpu_duration)
693
+ step_ratio = num_inference_steps / 8 # Normalize to 8 steps as baseline
694
+ return int(base_duration * step_ratio)
695
 
696
  @spaces.GPU(duration=get_gpu_duration)
697
  @torch.inference_mode()
 
728
  voice_strength: float = 0.0,
729
  realism_strength: float = 0.0,
730
  transition_strength: float = 0.0,
731
+ num_inference_steps: int = 8,
732
  progress=gr.Progress(track_tqdm=True),
733
  ):
734
  try:
 
799
  video_guider_params=video_guider_params,
800
  audio_guider_params=audio_guider_params,
801
  images=images,
802
+ num_inference_steps=num_inference_steps,
803
  audio_path=input_audio,
804
  tiling_config=tiling_config,
805
  enhance_prompt=enhance_prompt,
 
877
  with gr.Row():
878
  width = gr.Number(label="Width", value=1536, precision=0)
879
  height = gr.Number(label="Height", value=1024, precision=0)
880
+
881
+ with gr.Row():
882
+ num_inference_steps = gr.Slider(
883
+ label="Stage 1 Inference Steps",
884
+ minimum=2, maximum=16, value=8, step=1,
885
+ info="Higher = more quality but slower (Stage 2 uses fixed 3 steps)"
886
+ )
887
 
888
  generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
889
 
 
997
  pose_strength, general_strength, motion_strength,
998
  dreamlay_strength, mself_strength, dramatic_strength, fluid_strength,
999
  liquid_strength, demopose_strength, voice_strength, realism_strength,
1000
+ transition_strength, num_inference_steps,
1001
  ],
1002
  outputs=[output_video, seed],
1003
  )