dagloop5 committed on
Commit
4e8337c
·
verified ·
1 Parent(s): 9a24168

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -82
app.py CHANGED
@@ -102,7 +102,7 @@ RESOLUTIONS = {
102
 
103
 
104
  class LTX23DistilledA2VPipeline(DistilledPipeline):
105
- """DistilledPipeline with optional audio conditioning."""
106
 
107
  def __call__(
108
  self,
@@ -117,20 +117,7 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
117
  tiling_config: TilingConfig | None = None,
118
  enhance_prompt: bool = False,
119
  ):
120
- # Standard path when no audio input is provided.
121
  print(prompt)
122
- if audio_path is None:
123
- return super().__call__(
124
- prompt=prompt,
125
- seed=seed,
126
- height=height,
127
- width=width,
128
- num_frames=num_frames,
129
- frame_rate=frame_rate,
130
- images=images,
131
- tiling_config=tiling_config,
132
- enhance_prompt=enhance_prompt,
133
- )
134
 
135
  generator = torch.Generator(device=self.device).manual_seed(seed)
136
  noiser = GaussianNoiser(generator=generator)
@@ -145,32 +132,41 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
145
  )
146
  video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding
147
 
148
- video_duration = num_frames / frame_rate
149
- decoded_audio = decode_audio_from_file(audio_path, self.device, 0.0, video_duration)
150
- if decoded_audio is None:
151
- raise ValueError(f"Could not extract audio stream from {audio_path}")
152
-
153
- encoded_audio_latent = vae_encode_audio(decoded_audio, self.model_ledger.audio_encoder())
154
- audio_shape = AudioLatentShape.from_duration(batch=1, duration=video_duration, channels=8, mel_bins=16)
155
- expected_frames = audio_shape.frames
156
- actual_frames = encoded_audio_latent.shape[2]
157
-
158
- if actual_frames > expected_frames:
159
- encoded_audio_latent = encoded_audio_latent[:, :, :expected_frames, :]
160
- elif actual_frames < expected_frames:
161
- pad = torch.zeros(
162
- encoded_audio_latent.shape[0],
163
- encoded_audio_latent.shape[1],
164
- expected_frames - actual_frames,
165
- encoded_audio_latent.shape[3],
166
- device=encoded_audio_latent.device,
167
- dtype=encoded_audio_latent.dtype,
 
 
 
 
 
 
 
 
 
 
168
  )
169
- encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
170
 
171
  video_encoder = self.model_ledger.video_encoder()
172
  transformer = self.model_ledger.transformer()
173
- stage_1_sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, device=self.device)
174
 
175
  def denoising_loop(sigmas, video_state, audio_state, stepper):
176
  return euler_denoising_loop(
@@ -185,26 +181,26 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
185
  ),
186
  )
187
 
188
- stage_1_output_shape = VideoPixelShape(
189
  batch=1,
190
  frames=num_frames,
191
- width=width // 2,
192
- height=height // 2,
193
  fps=frame_rate,
194
  )
195
- stage_1_conditionings = combined_image_conditionings(
196
  images=images,
197
- height=stage_1_output_shape.height,
198
- width=stage_1_output_shape.width,
199
  video_encoder=video_encoder,
200
  dtype=dtype,
201
  device=self.device,
202
  )
203
  video_state = denoise_video_only(
204
- output_shape=stage_1_output_shape,
205
- conditionings=stage_1_conditionings,
206
  noiser=noiser,
207
- sigmas=stage_1_sigmas,
208
  stepper=stepper,
209
  denoising_loop_fn=denoising_loop,
210
  components=self.pipeline_components,
@@ -213,39 +209,6 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
213
  initial_audio_latent=encoded_audio_latent,
214
  )
215
 
216
- torch.cuda.synchronize()
217
- cleanup_memory()
218
-
219
- upscaled_video_latent = upsample_video(
220
- latent=video_state.latent[:1],
221
- video_encoder=video_encoder,
222
- upsampler=self.model_ledger.spatial_upsampler(),
223
- )
224
- stage_2_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=self.device)
225
- stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
226
- stage_2_conditionings = combined_image_conditionings(
227
- images=images,
228
- height=stage_2_output_shape.height,
229
- width=stage_2_output_shape.width,
230
- video_encoder=video_encoder,
231
- dtype=dtype,
232
- device=self.device,
233
- )
234
- video_state = denoise_video_only(
235
- output_shape=stage_2_output_shape,
236
- conditionings=stage_2_conditionings,
237
- noiser=noiser,
238
- sigmas=stage_2_sigmas,
239
- stepper=stepper,
240
- denoising_loop_fn=denoising_loop,
241
- components=self.pipeline_components,
242
- dtype=dtype,
243
- device=self.device,
244
- noise_scale=stage_2_sigmas[0],
245
- initial_video_latent=upscaled_video_latent,
246
- initial_audio_latent=encoded_audio_latent,
247
- )
248
-
249
  torch.cuda.synchronize()
250
  del transformer
251
  del video_encoder
@@ -257,10 +220,7 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
257
  tiling_config,
258
  generator,
259
  )
260
- original_audio = Audio(
261
- waveform=decoded_audio.waveform.squeeze(0),
262
- sampling_rate=decoded_audio.sampling_rate,
263
- )
264
  return decoded_video, original_audio
265
 
266
 
 
102
 
103
 
104
  class LTX23DistilledA2VPipeline(DistilledPipeline):
105
+ """DistilledPipeline: single stage, full resolution, 8 steps, with optional audio."""
106
 
107
  def __call__(
108
  self,
 
117
  tiling_config: TilingConfig | None = None,
118
  enhance_prompt: bool = False,
119
  ):
 
120
  print(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  generator = torch.Generator(device=self.device).manual_seed(seed)
123
  noiser = GaussianNoiser(generator=generator)
 
132
  )
133
  video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding
134
 
135
+ # Audio encoding only runs if audio is provided
136
+ encoded_audio_latent = None
137
+ original_audio = None
138
+ if audio_path is not None:
139
+ video_duration = num_frames / frame_rate
140
+ decoded_audio = decode_audio_from_file(audio_path, self.device, 0.0, video_duration)
141
+ if decoded_audio is None:
142
+ raise ValueError(f"Could not extract audio stream from {audio_path}")
143
+
144
+ encoded_audio_latent = vae_encode_audio(decoded_audio, self.model_ledger.audio_encoder())
145
+ audio_shape = AudioLatentShape.from_duration(batch=1, duration=video_duration, channels=8, mel_bins=16)
146
+ expected_frames = audio_shape.frames
147
+ actual_frames = encoded_audio_latent.shape[2]
148
+
149
+ if actual_frames > expected_frames:
150
+ encoded_audio_latent = encoded_audio_latent[:, :, :expected_frames, :]
151
+ elif actual_frames < expected_frames:
152
+ pad = torch.zeros(
153
+ encoded_audio_latent.shape[0],
154
+ encoded_audio_latent.shape[1],
155
+ expected_frames - actual_frames,
156
+ encoded_audio_latent.shape[3],
157
+ device=encoded_audio_latent.device,
158
+ dtype=encoded_audio_latent.dtype,
159
+ )
160
+ encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
161
+
162
+ original_audio = Audio(
163
+ waveform=decoded_audio.waveform.squeeze(0),
164
+ sampling_rate=decoded_audio.sampling_rate,
165
  )
 
166
 
167
  video_encoder = self.model_ledger.video_encoder()
168
  transformer = self.model_ledger.transformer()
169
+ sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, device=self.device)
170
 
171
  def denoising_loop(sigmas, video_state, audio_state, stepper):
172
  return euler_denoising_loop(
 
181
  ),
182
  )
183
 
184
+ output_shape = VideoPixelShape(
185
  batch=1,
186
  frames=num_frames,
187
+ width=width,
188
+ height=height,
189
  fps=frame_rate,
190
  )
191
+ conditionings = combined_image_conditionings(
192
  images=images,
193
+ height=output_shape.height,
194
+ width=output_shape.width,
195
  video_encoder=video_encoder,
196
  dtype=dtype,
197
  device=self.device,
198
  )
199
  video_state = denoise_video_only(
200
+ output_shape=output_shape,
201
+ conditionings=conditionings,
202
  noiser=noiser,
203
+ sigmas=sigmas,
204
  stepper=stepper,
205
  denoising_loop_fn=denoising_loop,
206
  components=self.pipeline_components,
 
209
  initial_audio_latent=encoded_audio_latent,
210
  )
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  torch.cuda.synchronize()
213
  del transformer
214
  del video_encoder
 
220
  tiling_config,
221
  generator,
222
  )
223
+
 
 
 
224
  return decoded_video, original_audio
225
 
226