Update app.py
Browse files
app.py
CHANGED
|
@@ -102,7 +102,7 @@ RESOLUTIONS = {
|
|
| 102 |
|
| 103 |
|
| 104 |
class LTX23DistilledA2VPipeline(DistilledPipeline):
|
| 105 |
-
"""DistilledPipeline with optional audio
|
| 106 |
|
| 107 |
def __call__(
|
| 108 |
self,
|
|
@@ -117,20 +117,7 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
|
|
| 117 |
tiling_config: TilingConfig | None = None,
|
| 118 |
enhance_prompt: bool = False,
|
| 119 |
):
|
| 120 |
-
# Standard path when no audio input is provided.
|
| 121 |
print(prompt)
|
| 122 |
-
if audio_path is None:
|
| 123 |
-
return super().__call__(
|
| 124 |
-
prompt=prompt,
|
| 125 |
-
seed=seed,
|
| 126 |
-
height=height,
|
| 127 |
-
width=width,
|
| 128 |
-
num_frames=num_frames,
|
| 129 |
-
frame_rate=frame_rate,
|
| 130 |
-
images=images,
|
| 131 |
-
tiling_config=tiling_config,
|
| 132 |
-
enhance_prompt=enhance_prompt,
|
| 133 |
-
)
|
| 134 |
|
| 135 |
generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 136 |
noiser = GaussianNoiser(generator=generator)
|
|
@@ -145,32 +132,41 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
|
|
| 145 |
)
|
| 146 |
video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding
|
| 147 |
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
encoded_audio_latent
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
)
|
| 169 |
-
encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
|
| 170 |
|
| 171 |
video_encoder = self.model_ledger.video_encoder()
|
| 172 |
transformer = self.model_ledger.transformer()
|
| 173 |
-
|
| 174 |
|
| 175 |
def denoising_loop(sigmas, video_state, audio_state, stepper):
|
| 176 |
return euler_denoising_loop(
|
|
@@ -185,26 +181,26 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
|
|
| 185 |
),
|
| 186 |
)
|
| 187 |
|
| 188 |
-
|
| 189 |
batch=1,
|
| 190 |
frames=num_frames,
|
| 191 |
-
width=width
|
| 192 |
-
height=height
|
| 193 |
fps=frame_rate,
|
| 194 |
)
|
| 195 |
-
|
| 196 |
images=images,
|
| 197 |
-
height=
|
| 198 |
-
width=
|
| 199 |
video_encoder=video_encoder,
|
| 200 |
dtype=dtype,
|
| 201 |
device=self.device,
|
| 202 |
)
|
| 203 |
video_state = denoise_video_only(
|
| 204 |
-
output_shape=
|
| 205 |
-
conditionings=
|
| 206 |
noiser=noiser,
|
| 207 |
-
sigmas=
|
| 208 |
stepper=stepper,
|
| 209 |
denoising_loop_fn=denoising_loop,
|
| 210 |
components=self.pipeline_components,
|
|
@@ -213,39 +209,6 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
|
|
| 213 |
initial_audio_latent=encoded_audio_latent,
|
| 214 |
)
|
| 215 |
|
| 216 |
-
torch.cuda.synchronize()
|
| 217 |
-
cleanup_memory()
|
| 218 |
-
|
| 219 |
-
upscaled_video_latent = upsample_video(
|
| 220 |
-
latent=video_state.latent[:1],
|
| 221 |
-
video_encoder=video_encoder,
|
| 222 |
-
upsampler=self.model_ledger.spatial_upsampler(),
|
| 223 |
-
)
|
| 224 |
-
stage_2_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=self.device)
|
| 225 |
-
stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
|
| 226 |
-
stage_2_conditionings = combined_image_conditionings(
|
| 227 |
-
images=images,
|
| 228 |
-
height=stage_2_output_shape.height,
|
| 229 |
-
width=stage_2_output_shape.width,
|
| 230 |
-
video_encoder=video_encoder,
|
| 231 |
-
dtype=dtype,
|
| 232 |
-
device=self.device,
|
| 233 |
-
)
|
| 234 |
-
video_state = denoise_video_only(
|
| 235 |
-
output_shape=stage_2_output_shape,
|
| 236 |
-
conditionings=stage_2_conditionings,
|
| 237 |
-
noiser=noiser,
|
| 238 |
-
sigmas=stage_2_sigmas,
|
| 239 |
-
stepper=stepper,
|
| 240 |
-
denoising_loop_fn=denoising_loop,
|
| 241 |
-
components=self.pipeline_components,
|
| 242 |
-
dtype=dtype,
|
| 243 |
-
device=self.device,
|
| 244 |
-
noise_scale=stage_2_sigmas[0],
|
| 245 |
-
initial_video_latent=upscaled_video_latent,
|
| 246 |
-
initial_audio_latent=encoded_audio_latent,
|
| 247 |
-
)
|
| 248 |
-
|
| 249 |
torch.cuda.synchronize()
|
| 250 |
del transformer
|
| 251 |
del video_encoder
|
|
@@ -257,10 +220,7 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
|
|
| 257 |
tiling_config,
|
| 258 |
generator,
|
| 259 |
)
|
| 260 |
-
|
| 261 |
-
waveform=decoded_audio.waveform.squeeze(0),
|
| 262 |
-
sampling_rate=decoded_audio.sampling_rate,
|
| 263 |
-
)
|
| 264 |
return decoded_video, original_audio
|
| 265 |
|
| 266 |
|
|
|
|
| 102 |
|
| 103 |
|
| 104 |
class LTX23DistilledA2VPipeline(DistilledPipeline):
|
| 105 |
+
"""DistilledPipeline: single stage, full resolution, 8 steps, with optional audio."""
|
| 106 |
|
| 107 |
def __call__(
|
| 108 |
self,
|
|
|
|
| 117 |
tiling_config: TilingConfig | None = None,
|
| 118 |
enhance_prompt: bool = False,
|
| 119 |
):
|
|
|
|
| 120 |
print(prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 123 |
noiser = GaussianNoiser(generator=generator)
|
|
|
|
| 132 |
)
|
| 133 |
video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding
|
| 134 |
|
| 135 |
+
# Audio encoding — only runs if audio is provided
|
| 136 |
+
encoded_audio_latent = None
|
| 137 |
+
original_audio = None
|
| 138 |
+
if audio_path is not None:
|
| 139 |
+
video_duration = num_frames / frame_rate
|
| 140 |
+
decoded_audio = decode_audio_from_file(audio_path, self.device, 0.0, video_duration)
|
| 141 |
+
if decoded_audio is None:
|
| 142 |
+
raise ValueError(f"Could not extract audio stream from {audio_path}")
|
| 143 |
+
|
| 144 |
+
encoded_audio_latent = vae_encode_audio(decoded_audio, self.model_ledger.audio_encoder())
|
| 145 |
+
audio_shape = AudioLatentShape.from_duration(batch=1, duration=video_duration, channels=8, mel_bins=16)
|
| 146 |
+
expected_frames = audio_shape.frames
|
| 147 |
+
actual_frames = encoded_audio_latent.shape[2]
|
| 148 |
+
|
| 149 |
+
if actual_frames > expected_frames:
|
| 150 |
+
encoded_audio_latent = encoded_audio_latent[:, :, :expected_frames, :]
|
| 151 |
+
elif actual_frames < expected_frames:
|
| 152 |
+
pad = torch.zeros(
|
| 153 |
+
encoded_audio_latent.shape[0],
|
| 154 |
+
encoded_audio_latent.shape[1],
|
| 155 |
+
expected_frames - actual_frames,
|
| 156 |
+
encoded_audio_latent.shape[3],
|
| 157 |
+
device=encoded_audio_latent.device,
|
| 158 |
+
dtype=encoded_audio_latent.dtype,
|
| 159 |
+
)
|
| 160 |
+
encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
|
| 161 |
+
|
| 162 |
+
original_audio = Audio(
|
| 163 |
+
waveform=decoded_audio.waveform.squeeze(0),
|
| 164 |
+
sampling_rate=decoded_audio.sampling_rate,
|
| 165 |
)
|
|
|
|
| 166 |
|
| 167 |
video_encoder = self.model_ledger.video_encoder()
|
| 168 |
transformer = self.model_ledger.transformer()
|
| 169 |
+
sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, device=self.device)
|
| 170 |
|
| 171 |
def denoising_loop(sigmas, video_state, audio_state, stepper):
|
| 172 |
return euler_denoising_loop(
|
|
|
|
| 181 |
),
|
| 182 |
)
|
| 183 |
|
| 184 |
+
output_shape = VideoPixelShape(
|
| 185 |
batch=1,
|
| 186 |
frames=num_frames,
|
| 187 |
+
width=width,
|
| 188 |
+
height=height,
|
| 189 |
fps=frame_rate,
|
| 190 |
)
|
| 191 |
+
conditionings = combined_image_conditionings(
|
| 192 |
images=images,
|
| 193 |
+
height=output_shape.height,
|
| 194 |
+
width=output_shape.width,
|
| 195 |
video_encoder=video_encoder,
|
| 196 |
dtype=dtype,
|
| 197 |
device=self.device,
|
| 198 |
)
|
| 199 |
video_state = denoise_video_only(
|
| 200 |
+
output_shape=output_shape,
|
| 201 |
+
conditionings=conditionings,
|
| 202 |
noiser=noiser,
|
| 203 |
+
sigmas=sigmas,
|
| 204 |
stepper=stepper,
|
| 205 |
denoising_loop_fn=denoising_loop,
|
| 206 |
components=self.pipeline_components,
|
|
|
|
| 209 |
initial_audio_latent=encoded_audio_latent,
|
| 210 |
)
|
| 211 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
torch.cuda.synchronize()
|
| 213 |
del transformer
|
| 214 |
del video_encoder
|
|
|
|
| 220 |
tiling_config,
|
| 221 |
generator,
|
| 222 |
)
|
| 223 |
+
|
|
|
|
|
|
|
|
|
|
| 224 |
return decoded_video, original_audio
|
| 225 |
|
| 226 |
|