"""Precompute joint CogVideoX VAE latents for paired source/masked videos.

Each source clip and its masked counterpart (overlaid with dashed rays) are
encoded with the CogVideoX VAE, concatenated along the temporal axis, and
saved to one .safetensors file per video.
"""
import argparse
import math
import os

import numpy as np
import torch
import tqdm
from decord import VideoReader
from diffusers import AutoencoderKLCogVideoX
from safetensors.torch import save_file
|
|
def encode_video(video, vae):
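    """Encode a [T, C, H, W] video into scaled CogVideoX VAE latents."""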
    # [T, C, H, W] -> [1, C, T, H, W], the layout the video VAE expects.
    video = video[None].permute(0, 2, 1, 3, 4).contiguous()
    video = video.to(vae.device, dtype=vae.dtype)
    latent_dist = vae.encode(video).latent_dist
    # Sample from the posterior and apply the VAE's scaling factor.
    latent = latent_dist.sample() * vae.config.scaling_factor
    return latent
|
|
def add_dashed_rays_to_video(video_tensor, num_perp_samples=50, density_decay=0.075):
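    """Overlay dashed rays on every frame after the first.

    All rays share one random direction; their origins are spread along the
    perpendicular line through the frame center. Each ray is drawn in the
    color of its origin pixel in frame 0, with dash spacing that grows with
    distance from the origin (rate set by density_decay).
    """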
    T, C, H, W = video_tensor.shape
    # A ray can span at most the frame diagonal; pad slightly for safety.
    max_length = int((H**2 + W**2) ** 0.5) + 10
    center = torch.tensor([W / 2, H / 2])
    # One random direction shared by all rays in this video.
    theta = torch.rand(1).item() * 2 * math.pi
    direction = torch.tensor([math.cos(theta), math.sin(theta)])
    # Perpendicular direction along which the ray origins are spread.
    d_perp = torch.stack([-direction[1], direction[0]])
    half_len = max(H, W) // 2
    positions = torch.linspace(-half_len, half_len, num_perp_samples)
    perp_coords = center[None, :] + positions[:, None] * d_perp[None, :]  # [N, 2] ray origins
    x0, y0 = perp_coords[:, 0], perp_coords[:, 1]
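    # Distances along the ray: spacing grows with distance (density_decay),
    # so the dash samples thin out away from the ray origin.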
    steps = []
    dist = 0.0
    while dist < max_length:
        steps.append(dist)
        dist += 1.0 + density_decay * dist
    steps = torch.tensor(steps)
    S = len(steps)
    # All sample coordinates: one row of S steps per ray origin.
    dxdy = direction[None, :] * steps[:, None]
    all_xy = perp_coords[:, None, :] + dxdy[None, :, :]  # [num_perp_samples, S, 2]
    all_xy = all_xy.reshape(-1, 2)
    all_x = all_xy[:, 0].round().long()
    all_y = all_xy[:, 1].round().long()
    valid = (all_x >= 0) & (all_x < W) & (all_y >= 0) & (all_y < H)
    all_x = all_x[valid]
    all_y = all_y[valid]
    # Each ray takes the color of its origin pixel in the first frame.
    x0r = x0.round().long().clamp(0, W - 1)
    y0r = y0.round().long().clamp(0, H - 1)
    frame0 = video_tensor[0]
    base_colors = frame0[:, y0r, x0r]  # [C, num_perp_samples]
    base_colors = base_colors.repeat_interleave(S, dim=1)[:, valid]
    video_out = video_tensor.clone()
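    # Stamp each sample as a 2x2 pixel block so the dashes stay visible.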
    offsets = [(0, 0), (0, 1), (1, 0), (1, 1)]
    for dxo, dyo in offsets:
        ox = all_x + dxo
        oy = all_y + dyo
        inside = (ox >= 0) & (ox < W) & (oy >= 0) & (oy < H)
        ox = ox[inside]
        oy = oy[inside]
        colors = base_colors[:, inside]
        # Paint the dash pixels into every frame except the first.
        for c in range(C):
            video_out[1:, c, oy, ox] = colors[c][None, :].expand(T - 1, -1)
    return video_out
|
|
def main(args):
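    """Encode the assigned slice of videos and save joint latents to disk."""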
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Frozen fp16 VAE: encoding only, no gradients needed.
    vae = AutoencoderKLCogVideoX.from_pretrained(args.pretrained_model_path, subfolder="vae")
    vae.requires_grad_(False)
    vae = vae.to(device, dtype=torch.float16)
|
    masked_video_path = os.path.join(args.video_root, "masked_videos")
    source_video_path = os.path.join(args.video_root, "videos")
    joint_latent_path = os.path.join(args.video_root, "joint_latents")
    os.makedirs(joint_latent_path, exist_ok=True)
|
    # Each worker handles a contiguous slice of the sorted file list.
    all_video_names = sorted(os.listdir(source_video_path))
    video_names = all_video_names[args.start_idx : args.end_idx]
|
    for video_name in tqdm.tqdm(video_names, desc=f"GPU {args.gpu_id}"):
        masked_video_file = os.path.join(masked_video_path, video_name)
        source_video_file = os.path.join(source_video_path, video_name)
        output_file = os.path.join(joint_latent_path, video_name.replace(".mp4", ".safetensors"))

        if not os.path.exists(masked_video_file):
            print(f"Skipping {video_name}, masked video not found.")
            continue
        # Resumable: skip videos whose latents already exist.
        if os.path.exists(output_file):
            continue
|
        try:
            # First 49 frames (the CogVideoX clip length); assumes each clip
            # has at least 49 frames, else the except below catches it.
            vr = VideoReader(source_video_file)
            video = torch.from_numpy(vr.get_batch(np.arange(49)).asnumpy()).permute(0, 3, 1, 2).contiguous()
            video = (video / 255.0) * 2 - 1  # uint8 [0, 255] -> float [-1, 1]
            source_latent = encode_video(video, vae)
|
            # The masked video additionally gets dashed-ray overlays before encoding.
            vr = VideoReader(masked_video_file)
            video = torch.from_numpy(vr.get_batch(np.arange(49)).asnumpy()).permute(0, 3, 1, 2).contiguous()
            video = (video / 255.0) * 2 - 1
            video = add_dashed_rays_to_video(video)
            masked_latent = encode_video(video, vae)
|
            # Concatenate along the temporal axis and save both latents together.
            source_latent = source_latent.to("cpu")
            masked_latent = masked_latent.to("cpu")
            cated_latent = torch.cat([source_latent, masked_latent], dim=2)
            save_file({"joint_latents": cated_latent}, output_file)
|
        except Exception as e:
            print(f"[GPU {args.gpu_id}] Error processing {video_name}: {e}")
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_root", type=str, required=True)
    parser.add_argument("--pretrained_model_path", type=str, required=True)
    parser.add_argument("--start_idx", type=int, required=True)
    parser.add_argument("--end_idx", type=int, required=True)
    parser.add_argument("--gpu_id", type=int, required=True)
    args = parser.parse_args()
|
    # Pin this worker to one GPU. Setting this after `import torch` still
    # works because torch initializes CUDA lazily, on first use.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    main(args)
|
|
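# Example invocation, one process per GPU (script name, paths, and model id
# below are illustrative, not taken from the repo):
#   python precompute_joint_latents.py --video_root /data/clips \
#       --pretrained_model_path THUDM/CogVideoX-2b \
#       --start_idx 0 --end_idx 1000 --gpu_id 0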