sam-motamed committed on
Commit
b20559f
·
verified ·
1 Parent(s): c6cc81a

Add diffusers/pipeline_void.py

Browse files
Files changed (1) hide show
  1. diffusers/pipeline_void.py +559 -0
diffusers/pipeline_void.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VOID (Video Object and Interaction Deletion) Pipeline.
3
+
4
+ Simple usage:
5
+
6
+ from pipeline_void import VOIDPipeline
7
+
8
+ pipe = VOIDPipeline.from_pretrained("netflix/void-model")
9
+ result = pipe.inpaint("input.mp4", "quadmask.mp4", "A lime falls on the table.")
10
+ result.save("output.mp4")
11
+
12
+ Pass 2 refinement:
13
+
14
+ pipe2 = VOIDPipeline.from_pretrained("netflix/void-model", void_pass=2)
15
+ result2 = pipe2.inpaint("input.mp4", "quadmask.mp4", "A lime falls on the table.",
16
+ pass1_video="output.mp4")
17
+ result2.save("output_refined.mp4")
18
+ """
19
+
20
+ import os
21
+ import json
22
+ import subprocess
23
+ import sys
24
+ import tempfile
25
+ from dataclasses import dataclass
26
+ from typing import List, Optional, Tuple, Union
27
+
28
+ import cv2
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn.functional as F
32
+ from huggingface_hub import hf_hub_download, snapshot_download
33
+ from safetensors.torch import load_file
34
+ from diffusers import CogVideoXDDIMScheduler
35
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
36
+
37
+ from cogvideox_transformer3d import CogVideoXTransformer3DModel
38
+ from cogvideox_vae import AutoencoderKLCogVideoX
39
+ from pipeline_cogvideox_fun_inpaint import CogVideoXFunInpaintPipeline
40
+
41
# Hugging Face repo of the base model that VOID is fine-tuned from.
BASE_MODEL_REPO = "alibaba-pai/CogVideoX-Fun-V1.5-5b-InP"

# Maps the pass number (1 or 2) to its checkpoint filename in the VOID repo.
PASS_CHECKPOINTS = {
    pass_id: f"void_pass{pass_id}.safetensors" for pass_id in (1, 2)
}

# Default negative prompt (from config/quadmask_cogvideox.py).
# Every sentence, including the last one, is followed by a single space.
DEFAULT_NEGATIVE_PROMPT = "".join(
    sentence + " "
    for sentence in (
        "The video is not of a high quality, it has a low resolution.",
        "Watermark present in each frame.",
        "The background is solid.",
        "Strange body and strange trajectory.",
        "Distortion.",
    )
)
56
+
57
+
58
@dataclass
class VOIDOutput:
    """Result container returned by the VOID pipeline.

    Attributes:
        video: Output frames, (T, H, W, 3) uint8.
        video_float: Output clip, (1, C, T, H, W) float in [0, 1].
    """
    video: torch.Tensor  # (T, H, W, 3) uint8
    video_float: torch.Tensor  # (1, C, T, H, W) float [0, 1]

    def save(self, path: str, fps: int = 12):
        """Write ``video`` to ``path`` via imageio and print the frame count.

        Args:
            path: Destination file path handed to ``imageio.mimwrite``.
            fps: Frame rate forwarded to the writer.
        """
        import imageio

        # Iterating the array along its first axis yields one frame per item.
        frame_list = list(self.video.cpu().numpy())
        imageio.mimwrite(path, frame_list, fps=fps)
        print(f"Saved {len(frame_list)} frames to {path}")
70
+
71
+
72
def _merge_void_weights(transformer, checkpoint_path):
    """Merge a VOID checkpoint into the base transformer (in place).

    The checkpoint is loaded non-strictly. If its ``patch_embed.proj.weight``
    has a different input width (``size(1)``) than the transformer's, only the
    first and last ``16 * 8 = 128`` input columns are taken from the
    checkpoint; the remaining columns keep the transformer's current values.

    Args:
        transformer: Model exposing ``state_dict()`` / ``load_state_dict()``.
        checkpoint_path: Path to a ``.safetensors`` file.

    Returns:
        The same ``transformer`` object, after loading.
    """
    state_dict = load_file(checkpoint_path)
    param_name = "patch_embed.proj.weight"

    # A checkpoint without the patch embedding has nothing to reconcile; the
    # non-strict load below reports it as a missing key instead of crashing.
    ckpt_weight = state_dict.get(param_name)
    if ckpt_weight is not None:
        # state_dict() rebuilds the whole mapping on every call, so fetch the
        # one tensor that is needed exactly once.
        base_weight = transformer.state_dict()[param_name]

        if ckpt_weight.size(1) != base_weight.size(1):
            latent_ch = 16
            feat_scale = 8
            # presumably latent channels x patch volume -- TODO confirm
            feat_dim = int(latent_ch * feat_scale)

            new_weight = base_weight.clone()
            new_weight[:, :feat_dim] = ckpt_weight[:, :feat_dim]
            new_weight[:, -feat_dim:] = ckpt_weight[:, -feat_dim:]
            state_dict[param_name] = new_weight

    missing, unexpected = transformer.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"[VOID] Missing keys: {len(missing)}")
    if unexpected:
        print(f"[VOID] Unexpected keys: {len(unexpected)}")

    return transformer
94
+
95
+
96
def _load_video(path: str, max_frames: int) -> np.ndarray:
    """Load the leading frames of a video as one stacked numpy array.

    Args:
        path: Video file readable by ``imageio.imiter``.
        max_frames: Upper bound on the number of frames to read.

    Returns:
        Array whose first axis indexes frames; (T, H, W, 3) uint8 for a
        typical RGB video.

    Raises:
        ValueError: If no frame could be read.
    """
    import imageio
    from itertools import islice

    # Stop decoding as soon as enough frames are read instead of
    # materialising the whole clip and slicing it afterwards.
    frames = list(islice(imageio.imiter(path), max(max_frames, 0)))
    if not frames:
        raise ValueError(
            f"No frames could be read from {path!r} (max_frames={max_frames})"
        )
    return np.array(frames)
102
+
103
+
104
+ def _prep_video_tensor(
105
+ video_np: np.ndarray,
106
+ sample_size: Tuple[int, int],
107
+ ) -> torch.Tensor:
108
+ """Convert video numpy array to pipeline input tensor.
109
+
110
+ Returns: (1, C, T, H, W) float32 in [0, 1]
111
+ """
112
+ video = torch.from_numpy(video_np).float()
113
+ video = video.permute(3, 0, 1, 2) / 255.0 # (C, T, H, W)
114
+ video = F.interpolate(video, sample_size, mode="area")
115
+ return video.unsqueeze(0) # (1, C, T, H, W)
116
+
117
+
118
+ def _prep_mask_tensor(
119
+ mask_np: np.ndarray,
120
+ sample_size: Tuple[int, int],
121
+ use_quadmask: bool = True,
122
+ ) -> torch.Tensor:
123
+ """Convert mask numpy array to pipeline input tensor.
124
+
125
+ Quantizes to quadmask values [0, 63, 127, 255], inverts,
126
+ and normalizes to [0, 1].
127
+
128
+ Returns: (1, 1, T, H, W) float32 in [0, 1]
129
+ """
130
+ mask = torch.from_numpy(mask_np).float()
131
+ if mask.ndim == 4:
132
+ mask = mask[..., 0] # drop channel dim -> (T, H, W)
133
+ mask = F.interpolate(mask.unsqueeze(0), sample_size, mode="area")
134
+ mask = mask.unsqueeze(0) # (1, 1, T, H, W)
135
+
136
+ if use_quadmask:
137
+ # Quantize to 4 values
138
+ mask = torch.where(mask <= 31, 0., mask)
139
+ mask = torch.where((mask > 31) * (mask <= 95), 63., mask)
140
+ mask = torch.where((mask > 95) * (mask <= 191), 127., mask)
141
+ mask = torch.where(mask > 191, 255., mask)
142
+ else:
143
+ # Trimask: 3 values
144
+ mask = torch.where(mask > 192, 255., mask)
145
+ mask = torch.where((mask <= 192) * (mask >= 64), 128., mask)
146
+ mask = torch.where(mask < 64, 0., mask)
147
+
148
+ # Invert and normalize to [0, 1]
149
+ mask = (255. - mask) / 255.
150
+
151
+ return mask
152
+
153
+
154
+ def _temporal_padding(
155
+ tensor: torch.Tensor,
156
+ min_length: int = 85,
157
+ max_length: int = 197,
158
+ dim: int = 2,
159
+ ) -> torch.Tensor:
160
+ """Pad video temporally by mirroring, matching CogVideoX requirements."""
161
+ length = tensor.size(dim)
162
+
163
+ min_len = (length // 4) * 4 + 1
164
+ if min_len < length:
165
+ min_len += 4
166
+ if (min_len / 4) % 2 == 0:
167
+ min_len += 4
168
+ target_length = min(min_len, max_length)
169
+ target_length = max(min_length, target_length)
170
+
171
+ # Truncate if needed
172
+ if dim == 2:
173
+ tensor = tensor[:, :, :target_length]
174
+ else:
175
+ raise NotImplementedError(f"dim={dim} not supported")
176
+
177
+ # Pad by mirroring
178
+ while tensor.size(dim) < target_length:
179
+ flipped = torch.flip(tensor, [dim])
180
+ tensor = torch.cat([tensor, flipped], dim=dim)
181
+
182
+ if dim == 2:
183
+ tensor = tensor[:, :, :target_length]
184
+
185
+ return tensor
186
+
187
+
188
def _generate_warped_noise(
    pass1_video_path: str,
    target_shape: Tuple[int, int, int, int],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Produce warped noise from a Pass 1 output video.

    Two routes are tried in order: in-process generation through the ``rp``
    package, then a ``make_warped_noise.py`` script run as a subprocess.

    Args:
        pass1_video_path: Path to Pass 1 output video.
        target_shape: (latent_T, latent_H, latent_W, latent_C)
        device: Target device.
        dtype: Target dtype.

    Returns: (1, T, C, H, W) warped noise tensor.

    Raises:
        RuntimeError: If neither route is available, the script exits with a
            non-zero status, or no ``noises.npy`` is produced.
    """
    # Route 1: direct generation via rp / noise_warp.
    try:
        # Fix for SLURM: rp crashes parsing GPU UUIDs like "GPU-9fca2b4f-...",
        # so a non-numeric CUDA_VISIBLE_DEVICES is replaced by index "0".
        # Note this changes the process environment and is not restored.
        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        if visible_devices and not visible_devices.replace(",", "").isdigit():
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"

        import rp
        # presumably lets rp install missing packages without prompting -- confirm
        rp.r._pip_import_autoyes = True
        rp.git_import('CommonSource')
        import rp.git.CommonSource.noise_warp as nw  # availability probe only
        return _generate_warped_noise_direct(pass1_video_path, target_shape, device, dtype)
    except ImportError as e:
        print(f"[VOID] rp/noise_warp not available: {e}")
    except Exception as e:
        # Best effort: report the failure and fall through to the script route.
        print(f"[VOID] Warped noise generation via rp failed: {e}")
        import traceback
        traceback.print_exc()

    # Route 2: locate make_warped_noise.py next to this file or in the
    # sibling inference tree, and run it as a subprocess.
    here = os.path.dirname(__file__)
    gwf_script = next(
        (
            script
            for script in (
                os.path.join(here, "make_warped_noise.py"),
                os.path.join(here, "..", "inference", "cogvideox_fun", "make_warped_noise.py"),
            )
            if os.path.exists(script)
        ),
        None,
    )
    if gwf_script is None:
        raise RuntimeError(
            "Cannot generate warped noise: 'rp' package not installed and "
            "make_warped_noise.py not found. Install 'rp' package or provide "
            "pre-computed warped noise via warped_noise_path parameter."
        )

    with tempfile.TemporaryDirectory() as tmpdir:
        command = [sys.executable, gwf_script, os.path.abspath(pass1_video_path), tmpdir]
        print("[VOID] Generating warped noise (this may take a few minutes)...")
        completed = subprocess.run(command, capture_output=True, text=True, timeout=600)
        if completed.returncode != 0:
            raise RuntimeError(f"Warped noise generation failed:\n{completed.stderr}")

        # The script may write into a per-video subfolder or directly into tmpdir.
        video_stem = os.path.splitext(os.path.basename(pass1_video_path))[0]
        for noise_path in (
            os.path.join(tmpdir, video_stem, "noises.npy"),
            os.path.join(tmpdir, "noises.npy"),
        ):
            if os.path.exists(noise_path):
                # Load while the temporary directory still exists.
                return _load_warped_noise(noise_path, target_shape, device, dtype)
        raise RuntimeError("Warped noise file not found after generation")
259
+
260
+
261
def _generate_warped_noise_direct(
    video_path: str,
    target_shape: Tuple[int, int, int, int],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Generate warped noise in-process with the ``rp`` package.

    The video is resampled to 72 frames, resized and center-cropped to
    480x720, and passed to ``noise_warp.get_noise_from_video``. The resulting
    noise is converted by ``_numpy_noise_to_tensor``.
    """
    import rp
    import rp.git.CommonSource.noise_warp as nw

    frame_size = dict(height=480, width=720)
    clip = rp.load_video(video_path)
    clip = rp.resize_list(clip, length=72)
    clip = rp.resize_images_to_hold(clip, **frame_size)
    clip = rp.crop_images(clip, origin='center', **frame_size)
    clip = rp.as_numpy_array(clip)

    frame_scale = 2**-1
    flow_scale = 2**3
    latent_scale = 8

    warp_result = nw.get_noise_from_video(
        clip,
        remove_background=False,
        visualize=False,
        save_files=False,
        noise_channels=16,
        resize_frames=frame_scale,
        resize_flow=flow_scale,
        downscale_factor=round(frame_scale * flow_scale) * latent_scale,
    )

    # numpy_noises is laid out as (T, H, W, C).
    return _numpy_noise_to_tensor(warp_result.numpy_noises, target_shape, device, dtype)
294
+
295
+
296
+ def _load_warped_noise(
297
+ noise_path: str,
298
+ target_shape: Tuple[int, int, int, int],
299
+ device: torch.device,
300
+ dtype: torch.dtype,
301
+ ) -> torch.Tensor:
302
+ """Load and resize pre-computed warped noise."""
303
+ noises = np.load(noise_path)
304
+ if noises.dtype == np.float16:
305
+ noises = noises.astype(np.float32)
306
+ # Ensure THWC format
307
+ if noises.shape[1] == 16: # TCHW -> THWC
308
+ noises = np.transpose(noises, (0, 2, 3, 1))
309
+ return _numpy_noise_to_tensor(noises, target_shape, device, dtype)
310
+
311
+
312
+ def _numpy_noise_to_tensor(
313
+ noises: np.ndarray,
314
+ target_shape: Tuple[int, int, int, int],
315
+ device: torch.device,
316
+ dtype: torch.dtype,
317
+ ) -> torch.Tensor:
318
+ """Convert numpy noise (T, H, W, C) to pipeline tensor (1, T, C, H, W)."""
319
+ latent_T, latent_H, latent_W, latent_C = target_shape
320
+
321
+ # Temporal resize if needed
322
+ if noises.shape[0] != latent_T:
323
+ indices = np.linspace(0, noises.shape[0] - 1, latent_T)
324
+ lower = np.floor(indices).astype(int)
325
+ upper = np.ceil(indices).astype(int)
326
+ frac = indices - lower
327
+ noises = noises[lower] * (1 - frac[:, None, None, None]) + noises[upper] * frac[:, None, None, None]
328
+
329
+ # Spatial resize if needed
330
+ if noises.shape[1] != latent_H or noises.shape[2] != latent_W:
331
+ resized = np.zeros((latent_T, latent_H, latent_W, latent_C), dtype=noises.dtype)
332
+ for t in range(latent_T):
333
+ for c in range(latent_C):
334
+ resized[t, :, :, c] = cv2.resize(
335
+ noises[t, :, :, c], (latent_W, latent_H),
336
+ interpolation=cv2.INTER_LINEAR,
337
+ )
338
+ noises = resized
339
+
340
+ # Convert to tensor: (T, H, W, C) -> (1, T, C, H, W)
341
+ tensor = torch.from_numpy(noises).permute(0, 3, 1, 2).unsqueeze(0)
342
+ return tensor.to(device=device, dtype=dtype)
343
+
344
+
345
class VOIDPipeline(CogVideoXFunInpaintPipeline):
    """
    VOID: Video Object and Interaction Deletion.

    Removes objects and their physical interactions from videos using
    quadmask-conditioned video inpainting.
    """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        void_pass: int = 1,
        base_model: str = BASE_MODEL_REPO,
        torch_dtype: torch.dtype = torch.bfloat16,
        **kwargs,
    ):
        """
        Load the VOID pipeline.

        Args:
            pretrained_model_name_or_path: HF repo ID or local path containing
                VOID checkpoint files (void_pass1.safetensors, etc.)
            void_pass: Which pass checkpoint to load (1 or 2). Default: 1.
            base_model: HF repo ID or local directory of the base
                CogVideoX-Fun model.
            torch_dtype: Weight dtype. Default: torch.bfloat16.
            **kwargs: Accepted for call compatibility but not used; a notice
                is printed when any are given.

        Returns:
            A ready-to-use ``VOIDPipeline``.

        Raises:
            ValueError: If ``void_pass`` is not a known pass.
            FileNotFoundError: If a local checkpoint directory lacks the
                checkpoint file for the requested pass.
        """
        if void_pass not in PASS_CHECKPOINTS:
            raise ValueError(f"void_pass must be 1 or 2, got {void_pass}")
        if kwargs:
            # Say so explicitly rather than silently dropping options.
            print(f"[VOID] Ignoring unsupported from_pretrained arguments: {sorted(kwargs)}")

        # --- Download VOID checkpoint ---
        checkpoint_name = PASS_CHECKPOINTS[void_pass]
        print(f"[VOID] Loading Pass {void_pass} checkpoint...")

        if os.path.isdir(pretrained_model_name_or_path):
            checkpoint_path = os.path.join(pretrained_model_name_or_path, checkpoint_name)
            if not os.path.isfile(checkpoint_path):
                raise FileNotFoundError(
                    f"VOID checkpoint not found: {checkpoint_path}"
                )
        else:
            checkpoint_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename=checkpoint_name,
            )

        # --- Download and load base model ---
        print(f"[VOID] Loading base model: {base_model}")
        # A local directory is used as-is, mirroring the checkpoint handling.
        if os.path.isdir(base_model):
            base_model_path = base_model
        else:
            base_model_path = snapshot_download(repo_id=base_model)

        # Transformer (with VAE mask channels)
        print("[VOID] Loading transformer...")
        transformer = CogVideoXTransformer3DModel.from_pretrained(
            base_model_path,
            subfolder="transformer",
            low_cpu_mem_usage=True,
            torch_dtype=torch_dtype,
            use_vae_mask=True,
        )

        # Merge VOID weights
        print(f"[VOID] Merging Pass {void_pass} weights...")
        transformer = _merge_void_weights(transformer, checkpoint_path)
        transformer = transformer.to(torch_dtype)

        # VAE
        print("[VOID] Loading VAE...")
        vae = AutoencoderKLCogVideoX.from_pretrained(
            base_model_path, subfolder="vae"
        ).to(torch_dtype)

        # Tokenizer + Text encoder
        print("[VOID] Loading tokenizer and text encoder...")
        from transformers import T5Tokenizer, T5EncoderModel
        tokenizer = T5Tokenizer.from_pretrained(base_model_path, subfolder="tokenizer")
        text_encoder = T5EncoderModel.from_pretrained(
            base_model_path, subfolder="text_encoder", torch_dtype=torch_dtype,
        )

        # Scheduler
        scheduler = CogVideoXDDIMScheduler.from_pretrained(
            base_model_path, subfolder="scheduler"
        )

        # Build pipeline
        pipe = cls(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            vae=vae,
            transformer=transformer,
            scheduler=scheduler,
        )
        pipe._void_pass = void_pass

        print("[VOID] Pipeline ready!")
        return pipe

    def inpaint(
        self,
        video_path: str,
        mask_path: str,
        prompt: str,
        negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
        height: int = 384,
        width: int = 672,
        num_inference_steps: int = 30,
        guidance_scale: float = 1.0,
        strength: float = 1.0,
        temporal_window_size: int = 85,
        max_video_length: int = 197,
        fps: int = 12,
        seed: int = 42,
        pass1_video: Optional[str] = None,
        warped_noise_path: Optional[str] = None,
        use_quadmask: bool = True,
    ) -> VOIDOutput:
        """
        Run VOID inpainting on a video.

        Args:
            video_path: Path to input video (mp4).
            mask_path: Path to quadmask video (mp4). Grayscale with values:
                0=object to remove, 63=overlap, 127=affected region, 255=background.
            prompt: Text description of the desired result after removal.
                E.g., "A lime falls on the table."
            negative_prompt: Negative prompt for generation quality.
            height: Output height (default 384).
            width: Output width (default 672).
            num_inference_steps: Denoising steps (default 30).
            guidance_scale: CFG scale (default 1.0 = no CFG).
            strength: Denoising strength (default 1.0).
            temporal_window_size: Frames per inference window (default 85).
            max_video_length: Max frames to process (default 197).
            fps: Currently unused by this method; pass ``fps`` to
                ``VOIDOutput.save()`` to control the written frame rate.
            seed: Random seed (default 42).
            pass1_video: Path to Pass 1 output video, for Pass 2 warped noise init.
            warped_noise_path: Path to pre-computed warped noise (.npy).
                Takes precedence over ``pass1_video`` when both are given.
            use_quadmask: Use 4-value quadmask (default True). Set False for trimask.

        Returns:
            VOIDOutput with .video (uint8) and .save() method.

        Raises:
            ValueError: If the video or the mask contains no frames.
        """
        sample_size = (height, width)

        # Align video length to VAE temporal compression ratio
        vae_temporal_ratio = self.vae.config.temporal_compression_ratio
        video_length = int((max_video_length - 1) // vae_temporal_ratio * vae_temporal_ratio) + 1

        # --- Load and prep video ---
        print("[VOID] Loading video and mask...")
        vid_np = _load_video(video_path, video_length)
        mask_np = _load_video(mask_path, video_length)

        n_video, n_mask = len(vid_np), len(mask_np)
        if n_video == 0 or n_mask == 0:
            raise ValueError(
                f"No frames loaded (video: {n_video}, mask: {n_mask}); "
                "check video_path and mask_path"
            )
        if n_video != n_mask:
            # Video and mask are padded independently below. With unequal
            # lengths they would be mirrored at different frames and drift
            # out of sync, so both are cut to the shared prefix first.
            common = min(n_video, n_mask)
            print(
                f"[VOID] Warning: video has {n_video} frames but mask has {n_mask}; "
                f"using the first {common} of each."
            )
            vid_np, mask_np = vid_np[:common], mask_np[:common]

        video = _prep_video_tensor(vid_np, sample_size)
        mask = _prep_mask_tensor(mask_np, sample_size, use_quadmask=use_quadmask)

        # Temporal padding
        video = _temporal_padding(video, min_length=temporal_window_size, max_length=max_video_length)
        mask = _temporal_padding(mask, min_length=temporal_window_size, max_length=max_video_length)

        num_frames = min(video.shape[2], temporal_window_size)

        print(f"[VOID] Video: {video.shape}, Mask: {mask.shape}, Frames: {num_frames}")

        # --- Handle warped noise for Pass 2 ---
        latents = None
        if warped_noise_path is not None or pass1_video is not None:
            latent_T = (num_frames - 1) // 4 + 1
            latent_H = height // 8
            latent_W = width // 8
            latent_C = 16
            target_shape = (latent_T, latent_H, latent_W, latent_C)
            # Follow the loaded weights' dtype instead of assuming bfloat16;
            # fall back to bfloat16 if the transformer exposes no dtype.
            latent_dtype = getattr(self.transformer, "dtype", torch.bfloat16)

            if warped_noise_path is not None:
                print(f"[VOID] Loading pre-computed warped noise from {warped_noise_path}")
                latents = _load_warped_noise(
                    warped_noise_path, target_shape,
                    device=torch.device("cpu"), dtype=latent_dtype,
                )
            else:
                print("[VOID] Generating warped noise from Pass 1 output...")
                latents = _generate_warped_noise(
                    pass1_video, target_shape,
                    device=torch.device("cpu"), dtype=latent_dtype,
                )
            print(f"[VOID] Warped noise: {latents.shape}, mean={latents.mean():.4f}, std={latents.std():.4f}")

        # --- Run inference ---
        generator = torch.Generator(device="cpu").manual_seed(seed)

        print(f"[VOID] Running inference ({num_frames} frames, {num_inference_steps} steps)...")
        with torch.no_grad():
            # NOTE(review): use_trimask=True is passed regardless of
            # use_quadmask, as in the original -- confirm against the
            # base pipeline's semantics.
            output = self(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_frames=num_frames,
                height=height,
                width=width,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                video=video,
                mask_video=mask,
                strength=strength,
                use_trimask=True,
                use_vae_mask=True,
                latents=latents,
            ).videos

        # --- Process output ---
        if isinstance(output, np.ndarray):
            output = torch.from_numpy(output)

        # output is (B, C, T, H, W) in [0, 1]
        video_float = output
        # Scale in float32 and round to nearest; a bare uint8 cast truncates
        # and biases every pixel downwards.
        frames = output[0].permute(1, 2, 3, 0).float().clamp(0, 1)
        video_uint8 = (frames * 255).round().to(torch.uint8)

        print(f"[VOID] Done! Output: {video_uint8.shape}")
        return VOIDOutput(video=video_uint8, video_float=video_float)