# Ctrl-World-Graph/scripts/eval_graph_video.py
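"""Qualitative evaluation for Ctrl-World-Graph.

Loads a trained CtrlWorldGraph checkpoint, takes the first validation sample,
rolls out a video prediction conditioned on the graph sequence, and writes
predicted, ground-truth, and side-by-side comparison MP4s to disk.
"""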
import sys
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import DataLoader
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from graphwm.config_graph import GraphWMArgs
from graphwm.dataset.collate_graph_wm import collate_graph_wm
from graphwm.models.ctrl_world_graph import CtrlWorldGraph
from graphwm.original_ctrl_world import import_original_modules
from scripts.train_wm_graph import build_datasets
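
# Write a (T, H, W, C) uint8 frame array as a video; prefer mediapy and fall
# back to imageio if mediapy is unavailable or fails.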
def write_video(path: Path, frames: np.ndarray, fps: int = 5):
    try:
        import mediapy as media
        media.write_video(str(path), frames, fps=fps)
        return
    except Exception:
        import imageio.v2 as imageio
        imageio.mimwrite(str(path), frames, fps=fps, macro_block_size=None)
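
# Decode video latents back to uint8 RGB frames, chunking the VAE decode to
# bound memory use. The x8 spatial upscale assumes the standard SD/SVD VAE
# downsampling factor used by the pipeline.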
def decode_latents_to_video(pipeline, latents: torch.Tensor, decode_chunk_size: int):
    bsz, num_frames = latents.shape[:2]
    flat = latents.flatten(0, 1)
    decoded = []
    for i in range(0, flat.shape[0], decode_chunk_size):
        chunk = flat[i:i + decode_chunk_size] / pipeline.vae.config.scaling_factor
        sample = pipeline.vae.decode(chunk, num_frames=chunk.shape[0]).sample
        decoded.append(sample)
    video = torch.cat(decoded, dim=0).reshape(bsz, num_frames, -1, flat.shape[-2] * 8, flat.shape[-1] * 8)
    video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).byte()
    return video.permute(0, 1, 3, 4, 2).cpu().numpy()
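
# Entry point: evaluates a single validation clip with a hard-coded checkpoint
# path and output directory; adjust these for a different workspace layout.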
def main():
    args = GraphWMArgs()
    args.ckpt_path = '/workspace/Ctrl-World-Graph/model_ckpt/ctrl_world_graph/checkpoint-100000.pt'
    args.eval_batch_size = 1
    args.num_workers = 0
    _, val_ds = build_datasets(args)
    if val_ds is None or len(val_ds) == 0:
        raise ValueError('Validation dataset is empty.')
    sample = val_ds[0]
    batch = collate_graph_wm([sample])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
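    # Build the graph-conditioned world model and load the trained weights.
    # strict=False keeps loading tolerant of keys missing from the checkpoint
    # (e.g. components of the frozen base pipeline, if any).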
    model = CtrlWorldGraph(args).to(device)
    state_dict = torch.load(args.ckpt_path, map_location='cpu')
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    original = import_original_modules(args.ctrl_world_root)
    CtrlWorldDiffusionPipeline = original['CtrlWorldDiffusionPipeline']
    batch['rgb'] = batch['rgb'].to(device)
    batch['graph_seq'] = [g.to(device) for g in batch['graph_seq']]
    with torch.no_grad():
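        # Encode the ground-truth frames into VAE latents and the graph
        # sequence into conditioning embeddings matching the UNet dtype.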
        latents = model.encode_rgb_to_latents(batch['rgb'])
        graph_hidden = model.encode_graph_condition(batch).to(device=device, dtype=model.unet.dtype)
        current_latent = latents[:, args.num_history]
        history = latents[:, :args.num_history] if args.num_history > 0 else None
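        # Roll out the prediction with the original Ctrl-World diffusion
        # pipeline (called unbound on model.pipeline). The graph embedding is
        # passed through the pipeline's `text` conditioning argument, and
        # output_type='latent' defers decoding to the chunked helper above.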
        _, pred_latents = CtrlWorldDiffusionPipeline.__call__(
            model.pipeline,
            image=current_latent,
            text=graph_hidden,
            width=args.width,
            height=args.height,
            num_frames=args.num_frames,
            history=history,
            num_inference_steps=args.num_inference_steps,
            decode_chunk_size=args.decode_chunk_size,
            max_guidance_scale=args.guidance_scale,
            fps=args.fps,
            motion_bucket_id=args.motion_bucket_id,
            output_type='latent',
            return_dict=False,
            frame_level_cond=args.frame_level_cond,
            his_cond_zero=args.his_cond_zero,
        )
        pred_video = decode_latents_to_video(model.pipeline, pred_latents, args.decode_chunk_size)[0]
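    # Convert the ground-truth clip to uint8 frames and build a side-by-side
    # (GT | prediction) comparison along the width axis.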
    gt_video = (batch['rgb'][0].permute(0, 2, 3, 1).clamp(0, 1) * 255).byte().cpu().numpy()
    compare_video = np.concatenate([gt_video, pred_video], axis=2)
    out_dir = Path('/workspace/Ctrl-World-Graph/eval_videos')
    out_dir.mkdir(parents=True, exist_ok=True)
    pred_path = out_dir / 'val0_pred.mp4'
    gt_path = out_dir / 'val0_gt.mp4'
    compare_path = out_dir / 'val0_compare.mp4'
    write_video(pred_path, pred_video, fps=args.fps)
    write_video(gt_path, gt_video, fps=args.fps)
    write_video(compare_path, compare_video, fps=args.fps)
    print('saved_pred=', pred_path)
    print('saved_gt=', gt_path)
    print('saved_compare=', compare_path)
    print('frame_ids=', batch['frame_ids'][0].tolist())
if __name__ == '__main__':
    main()