Instructions to use EndeavourDD/gnn_wm with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use EndeavourDD/gnn_wm with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for Apple devices
pipe = DiffusionPipeline.from_pretrained("EndeavourDD/gnn_wm", dtype=torch.bfloat16, device_map="cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]

- Notebooks
- Google Colab
- Kaggle
import sys
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader

# Make the repository root importable when this script is run directly
# (the project-local imports below depend on it).
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from graphwm.config_graph import GraphWMArgs
from graphwm.dataset.collate_graph_wm import collate_graph_wm
from graphwm.models.ctrl_world_graph import CtrlWorldGraph
from graphwm.original_ctrl_world import import_original_modules
from scripts.train_wm_graph import build_datasets
def write_video(path: Path, frames: np.ndarray, fps: int = 5):
    """Save *frames* to *path* as a video file.

    Tries mediapy first; if importing it or writing with it fails for any
    reason, falls back to imageio (best-effort — no error is surfaced from
    the mediapy attempt).
    """
    media = None
    try:
        import mediapy as media
    except Exception:
        pass
    if media is not None:
        try:
            media.write_video(str(path), frames, fps=fps)
            return
        except Exception:
            pass  # fall through to the imageio writer
    import imageio.v2 as imageio

    # macro_block_size=None avoids imageio resizing frames to multiples of 16.
    imageio.mimwrite(str(path), frames, fps=fps, macro_block_size=None)
def decode_latents_to_video(pipeline, latents: torch.Tensor, decode_chunk_size: int):
    """Decode VAE latents into a uint8 video array.

    Args:
        pipeline: object exposing ``vae`` with a ``decode(chunk, num_frames=...)``
            method returning ``.sample`` and a ``config.scaling_factor``
            (e.g. a diffusers video pipeline).
        latents: tensor of shape (batch, frames, channels, h, w).
        decode_chunk_size: number of flattened frames decoded per VAE call,
            bounding peak decode memory.

    Returns:
        ``np.ndarray`` of shape (batch, frames, h * 8, w * 8, out_channels),
        dtype uint8, values in [0, 255]. The reshape assumes the VAE
        upsamples spatially by a factor of 8.
    """
    bsz, num_frames = latents.shape[:2]
    flat = latents.flatten(0, 1)  # merge batch and time for chunked decoding
    decoded = []
    # no_grad: this is pure inference — without it every decoded chunk would
    # retain an autograd graph and inflate peak memory.
    with torch.no_grad():
        for i in range(0, flat.shape[0], decode_chunk_size):
            # Undo the scaling the VAE applied at encode time.
            chunk = flat[i:i + decode_chunk_size] / pipeline.vae.config.scaling_factor
            sample = pipeline.vae.decode(chunk, num_frames=chunk.shape[0]).sample
            decoded.append(sample)
    video = torch.cat(decoded, dim=0).reshape(
        bsz, num_frames, -1, flat.shape[-2] * 8, flat.shape[-1] * 8
    )
    # Map the decoder's [-1, 1] output range to [0, 255] uint8.
    video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).byte()
    # (B, T, C, H, W) -> (B, T, H, W, C) channel-last numpy array.
    return video.permute(0, 1, 3, 4, 2).cpu().numpy()
def main():
    """Run one validation sample through the graph world model and write the
    predicted, ground-truth, and side-by-side comparison videos to disk."""
    args = GraphWMArgs()
    # NOTE(review): checkpoint and output paths are hard-coded to one
    # workspace layout — adjust for your environment.
    args.ckpt_path = '/workspace/Ctrl-World-Graph/model_ckpt/ctrl_world_graph/checkpoint-100000.pt'
    args.eval_batch_size = 1
    args.num_workers = 0
    _, val_ds = build_datasets(args)
    if val_ds is None or len(val_ds) == 0:
        raise ValueError('Validation dataset is empty.')
    # Evaluate only the first validation sample, as a single-item batch.
    sample = val_ds[0]
    batch = collate_graph_wm([sample])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CtrlWorldGraph(args).to(device)
    # strict=False: presumably the checkpoint omits some submodule weights
    # (e.g. parts loaded from a pretrained pipeline) — TODO confirm.
    state_dict = torch.load(args.ckpt_path, map_location='cpu')
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    # Borrow the sampling loop from the original Ctrl-World codebase.
    original = import_original_modules(args.ctrl_world_root)
    CtrlWorldDiffusionPipeline = original['CtrlWorldDiffusionPipeline']
    batch['rgb'] = batch['rgb'].to(device)
    batch['graph_seq'] = [g.to(device) for g in batch['graph_seq']]
    with torch.no_grad():
        latents = model.encode_rgb_to_latents(batch['rgb'])
        # Graph conditioning is fed where the pipeline expects text embeddings.
        graph_hidden = model.encode_graph_condition(batch).to(device=device, dtype=model.unet.dtype)
        # Condition on the frame right after the history window; the earlier
        # frames (if any) are passed separately as history.
        current_latent = latents[:, args.num_history]
        history = latents[:, :args.num_history] if args.num_history > 0 else None
        # Call __call__ unbound so the original pipeline's sampling loop runs
        # against our model's pipeline instance.
        _, pred_latents = CtrlWorldDiffusionPipeline.__call__(
            model.pipeline,
            image=current_latent,
            text=graph_hidden,
            width=args.width,
            height=args.height,
            num_frames=args.num_frames,
            history=history,
            num_inference_steps=args.num_inference_steps,
            decode_chunk_size=args.decode_chunk_size,
            max_guidance_scale=args.guidance_scale,
            fps=args.fps,
            motion_bucket_id=args.motion_bucket_id,
            output_type='latent',  # keep latents; decoded to pixels below
            return_dict=False,
            frame_level_cond=args.frame_level_cond,
            his_cond_zero=args.his_cond_zero,
        )
    pred_video = decode_latents_to_video(model.pipeline, pred_latents, args.decode_chunk_size)[0]
    # Ground truth: treats batch['rgb'] as [0, 1]-normalized (T, C, H, W);
    # converted to (T, H, W, C) uint8 for the video writer.
    gt_video = (batch['rgb'][0].permute(0, 2, 3, 1).clamp(0, 1) * 255).byte().cpu().numpy()
    # Side-by-side GT | prediction along the width axis; requires matching
    # frame count and height — TODO confirm gt/pred shapes agree.
    compare_video = np.concatenate([gt_video, pred_video], axis=2)
    out_dir = Path('/workspace/Ctrl-World-Graph/eval_videos')
    out_dir.mkdir(parents=True, exist_ok=True)
    pred_path = out_dir / 'val0_pred.mp4'
    gt_path = out_dir / 'val0_gt.mp4'
    compare_path = out_dir / 'val0_compare.mp4'
    write_video(pred_path, pred_video, fps=args.fps)
    write_video(gt_path, gt_video, fps=args.fps)
    write_video(compare_path, compare_video, fps=args.fps)
    print('saved_pred=', pred_path)
    print('saved_gt=', gt_path)
    print('saved_compare=', compare_path)
    print('frame_ids=', batch['frame_ids'][0].tolist())
# Script entry point.
if __name__ == '__main__':
    main()