dagloop5 commited on
Commit
593d864
·
verified ·
1 Parent(s): 733dd76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -60
app.py CHANGED
@@ -66,6 +66,8 @@ from ltx_pipelines.utils.helpers import (
66
  simple_denoising_func,
67
  )
68
  from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
 
 
69
  from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
70
  from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
71
 
@@ -131,7 +133,6 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
131
 
132
  generator = torch.Generator(device=self.device).manual_seed(seed)
133
  noiser = GaussianNoiser(generator=generator)
134
- stepper = EulerDiffusionStep()
135
  dtype = torch.bfloat16
136
 
137
  (ctx_p,) = encode_prompts(
@@ -148,7 +149,8 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
148
  raise ValueError(f"Could not extract audio stream from {audio_path}")
149
 
150
  encoded_audio_latent = vae_encode_audio(decoded_audio, self.model_ledger.audio_encoder())
151
- # Keep the uploaded audio as a soft conditioning signal, not a hard copy.
 
152
  audio_mix_ratio = float(max(0.0, min(1.0, audio_mix_ratio)))
153
  if audio_mix_ratio < 1.0:
154
  noise = torch.randn(
@@ -161,7 +163,13 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
161
  audio_mix_ratio * encoded_audio_latent
162
  + (1.0 - audio_mix_ratio) * noise
163
  )
164
- audio_shape = AudioLatentShape.from_duration(batch=1, duration=video_duration, channels=8, mel_bins=16)
 
 
 
 
 
 
165
  expected_frames = audio_shape.frames
166
  actual_frames = encoded_audio_latent.shape[2]
167
 
@@ -178,22 +186,8 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
178
  )
179
  encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
180
 
181
- video_encoder = self.model_ledger.video_encoder()
182
- transformer = self.model_ledger.transformer()
183
  stage_1_sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, device=self.device)
184
-
185
- def denoising_loop(sigmas, video_state, audio_state, stepper):
186
- return euler_denoising_loop(
187
- sigmas=sigmas,
188
- video_state=video_state,
189
- audio_state=audio_state,
190
- stepper=stepper,
191
- denoise_fn=simple_denoising_func(
192
- video_context=video_context,
193
- audio_context=audio_context,
194
- transformer=transformer,
195
- ),
196
- )
197
 
198
  stage_1_output_shape = VideoPixelShape(
199
  batch=1,
@@ -206,21 +200,28 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
206
  images=images,
207
  height=stage_1_output_shape.height,
208
  width=stage_1_output_shape.width,
209
- video_encoder=video_encoder,
210
  dtype=dtype,
211
  device=self.device,
212
  )
213
- video_state = denoise_video_only(
214
- output_shape=stage_1_output_shape,
215
- conditionings=stage_1_conditionings,
216
- noiser=noiser,
217
  sigmas=stage_1_sigmas,
218
- stepper=stepper,
219
- denoising_loop_fn=denoising_loop,
220
- components=self.pipeline_components,
221
- dtype=dtype,
222
- device=self.device,
223
- initial_audio_latent=encoded_audio_latent,
 
 
 
 
 
 
 
 
224
  )
225
 
226
  torch.cuda.synchronize()
@@ -228,56 +229,56 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
228
 
229
  upscaled_video_latent = upsample_video(
230
  latent=video_state.latent[:1],
231
- video_encoder=video_encoder,
232
  upsampler=self.model_ledger.spatial_upsampler(),
233
  )
234
- stage_2_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=self.device)
235
- stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
 
 
 
 
 
 
236
  stage_2_conditionings = combined_image_conditionings(
237
  images=images,
238
  height=stage_2_output_shape.height,
239
  width=stage_2_output_shape.width,
240
- video_encoder=video_encoder,
241
  dtype=dtype,
242
  device=self.device,
243
  )
244
- video_state = denoise_video_only(
245
- output_shape=stage_2_output_shape,
246
- conditionings=stage_2_conditionings,
247
- noiser=noiser,
248
  sigmas=stage_2_sigmas,
249
- stepper=stepper,
250
- denoising_loop_fn=denoising_loop,
251
- components=self.pipeline_components,
252
- dtype=dtype,
253
- device=self.device,
254
- noise_scale=stage_2_sigmas[0],
255
- initial_video_latent=upscaled_video_latent,
256
- initial_audio_latent=encoded_audio_latent,
 
 
 
 
 
 
 
 
257
  )
258
 
259
  torch.cuda.synchronize()
260
- del transformer
261
- del video_encoder
262
  cleanup_memory()
263
 
264
- decoded_video = vae_decode_video(
265
  video_state.latent,
266
- self.model_ledger.video_decoder(),
267
  tiling_config,
268
  generator,
269
  )
270
-
271
- generated_audio_latent = getattr(video_state, "audio_latent", None)
272
- if generated_audio_latent is None:
273
- raise RuntimeError(
274
- "No generated audio latent was returned. "
275
- "Patch denoise_video_only() to expose the audio latent, "
276
- "or switch this block to the upstream stage API that returns "
277
- "video_state, audio_state."
278
- )
279
-
280
- decoded_audio = self.model_ledger.audio_decoder()(generated_audio_latent)
281
  return decoded_video, decoded_audio
282
 
283
 
 
66
  simple_denoising_func,
67
  )
68
  from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
69
+ from ltx_pipelines.utils.denoisers import SimpleDenoiser
70
+ from ltx_pipelines.utils.types import ModalitySpec
71
  from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
72
  from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
73
 
 
133
 
134
  generator = torch.Generator(device=self.device).manual_seed(seed)
135
  noiser = GaussianNoiser(generator=generator)
 
136
  dtype = torch.bfloat16
137
 
138
  (ctx_p,) = encode_prompts(
 
149
  raise ValueError(f"Could not extract audio stream from {audio_path}")
150
 
151
  encoded_audio_latent = vae_encode_audio(decoded_audio, self.model_ledger.audio_encoder())
152
+
153
+ # Keep the uploaded audio as a soft prior instead of a hard target.
154
  audio_mix_ratio = float(max(0.0, min(1.0, audio_mix_ratio)))
155
  if audio_mix_ratio < 1.0:
156
  noise = torch.randn(
 
163
  audio_mix_ratio * encoded_audio_latent
164
  + (1.0 - audio_mix_ratio) * noise
165
  )
166
+
167
+ audio_shape = AudioLatentShape.from_duration(
168
+ batch=1,
169
+ duration=video_duration,
170
+ channels=8,
171
+ mel_bins=16,
172
+ )
173
  expected_frames = audio_shape.frames
174
  actual_frames = encoded_audio_latent.shape[2]
175
 
 
186
  )
187
  encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
188
 
 
 
189
  stage_1_sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, device=self.device)
190
+ stage_2_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=self.device)
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
  stage_1_output_shape = VideoPixelShape(
193
  batch=1,
 
200
  images=images,
201
  height=stage_1_output_shape.height,
202
  width=stage_1_output_shape.width,
203
+ video_encoder=self.model_ledger.video_encoder(),
204
  dtype=dtype,
205
  device=self.device,
206
  )
207
+
208
+ video_state, audio_state = self.stage(
209
+ denoiser=SimpleDenoiser(video_context, audio_context),
 
210
  sigmas=stage_1_sigmas,
211
+ noiser=noiser,
212
+ width=stage_1_output_shape.width,
213
+ height=stage_1_output_shape.height,
214
+ frames=num_frames,
215
+ fps=frame_rate,
216
+ video=ModalitySpec(
217
+ context=video_context,
218
+ conditionings=stage_1_conditionings,
219
+ ),
220
+ audio=ModalitySpec(
221
+ context=audio_context,
222
+ noise_scale=stage_1_sigmas[0].item(),
223
+ initial_latent=encoded_audio_latent,
224
+ ),
225
  )
226
 
227
  torch.cuda.synchronize()
 
229
 
230
  upscaled_video_latent = upsample_video(
231
  latent=video_state.latent[:1],
232
+ video_encoder=self.model_ledger.video_encoder(),
233
  upsampler=self.model_ledger.spatial_upsampler(),
234
  )
235
+
236
+ stage_2_output_shape = VideoPixelShape(
237
+ batch=1,
238
+ frames=num_frames,
239
+ width=width,
240
+ height=height,
241
+ fps=frame_rate,
242
+ )
243
  stage_2_conditionings = combined_image_conditionings(
244
  images=images,
245
  height=stage_2_output_shape.height,
246
  width=stage_2_output_shape.width,
247
+ video_encoder=self.model_ledger.video_encoder(),
248
  dtype=dtype,
249
  device=self.device,
250
  )
251
+
252
+ video_state, audio_state = self.stage(
253
+ denoiser=SimpleDenoiser(video_context, audio_context),
 
254
  sigmas=stage_2_sigmas,
255
+ noiser=noiser,
256
+ width=stage_2_output_shape.width,
257
+ height=stage_2_output_shape.height,
258
+ frames=num_frames,
259
+ fps=frame_rate,
260
+ video=ModalitySpec(
261
+ context=video_context,
262
+ conditionings=stage_2_conditionings,
263
+ noise_scale=stage_2_sigmas[0].item(),
264
+ initial_latent=upscaled_video_latent,
265
+ ),
266
+ audio=ModalitySpec(
267
+ context=audio_context,
268
+ noise_scale=stage_2_sigmas[0].item(),
269
+ initial_latent=audio_state.latent,
270
+ ),
271
  )
272
 
273
  torch.cuda.synchronize()
 
 
274
  cleanup_memory()
275
 
276
+ decoded_video = self.model_ledger.video_decoder()(
277
  video_state.latent,
 
278
  tiling_config,
279
  generator,
280
  )
281
+ decoded_audio = self.model_ledger.audio_decoder()(audio_state.latent)
 
 
 
 
 
 
 
 
 
 
282
  return decoded_video, decoded_audio
283
 
284