BonanDing committed on
Commit
8652b14
·
1 Parent(s): f40de20

update lfs

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50)
  1. .gitattributes +5 -0
  2. LICENSE.md +14 -0
  3. README.md +201 -0
  4. algorithms/README.md +21 -0
  5. algorithms/__init__.py +0 -0
  6. algorithms/common/README.md +5 -0
  7. algorithms/common/__init__.py +0 -0
  8. algorithms/common/base_algo.py +21 -0
  9. algorithms/common/base_pytorch_algo.py +252 -0
  10. algorithms/common/metrics/__init__.py +3 -0
  11. algorithms/common/metrics/fid.py +1 -0
  12. algorithms/common/metrics/fvd.py +158 -0
  13. algorithms/common/metrics/lpips.py +1 -0
  14. algorithms/common/models/__init__.py +0 -0
  15. algorithms/common/models/cnn.py +141 -0
  16. algorithms/common/models/mlp.py +22 -0
  17. algorithms/worldmem/__init__.py +2 -0
  18. algorithms/worldmem/df_base.py +307 -0
  19. algorithms/worldmem/df_video.py +926 -0
  20. algorithms/worldmem/models/attention.py +342 -0
  21. algorithms/worldmem/models/cameractrl_module.py +12 -0
  22. algorithms/worldmem/models/diffusion.py +520 -0
  23. algorithms/worldmem/models/dit.py +572 -0
  24. algorithms/worldmem/models/pose_prediction.py +42 -0
  25. algorithms/worldmem/models/rotary_embedding_torch.py +302 -0
  26. algorithms/worldmem/models/utils.py +163 -0
  27. algorithms/worldmem/models/vae.py +359 -0
  28. algorithms/worldmem/pose_prediction.py +374 -0
  29. app.py +576 -0
  30. assets/desert.png +3 -0
  31. assets/ice_plains.png +3 -0
  32. assets/place.png +3 -0
  33. assets/plains.png +3 -0
  34. assets/rain_sunflower_plains.png +3 -0
  35. assets/savanna.png +3 -0
  36. assets/sunflower_plains.png +3 -0
  37. assets/worldmem_logo.png +3 -0
  38. calculate_fid.py +277 -0
  39. configurations/algorithm/base_algo.yaml +3 -0
  40. configurations/algorithm/base_pytorch_algo.yaml +4 -0
  41. configurations/algorithm/df_base.yaml +42 -0
  42. configurations/algorithm/df_video_worldmemminecraft.yaml +38 -0
  43. configurations/dataset/base_dataset.yaml +3 -0
  44. configurations/dataset/base_video.yaml +14 -0
  45. configurations/dataset/video_minecraft.yaml +14 -0
  46. configurations/experiment/base_experiment.yaml +2 -0
  47. configurations/experiment/base_pytorch.yaml +50 -0
  48. configurations/experiment/exp_video.yaml +31 -0
  49. configurations/huggingface.yaml +60 -0
  50. configurations/training.yaml +16 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ <<<<<<< HEAD
37
+ =======
38
+ assets/* filter=lfs diff=lfs merge=lfs -text
39
+ *.png filter=lfs diff=lfs merge=lfs -text
40
+ >>>>>>> def529c (Baseline WorldMem)
LICENSE.md ADDED
@@ -0,0 +1,14 @@
1
+ # S-Lab License 1.0
2
+
3
+ Copyright 2025 S-Lab
4
+
5
+ Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
6
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
9
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
10
+ 4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
11
+
12
+
13
+ ---
14
+ For the commercial use of the code, please consult Prof. Chen Change Loy (ccloy@ntu.edu.sg)
README.md ADDED
@@ -0,0 +1,201 @@
1
+
2
+ <br>
3
+ <p align="center">
4
+
5
+ <p align="center">
6
+ <img src="assets/worldmem_logo.png" alt="WORLDMEM Icon" width="80"/>
7
+ </p>
8
+ <h1 align="center"><strong>WorldMem: Long-term Consistent World Simulation <br> with Memory</strong></h1>
9
+ <p align="center"><span><a href=""></a></span>
10
+ <a href="https://xizaoqu.github.io">Zeqi Xiao<sup>1</sup></a>
11
+ <a href="https://nirvanalan.github.io/">Yushi Lan<sup>1</sup></a>
12
+ <a href="https://zhouyifan.net/about/">Yifan Zhou<sup>1</sup></a>
13
+ <a href="https://vicky0522.github.io/Wenqi-Ouyang/">Wenqi Ouyang<sup>1</sup></a>
14
+ <a href="https://williamyang1991.github.io/">Shuai Yang<sup>2</sup></a>
15
+ <a href="https://zengyh1900.github.io/">Yanhong Zeng<sup>3</sup></a>
16
+ <a href="https://xingangpan.github.io/">Xingang Pan<sup>1</sup></a> <br>
17
+ <sup>1</sup>S-Lab, Nanyang Technological University, <br> <sup>2</sup>Wangxuan Institute of Computer Technology, Peking University,<br> <sup>3</sup>Shanghai AI Laboratory
18
+ </p>
19
+ </p>
20
+
21
+ <p align="center">
22
+ <a href="https://arxiv.org/abs/2504.12369" target='_blank'>
23
+ <img src="https://img.shields.io/badge/arXiv-2504.12369-blue?">
24
+ </a>
25
+ <a href="https://xizaoqu.github.io/worldmem/" target='_blank'>
26
+ <img src="https://img.shields.io/badge/Project-&#x1F680-blue">
27
+ </a>
28
+ <a href="https://huggingface.co/spaces/yslan/worldmem" target="_blank">
29
+ <img src="https://img.shields.io/badge/🤗 HuggingFace-Demo-orange" />
30
+ </a>
31
+ </p>
32
+
33
+ https://github.com/user-attachments/assets/fb8a32e2-9470-4819-a93d-c38caf76d72c
34
+
35
+
36
+ ## Installation
37
+
38
+ ```bash
39
+ conda create python=3.10 -n worldmem
40
+ conda activate worldmem
41
+ pip install -r requirements.txt
42
+ conda install -c conda-forge ffmpeg=4.3.2
43
+ ```
44
+
45
+
46
+ ## Quick start
47
+
48
+ ```bash
49
+ python app.py
50
+ ```
51
+
52
+ ## Run
53
+
54
+ To enable cloud logging with [Weights & Biases (wandb)](https://wandb.ai/site), follow these steps:
55
+
56
+ 1. Sign up for a wandb account.
57
+ 2. Run the following command to log in:
58
+
59
+ ```bash
60
+ wandb login
61
+ ```
62
+
63
+ 3. Open `configurations/training.yaml` and set the `entity` and `project` fields to match your wandb account (e.g., `entity` is your wandb username).
64
+
65
+ ---
66
+
67
+ ### Training
68
+
69
+ Download pretrained weights from [Oasis](https://github.com/etched-ai/open-oasis).
70
+
71
+ We train the model on 4 H100 GPUs; it converges after approximately 500K steps.
72
+ We observe that gradually increasing task difficulty improves performance. Thus, we adopt a multi-stage training strategy:
73
+ ,
74
+ ```bash
75
+ sh train_stage_1.sh # Small range, no vertical turning
76
+ sh train_stage_2.sh # Large range, no vertical turning
77
+ sh train_stage_3.sh # Large range, with vertical turning
78
+ ```
79
+
80
+ To resume training from a previous checkpoint, configure the `resume` and `output_dir` variables in the corresponding `.sh` script.
81
+
82
+ ---
83
+
84
+ ### Inference
85
+
86
+ To run inference:
87
+
88
+ ```bash
89
+ sh infer.sh
90
+ ```
91
+
92
+ You can either **load the diffusion model and VAE separately**:
93
+
94
+ ```bash
95
+ +diffusion_model_path=zeqixiao/worldmem_checkpoints/diffusion_only.ckpt \
96
+ +vae_path=zeqixiao/worldmem_checkpoints/vae_only.ckpt \
97
+ +customized_load=true \
98
+ +seperate_load=true \
99
+ ```
100
+
101
+ Or **load a combined checkpoint**:
102
+
103
+ ```bash
104
+ +load=your_model_path \
105
+ +customized_load=true \
106
+ +seperate_load=false \
107
+ ```
108
+
109
+ ### Evaluation
110
+
111
+ To run evaluation:
112
+
113
+ ```bash
114
+ sh evaluate.sh
115
+ ```
116
+
117
+ This script reproduces the results in Table 1 (beyond the context window). It reports PSNR and LPIPS. Evaluating 1 case on 1 A100 GPU takes approximately 6 minutes. You can adjust `experiment.test.limit_batch` to specify the number of cases to evaluate.
118
+
119
+ Visual results will be saved by default to a timestamped directory (e.g., `outputs/2025-11-30/00-02-42`).
120
+
121
+ To calculate the FID score, run:
122
+
123
+ ```bash
124
+ python calculate_fid.py --videos_dir <path_to_videos>
125
+ ```
126
+
127
+ For example:
128
+
129
+ ```bash
130
+ python calculate_fid.py --videos_dir outputs/2025-11-30/00-02-42/videos/test_vis
131
+ ```
132
+
133
+ **Expected Results:**
134
+
135
+ | Metric | Value |
136
+ |--------|--------|
137
+ | PSNR | 24.01 |
138
+ | LPIPS | 0.1667 |
139
+ | FID | 15.13 |
140
+
141
+ *Note: FID is computed over 5000 frames.*
142
+
143
+ ---
144
+
145
+ ## Dataset
146
+
147
+ Download the Minecraft dataset from [Hugging Face](https://huggingface.co/datasets/zeqixiao/worldmem_minecraft_dataset).
148
+
149
+ Place the dataset in the following directory structure:
150
+
151
+ ```
152
+ data/
153
+ └── minecraft/
154
+     ├── training/
155
+     ├── validation/
156
+     └── test/
157
+ ```
158
+
159
+ ## Data Generation
160
+
161
+ After setting up the environment as described in [MineDojo's GitHub repository](https://github.com/MineDojo/MineDojo), you can generate data using the following command:
162
+
163
+ ```bash
164
+ xvfb-run -a python data_generator.py -o data/test -z 4 --env_type plains
165
+ ```
166
+
167
+ **Parameters:**
168
+ - `-o`: Output directory for generated data
169
+ - `-z`: Number of parallel workers
170
+ - `--env_type`: Environment type (e.g., `plains`)
171
+
172
+
173
+ ## TODO
174
+
175
+ - [x] Release inference models and weights;
176
+ - [x] Release training pipeline on Minecraft;
177
+ - [x] Release training data on Minecraft;
178
+ - [x] Release evaluation scripts and data generator.
179
+
180
+
181
+
182
+ ## 🔗 Citation
183
+
184
+ If you find our work helpful, please cite:
185
+
186
+ ```
187
+ @misc{xiao2025worldmemlongtermconsistentworld,
188
+ title={WORLDMEM: Long-term Consistent World Simulation with Memory},
189
+ author={Zeqi Xiao and Yushi Lan and Yifan Zhou and Wenqi Ouyang and Shuai Yang and Yanhong Zeng and Xingang Pan},
190
+ year={2025},
191
+ eprint={2504.12369},
192
+ archivePrefix={arXiv},
193
+ primaryClass={cs.CV},
194
+ url={https://arxiv.org/abs/2504.12369},
195
+ }
196
+ ```
197
+
198
+ ## 👏 Acknowledgements
199
+ - [Diffusion Forcing](https://github.com/buoyancy99/diffusion-forcing): Diffusion Forcing provides flexible training and inference strategies for our methods.
200
+ - [Minedojo](https://github.com/MineDojo/MineDojo): We collect our Minecraft dataset from Minedojo.
201
+ - [Open-oasis](https://github.com/etched-ai/open-oasis): Our model architecture is based on Open-oasis. We also use pretrained VAE and DiT weights from it.
algorithms/README.md ADDED
@@ -0,0 +1,21 @@
1
+ # algorithms
2
+
3
+ The `algorithms` folder is designed to contain implementations of algorithms or models.
4
+ Content in `algorithms` can be loosely grouped components (e.g. models) or an algorithm that already has all
5
+ components chained together (e.g. a Lightning Module or an RL algorithm).
6
+ You should create a folder named after your own algorithm or baseline in it.
7
+
8
+ Two examples can be found in the `examples` subfolder.
9
+
10
+ The `common` subfolder is designed to contain general-purpose classes that are useful for many projects, e.g. MLP.
11
+
12
+ You should not run any `.py` file from the `algorithms` folder.
13
+ Instead, write unit tests / debug scripts in `debug` and launch scripts in `experiments`.
14
+
15
+ You are discouraged from putting visualization utilities in `algorithms`, as those should go to `utils` in the project root.
16
+
17
+ Each algorithm class takes a DictConfig `cfg` in its `__init__`, which allows you to pass in arguments via a configuration file in `configurations/algorithm` or a [command line override](https://hydra.cc/docs/tutorials/basic/your_first_app/simple_cli/), as sketched after this README.
18
+
19
+ ---
20
+
21
+ This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.
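As a rough illustration of the `cfg` convention described in this README (a sketch only; `MyAlgo` and its config keys are hypothetical and not part of this commit):

```python
from omegaconf import DictConfig, OmegaConf

from algorithms.common.base_algo import BaseAlgo


class MyAlgo(BaseAlgo):
    """Hypothetical algorithm; the config keys below are made up for illustration."""

    def __init__(self, cfg: DictConfig):
        super().__init__(cfg)             # BaseAlgo stores cfg on self.cfg
        self.lr = cfg.lr                  # e.g. from configurations/algorithm/my_algo.yaml
        self.hidden_dim = cfg.hidden_dim  # or a Hydra command-line override

    def run(self):
        print(f"running with lr={self.lr}, hidden_dim={self.hidden_dim}")


# Stand-in for Hydra composing the YAML config and any CLI overrides.
MyAlgo(OmegaConf.create({"lr": 1e-4, "hidden_dim": 64})).run()
```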
algorithms/__init__.py ADDED
File without changes
algorithms/common/README.md ADDED
@@ -0,0 +1,5 @@
1
+ This folder contains models / algorithms that are general-purpose and useful across many algorithms.
2
+
3
+ ---
4
+
5
+ This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.
algorithms/common/__init__.py ADDED
File without changes
algorithms/common/base_algo.py ADDED
@@ -0,0 +1,21 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ from omegaconf import DictConfig
5
+
6
+
7
+ class BaseAlgo(ABC):
8
+ """
9
+ A base class for generic algorithms.
10
+ """
11
+
12
+ def __init__(self, cfg: DictConfig):
13
+ super().__init__()
14
+ self.cfg = cfg
15
+
16
+ @abstractmethod
17
+ def run(self, *args: Any, **kwargs: Any) -> Any:
18
+ """
19
+ Run the algorithm.
20
+ """
21
+ raise NotImplementedError
algorithms/common/base_pytorch_algo.py ADDED
@@ -0,0 +1,252 @@
1
+ from abc import ABC, abstractmethod
2
+ import warnings
3
+ from typing import Any, Union, Sequence, Optional
4
+
5
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
6
+ from omegaconf import DictConfig
7
+ import lightning.pytorch as pl
8
+ import torch
9
+ import numpy as np
10
+ from PIL import Image
11
+ import wandb
12
+ import einops
13
+
14
+
15
+ class BasePytorchAlgo(pl.LightningModule, ABC):
16
+ """
17
+ A base class for Pytorch algorithms using Pytorch Lightning.
18
+ See https://lightning.ai/docs/pytorch/stable/starter/introduction.html for more details.
19
+ """
20
+
21
+ def __init__(self, cfg: DictConfig):
22
+ super().__init__()
23
+ self.cfg = cfg
24
+ self._build_model()
25
+
26
+ @abstractmethod
27
+ def _build_model(self):
28
+ """
29
+ Create all pytorch nn.Modules here.
30
+ """
31
+ raise NotImplementedError
32
+
33
+ @abstractmethod
34
+ def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
35
+ r"""Here you compute and return the training loss and some additional metrics for e.g. the progress bar or
36
+ logger.
37
+
38
+ Args:
39
+ batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
40
+ batch_idx: The index of this batch.
41
+ dataloader_idx: (only if multiple dataloaders used) The index of the dataloader that produced this batch.
42
+
43
+ Return:
44
+ Any of these options:
45
+ - :class:`~torch.Tensor` - The loss tensor
46
+ - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
47
+ - ``None`` - Skip to the next batch. This is only supported for automatic optimization.
48
+ This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.
49
+
50
+ In this step you'd normally do the forward pass and calculate the loss for a batch.
51
+ You can also do fancier things like multiple forward passes or something model specific.
52
+
53
+ Example::
54
+
55
+ def training_step(self, batch, batch_idx):
56
+ x, y, z = batch
57
+ out = self.encoder(x)
58
+ loss = self.loss(out, x)
59
+ return loss
60
+
61
+ To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:
62
+
63
+ .. code-block:: python
64
+
65
+ def __init__(self):
66
+ super().__init__()
67
+ self.automatic_optimization = False
68
+
69
+
70
+ # Multiple optimizers (e.g.: GANs)
71
+ def training_step(self, batch, batch_idx):
72
+ opt1, opt2 = self.optimizers()
73
+
74
+ # do training_step with encoder
75
+ ...
76
+ opt1.step()
77
+ # do training_step with decoder
78
+ ...
79
+ opt2.step()
80
+
81
+ Note:
82
+ When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically
83
+ normalized by ``accumulate_grad_batches`` internally.
84
+
85
+ """
86
+ return super().training_step(*args, **kwargs)
87
+
88
+ def configure_optimizers(self):
89
+ """
90
+ Return an optimizer. If you need to use more than one optimizer, refer to pytorch lightning documentation:
91
+ https://lightning.ai/docs/pytorch/stable/common/optimization.html
92
+ """
93
+ parameters = self.parameters()
94
+ return torch.optim.Adam(parameters, lr=self.cfg.lr)
95
+
96
+ def log_video(
97
+ self,
98
+ key: str,
99
+ video: Union[np.ndarray, torch.Tensor],
100
+ mean: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
101
+ std: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
102
+ fps: int = 5,
103
+ format: str = "mp4",
104
+ ):
105
+ """
106
+ Log video to wandb. WandbLogger in pytorch lightning does not support video logging yet, so we call wandb directly.
107
+
108
+ Args:
109
+ video: a numpy array or tensor, either in form (time, channel, height, width) or in the form
110
+ (batch, time, channel, height, width). The content must be in 0-255 if under dtype uint8
111
+ or [0, 1] otherwise.
112
+ mean: optional, the mean to unnormalize video tensor, assuming unnormalized data is in [0, 1].
113
+ std: optional, the std to unnormalize video tensor, assuming unnormalized data is in [0, 1].
114
+ key: the name of the video.
115
+ fps: the frame rate of the video.
116
+ format: the format of the video. Can be either "mp4" or "gif".
117
+ """
118
+
119
+ if isinstance(video, torch.Tensor):
120
+ video = video.detach().cpu().numpy()
121
+
122
+ expand_shape = [1] * (len(video.shape) - 2) + [3, 1, 1]
123
+ if std is not None:
124
+ if isinstance(std, (float, int)):
125
+ std = [std] * 3
126
+ if isinstance(std, torch.Tensor):
127
+ std = std.detach().cpu().numpy()
128
+ std = np.array(std).reshape(*expand_shape)
129
+ video = video * std
130
+ if mean is not None:
131
+ if isinstance(mean, (float, int)):
132
+ mean = [mean] * 3
133
+ if isinstance(mean, torch.Tensor):
134
+ mean = mean.detach().cpu().numpy()
135
+ mean = np.array(mean).reshape(*expand_shape)
136
+ video = video + mean
137
+
138
+ if video.dtype != np.uint8:
139
+ video = np.clip(video, a_min=0, a_max=1) * 255
140
+ video = video.astype(np.uint8)
141
+
142
+ self.logger.experiment.log(
143
+ {
144
+ key: wandb.Video(video, fps=fps, format=format),
145
+ },
146
+ step=self.global_step,
147
+ )
148
+
149
+ def log_image(
150
+ self,
151
+ key: str,
152
+ image: Union[np.ndarray, torch.Tensor, Image.Image, Sequence[Image.Image]],
153
+ mean: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
154
+ std: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
155
+ **kwargs: Any,
156
+ ):
157
+ """
158
+ Log image(s) using WandbLogger.
159
+ Args:
160
+ key: the name of the video.
161
+ image: a single image or a batch of images. If a batch of images, the shape should be (batch, channel, height, width).
162
+ mean: optional, the mean to unnormalize image tensor, assuming unnormalized data is in [0, 1].
163
+ std: optional, the std to unnormalize tensor, assuming unnormalized data is in [0, 1].
164
+ kwargs: optional, WandbLogger log_image kwargs, such as captions=xxx.
165
+ """
166
+ if isinstance(image, Image.Image):
167
+ image = [image]
168
+ elif len(image) and not isinstance(image[0], Image.Image):
169
+ if isinstance(image, torch.Tensor):
170
+ image = image.detach().cpu().numpy()
171
+
172
+ if len(image.shape) == 3:
173
+ image = image[None]
174
+
175
+ if image.shape[1] == 3:
176
+ if image.shape[-1] == 3:
177
+ warnings.warn(f"Two channels in shape {image.shape} have size 3, assuming channel first.")
178
+ image = einops.rearrange(image, "b c h w -> b h w c")
179
+
180
+ if std is not None:
181
+ if isinstance(std, (float, int)):
182
+ std = [std] * 3
183
+ if isinstance(std, torch.Tensor):
184
+ std = std.detach().cpu().numpy()
185
+ std = np.array(std)[None, None, None]
186
+ image = image * std
187
+ if mean is not None:
188
+ if isinstance(mean, (float, int)):
189
+ mean = [mean] * 3
190
+ if isinstance(mean, torch.Tensor):
191
+ mean = mean.detach().cpu().numpy()
192
+ mean = np.array(mean)[None, None, None]
193
+ image = image + mean
194
+
195
+ if image.dtype != np.uint8:
196
+ image = np.clip(image, a_min=0.0, a_max=1.0) * 255
197
+ image = image.astype(np.uint8)
198
+ image = [img for img in image]
199
+
200
+ self.logger.log_image(key=key, images=image, **kwargs)
201
+
202
+ def log_gradient_stats(self):
203
+ """Log gradient statistics such as the mean or std of norm."""
204
+
205
+ with torch.no_grad():
206
+ grad_norms = []
207
+ gpr = [] # gradient-to-parameter ratio
208
+ for param in self.parameters():
209
+ if param.grad is not None:
210
+ grad_norms.append(torch.norm(param.grad).item())
211
+ gpr.append(torch.norm(param.grad) / torch.norm(param))
212
+ if len(grad_norms) == 0:
213
+ return
214
+ grad_norms = torch.tensor(grad_norms)
215
+ gpr = torch.tensor(gpr)
216
+ self.log_dict(
217
+ {
218
+ "train/grad_norm/min": grad_norms.min(),
219
+ "train/grad_norm/max": grad_norms.max(),
220
+ "train/grad_norm/std": grad_norms.std(),
221
+ "train/grad_norm/mean": grad_norms.mean(),
222
+ "train/grad_norm/median": torch.median(grad_norms),
223
+ "train/gpr/min": gpr.min(),
224
+ "train/gpr/max": gpr.max(),
225
+ "train/gpr/std": gpr.std(),
226
+ "train/gpr/mean": gpr.mean(),
227
+ "train/gpr/median": torch.median(gpr),
228
+ }
229
+ )
230
+
231
+ def register_data_mean_std(
232
+ self, mean: Union[str, float, Sequence], std: Union[str, float, Sequence], namespace: str = "data"
233
+ ):
234
+ """
235
+ Register mean and std of data as tensor buffer.
236
+
237
+ Args:
238
+ mean: the mean of data.
239
+ std: the std of data.
240
+ namespace: the namespace of the registered buffer.
241
+ """
242
+ for k, v in [("mean", mean), ("std", std)]:
243
+ if isinstance(v, str):
244
+ if v.endswith(".npy"):
245
+ v = torch.from_numpy(np.load(v))
246
+ elif v.endswith(".pt"):
247
+ v = torch.load(v)
248
+ else:
249
+ raise ValueError(f"Unsupported file type {v.split('.')[-1]}.")
250
+ else:
251
+ v = torch.tensor(v)
252
+ self.register_buffer(f"{namespace}_{k}", v.float().to(self.device))
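A minimal subclass sketch for `BasePytorchAlgo` above (the toy model and config keys are hypothetical, not part of this commit):

```python
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf

from algorithms.common.base_pytorch_algo import BasePytorchAlgo


class ToyAlgo(BasePytorchAlgo):
    def _build_model(self):
        # Called from BasePytorchAlgo.__init__ once self.cfg is set.
        self.net = nn.Linear(self.cfg.in_dim, self.cfg.out_dim)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.mse_loss(self.net(x), y)
        self.log("training/loss", loss)
        return loss


# The default configure_optimizers in the base class reads cfg.lr.
algo = ToyAlgo(OmegaConf.create({"in_dim": 8, "out_dim": 1, "lr": 1e-3}))
```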
algorithms/common/metrics/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .fid import FrechetInceptionDistance
2
+ from .lpips import LearnedPerceptualImagePatchSimilarity
3
+ from .fvd import FrechetVideoDistance
algorithms/common/metrics/fid.py ADDED
@@ -0,0 +1 @@
1
+ from torchmetrics.image.fid import FrechetInceptionDistance
algorithms/common/metrics/fvd.py ADDED
@@ -0,0 +1,158 @@
1
+ """
2
+ Adapted from https://github.com/cvpr2022-stylegan-v/stylegan-v
3
+ Verified to be the same as tf version by https://github.com/universome/fvd-comparison
4
+ """
5
+
6
+ import io
7
+ import re
8
+ import requests
9
+ import html
10
+ import hashlib
11
+ import urllib
12
+ import urllib.request
13
+ from typing import Any, List, Tuple, Union, Dict
14
+ import scipy
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import numpy as np
19
+
20
+
21
+ def open_url(
22
+ url: str,
23
+ num_attempts: int = 10,
24
+ verbose: bool = True,
25
+ return_filename: bool = False,
26
+ ) -> Any:
27
+ """Download the given URL and return a binary-mode file object to access the data."""
28
+ assert num_attempts >= 1
29
+
30
+ # Doesn't look like an URL scheme so interpret it as a local filename.
31
+ if not re.match("^[a-z]+://", url):
32
+ return url if return_filename else open(url, "rb")
33
+
34
+ # Handle file URLs. This code handles unusual file:// patterns that
35
+ # arise on Windows:
36
+ #
37
+ # file:///c:/foo.txt
38
+ #
39
+ # which would translate to a local '/c:/foo.txt' filename that's
40
+ # invalid. Drop the forward slash for such pathnames.
41
+ #
42
+ # If you touch this code path, you should test it on both Linux and
43
+ # Windows.
44
+ #
45
+ # Some internet resources suggest using urllib.request.url2pathname() but
46
+ # that converts forward slashes to backslashes and this causes
47
+ # its own set of problems.
48
+ if url.startswith("file://"):
49
+ filename = urllib.parse.urlparse(url).path
50
+ if re.match(r"^/[a-zA-Z]:", filename):
51
+ filename = filename[1:]
52
+ return filename if return_filename else open(filename, "rb")
53
+
54
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
55
+
56
+ # Download.
57
+ url_name = None
58
+ url_data = None
59
+ with requests.Session() as session:
60
+ if verbose:
61
+ print("Downloading %s ..." % url, end="", flush=True)
62
+ for attempts_left in reversed(range(num_attempts)):
63
+ try:
64
+ with session.get(url) as res:
65
+ res.raise_for_status()
66
+ if len(res.content) == 0:
67
+ raise IOError("No data received")
68
+
69
+ if len(res.content) < 8192:
70
+ content_str = res.content.decode("utf-8")
71
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
72
+ links = [
73
+ html.unescape(link)
74
+ for link in content_str.split('"')
75
+ if "export=download" in link
76
+ ]
77
+ if len(links) == 1:
78
+ url = requests.compat.urljoin(url, links[0])
79
+ raise IOError("Google Drive virus checker nag")
80
+ if "Google Drive - Quota exceeded" in content_str:
81
+ raise IOError(
82
+ "Google Drive download quota exceeded -- please try again later"
83
+ )
84
+
85
+ match = re.search(
86
+ r'filename="([^"]*)"',
87
+ res.headers.get("Content-Disposition", ""),
88
+ )
89
+ url_name = match[1] if match else url
90
+ url_data = res.content
91
+ if verbose:
92
+ print(" done")
93
+ break
94
+ except KeyboardInterrupt:
95
+ raise
96
+ except:
97
+ if not attempts_left:
98
+ if verbose:
99
+ print(" failed")
100
+ raise
101
+ if verbose:
102
+ print(".", end="", flush=True)
103
+
104
+ # Return data as file object.
105
+ assert not return_filename
106
+ return io.BytesIO(url_data)
107
+
108
+
109
+ def compute_fvd(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
110
+ mu_gen, sigma_gen = compute_stats(feats_fake)
111
+ mu_real, sigma_real = compute_stats(feats_real)
112
+
113
+ m = np.square(mu_gen - mu_real).sum()
114
+ s, _ = scipy.linalg.sqrtm(
115
+ np.dot(sigma_gen, sigma_real), disp=False
116
+ ) # pylint: disable=no-member
117
+ fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
118
+
119
+ return float(fid)
120
+
121
+
122
+ def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
123
+ mu = feats.mean(axis=0) # [d]
124
+ sigma = np.cov(feats, rowvar=False) # [d, d]
125
+
126
+ return mu, sigma
127
+
128
+
129
+ class FrechetVideoDistance(nn.Module):
130
+ def __init__(self):
131
+ super().__init__()
132
+ detector_url = (
133
+ "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1"
134
+ )
135
+ # Return raw features before the softmax layer.
136
+ self.detector_kwargs = dict(rescale=False, resize=True, return_features=True)
137
+ with open_url(detector_url, verbose=False) as f:
138
+ self.detector = torch.jit.load(f).eval()
139
+
140
+ @torch.no_grad()
141
+ def compute(self, videos_fake: torch.Tensor, videos_real: torch.Tensor):
142
+ """
143
+ :param videos_fake: predicted video tensor of shape (frame, batch, channel, height, width)
144
+ :param videos_real: ground-truth observation tensor of shape (frame, batch, channel, height, width)
145
+ :return:
146
+ """
147
+ n_frames, batch_size, c, h, w = videos_fake.shape
148
+ if n_frames < 2:
149
+ raise ValueError("Video must have more than 1 frame for FVD")
150
+
151
+ videos_fake = videos_fake.permute(1, 2, 0, 3, 4).contiguous()
152
+ videos_real = videos_real.permute(1, 2, 0, 3, 4).contiguous()
153
+
154
+ # detector takes in tensors of shape [batch_size, c, video_len, h, w] with range -1 to 1
155
+ feats_fake = self.detector(videos_fake, **self.detector_kwargs).cpu().numpy()
156
+ feats_real = self.detector(videos_real, **self.detector_kwargs).cpu().numpy()
157
+
158
+ return compute_fvd(feats_fake, feats_real)
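A minimal usage sketch for the metric above (illustrative only; the random tensors are placeholders, the shapes and the [-1, 1] value range follow the `compute` docstring, and constructing the class downloads the TorchScript I3D detector):

```python
import torch

from algorithms.common.metrics.fvd import FrechetVideoDistance

fvd = FrechetVideoDistance()  # downloads i3d_torchscript.pt on construction

# (frame, batch, channel, height, width), values in [-1, 1]
fake = torch.rand(16, 4, 3, 224, 224) * 2 - 1
real = torch.rand(16, 4, 3, 224, 224) * 2 - 1

print(fvd.compute(fake, real))  # scalar FVD between the two sets of videos
```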
algorithms/common/metrics/lpips.py ADDED
@@ -0,0 +1 @@
1
+ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
algorithms/common/models/__init__.py ADDED
File without changes
algorithms/common/models/cnn.py ADDED
@@ -0,0 +1,141 @@
1
+ import math
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+
5
+
6
+ def is_square_of_two(num):
7
+ if num <= 0:
8
+ return False
9
+ return num & (num - 1) == 0
10
+
11
+ class CnnEncoder(nn.Module):
12
+ """
13
+ Simple cnn encoder that encodes a 64x64 image to embeddings
14
+ """
15
+ def __init__(self, embedding_size, activation_function='relu'):
16
+ super().__init__()
17
+ self.act_fn = getattr(F, activation_function)
18
+ self.embedding_size = embedding_size
19
+ self.fc = nn.Linear(1024, self.embedding_size)
20
+ self.conv1 = nn.Conv2d(3, 32, 4, stride=2)
21
+ self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
22
+ self.conv3 = nn.Conv2d(64, 128, 4, stride=2)
23
+ self.conv4 = nn.Conv2d(128, 256, 4, stride=2)
24
+ self.modules = [self.conv1, self.conv2, self.conv3, self.conv4]
25
+
26
+ def forward(self, observation):
27
+ batch_size = observation.shape[0]
28
+ hidden = self.act_fn(self.conv1(observation))
29
+ hidden = self.act_fn(self.conv2(hidden))
30
+ hidden = self.act_fn(self.conv3(hidden))
31
+ hidden = self.act_fn(self.conv4(hidden))
32
+ hidden = self.fc(hidden.view(batch_size, 1024))
33
+ return hidden
34
+
35
+
36
+ class CnnDecoder(nn.Module):
37
+ """
38
+ Simple Cnn decoder that decodes an embedding to 64x64 images
39
+ """
40
+ def __init__(self, embedding_size, activation_function='relu'):
41
+ super().__init__()
42
+ self.act_fn = getattr(F, activation_function)
43
+ self.embedding_size = embedding_size
44
+ self.fc = nn.Linear(embedding_size, 128)
45
+ self.conv1 = nn.ConvTranspose2d(128, 128, 5, stride=2)
46
+ self.conv2 = nn.ConvTranspose2d(128, 64, 5, stride=2)
47
+ self.conv3 = nn.ConvTranspose2d(64, 32, 6, stride=2)
48
+ self.conv4 = nn.ConvTranspose2d(32, 3, 6, stride=2)
49
+ self.modules = [self.conv1, self.conv2, self.conv3, self.conv4]
50
+
51
+ def forward(self, embedding):
52
+ batch_size = embedding.shape[0]
53
+ hidden = self.fc(embedding)
54
+ hidden = hidden.view(batch_size, 128, 1, 1)
55
+ hidden = self.act_fn(self.conv1(hidden))
56
+ hidden = self.act_fn(self.conv2(hidden))
57
+ hidden = self.act_fn(self.conv3(hidden))
58
+ observation = self.conv4(hidden)
59
+ return observation
60
+
61
+
62
+ class FullyConvEncoder(nn.Module):
63
+ """
64
+ Simple fully convolutional encoder, with 2D input and 2D output
65
+ """
66
+ def __init__(self,
67
+ input_shape=(3, 64, 64),
68
+ embedding_shape=(8, 16, 16),
69
+ activation_function='relu',
70
+ init_channels=16,
71
+ ):
72
+ super().__init__()
73
+
74
+ assert len(input_shape) == 3, "input_shape must be a tuple of length 3"
75
+ assert len(embedding_shape) == 3, "embedding_shape must be a tuple of length 3"
76
+ assert input_shape[1] == input_shape[2] and is_square_of_two(input_shape[1]), "input_shape must be square"
77
+ assert embedding_shape[1] == embedding_shape[2], "embedding_shape must be square"
78
+ assert input_shape[1] % embedding_shape[1] == 0, "input_shape must be divisible by embedding_shape"
79
+ assert is_square_of_two(init_channels), "init_channels must be a square of 2"
80
+
81
+ depth = int(math.sqrt(input_shape[1] / embedding_shape[1])) + 1
82
+ channels_per_layer = [init_channels * (2 ** i) for i in range(depth)]
83
+ self.act_fn = getattr(F, activation_function)
84
+
85
+ self.downs = nn.ModuleList([])
86
+ self.downs.append(nn.Conv2d(input_shape[0], channels_per_layer[0], kernel_size=3, stride=1, padding=1))
87
+
88
+ for i in range(1, depth):
89
+ self.downs.append(nn.Conv2d(channels_per_layer[i-1], channels_per_layer[i],
90
+ kernel_size=3, stride=2, padding=1))
91
+
92
+ # Bottleneck layer
93
+ self.downs.append(nn.Conv2d(channels_per_layer[-1], embedding_shape[0], kernel_size=1, stride=1, padding=0))
94
+
95
+ def forward(self, observation):
96
+ hidden = observation
97
+ for layer in self.downs:
98
+ hidden = self.act_fn(layer(hidden))
99
+ return hidden
100
+
101
+
102
+ class FullyConvDecoder(nn.Module):
103
+ """
104
+ Simple fully convolutional decoder, with 2D input and 2D output
105
+ """
106
+ def __init__(self,
107
+ embedding_shape=(8, 16, 16),
108
+ output_shape=(3, 64, 64),
109
+ activation_function='relu',
110
+ init_channels=16,
111
+ ):
112
+ super().__init__()
113
+
114
+ assert len(embedding_shape) == 3, "embedding_shape must be a tuple of length 3"
115
+ assert len(output_shape) == 3, "output_shape must be a tuple of length 3"
116
+ assert output_shape[1] == output_shape[2] and is_square_of_two(output_shape[1]), "output_shape must be square"
117
+ assert embedding_shape[1] == embedding_shape[2], "input_shape must be square"
118
+ assert output_shape[1] % embedding_shape[1] == 0, "output_shape must be divisible by input_shape"
119
+ assert is_square_of_two(init_channels), "init_channels must be a square of 2"
120
+
121
+ depth = int(math.sqrt(output_shape[1] / embedding_shape[1])) + 1
122
+ channels_per_layer = [init_channels * (2 ** i) for i in range(depth)]
123
+ self.act_fn = getattr(F, activation_function)
124
+
125
+ self.ups = nn.ModuleList([])
126
+ self.ups.append(nn.ConvTranspose2d(embedding_shape[0], channels_per_layer[-1],
127
+ kernel_size=1, stride=1, padding=0))
128
+
129
+ for i in range(1, depth):
130
+ self.ups.append(nn.ConvTranspose2d(channels_per_layer[-i], channels_per_layer[-i-1],
131
+ kernel_size=3, stride=2, padding=1, output_padding=1))
132
+
133
+ self.output_layer = nn.ConvTranspose2d(channels_per_layer[0], output_shape[0],
134
+ kernel_size=3, stride=1, padding=1)
135
+
136
+ def forward(self, embedding):
137
+ hidden = embedding
138
+ for layer in self.ups:
139
+ hidden = self.act_fn(layer(hidden))
140
+
141
+ return self.output_layer(hidden)
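A quick shape-check sketch for the fully convolutional pair above (using the constructors' default shapes; the random input is a placeholder):

```python
import torch

from algorithms.common.models.cnn import FullyConvEncoder, FullyConvDecoder

enc = FullyConvEncoder(input_shape=(3, 64, 64), embedding_shape=(8, 16, 16))
dec = FullyConvDecoder(embedding_shape=(8, 16, 16), output_shape=(3, 64, 64))

x = torch.randn(2, 3, 64, 64)   # (batch, channel, height, width)
z = enc(x)                      # -> (2, 8, 16, 16): two stride-2 convs plus a 1x1 bottleneck
print(dec(z).shape)             # -> torch.Size([2, 3, 64, 64])
```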
algorithms/common/models/mlp.py ADDED
@@ -0,0 +1,22 @@
1
+ from typing import Type, Optional
2
+
3
+ import torch
4
+ from torch import nn as nn
5
+
6
+
7
+ class SimpleMlp(nn.Module):
8
+ """
9
+ A class for very simple multi layer perceptron
10
+ """
11
+ def __init__(self, in_dim=2, out_dim=1, hidden_dim=64, n_layers=2,
12
+ activation: Type[nn.Module] = nn.ReLU, output_activation: Optional[Type[nn.Module]] = None):
13
+ super(SimpleMlp, self).__init__()
14
+ layers = [nn.Linear(in_dim, hidden_dim), activation()]
15
+ layers.extend([nn.Linear(hidden_dim, hidden_dim), activation()] * (n_layers - 2))
16
+ layers.append(nn.Linear(hidden_dim, out_dim))
17
+ if output_activation:
18
+ layers.append(output_activation())
19
+ self.net = nn.Sequential(*layers)
20
+
21
+ def forward(self, x):
22
+ return self.net(x)
algorithms/worldmem/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .df_video import WorldMemMinecraft
2
+ from .pose_prediction import PosePrediction
algorithms/worldmem/df_base.py ADDED
@@ -0,0 +1,307 @@
1
+ """
2
+ This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research
3
+ template [repo](https://github.com/buoyancy99/research-template).
4
+ By its MIT license, you must keep the above sentence in `README.md`
5
+ and the `LICENSE` file to credit the author.
6
+ """
7
+
8
+ from typing import Optional
9
+ from tqdm import tqdm
10
+ from omegaconf import DictConfig
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from typing import Any
15
+ from einops import rearrange
16
+
17
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
18
+
19
+ from algorithms.common.base_pytorch_algo import BasePytorchAlgo
20
+ from .models.diffusion import Diffusion
21
+
22
+
23
+ class DiffusionForcingBase(BasePytorchAlgo):
24
+ def __init__(self, cfg: DictConfig):
25
+ self.cfg = cfg
26
+ self.x_shape = cfg.x_shape
27
+ self.frame_stack = cfg.frame_stack
28
+ self.x_stacked_shape = list(self.x_shape)
29
+ self.x_stacked_shape[0] *= cfg.frame_stack
30
+ self.guidance_scale = cfg.guidance_scale
31
+ self.context_frames = cfg.context_frames
32
+ self.chunk_size = cfg.chunk_size
33
+ self.action_cond_dim = cfg.action_cond_dim
34
+ self.causal = cfg.causal
35
+
36
+ self.uncertainty_scale = cfg.uncertainty_scale
37
+ self.timesteps = cfg.diffusion.timesteps
38
+ self.sampling_timesteps = cfg.diffusion.sampling_timesteps
39
+ self.clip_noise = cfg.diffusion.clip_noise
40
+
41
+ self.cfg.diffusion.cum_snr_decay = self.cfg.diffusion.cum_snr_decay ** (self.frame_stack * cfg.frame_skip)
42
+
43
+ self.validation_step_outputs = []
44
+ super().__init__(cfg)
45
+
46
+ def _build_model(self):
47
+ self.diffusion_model = Diffusion(
48
+ x_shape=self.x_stacked_shape,
49
+ action_cond_dim=self.action_cond_dim,
50
+ is_causal=self.causal,
51
+ cfg=self.cfg.diffusion,
52
+ )
53
+ self.register_data_mean_std(self.cfg.data_mean, self.cfg.data_std)
54
+
55
+ def configure_optimizers(self):
56
+ params = tuple(self.diffusion_model.parameters())
57
+ optimizer_dynamics = torch.optim.AdamW(
58
+ params, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay, betas=self.cfg.optimizer_beta
59
+ )
60
+ return optimizer_dynamics
61
+
62
+ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
63
+ # update params
64
+ optimizer.step(closure=optimizer_closure)
65
+
66
+ # manually warm up lr without a scheduler
67
+ if self.trainer.global_step < self.cfg.warmup_steps:
68
+ lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.cfg.warmup_steps)
69
+ for pg in optimizer.param_groups:
70
+ pg["lr"] = lr_scale * self.cfg.lr
71
+
72
+ def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
73
+ xs, conditions, masks = self._preprocess_batch(batch)
74
+
75
+ rand_length = torch.randint(3,xs.shape[0]-2, (1,))[0].item()
76
+ xs = torch.cat([xs[:rand_length], xs[rand_length-3:rand_length-1]])
77
+ conditions = torch.cat([conditions[:rand_length], conditions[rand_length-3:rand_length-1]])
78
+ masks = torch.cat([masks[:rand_length], masks[rand_length-3:rand_length-1]])
79
+ noise_levels=self._generate_noise_levels(xs)
80
+ noise_levels[:rand_length] = 15 # stable_noise_levels
81
+ noise_levels[rand_length+1:] = 15 # stable_noise_levels
82
+
83
+ xs_pred, loss = self.diffusion_model(xs, conditions, noise_levels=noise_levels)
84
+ loss = self.reweight_loss(loss, masks)
85
+
86
+ # log the loss
87
+ if batch_idx % 20 == 0:
88
+ self.log("training/loss", loss)
89
+
90
+ xs = self._unstack_and_unnormalize(xs)
91
+ xs_pred = self._unstack_and_unnormalize(xs_pred)
92
+
93
+ output_dict = {
94
+ "loss": loss,
95
+ "xs_pred": xs_pred,
96
+ "xs": xs,
97
+ }
98
+
99
+ return output_dict
100
+
101
+ @torch.no_grad()
102
+ def validation_step(self, batch, batch_idx, namespace="validation") -> STEP_OUTPUT:
103
+ xs, conditions, masks = self._preprocess_batch(batch)
104
+ n_frames, batch_size, *_ = xs.shape
105
+ xs_pred = []
106
+ curr_frame = 0
107
+
108
+ # context
109
+ n_context_frames = self.context_frames // self.frame_stack
110
+ xs_pred = xs[:n_context_frames].clone()
111
+ curr_frame += n_context_frames
112
+
113
+ if self.condtion_similar_length:
114
+ n_frames -= self.condtion_similar_length
115
+
116
+ pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
117
+ while curr_frame < n_frames:
118
+ if self.chunk_size > 0:
119
+ horizon = min(n_frames - curr_frame, self.chunk_size)
120
+ else:
121
+ horizon = n_frames - curr_frame
122
+ assert horizon <= self.n_tokens, "horizon exceeds the number of tokens."
123
+ scheduling_matrix = self._generate_scheduling_matrix(horizon)
124
+
125
+ chunk = torch.randn((horizon, batch_size, *self.x_stacked_shape), device=self.device)
126
+ chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise)
127
+ xs_pred = torch.cat([xs_pred, chunk], 0)
128
+
129
+ # sliding window: only input the last n_tokens frames
130
+ start_frame = max(0, curr_frame + horizon - self.n_tokens)
131
+
132
+ pbar.set_postfix(
133
+ {
134
+ "start": start_frame,
135
+ "end": curr_frame + horizon,
136
+ }
137
+ )
138
+
139
+ if self.condtion_similar_length:
140
+ xs_pred = torch.cat([xs_pred, xs[curr_frame-self.condtion_similar_length:curr_frame].clone()], 0)
141
+
142
+ for m in range(scheduling_matrix.shape[0] - 1):
143
+
144
+ from_noise_levels = np.concatenate((np.zeros((curr_frame,), dtype=np.int64), scheduling_matrix[m]))[
145
+ :, None
146
+ ].repeat(batch_size, axis=1)
147
+ to_noise_levels = np.concatenate(
148
+ (
149
+ np.zeros((curr_frame,), dtype=np.int64),
150
+ scheduling_matrix[m + 1],
151
+ )
152
+ )[
153
+ :, None
154
+ ].repeat(batch_size, axis=1)
155
+
156
+ if self.condtion_similar_length:
157
+ from_noise_levels = np.concatenate([from_noise_levels, np.array([[0,0,0,0]*self.condtion_similar_length])], axis=0)
158
+ to_noise_levels = np.concatenate([to_noise_levels, np.array([[0,0,0,0]*self.condtion_similar_length])], axis=0)
159
+
160
+ from_noise_levels = torch.from_numpy(from_noise_levels).to(self.device)
161
+ to_noise_levels = torch.from_numpy(to_noise_levels).to(self.device)
162
+
163
+ # update xs_pred by DDIM or DDPM sampling
164
+ # input frames within the sliding window
165
+
166
+ try:
167
+ input_condition = conditions[start_frame : curr_frame + horizon].clone()
168
+ except:
169
+ import pdb;pdb.set_trace()
170
+ if self.condtion_similar_length:
171
+ input_condition = torch.cat([conditions[start_frame : curr_frame + horizon], conditions[-self.condtion_similar_length:]], dim=0)
172
+ xs_pred[start_frame:] = self.diffusion_model.sample_step(
173
+ xs_pred[start_frame:],
174
+ input_condition,
175
+ from_noise_levels[start_frame:],
176
+ to_noise_levels[start_frame:],
177
+ )
178
+
179
+ if self.condtion_similar_length:
180
+ xs_pred = xs_pred[:-self.condtion_similar_length]
181
+
182
+ curr_frame += horizon
183
+ pbar.update(horizon)
184
+
185
+ if self.condtion_similar_length:
186
+ xs = xs[:-self.condtion_similar_length]
187
+ # FIXME: loss
188
+ loss = F.mse_loss(xs_pred, xs, reduction="none")
189
+ loss = self.reweight_loss(loss, masks)
190
+ self.validation_step_outputs.append((xs_pred.detach().cpu(), xs.detach().cpu()))
191
+
192
+ return loss
193
+
194
+ def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
195
+ return self.validation_step(*args, **kwargs, namespace="test")
196
+
197
+ def on_test_epoch_end(self) -> None:
198
+ self.on_validation_epoch_end(namespace="test")
199
+
200
+ def _generate_noise_levels(self, xs: torch.Tensor, masks: Optional[torch.Tensor] = None) -> torch.Tensor:
201
+ """
202
+ Generate noise levels for training.
203
+ """
204
+ num_frames, batch_size, *_ = xs.shape
205
+ match self.cfg.noise_level:
206
+ case "random_all": # entirely random noise levels
207
+ noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
208
+ case "same":
209
+ noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
210
+ noise_levels[1:] = noise_levels[0]
211
+
212
+ if masks is not None:
213
+ # for frames that are not available, treat as full noise
214
+ discard = torch.all(~rearrange(masks.bool(), "(t fs) b -> t b fs", fs=self.frame_stack), -1)
215
+ noise_levels = torch.where(discard, torch.full_like(noise_levels, self.timesteps - 1), noise_levels)
216
+
217
+ return noise_levels
218
+
219
+ def _generate_scheduling_matrix(self, horizon: int):
220
+ match self.cfg.scheduling_matrix:
221
+ case "pyramid":
222
+ return self._generate_pyramid_scheduling_matrix(horizon, self.uncertainty_scale)
223
+ case "full_sequence":
224
+ return np.arange(self.sampling_timesteps, -1, -1)[:, None].repeat(horizon, axis=1)
225
+ case "autoregressive":
226
+ return self._generate_pyramid_scheduling_matrix(horizon, self.sampling_timesteps)
227
+ case "trapezoid":
228
+ return self._generate_trapezoid_scheduling_matrix(horizon, self.uncertainty_scale)
229
+
230
+ def _generate_pyramid_scheduling_matrix(self, horizon: int, uncertainty_scale: float):
231
+ height = self.sampling_timesteps + int((horizon - 1) * uncertainty_scale) + 1
232
+ scheduling_matrix = np.zeros((height, horizon), dtype=np.int64)
233
+ for m in range(height):
234
+ for t in range(horizon):
235
+ scheduling_matrix[m, t] = self.sampling_timesteps + int(t * uncertainty_scale) - m
236
+
237
+ return np.clip(scheduling_matrix, 0, self.sampling_timesteps)
238
+
239
+ def _generate_trapezoid_scheduling_matrix(self, horizon: int, uncertainty_scale: float):
240
+ height = self.sampling_timesteps + int((horizon + 1) // 2 * uncertainty_scale)
241
+ scheduling_matrix = np.zeros((height, horizon), dtype=np.int64)
242
+ for m in range(height):
243
+ for t in range((horizon + 1) // 2):
244
+ scheduling_matrix[m, t] = self.sampling_timesteps + int(t * uncertainty_scale) - m
245
+ scheduling_matrix[m, -t] = self.sampling_timesteps + int(t * uncertainty_scale) - m
246
+
247
+ return np.clip(scheduling_matrix, 0, self.sampling_timesteps)
248
+
249
+ def reweight_loss(self, loss, weight=None):
250
+ # Note there is another part of loss reweighting (fused_snr) inside the Diffusion class!
251
+ loss = rearrange(loss, "t b (fs c) ... -> t b fs c ...", fs=self.frame_stack)
252
+ if weight is not None:
253
+ expand_dim = len(loss.shape) - len(weight.shape) - 1
254
+ weight = rearrange(
255
+ weight,
256
+ "(t fs) b ... -> t b fs ..." + " 1" * expand_dim,
257
+ fs=self.frame_stack,
258
+ )
259
+ loss = loss * weight
260
+
261
+ return loss.mean()
262
+
263
+ def _preprocess_batch(self, batch):
264
+ xs = batch[0]
265
+ batch_size, n_frames = xs.shape[:2]
266
+
267
+ if n_frames % self.frame_stack != 0:
268
+ raise ValueError("Number of frames must be divisible by frame stack size")
269
+ if self.context_frames % self.frame_stack != 0:
270
+ raise ValueError("Number of context frames must be divisible by frame stack size")
271
+
272
+ masks = torch.ones(n_frames, batch_size).to(xs.device)
273
+ n_frames = n_frames // self.frame_stack
274
+
275
+ if self.action_cond_dim:
276
+ conditions = batch[1]
277
+ conditions = torch.cat([torch.zeros_like(conditions[:, :1]), conditions[:, 1:]], 1)
278
+ conditions = rearrange(conditions, "b (t fs) d -> t b (fs d)", fs=self.frame_stack).contiguous()
279
+
280
+ # f, _, _ = conditions.shape
281
+ # predefined_1 = torch.tensor([0,0,0,1]).to(conditions.device)
282
+ # predefined_2 = torch.tensor([0,0,1,0]).to(conditions.device)
283
+ # conditions[:f//2] = predefined_1
284
+ # conditions[f//2:] = predefined_2
285
+ else:
286
+ conditions = [None for _ in range(n_frames)]
287
+
288
+ xs = self._normalize_x(xs)
289
+ xs = rearrange(xs, "b (t fs) c ... -> t b (fs c) ...", fs=self.frame_stack).contiguous()
290
+
291
+ return xs, conditions, masks
292
+
293
+ def _normalize_x(self, xs):
294
+ shape = [1] * (xs.ndim - self.data_mean.ndim) + list(self.data_mean.shape)
295
+ mean = self.data_mean.reshape(shape)
296
+ std = self.data_std.reshape(shape)
297
+ return (xs - mean) / std
298
+
299
+ def _unnormalize_x(self, xs):
300
+ shape = [1] * (xs.ndim - self.data_mean.ndim) + list(self.data_mean.shape)
301
+ mean = self.data_mean.reshape(shape)
302
+ std = self.data_std.reshape(shape)
303
+ return xs * std + mean
304
+
305
+ def _unstack_and_unnormalize(self, xs):
306
+ xs = rearrange(xs, "t b (fs c) ... -> (t fs) b c ...", fs=self.frame_stack)
307
+ return self._unnormalize_x(xs)
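To make the pyramid schedule above concrete, here is a standalone sketch of the same computation with toy parameters (illustrative only; it mirrors `_generate_pyramid_scheduling_matrix` rather than reusing the class):

```python
import numpy as np


def pyramid_matrix(sampling_timesteps: int, horizon: int, uncertainty_scale: float) -> np.ndarray:
    # Later frames stay at higher noise levels for longer, giving a "pyramid" of noise.
    height = sampling_timesteps + int((horizon - 1) * uncertainty_scale) + 1
    m = np.zeros((height, horizon), dtype=np.int64)
    for row in range(height):
        for t in range(horizon):
            m[row, t] = sampling_timesteps + int(t * uncertainty_scale) - row
    return np.clip(m, 0, sampling_timesteps)


print(pyramid_matrix(sampling_timesteps=4, horizon=3, uncertainty_scale=1.0))
# Each row is one sampling step; column t is the noise level of frame t:
# [[4 4 4]
#  [3 4 4]
#  [2 3 4]
#  [1 2 3]
#  [0 1 2]
#  [0 0 1]
#  [0 0 0]]
```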
algorithms/worldmem/df_video.py ADDED
@@ -0,0 +1,926 @@
1
+ import os
2
+ import random
3
+ import math
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torchvision.transforms.functional as TF
8
+ from torchvision.transforms import InterpolationMode
9
+ from PIL import Image
10
+ from packaging import version as pver
11
+ from einops import rearrange
12
+ from tqdm import tqdm
13
+ from omegaconf import DictConfig
14
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
15
+ from algorithms.common.metrics import (
16
+ LearnedPerceptualImagePatchSimilarity,
17
+ )
18
+ from utils.logging_utils import log_video, get_validation_metrics_for_videos
19
+ from .df_base import DiffusionForcingBase
20
+ from .models.vae import VAE_models
21
+ from .models.diffusion import Diffusion
22
+ from .models.pose_prediction import PosePredictionNet
23
+ import glob
24
+
25
+ # Utility Functions
26
+ def euler_to_rotation_matrix(pitch, yaw):
27
+ """
28
+ Convert pitch and yaw angles (in radians) to a 3x3 rotation matrix.
29
+ Supports batch input.
30
+
31
+ Args:
32
+ pitch (torch.Tensor): Pitch angles in radians.
33
+ yaw (torch.Tensor): Yaw angles in radians.
34
+
35
+ Returns:
36
+ torch.Tensor: Rotation matrix of shape (batch_size, 3, 3).
37
+ """
38
+ cos_pitch, sin_pitch = torch.cos(pitch), torch.sin(pitch)
39
+ cos_yaw, sin_yaw = torch.cos(yaw), torch.sin(yaw)
40
+
41
+ R_pitch = torch.stack([
42
+ torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch),
43
+ torch.zeros_like(pitch), cos_pitch, -sin_pitch,
44
+ torch.zeros_like(pitch), sin_pitch, cos_pitch
45
+ ], dim=-1).reshape(-1, 3, 3)
46
+
47
+ R_yaw = torch.stack([
48
+ cos_yaw, torch.zeros_like(yaw), sin_yaw,
49
+ torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw),
50
+ -sin_yaw, torch.zeros_like(yaw), cos_yaw
51
+ ], dim=-1).reshape(-1, 3, 3)
52
+
53
+ return torch.matmul(R_yaw, R_pitch)
54
+
55
+
56
+ def euler_to_camera_to_world_matrix(pose):
57
+ """
58
+ Convert (x, y, z, pitch, yaw) to a 4x4 camera-to-world transformation matrix using torch.
59
+ Supports both (5,) and (f, b, 5) shaped inputs.
60
+
61
+ Args:
62
+ pose (torch.Tensor): Pose tensor of shape (5,) or (f, b, 5).
63
+
64
+ Returns:
65
+ torch.Tensor: Camera-to-world transformation matrix of shape (4, 4).
66
+ """
67
+
68
+ origin_dim = pose.ndim
69
+ if origin_dim == 1:
70
+ pose = pose.unsqueeze(0).unsqueeze(0) # Convert (5,) -> (1, 1, 5)
71
+ elif origin_dim == 2:
72
+ pose = pose.unsqueeze(0)
73
+
74
+ x, y, z, pitch, yaw = pose[..., 0], pose[..., 1], pose[..., 2], pose[..., 3], pose[..., 4]
75
+ pitch, yaw = torch.deg2rad(pitch), torch.deg2rad(yaw)
76
+
77
+ # Compute rotation matrix (batch mode)
78
+ R = euler_to_rotation_matrix(pitch, yaw) # Shape (f*b, 3, 3)
79
+
80
+ # Create the 4x4 transformation matrix
81
+ eye = torch.eye(4, dtype=torch.float32, device=pose.device)
82
+ camera_to_world = eye.repeat(R.shape[0], 1, 1) # Shape (f*b, 4, 4)
83
+
84
+ # Assign rotation
85
+ camera_to_world[:, :3, :3] = R
86
+
87
+ # Assign translation
88
+ camera_to_world[:, :3, 3] = torch.stack([x.reshape(-1), y.reshape(-1), z.reshape(-1)], dim=-1)
89
+
90
+ # Reshape back to (f, b, 4, 4) if needed
91
+ if origin_dim == 3:
92
+ return camera_to_world.view(pose.shape[0], pose.shape[1], 4, 4)
93
+ elif origin_dim == 2:
94
+ return camera_to_world.view(pose.shape[0], 4, 4)
95
+ else:
96
+ return camera_to_world.squeeze(0).squeeze(0) # Convert (1,1,4,4) -> (4,4)
97
+
98
+ def is_inside_fov_3d_hv(points, center, center_pitch, center_yaw, fov_half_h, fov_half_v):
99
+ """
100
+ Check whether points are within a given 3D field of view (FOV)
101
+ with separately defined horizontal and vertical ranges.
102
+
103
+ The center view direction is specified by pitch and yaw (in degrees).
104
+
105
+ :param points: (N, B, 3) Sample point coordinates
106
+ :param center: (3,) Center coordinates of the FOV
107
+ :param center_pitch: Pitch angle of the center view (in degrees)
108
+ :param center_yaw: Yaw angle of the center view (in degrees)
109
+ :param fov_half_h: Horizontal half-FOV angle (in degrees)
110
+ :param fov_half_v: Vertical half-FOV angle (in degrees)
111
+ :return: Boolean tensor (N, B), indicating whether each point is inside the FOV
112
+ """
113
+ # Compute vectors relative to the center
114
+ vectors = points - center # shape (N, B, 3)
115
+ x = vectors[..., 0]
116
+ y = vectors[..., 1]
117
+ z = vectors[..., 2]
118
+
119
+ # Compute horizontal angle (yaw): measured with respect to the z-axis as the forward direction,
120
+ # and the x-axis as left-right, resulting in a range of -180 to 180 degrees.
121
+ azimuth = torch.atan2(x, z) * (180 / math.pi)
122
+
123
+ # Compute vertical angle (pitch): measured with respect to the horizontal plane,
124
+ # resulting in a range of -90 to 90 degrees.
125
+ elevation = torch.atan2(y, torch.sqrt(x**2 + z**2)) * (180 / math.pi)
126
+
127
+ # Compute the angular difference from the center view (handling circular angle wrap-around)
128
+ diff_azimuth = (azimuth - center_yaw).abs() % 360
129
+ diff_elevation = (elevation - center_pitch).abs() % 360
130
+
131
+ # Adjust values greater than 180 degrees to the shorter angular difference
132
+ diff_azimuth = torch.where(diff_azimuth > 180, 360 - diff_azimuth, diff_azimuth)
133
+ diff_elevation = torch.where(diff_elevation > 180, 360 - diff_elevation, diff_elevation)
134
+
135
+ # Check if both horizontal and vertical angles are within their respective FOV limits
136
+ return (diff_azimuth < fov_half_h) & (diff_elevation < fov_half_v)
137
+
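A minimal sanity-check sketch of the FOV test above, with illustrative values: a point straight ahead (+z) of a camera at the origin with pitch = yaw = 0 lies inside a 105 x 75 degree FOV, while a point directly behind it does not.

import torch

points = torch.tensor([[[0.0, 0.0, 5.0]],
                       [[0.0, 0.0, -5.0]]])   # (N=2, B=1, 3)
center = torch.zeros(1, 3)
inside = is_inside_fov_3d_hv(points, center,
                             center_pitch=torch.zeros(1), center_yaw=torch.zeros(1),
                             fov_half_h=torch.tensor(105.0 / 2), fov_half_v=torch.tensor(75.0 / 2))
print(inside)   # tensor([[ True], [False]])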
138
+ def generate_points_in_sphere(n_points, radius):
139
+ # Sample three independent uniform distributions
140
+ samples_r = torch.rand(n_points) # For radius distribution
141
+ samples_phi = torch.rand(n_points) # For azimuthal angle phi
142
+ samples_u = torch.rand(n_points) # For polar angle theta
143
+
144
+ # Apply cube root to ensure uniform volumetric distribution
145
+ r = radius * torch.pow(samples_r, 1/3)
146
+ # Azimuthal angle phi uniformly distributed in [0, 2π]
147
+ phi = 2 * math.pi * samples_phi
148
+ # Convert u to theta to ensure cos(theta) is uniformly distributed
149
+ theta = torch.acos(1 - 2 * samples_u)
150
+
151
+ # Convert spherical coordinates to Cartesian coordinates
152
+ x = r * torch.sin(theta) * torch.cos(phi)
153
+ y = r * torch.sin(theta) * torch.sin(phi)
154
+ z = r * torch.cos(theta)
155
+
156
+ points = torch.stack((x, y, z), dim=1)
157
+ return points
158
+
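A small sketch of the sampler above: the cube-root radius transform keeps the samples uniform per unit volume, so the fraction of points inside half the radius is (1/2)^3, and no point leaves the sphere.

import torch

pts = generate_points_in_sphere(10000, radius=30.0)
print(pts.shape)                                 # torch.Size([10000, 3])
print(pts.norm(dim=1).max())                     # just below 30.0
print((pts.norm(dim=1) < 15.0).float().mean())   # roughly 0.125 == (15/30)**3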
159
+ def tensor_max_with_number(tensor, number):
160
+ number_tensor = torch.tensor(number, dtype=tensor.dtype, device=tensor.device)
161
+ result = torch.max(tensor, number_tensor)
162
+ return result
163
+
164
+ def custom_meshgrid(*args):
165
+ # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
166
+ if pver.parse(torch.__version__) < pver.parse('1.10'):
167
+ return torch.meshgrid(*args)
168
+ else:
169
+ return torch.meshgrid(*args, indexing='ij')
170
+
171
+ def camera_to_world_to_world_to_camera(camera_to_world: torch.Tensor) -> torch.Tensor:
172
+ """
173
+ Convert Camera-to-World matrices to World-to-Camera matrices for a tensor with shape (f, b, 4, 4).
174
+
175
+ Args:
176
+ camera_to_world (torch.Tensor): A tensor of shape (f, b, 4, 4), where:
177
+ f = number of frames,
178
+ b = batch size.
179
+
180
+ Returns:
181
+ torch.Tensor: A tensor of shape (f, b, 4, 4) representing the World-to-Camera matrices.
182
+ """
183
+ # Ensure input is a 4D tensor
184
+ assert camera_to_world.ndim == 4 and camera_to_world.shape[2:] == (4, 4), \
185
+ "Input must be of shape (f, b, 4, 4)"
186
+
187
+ # Extract the rotation (R) and translation (T) parts
188
+ R = camera_to_world[:, :, :3, :3] # Shape: (f, b, 3, 3)
189
+ T = camera_to_world[:, :, :3, 3] # Shape: (f, b, 3)
190
+
191
+ # Initialize an identity matrix for the output
192
+ world_to_camera = torch.eye(4, device=camera_to_world.device).unsqueeze(0).unsqueeze(0)
193
+ world_to_camera = world_to_camera.repeat(camera_to_world.size(0), camera_to_world.size(1), 1, 1) # Shape: (f, b, 4, 4)
194
+
195
+ # Compute the rotation (transpose of R)
196
+ world_to_camera[:, :, :3, :3] = R.transpose(2, 3)
197
+
198
+ # Compute the translation (-R^T * T)
199
+ world_to_camera[:, :, :3, 3] = -torch.matmul(R.transpose(2, 3), T.unsqueeze(-1)).squeeze(-1)
200
+
201
+ return world_to_camera.to(camera_to_world.dtype)
202
+
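A minimal check sketch, with illustrative values, that the matrices returned above are the rigid-transform inverses of the inputs, i.e. [R^T | -R^T t]:

import torch

pose = torch.tensor([[[1.0, 2.0, 3.0, 30.0, 60.0]]])       # (f=1, b=1, 5)
c2w = euler_to_camera_to_world_matrix(pose)                 # (1, 1, 4, 4)
w2c = camera_to_world_to_world_to_camera(c2w)
print(torch.allclose(w2c @ c2w, torch.eye(4), atol=1e-5))   # True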
203
+ def convert_to_plucker(poses, curr_frame, focal_length, image_width, image_height):
204
+
205
+ intrinsic = np.asarray([focal_length * image_width,
206
+ focal_length * image_height,
207
+ 0.5 * image_width,
208
+ 0.5 * image_height], dtype=np.float32)
209
+
210
+ c2ws = get_relative_pose(poses, zero_first_frame_scale=curr_frame)
211
+ c2ws = rearrange(c2ws, "t b m n -> b t m n")
212
+
213
+ K = torch.as_tensor(intrinsic, device=poses.device, dtype=poses.dtype).repeat(c2ws.shape[0],c2ws.shape[1],1) # [B, F, 4]
214
+ plucker_embedding = ray_condition(K, c2ws, image_height, image_width, device=c2ws.device)
215
+ plucker_embedding = rearrange(plucker_embedding, "b t h w d -> t b h w d").contiguous()
216
+
217
+ return plucker_embedding
218
+
219
+
220
+ def get_relative_pose(abs_c2ws, zero_first_frame_scale):
221
+ abs_w2cs = camera_to_world_to_world_to_camera(abs_c2ws)
222
+ target_cam_c2w = torch.tensor([
223
+ [1, 0, 0, 0],
224
+ [0, 1, 0, 0],
225
+ [0, 0, 1, 0],
226
+ [0, 0, 0, 1]
227
+ ]).to(abs_c2ws.device).to(abs_c2ws.dtype)
228
+ abs2rel = target_cam_c2w @ abs_w2cs[zero_first_frame_scale]
229
+ ret_poses = [abs2rel @ abs_c2w for abs_c2w in abs_c2ws]
230
+ ret_poses = torch.stack(ret_poses)
231
+ return ret_poses
232
+
233
+ def ray_condition(K, c2w, H, W, device):
234
+ # c2w: B, V, 4, 4
235
+ # K: B, V, 4
236
+
237
+ B = K.shape[0]
238
+
239
+ j, i = custom_meshgrid(
240
+ torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
241
+ torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
242
+ )
243
+ i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
244
+ j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
245
+
246
+ fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
247
+
248
+ zs = torch.ones_like(i, device=device, dtype=c2w.dtype) # [B, HxW]
249
+ xs = -(i - cx) / fx * zs
250
+ ys = -(j - cy) / fy * zs
251
+
252
+ zs = zs.expand_as(ys)
253
+
254
+ directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
255
+ directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
256
+
257
+ rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
258
+ rays_o = c2w[..., :3, 3] # B, V, 3
259
+ rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
260
+ # c2w @ dirctions
261
+ rays_dxo = torch.linalg.cross(rays_o, rays_d)
262
+ plucker = torch.cat([rays_dxo, rays_d], dim=-1)
263
+ plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
264
+
265
+ return plucker
266
+
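A minimal shape sketch of the Plücker ray map built above, for a tiny illustrative image and an identity pose (the 0.35 factor mirrors the focal_length default used later in this file): every pixel receives a 6-dim (moment, direction) embedding.

import torch

B, V, H, W = 1, 1, 4, 6
K = torch.tensor([[[0.35 * W, 0.35 * H, 0.5 * W, 0.5 * H]]])   # fx, fy, cx, cy -> (B, V, 4)
c2w = torch.eye(4).repeat(B, V, 1, 1)                           # (B, V, 4, 4)
plucker = ray_condition(K, c2w, H, W, device="cpu")
print(plucker.shape)                                            # torch.Size([1, 1, 4, 6, 6])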
267
+ def random_transform(tensor):
268
+ """
269
+ Apply the same random translation, rotation, and scaling to all frames in the batch.
270
+
271
+ Args:
272
+ tensor (torch.Tensor): Input tensor of shape (F, B, 3, H, W).
273
+
274
+ Returns:
275
+ torch.Tensor: Transformed tensor of shape (F, B, 3, H, W).
276
+ """
277
+ if tensor.ndim != 5:
278
+ raise ValueError("Input tensor must have shape (F, B, 3, H, W)")
279
+
280
+ F, B, C, H, W = tensor.shape
281
+
282
+ # Generate random transformation parameters
283
+ max_translate = 0.2 # Translate up to 20% of width/height
284
+ max_rotate = 30 # Rotate up to 30 degrees
285
+ max_scale = 0.2 # Scale change by up to +/- 20%
286
+
287
+ translate_x = random.uniform(-max_translate, max_translate) * W
288
+ translate_y = random.uniform(-max_translate, max_translate) * H
289
+ rotate_angle = random.uniform(-max_rotate, max_rotate)
290
+ scale_factor = 1 + random.uniform(-max_scale, max_scale)
291
+
292
+ # Apply the same transformation to all frames and batches
293
+
294
+ tensor = tensor.reshape(F*B, C, H, W)
295
+ transformed_tensor = TF.affine(
296
+ tensor,
297
+ angle=rotate_angle,
298
+ translate=(translate_x, translate_y),
299
+ scale=scale_factor,
300
+ shear=(0, 0),
301
+ interpolation=InterpolationMode.BILINEAR,
302
+ fill=0
303
+ )
304
+
305
+ transformed_tensor = transformed_tensor.reshape(F, B, C, H, W)
306
+ return transformed_tensor
307
+
308
+ def save_tensor_as_png(tensor, file_path):
309
+ """
310
+ Save a 3*H*W tensor as a PNG image.
311
+
312
+ Args:
313
+ tensor (torch.Tensor): Input tensor of shape (3, H, W).
314
+ file_path (str): Path to save the PNG file.
315
+ """
316
+ if tensor.ndim != 3 or tensor.shape[0] != 3:
317
+ raise ValueError("Input tensor must have shape (3, H, W)")
318
+
319
+ # Convert tensor to PIL Image
320
+ image = TF.to_pil_image(tensor)
321
+
322
+ # Save image
323
+ image.save(file_path)
324
+
325
+ class WorldMemMinecraft(DiffusionForcingBase):
326
+ """
327
+ Video generation for Minecraft with memory.
328
+ """
329
+
330
+ def __init__(self, cfg: DictConfig):
331
+ """
332
+ Initialize the WorldMemMinecraft class with the given configuration.
333
+
334
+ Args:
335
+ cfg (DictConfig): Configuration object.
336
+ """
337
+ self.n_tokens = cfg.n_frames // cfg.frame_stack # number of max tokens for the model
338
+ self.n_frames = cfg.n_frames
339
+ if hasattr(cfg, "n_tokens"):
340
+ self.n_tokens = cfg.n_tokens // cfg.frame_stack
341
+ self.memory_condition_length = cfg.memory_condition_length
342
+ self.pose_cond_dim = getattr(cfg, "pose_cond_dim", 5)
343
+
344
+ self.use_plucker = getattr(cfg, "use_plucker", True)
345
+ self.relative_embedding = getattr(cfg, "relative_embedding", True)
346
+ self.state_embed_only_on_qk = getattr(cfg, "state_embed_only_on_qk", True)
347
+ self.use_memory_attention = getattr(cfg, "use_memory_attention", True)
348
+ self.add_timestamp_embedding = getattr(cfg, "add_timestamp_embedding", True)
349
+ self.ref_mode = getattr(cfg, "ref_mode", 'sequential')
350
+ self.log_curve = getattr(cfg, "log_curve", False)
351
+ self.focal_length = getattr(cfg, "focal_length", 0.35)
352
+ self.log_video = cfg.log_video
353
+ self.save_local = getattr(cfg, "save_local", True)
354
+ self.local_save_dir = getattr(cfg, "local_save_dir", None)
355
+ self.lpips_batch_size = getattr(cfg, "lpips_batch_size", 16)
356
+ self.next_frame_length = getattr(cfg, "next_frame_length", 1)
357
+ self.require_pose_prediction = getattr(cfg, "require_pose_prediction", False)
358
+
359
+ super().__init__(cfg)
360
+
361
+ def _build_model(self):
362
+
363
+ self.diffusion_model = Diffusion(
364
+ reference_length=self.memory_condition_length,
365
+ x_shape=self.x_stacked_shape,
366
+ action_cond_dim=self.action_cond_dim,
367
+ pose_cond_dim=self.pose_cond_dim,
368
+ is_causal=self.causal,
369
+ cfg=self.cfg.diffusion,
370
+ is_dit=True,
371
+ use_plucker=self.use_plucker,
372
+ relative_embedding=self.relative_embedding,
373
+ state_embed_only_on_qk=self.state_embed_only_on_qk,
374
+ use_memory_attention=self.use_memory_attention,
375
+ add_timestamp_embedding=self.add_timestamp_embedding,
376
+ ref_mode=self.ref_mode
377
+ )
378
+
379
+ self.validation_lpips_model = LearnedPerceptualImagePatchSimilarity()
380
+ vae = VAE_models["vit-l-20-shallow-encoder"]()
381
+ self.vae = vae.eval()
382
+
383
+ if self.require_pose_prediction:
384
+ self.pose_prediction_model = PosePredictionNet()
385
+
386
+ def _generate_noise_levels(self, xs: torch.Tensor, masks = None) -> torch.Tensor:
387
+ """
388
+ Generate noise levels for training.
389
+ """
390
+ num_frames, batch_size, *_ = xs.shape
391
+ match self.cfg.noise_level:
392
+ case "random_all": # entirely random noise levels
393
+ noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
394
+ case "same":
395
+ noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
396
+ noise_levels[1:] = noise_levels[0]
397
+
398
+ if masks is not None:
399
+ # for frames that are not available, treat as full noise
400
+ discard = torch.all(~rearrange(masks.bool(), "(t fs) b -> t b fs", fs=self.frame_stack), -1)
401
+ noise_levels = torch.where(discard, torch.full_like(noise_levels, self.timesteps - 1), noise_levels)
402
+
403
+ return noise_levels
404
+
405
+ def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
406
+ """
407
+ Perform a single training step.
408
+
409
+ This function processes the input batch,
410
+ encodes the input frames, generates noise levels, and computes the loss using the diffusion model.
411
+
412
+ Args:
413
+ batch: Input batch of data containing frames, conditions, poses, etc.
414
+ batch_idx: Index of the current batch.
415
+
416
+ Returns:
417
+ dict: A dictionary containing the training loss.
418
+ """
419
+ xs, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)
420
+
421
+ if self.use_plucker:
422
+ if self.relative_embedding:
423
+ input_pose_condition = []
424
+ frame_idx_list = []
425
+ for i in range(self.n_frames):
426
+ input_pose_condition.append(
427
+ convert_to_plucker(
428
+ torch.cat([c2w_mat[i:i + 1], c2w_mat[-self.memory_condition_length:]]).clone(),
429
+ 0,
430
+ focal_length=self.focal_length,
431
+ image_height=xs.shape[-2],image_width=xs.shape[-1]
432
+ ).to(xs.dtype)
433
+ )
434
+ frame_idx_list.append(
435
+ torch.cat([
436
+ frame_idx[i:i + 1] - frame_idx[i:i + 1],
437
+ frame_idx[-self.memory_condition_length:] - frame_idx[i:i + 1]
438
+ ]).clone()
439
+ )
440
+ input_pose_condition = torch.cat(input_pose_condition)
441
+ frame_idx_list = torch.cat(frame_idx_list)
442
+ else:
443
+ input_pose_condition = convert_to_plucker(
444
+ c2w_mat, 0, focal_length=self.focal_length, image_width=xs.shape[-1], image_height=xs.shape[-2]
445
+ ).to(xs.dtype)
446
+ frame_idx_list = frame_idx
447
+ else:
448
+ input_pose_condition = pose_conditions.to(xs.dtype)
449
+ frame_idx_list = None
450
+
451
+ xs = self.encode(xs)
452
+
453
+ noise_levels = self._generate_noise_levels(xs)
454
+
455
+ if self.memory_condition_length:
456
+ noise_levels[-self.memory_condition_length:] = self.diffusion_model.stabilization_level
457
+ conditions[-self.memory_condition_length:] *= 0
458
+
459
+ _, loss = self.diffusion_model(
460
+ xs,
461
+ conditions,
462
+ input_pose_condition,
463
+ noise_levels=noise_levels,
464
+ reference_length=self.memory_condition_length,
465
+ frame_idx=frame_idx_list
466
+ )
467
+
468
+ if self.memory_condition_length:
469
+ loss = loss[:-self.memory_condition_length]
470
+
471
+ loss = self.reweight_loss(loss, None)
472
+
473
+ if batch_idx % 20 == 0:
474
+ self.log("training/loss", loss.cpu())
475
+
476
+ return {"loss": loss}
477
+
478
+ def on_validation_epoch_end(self, namespace="validation") -> None:
479
+ if not self.validation_step_outputs:
480
+ return
481
+
482
+ xs_pred = []
483
+ xs = []
484
+ for pred, gt in self.validation_step_outputs:
485
+ xs_pred.append(pred)
486
+ xs.append(gt)
487
+
488
+ xs_pred = torch.cat(xs_pred, 1)
489
+ if gt is not None:
490
+ xs = torch.cat(xs, 1)
491
+ else:
492
+ xs = None
493
+
494
+ if self.logger and self.log_video:
495
+ log_video(
496
+ xs_pred,
497
+ xs,
498
+ step=None if namespace == "test" else self.global_step,
499
+ namespace=namespace + "_vis",
500
+ context_frames=self.context_frames,
501
+ logger=self.logger.experiment,
502
+ save_local=self.save_local,
503
+ local_save_dir=self.local_save_dir,
504
+ )
505
+
506
+ if xs is not None:
507
+ # Move data to the same device as LPIPS model for metric calculation
508
+ device = next(self.validation_lpips_model.parameters()).device
509
+ xs_pred_device = xs_pred.to(device)
510
+ xs_device = xs.to(device)
511
+
512
+ metric_dict = get_validation_metrics_for_videos(
513
+ xs_pred_device, xs_device,
514
+ lpips_model=self.validation_lpips_model,
515
+ lpips_batch_size=self.lpips_batch_size)
516
+
517
+ self.log_dict(
518
+ {"mse": metric_dict['mse'],
519
+ "psnr": metric_dict['psnr'],
520
+ "lpips": metric_dict['lpips']},
521
+ sync_dist=True
522
+ )
523
+
524
+ if self.log_curve:
525
+ psnr_values = metric_dict['frame_wise_psnr'].cpu().tolist()
526
+ frames = list(range(len(psnr_values)))
527
+ line_plot = wandb.plot.line_series(
528
+ xs = frames,
529
+ ys = [psnr_values],
530
+ keys = ["PSNR"],
531
+ title = "Frame-wise PSNR",
532
+ xname = "Frame index"
533
+ )
534
+
535
+ self.logger.experiment.log({"frame_wise_psnr_plot": line_plot})
536
+
537
+ self.validation_step_outputs.clear()
538
+
539
+ def _preprocess_batch(self, batch):
540
+
541
+ xs, conditions, pose_conditions, frame_index = batch
542
+
543
+ if self.action_cond_dim:
544
+ conditions = torch.cat([torch.zeros_like(conditions[:, :1]), conditions[:, 1:]], 1)
545
+ conditions = rearrange(conditions, "b t d -> t b d").contiguous()
546
+ else:
547
+ raise NotImplementedError("Only support external cond.")
548
+
549
+ pose_conditions = rearrange(pose_conditions, "b t d -> t b d").contiguous()
550
+ c2w_mat = euler_to_camera_to_world_matrix(pose_conditions)
551
+ xs = rearrange(xs, "b t c ... -> t b c ...").contiguous()
552
+ frame_index = rearrange(frame_index, "b t -> t b").contiguous()
553
+
554
+ return xs, conditions, pose_conditions, c2w_mat, frame_index
555
+
556
+ def encode(self, x):
557
+ # vae encoding
558
+ T = x.shape[0]
559
+ H, W = x.shape[-2:]
560
+ scaling_factor = 0.07843137255
561
+
562
+ x = rearrange(x, "t b c h w -> (t b) c h w")
563
+ with torch.no_grad():
564
+ x = self.vae.encode(x * 2 - 1).mean * scaling_factor
565
+ x = rearrange(x, "(t b) (h w) c -> t b c h w", t=T, h=H // self.vae.patch_size, w=W // self.vae.patch_size)
566
+ return x
567
+
568
+ def decode(self, x):
569
+ total_frames = x.shape[0]
570
+ scaling_factor = 0.07843137255
571
+ x = rearrange(x, "t b c h w -> (t b) (h w) c")
572
+ with torch.no_grad():
573
+ x = (self.vae.decode(x / scaling_factor) + 1) / 2
574
+ x = rearrange(x, "(t b) c h w-> t b c h w", t=total_frames)
575
+ return x
576
+
577
+ def _generate_condition_indices(self, curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, horizon):
578
+ """
579
+ Generate indices for condition similarity based on the current frame and pose conditions.
580
+ """
581
+ if curr_frame < memory_condition_length:
582
+ random_idx = [i for i in range(curr_frame)] + [0] * (memory_condition_length - curr_frame)
583
+ random_idx = np.repeat(np.array(random_idx)[:, None], xs_pred.shape[1], -1)
584
+ else:
585
+ # Generate points in a sphere and filter based on field of view
586
+ num_samples = 10000
587
+ radius = 30
588
+ points = generate_points_in_sphere(num_samples, radius).to(pose_conditions.device)
589
+ points = points[:, None].repeat(1, pose_conditions.shape[1], 1)
590
+ points += pose_conditions[curr_frame, :, :3][None]
591
+ fov_half_h = torch.tensor(105 / 2, device=pose_conditions.device)
592
+ fov_half_v = torch.tensor(75 / 2, device=pose_conditions.device)
593
+
594
+ # in_fov1 = is_inside_fov_3d_hv(
595
+ # points, pose_conditions[curr_frame, :, :3],
596
+ # pose_conditions[curr_frame, :, -2], pose_conditions[curr_frame, :, -1],
597
+ # fov_half_h, fov_half_v
598
+ # )
599
+
600
+ in_fov1 = torch.stack([
601
+ is_inside_fov_3d_hv(points, pc[:, :3], pc[:, -2], pc[:, -1], fov_half_h, fov_half_v)
602
+ for pc in pose_conditions[curr_frame:curr_frame+horizon]
603
+ ])
604
+
605
+ in_fov1 = torch.sum(in_fov1, 0) > 0
606
+
607
+ # Compute overlap ratios and select indices
608
+ in_fov_list = torch.stack([
609
+ is_inside_fov_3d_hv(points, pc[:, :3], pc[:, -2], pc[:, -1], fov_half_h, fov_half_v)
610
+ for pc in pose_conditions[:curr_frame]
611
+ ])
612
+
613
+ random_idx = []
614
+ for _ in range(memory_condition_length):
615
+ overlap_ratio = ((in_fov1.bool() & in_fov_list).sum(1)) / in_fov1.sum()
616
+
617
+ confidence = overlap_ratio + (curr_frame - frame_idx[:curr_frame]) / curr_frame * (-0.2)
618
+
619
+ if len(random_idx) > 0:
620
+ confidence[torch.cat(random_idx)] = -1e10
621
+ _, r_idx = torch.topk(confidence, k=1, dim=0)
622
+ random_idx.append(r_idx[0])
623
+
624
+ # choice 1: directly remove overlapping region
625
+ occupied_mask = in_fov_list[r_idx[0, range(in_fov1.shape[-1])], :, range(in_fov1.shape[-1])].permute(1,0)
626
+ in_fov1 = in_fov1 & ~occupied_mask
627
+
628
+ # choice 2: apply similarity filter
629
+ # cos_sim = F.cosine_similarity(xs_pred.to(r_idx.device)[r_idx[:, range(in_fov1.shape[1])],
630
+ # range(in_fov1.shape[1])], xs_pred.to(r_idx.device)[:curr_frame], dim=2)
631
+ # cos_sim = cos_sim.mean((-2,-1))
632
+
633
+ # mask_sim = cos_sim>0.9
634
+ # in_fov_list = in_fov_list & ~mask_sim[:,None].to(in_fov_list.device)
635
+
636
+ random_idx = torch.stack(random_idx).cpu()
637
+
638
+ return random_idx
639
+
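A toy sketch of the scoring used above when picking memory frames, with illustrative numbers: FOV overlap with the frames being generated, minus a 0.2 recency penalty scaled by relative age, so older frames need proportionally more overlap to be retrieved.

import torch

overlap_ratio = torch.tensor([0.50, 0.45, 0.10])   # overlap of each stored frame with the target view
frame_idx = torch.tensor([0.0, 40.0, 79.0])        # when each stored frame was generated
curr_frame = 80
confidence = overlap_ratio + (curr_frame - frame_idx) / curr_frame * (-0.2)
print(confidence)   # tensor([0.3000, 0.3500, 0.0975]) -> the second frame is retrieved first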
640
+ def _prepare_conditions(self,
641
+ start_frame, curr_frame, horizon, conditions,
642
+ pose_conditions, c2w_mat, frame_idx, random_idx,
643
+ image_width, image_height):
644
+ """
645
+ Prepare input conditions and pose conditions for sampling.
646
+ """
647
+
648
+ padding = torch.zeros((len(random_idx),) + conditions.shape[1:], device=conditions.device, dtype=conditions.dtype)
649
+ input_condition = torch.cat([conditions[start_frame:curr_frame + horizon], padding], dim=0)
650
+
651
+ batch_size = conditions.shape[1]
652
+
653
+ if self.use_plucker:
654
+ if self.relative_embedding:
655
+ frame_idx_list = []
656
+ input_pose_condition = []
657
+ for i in range(start_frame, curr_frame + horizon):
658
+ input_pose_condition.append(convert_to_plucker(torch.cat([c2w_mat[i:i+1],c2w_mat[random_idx[:,range(batch_size)], range(batch_size)]]).clone(), 0, focal_length=self.focal_length,
659
+ image_width=image_width, image_height=image_height).to(conditions.dtype))
660
+ frame_idx_list.append(torch.cat([frame_idx[i:i+1]-frame_idx[i:i+1], frame_idx[random_idx[:,range(batch_size)], range(batch_size)]-frame_idx[i:i+1]]))
661
+ input_pose_condition = torch.cat(input_pose_condition)
662
+ frame_idx_list = torch.cat(frame_idx_list)
663
+
664
+ else:
665
+ input_pose_condition = torch.cat([c2w_mat[start_frame : curr_frame + horizon], c2w_mat[random_idx[:,range(batch_size)], range(batch_size)]], dim=0).clone()
666
+ input_pose_condition = convert_to_plucker(input_pose_condition, 0, focal_length=self.focal_length, image_width=image_width, image_height=image_height)
667
+ frame_idx_list = None
668
+ else:
669
+ input_pose_condition = torch.cat([pose_conditions[start_frame : curr_frame + horizon], pose_conditions[random_idx[:,range(batch_size)], range(batch_size)]], dim=0).clone()
670
+ frame_idx_list = None
671
+
672
+ return input_condition, input_pose_condition, frame_idx_list
673
+
674
+ def _prepare_noise_levels(self, scheduling_matrix, m, curr_frame, batch_size, memory_condition_length):
675
+ """
676
+ Prepare noise levels for the current sampling step.
677
+ """
678
+ from_noise_levels = np.concatenate((np.zeros((curr_frame,), dtype=np.int64), scheduling_matrix[m]))[:, None].repeat(batch_size, axis=1)
679
+ to_noise_levels = np.concatenate((np.zeros((curr_frame,), dtype=np.int64), scheduling_matrix[m + 1]))[:, None].repeat(batch_size, axis=1)
680
+ if memory_condition_length:
681
+ from_noise_levels = np.concatenate([from_noise_levels, np.zeros((memory_condition_length, from_noise_levels.shape[-1]), dtype=np.int32)], axis=0)
682
+ to_noise_levels = np.concatenate([to_noise_levels, np.zeros((memory_condition_length, from_noise_levels.shape[-1]), dtype=np.int32)], axis=0)
683
+ from_noise_levels = torch.from_numpy(from_noise_levels).to(self.device)
684
+ to_noise_levels = torch.from_numpy(to_noise_levels).to(self.device)
685
+ return from_noise_levels, to_noise_levels
686
+
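A toy sketch of the noise-level layout assembled above, with illustrative numbers: already-generated context frames are clean (level 0), the active chunk follows one row of the scheduling matrix, and the appended memory frames are also kept clean.

import numpy as np

curr_frame, batch_size, memory_condition_length = 3, 2, 2
scheduling_row = np.array([7, 9], dtype=np.int64)   # levels for a 2-frame chunk
levels = np.concatenate((np.zeros((curr_frame,), dtype=np.int64), scheduling_row))[:, None].repeat(batch_size, axis=1)
levels = np.concatenate([levels, np.zeros((memory_condition_length, batch_size), dtype=np.int64)], axis=0)
print(levels[:, 0])   # [0 0 0 7 9 0 0]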
687
+ def validation_step(self, batch, batch_idx, namespace="validation") -> STEP_OUTPUT:
688
+ """
689
+ Perform a single validation step.
690
+
691
+ This function processes the input batch, encodes frames, generates predictions using a sliding window approach,
692
+ and handles condition similarity logic for sampling. The results are decoded and stored for evaluation.
693
+
694
+ Args:
695
+ batch: Input batch of data containing frames, conditions, poses, etc.
696
+ batch_idx: Index of the current batch.
697
+ namespace: Namespace for logging (default: "validation").
698
+
699
+ Returns:
700
+ None: Appends the predicted and ground truth frames to `self.validation_step_outputs`.
701
+ """
702
+ # Preprocess the input batch
703
+ memory_condition_length = self.memory_condition_length
704
+ xs_raw, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)
705
+
706
+
707
+ # Encode frames in chunks if necessary
708
+ total_frame = xs_raw.shape[0]
709
+ if total_frame > 10:
710
+ xs = torch.cat([
711
+ self.encode(xs_raw[int(total_frame * i / 10):int(total_frame * (i + 1) / 10)]).cpu()
712
+ for i in range(10)
713
+ ])
714
+ else:
715
+ xs = self.encode(xs_raw).cpu()
716
+
717
+ n_frames, batch_size, *_ = xs.shape
718
+ curr_frame = 0
719
+
720
+ # Initialize context frames
721
+ n_context_frames = self.context_frames // self.frame_stack
722
+ xs_pred = xs[:n_context_frames].clone()
723
+ curr_frame += n_context_frames
724
+
725
+ # Progress bar for sampling
726
+ pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
727
+
728
+ while curr_frame < n_frames:
729
+ # Determine the horizon for the current chunk
730
+ horizon = min(n_frames - curr_frame, self.chunk_size) if self.chunk_size > 0 else n_frames - curr_frame
731
+ assert horizon <= self.n_tokens, "Horizon exceeds the number of tokens."
732
+
733
+ # Generate scheduling matrix and initialize noise
734
+ scheduling_matrix = self._generate_scheduling_matrix(horizon)
735
+ chunk = torch.randn((horizon, batch_size, *xs_pred.shape[2:]))
736
+ chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise).to(xs_pred.device)
737
+ xs_pred = torch.cat([xs_pred, chunk], 0)
738
+
739
+ # Sliding window: only input the last `n_tokens` frames
740
+ start_frame = max(0, curr_frame + horizon - self.n_tokens)
741
+ pbar.set_postfix({"start": start_frame, "end": curr_frame + horizon})
742
+
743
+ # Handle condition similarity logic
744
+ if memory_condition_length:
745
+ random_idx = self._generate_condition_indices(
746
+ curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, horizon
747
+ )
748
+
749
+ xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0)
750
+
751
+ # Prepare input conditions and pose conditions
752
+ input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions(
753
+ start_frame, curr_frame, horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx,
754
+ image_width=xs_raw.shape[-1], image_height=xs_raw.shape[-2]
755
+ )
756
+
757
+ # Perform sampling for each step in the scheduling matrix
758
+ for m in range(scheduling_matrix.shape[0] - 1):
759
+ from_noise_levels, to_noise_levels = self._prepare_noise_levels(
760
+ scheduling_matrix, m, curr_frame, batch_size, memory_condition_length
761
+ )
762
+
763
+ xs_pred[start_frame:] = self.diffusion_model.sample_step(
764
+ xs_pred[start_frame:].to(input_condition.device),
765
+ input_condition,
766
+ input_pose_condition,
767
+ from_noise_levels[start_frame:],
768
+ to_noise_levels[start_frame:],
769
+ current_frame=curr_frame,
770
+ mode="validation",
771
+ reference_length=memory_condition_length,
772
+ frame_idx=frame_idx_list
773
+ ).cpu()
774
+
775
+ # Remove condition similarity frames if applicable
776
+ if memory_condition_length:
777
+ xs_pred = xs_pred[:-memory_condition_length]
778
+
779
+ curr_frame += horizon
780
+ pbar.update(horizon)
781
+
782
+ # Decode predictions and ground truth
783
+ xs_pred = self.decode(xs_pred[n_context_frames:].to(conditions.device))
784
+ xs_decode = self.decode(xs[n_context_frames:].to(conditions.device))
785
+
786
+ # Store results for evaluation (move to CPU to save GPU memory)
787
+ self.validation_step_outputs.append((xs_pred.detach().cpu(), xs_decode.detach().cpu()))
788
+ return
789
+
790
+ @torch.no_grad()
791
+ def interactive(self, first_frame, new_actions, first_pose, device,
792
+ memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx):
793
+
794
+ memory_condition_length = self.memory_condition_length
795
+
796
+ if memory_latent_frames is None:
797
+ first_frame = torch.from_numpy(first_frame)
798
+ new_actions = torch.from_numpy(new_actions)
799
+ first_pose = torch.from_numpy(first_pose)
800
+ first_frame_encode = self.encode(first_frame[None, None].to(device))
801
+ memory_latent_frames = first_frame_encode.cpu()
802
+ memory_actions = new_actions[None, None].to(device)
803
+ memory_poses = first_pose[None, None].to(device)
804
+ new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
805
+ memory_c2w = new_c2w_mat[None, None].to(device)
806
+ memory_frame_idx = torch.tensor([[0]]).to(device)
807
+ return first_frame.cpu().numpy(), memory_latent_frames.cpu().numpy(), memory_actions.cpu().numpy(), memory_poses.cpu().numpy(), memory_c2w.cpu().numpy(), memory_frame_idx.cpu().numpy()
808
+ else:
809
+ memory_latent_frames = torch.from_numpy(memory_latent_frames)
810
+ memory_actions = torch.from_numpy(memory_actions).to(device)
811
+ memory_poses = torch.from_numpy(memory_poses).to(device)
812
+ memory_c2w = torch.from_numpy(memory_c2w).to(device)
813
+ memory_frame_idx = torch.from_numpy(memory_frame_idx).to(device)
814
+ new_actions = new_actions.to(device)
815
+
816
+ curr_frame = 0
817
+ batch_size = 1
818
+ horizon = self.next_frame_length
819
+ n_frames = curr_frame + horizon
820
+ # context
821
+ n_context_frames = len(memory_latent_frames)
822
+ xs_pred = memory_latent_frames[:n_context_frames].clone()
823
+ curr_frame += n_context_frames
824
+
825
+ pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
826
+
827
+ new_pose_condition_list = []
828
+ last_frame = xs_pred[-1].clone()
829
+ last_pose_condition = memory_poses[-1].clone()
830
+ curr_actions = new_actions.clone()
831
+ for hi in range(len(new_actions)):
832
+ last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
833
+ new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None, hi], last_pose_condition)
834
+ new_pose_condition_offset[:,3:] = torch.round(new_pose_condition_offset[:,3:])
835
+ new_pose_condition = last_pose_condition + new_pose_condition_offset
836
+ new_pose_condition[:,3:] = new_pose_condition[:,3:] * 15
837
+ new_pose_condition[:,3:] %= 360
838
+ last_pose_condition = new_pose_condition.clone()
839
+ new_pose_condition_list.append(new_pose_condition[None])
840
+ new_pose_condition_list = torch.cat(new_pose_condition_list, 0)
841
+
842
+ ai = 0
843
+ while ai < len(new_actions):
844
+ next_horizon = min(horizon, len(new_actions) - ai)
845
+ last_frame = xs_pred[-1].clone()
846
+ curr_actions = new_actions[ai:ai+next_horizon].clone()
847
+
848
+ new_pose_condition = new_pose_condition_list[ai:ai+next_horizon].clone()
849
+
850
+ new_c2w_mat = euler_to_camera_to_world_matrix(new_pose_condition)
851
+ memory_poses = torch.cat([memory_poses, new_pose_condition])
852
+ memory_actions = torch.cat([memory_actions, curr_actions[:, None]])
853
+ memory_c2w = torch.cat([memory_c2w, new_c2w_mat])
854
+ new_indices = memory_frame_idx[-1,0] + torch.arange(next_horizon, device=memory_frame_idx.device) + 1
855
+
856
+ memory_frame_idx = torch.cat([memory_frame_idx, new_indices[:, None]])
857
+
858
+ conditions = memory_actions.clone()
859
+ pose_conditions = memory_poses.clone()
860
+ c2w_mat = memory_c2w.clone()
861
+ frame_idx = memory_frame_idx.clone()
862
+
863
+ # generation on frame
864
+ scheduling_matrix = self._generate_scheduling_matrix(next_horizon)
865
+ chunk = torch.randn((next_horizon, batch_size, *xs_pred.shape[2:])).to(xs_pred.device)
866
+ chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise)
867
+
868
+ xs_pred = torch.cat([xs_pred, chunk], 0)
869
+
870
+ # sliding window: only input the last n_tokens frames
871
+ start_frame = max(0, curr_frame - self.n_tokens)
872
+
873
+ pbar.set_postfix(
874
+ {
875
+ "start": start_frame,
876
+ "end": curr_frame + next_horizon,
877
+ }
878
+ )
879
+
880
+ # Handle condition similarity logic
881
+ if memory_condition_length:
882
+ random_idx = self._generate_condition_indices(
883
+ curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, next_horizon
884
+ )
885
+
886
+ # random_idx = np.unique(random_idx)[:, None]
887
+ # memory_condition_length = len(random_idx)
888
+ xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0)
889
+
890
+ # Prepare input conditions and pose conditions
891
+ input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions(
892
+ start_frame, curr_frame, next_horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx,
893
+ image_width=first_frame.shape[-1], image_height=first_frame.shape[-2]
894
+ )
895
+
896
+ # Perform sampling for each step in the scheduling matrix
897
+ for m in range(scheduling_matrix.shape[0] - 1):
898
+ from_noise_levels, to_noise_levels = self._prepare_noise_levels(
899
+ scheduling_matrix, m, curr_frame, batch_size, memory_condition_length
900
+ )
901
+
902
+ xs_pred[start_frame:] = self.diffusion_model.sample_step(
903
+ xs_pred[start_frame:].to(input_condition.device),
904
+ input_condition,
905
+ input_pose_condition,
906
+ from_noise_levels[start_frame:],
907
+ to_noise_levels[start_frame:],
908
+ current_frame=curr_frame,
909
+ mode="validation",
910
+ reference_length=memory_condition_length,
911
+ frame_idx=frame_idx_list
912
+ ).cpu()
913
+
914
+
915
+ if memory_condition_length:
916
+ xs_pred = xs_pred[:-memory_condition_length]
917
+
918
+ curr_frame += next_horizon
919
+ pbar.update(next_horizon)
920
+ ai += next_horizon
921
+
922
+ memory_latent_frames = torch.cat([memory_latent_frames, xs_pred[n_context_frames:]])
923
+ xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
924
+
925
+ return xs_pred.cpu().numpy(), memory_latent_frames.cpu().numpy(), memory_actions.cpu().numpy(), \
926
+ memory_poses.cpu().numpy(), memory_c2w.cpu().numpy(), memory_frame_idx.cpu().numpy()
algorithms/worldmem/models/attention.py ADDED
@@ -0,0 +1,342 @@
1
+ """
2
+ Based on https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/attention.py
3
+ """
4
+
5
+ from typing import Optional
6
+ from collections import namedtuple
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from einops import rearrange
11
+ from .rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
12
+ import numpy as np
13
+
14
+ class TemporalAxialAttention(nn.Module):
15
+ def __init__(
16
+ self,
17
+ dim: int,
18
+ heads: int,
19
+ dim_head: int,
20
+ reference_length: int,
21
+ rotary_emb: RotaryEmbedding,
22
+ is_causal: bool = True,
23
+ is_temporal_independent: bool = False,
24
+ use_domain_adapter = False
25
+ ):
26
+ super().__init__()
27
+ self.inner_dim = dim_head * heads
28
+ self.heads = heads
29
+ self.head_dim = dim_head
30
+ self.inner_dim = dim_head * heads
31
+ self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
32
+
33
+ self.use_domain_adapter = use_domain_adapter
34
+ if self.use_domain_adapter:
35
+ lora_rank = 8
36
+ self.lora_A = nn.Linear(dim, lora_rank, bias=False)
37
+ self.lora_B = nn.Linear(lora_rank, self.inner_dim * 3, bias=False)
38
+
39
+ self.to_out = nn.Linear(self.inner_dim, dim)
40
+
41
+ self.rotary_emb = rotary_emb
42
+ self.is_causal = is_causal
43
+ self.is_temporal_independent = is_temporal_independent
44
+
45
+ self.reference_length = reference_length
46
+
47
+ def forward(self, x: torch.Tensor):
48
+ B, T, H, W, D = x.shape
49
+
50
+ # if T>=9:
51
+ # try:
52
+ # # x = torch.cat([x[:,:-1],x[:,16-T:17-T],x[:,-1:]], dim=1)
53
+ # x = torch.cat([x[:,16-T:17-T],x], dim=1)
54
+ # except:
55
+ # import pdb;pdb.set_trace()
56
+ # print("="*50)
57
+ # print(x.shape)
58
+
59
+ B, T, H, W, D = x.shape
60
+
61
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
62
+
63
+ if self.use_domain_adapter:
64
+ q_lora, k_lora, v_lora = self.lora_B(self.lora_A(x)).chunk(3, dim=-1)
65
+ q = q+q_lora
66
+ k = k+k_lora
67
+ v = v+v_lora
68
+
69
+ q = rearrange(q, "B T H W (h d) -> (B H W) h T d", h=self.heads)
70
+ k = rearrange(k, "B T H W (h d) -> (B H W) h T d", h=self.heads)
71
+ v = rearrange(v, "B T H W (h d) -> (B H W) h T d", h=self.heads)
72
+
73
+ q = self.rotary_emb.rotate_queries_or_keys(q, self.rotary_emb.freqs)
74
+ k = self.rotary_emb.rotate_queries_or_keys(k, self.rotary_emb.freqs)
75
+
76
+ q, k, v = map(lambda t: t.contiguous(), (q, k, v))
77
+
78
+ if self.is_temporal_independent:
79
+ attn_bias = torch.ones((T, T), dtype=q.dtype, device=q.device)
80
+ attn_bias = attn_bias.masked_fill(attn_bias == 1, float('-inf'))
81
+ attn_bias[range(T), range(T)] = 0
82
+ elif self.is_causal:
83
+ attn_bias = torch.triu(torch.ones((T, T), dtype=q.dtype, device=q.device), diagonal=1)
84
+ attn_bias = attn_bias.masked_fill(attn_bias == 1, float('-inf'))
85
+ attn_bias[(T-self.reference_length):] = float('-inf')
86
+ attn_bias[range(T), range(T)] = 0
87
+ else:
88
+ attn_bias = None
89
+
90
+ try:
91
+ x = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=attn_bias)
92
+ except RuntimeError as err:
93
+ raise RuntimeError(f"scaled_dot_product_attention failed for q of shape {tuple(q.shape)}") from err
94
+
95
+ x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
96
+ x = x.to(q.dtype)
97
+
98
+ # linear proj
99
+ x = self.to_out(x)
100
+
101
+ # if T>=10:
102
+ # try:
103
+ # # x = torch.cat([x[:,:-2],x[:,-1:]], dim=1)
104
+ # x = x[:,1:]
105
+ # except:
106
+ # import pdb;pdb.set_trace()
107
+ # print(x.shape)
108
+ return x
109
+
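A minimal sketch of the temporal attention bias built in the is_causal branch above, for T = 5 tokens with reference_length = 2: regular frames attend causally, while the two appended memory frames only attend to themselves along the time axis.

import torch

T, reference_length = 5, 2
attn_bias = torch.triu(torch.ones(T, T), diagonal=1)
attn_bias = attn_bias.masked_fill(attn_bias == 1, float('-inf'))
attn_bias[T - reference_length:] = float('-inf')
attn_bias[range(T), range(T)] = 0
print(attn_bias)
# rows 0-2: causal (lower-triangular) attention over earlier frames
# rows 3-4 (memory frames): -inf everywhere except their own diagonal entry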
110
+ class SpatialAxialAttention(nn.Module):
111
+ def __init__(
112
+ self,
113
+ dim: int,
114
+ heads: int,
115
+ dim_head: int,
116
+ rotary_emb: RotaryEmbedding,
117
+ use_domain_adapter = False
118
+ ):
119
+ super().__init__()
120
+ self.inner_dim = dim_head * heads
121
+ self.heads = heads
122
+ self.head_dim = dim_head
123
+ self.inner_dim = dim_head * heads
124
+ self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
125
+ self.use_domain_adapter = use_domain_adapter
126
+ if self.use_domain_adapter:
127
+ lora_rank = 8
128
+ self.lora_A = nn.Linear(dim, lora_rank, bias=False)
129
+ self.lora_B = nn.Linear(lora_rank, self.inner_dim * 3, bias=False)
130
+
131
+ self.to_out = nn.Linear(self.inner_dim, dim)
132
+
133
+ self.rotary_emb = rotary_emb
134
+
135
+ def forward(self, x: torch.Tensor):
136
+ B, T, H, W, D = x.shape
137
+
138
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
139
+
140
+ if self.use_domain_adapter:
141
+ q_lora, k_lora, v_lora = self.lora_B(self.lora_A(x)).chunk(3, dim=-1)
142
+ q = q+q_lora
143
+ k = k+k_lora
144
+ v = v+v_lora
145
+
146
+ q = rearrange(q, "B T H W (h d) -> (B T) h H W d", h=self.heads)
147
+ k = rearrange(k, "B T H W (h d) -> (B T) h H W d", h=self.heads)
148
+ v = rearrange(v, "B T H W (h d) -> (B T) h H W d", h=self.heads)
149
+
150
+ freqs = self.rotary_emb.get_axial_freqs(H, W)
151
+ q = apply_rotary_emb(freqs, q)
152
+ k = apply_rotary_emb(freqs, k)
153
+
154
+ # prepare for attn
155
+ q = rearrange(q, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
156
+ k = rearrange(k, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
157
+ v = rearrange(v, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
158
+
159
+ x = F.scaled_dot_product_attention(query=q, key=k, value=v, is_causal=False)
160
+
161
+ x = rearrange(x, "(B T) h (H W) d -> B T H W (h d)", B=B, H=H, W=W)
162
+ x = x.to(q.dtype)
163
+
164
+ # linear proj
165
+ x = self.to_out(x)
166
+ return x
167
+
168
+ class MemTemporalAxialAttention(nn.Module):
169
+ def __init__(
170
+ self,
171
+ dim: int,
172
+ heads: int,
173
+ dim_head: int,
174
+ rotary_emb: RotaryEmbedding,
175
+ is_causal: bool = True,
176
+ ):
177
+ super().__init__()
178
+ self.inner_dim = dim_head * heads
179
+ self.heads = heads
180
+ self.head_dim = dim_head
181
+ self.inner_dim = dim_head * heads
182
+ self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
183
+ self.to_out = nn.Linear(self.inner_dim, dim)
184
+
185
+ self.rotary_emb = rotary_emb
186
+ self.is_causal = is_causal
187
+
188
+ self.reference_length = 3
189
+
190
+ def forward(self, x: torch.Tensor):
191
+ B, T, H, W, D = x.shape
192
+
193
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
194
+
195
+
196
+ q = rearrange(q, "B T H W (h d) -> (B H W) h T d", h=self.heads)
197
+ k = rearrange(k, "B T H W (h d) -> (B H W) h T d", h=self.heads)
198
+ v = rearrange(v, "B T H W (h d) -> (B H W) h T d", h=self.heads)
199
+
200
+
201
+
202
+ # q = self.rotary_emb.rotate_queries_or_keys(q, self.rotary_emb.freqs)
203
+ # k = self.rotary_emb.rotate_queries_or_keys(k, self.rotary_emb.freqs)
204
+
205
+ q, k, v = map(lambda t: t.contiguous(), (q, k, v))
206
+
207
+ # if T == 21000:
208
+ # # Manually compute the scaled dot-product attention scores
209
+ # _, _, _, d_k = q.shape
210
+ # scores = torch.einsum("b h n d, b h m d -> b h n m", q, k) / (d_k ** 0.5) # Shape: (B, T_q, T_k)
211
+
212
+ # # Compute the attention map
213
+ # attention_map = F.softmax(scores, dim=-1) # Shape: (B, T_q, T_k)
214
+ # b_, h_, n_, m_ = attention_map.shape
215
+ # attention_map = attention_map.reshape(1, int(np.sqrt(b_/1)), int(np.sqrt(b_/1)), h_, n_, m_)
216
+ # attention_map = attention_map.mean(3)
217
+
218
+ # attn_bias = torch.zeros((T, T), dtype=q.dtype, device=q.device)
219
+ # T_origin = T - self.reference_length
220
+ # attn_bias[:T_origin, T_origin:] = 1
221
+ # attn_bias[range(T), range(T)] = 1
222
+
223
+ # attention_map = attention_map * attn_bias
224
+
225
+ # # Plot the attention map
226
+ # import matplotlib.pyplot as plt
227
+ # fig, axes = plt.subplots(21000, 21000, figsize=(9, 9)) # adjust figsize to fit the image size
228
+
229
+ # # Iterate over the 3*3 grid
230
+ # for i in range(21000):
231
+ # for j in range(21000):
232
+ # # Take the (i, j)-th sub-image
233
+ # img = attention_map[0, :, :, i, j].cpu().numpy()
234
+ # axes[i, j].imshow(img, cmap='viridis') # the cmap can be customized
235
+ # axes[i, j].axis('off') # hide the axes
236
+
237
+ # # Adjust subplot spacing
238
+ # plt.tight_layout()
239
+ # plt.savefig('attention_map.png')
240
+ # import pdb; pdb.set_trace()
241
+ # plt.close()
242
+
243
+ attn_bias = torch.zeros((T, T), dtype=q.dtype, device=q.device)
244
+ attn_bias = attn_bias.masked_fill(attn_bias == 0, float('-inf'))
245
+ T_origin = T - self.reference_length
246
+ attn_bias[:T_origin, T_origin:] = 0
247
+ attn_bias[range(T), range(T)] = 0
248
+
249
+ # if T==121000:
250
+ # import pdb;pdb.set_trace()
251
+
252
+ try:
253
+ x = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=attn_bias)
254
+ except RuntimeError as err:
255
+ raise RuntimeError(f"scaled_dot_product_attention failed for q of shape {tuple(q.shape)}") from err
256
+
257
+ x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
258
+ x = x.to(q.dtype)
259
+
260
+ # linear proj
261
+ x = self.to_out(x)
262
+ return x
263
+
264
+ class MemFullAttention(nn.Module):
265
+ def __init__(
266
+ self,
267
+ dim: int,
268
+ heads: int,
269
+ dim_head: int,
270
+ reference_length: int,
271
+ rotary_emb: RotaryEmbedding,
272
+ is_causal: bool = True
273
+ ):
274
+ super().__init__()
275
+ self.inner_dim = dim_head * heads
276
+ self.heads = heads
277
+ self.head_dim = dim_head
278
+ self.inner_dim = dim_head * heads
279
+ self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
280
+ self.to_out = nn.Linear(self.inner_dim, dim)
281
+
282
+ self.rotary_emb = rotary_emb
283
+ self.is_causal = is_causal
284
+
285
+ self.reference_length = reference_length
286
+
287
+ self.store = None
288
+
289
+ def forward(self, x: torch.Tensor, relative_embedding=False,
290
+ extra_condition=None,
291
+ state_embed_only_on_qk=False,
292
+ reference_length=None):
293
+
294
+ B, T, H, W, D = x.shape
295
+
296
+ if state_embed_only_on_qk:
297
+ q, k, _ = self.to_qkv(x+extra_condition).chunk(3, dim=-1)
298
+ _, _, v = self.to_qkv(x).chunk(3, dim=-1)
299
+ else:
300
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
301
+
302
+ if relative_embedding:
303
+ length = reference_length+1
304
+ n_frames = T // length
305
+ x = x.reshape(B, n_frames, length, H, W, D)
306
+
307
+ x_list = []
308
+
309
+ for i in range(n_frames):
310
+ if i == n_frames-1:
311
+ q_i = rearrange(q[:, i*length:], "B T H W (h d) -> B h (T H W) d", h=self.heads)
312
+ k_i = rearrange(k[:, i*length+1:(i+1)*length], "B T H W (h d) -> B h (T H W) d", h=self.heads)
313
+ v_i = rearrange(v[:, i*length+1:(i+1)*length], "B T H W (h d) -> B h (T H W) d", h=self.heads)
314
+ else:
315
+ q_i = rearrange(q[:, i*length:i*length+1], "B T H W (h d) -> B h (T H W) d", h=self.heads)
316
+ k_i = rearrange(k[:, i*length+1:(i+1)*length], "B T H W (h d) -> B h (T H W) d", h=self.heads)
317
+ v_i = rearrange(v[:, i*length+1:(i+1)*length], "B T H W (h d) -> B h (T H W) d", h=self.heads)
318
+
319
+ q_i, k_i, v_i = map(lambda t: t.contiguous(), (q_i, k_i, v_i))
320
+ x_i = F.scaled_dot_product_attention(query=q_i, key=k_i, value=v_i)
321
+ x_i = rearrange(x_i, "B h (T H W) d -> B T H W (h d)", B=B, H=H, W=W)
322
+ x_i = x_i.to(q.dtype)
323
+ x_list.append(x_i)
324
+
325
+ x = torch.cat(x_list, dim=1)
326
+
327
+
328
+ else:
329
+ T_ = T - reference_length
330
+ q = rearrange(q, "B T H W (h d) -> B h (T H W) d", h=self.heads)
331
+ k = rearrange(k[:, T_:], "B T H W (h d) -> B h (T H W) d", h=self.heads)
332
+ v = rearrange(v[:, T_:], "B T H W (h d) -> B h (T H W) d", h=self.heads)
333
+
334
+ q, k, v = map(lambda t: t.contiguous(), (q, k, v))
335
+ x = F.scaled_dot_product_attention(query=q, key=k, value=v)
336
+ x = rearrange(x, "B h (T H W) d -> B T H W (h d)", B=B, H=H, W=W)
337
+ x = x.to(q.dtype)
338
+
339
+ # linear proj
340
+ x = self.to_out(x)
341
+
342
+ return x
algorithms/worldmem/models/cameractrl_module.py ADDED
@@ -0,0 +1,12 @@
1
+ import torch.nn as nn
2
+ class SimpleCameraPoseEncoder(nn.Module):
3
+ def __init__(self, c_in, c_out, hidden_dim=128):
4
+ super(SimpleCameraPoseEncoder, self).__init__()
5
+ self.model = nn.Sequential(
6
+ nn.Linear(c_in, hidden_dim),
7
+ nn.ReLU(),
8
+ nn.Linear(hidden_dim, c_out)
9
+ )
10
+ def forward(self, x):
11
+ return self.model(x)
12
+
algorithms/worldmem/models/diffusion.py ADDED
@@ -0,0 +1,520 @@
1
+ from typing import Optional, Callable
2
+ from collections import namedtuple
3
+ from omegaconf import DictConfig
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from einops import rearrange
8
+ from .utils import linear_beta_schedule, cosine_beta_schedule, sigmoid_beta_schedule, extract
9
+ from .dit import DiT_models
10
+
11
+ ModelPrediction = namedtuple("ModelPrediction", ["pred_noise", "pred_x_start", "model_out"])
12
+
13
+
14
+ class Diffusion(nn.Module):
15
+ # Special thanks to lucidrains for the implementation of the base Diffusion model
16
+ # https://github.com/lucidrains/denoising-diffusion-pytorch
17
+
18
+ def __init__(
19
+ self,
20
+ x_shape: torch.Size,
21
+ reference_length: int,
22
+ action_cond_dim: int,
23
+ pose_cond_dim,
24
+ is_causal: bool,
25
+ cfg: DictConfig,
26
+ is_dit: bool=False,
27
+ use_plucker=False,
28
+ relative_embedding=False,
29
+ state_embed_only_on_qk=False,
30
+ use_memory_attention=False,
31
+ add_timestamp_embedding=False,
32
+ ref_mode='sequential'
33
+ ):
34
+ super().__init__()
35
+ self.cfg = cfg
36
+
37
+ self.x_shape = x_shape
38
+ self.action_cond_dim = action_cond_dim
39
+ self.timesteps = cfg.timesteps
40
+ self.sampling_timesteps = cfg.sampling_timesteps
41
+ self.beta_schedule = cfg.beta_schedule
42
+ self.schedule_fn_kwargs = cfg.schedule_fn_kwargs
43
+ self.objective = cfg.objective
44
+ self.use_fused_snr = cfg.use_fused_snr
45
+ self.snr_clip = cfg.snr_clip
46
+ self.cum_snr_decay = cfg.cum_snr_decay
47
+ self.ddim_sampling_eta = cfg.ddim_sampling_eta
48
+ self.clip_noise = cfg.clip_noise
49
+ self.arch = cfg.architecture
50
+ self.stabilization_level = cfg.stabilization_level
51
+ self.is_causal = is_causal
52
+ self.is_dit = is_dit
53
+ self.reference_length = reference_length
54
+ self.pose_cond_dim = pose_cond_dim
55
+ self.use_plucker = use_plucker
56
+ self.relative_embedding = relative_embedding
57
+ self.state_embed_only_on_qk = state_embed_only_on_qk
58
+ self.use_memory_attention = use_memory_attention
59
+ self.add_timestamp_embedding = add_timestamp_embedding
60
+ self.ref_mode = ref_mode
61
+
62
+ self._build_model()
63
+ self._build_buffer()
64
+
65
+ def _build_model(self):
66
+ x_channel = self.x_shape[0]
67
+ if self.is_dit:
68
+ self.model = DiT_models["DiT-S/2"](action_cond_dim=self.action_cond_dim,
69
+ pose_cond_dim=self.pose_cond_dim, reference_length=self.reference_length,
70
+ use_plucker=self.use_plucker,
71
+ relative_embedding=self.relative_embedding,
72
+ state_embed_only_on_qk=self.state_embed_only_on_qk,
73
+ use_memory_attention=self.use_memory_attention,
74
+ add_timestamp_embedding=self.add_timestamp_embedding,
75
+ ref_mode=self.ref_mode)
76
+ else:
77
+ raise NotImplementedError
78
+
79
+ def _build_buffer(self):
80
+ if self.beta_schedule == "linear":
81
+ beta_schedule_fn = linear_beta_schedule
82
+ elif self.beta_schedule == "cosine":
83
+ beta_schedule_fn = cosine_beta_schedule
84
+ elif self.beta_schedule == "sigmoid":
85
+ beta_schedule_fn = sigmoid_beta_schedule
86
+ else:
87
+ raise ValueError(f"unknown beta schedule {self.beta_schedule}")
88
+
89
+ betas = beta_schedule_fn(self.timesteps, **self.schedule_fn_kwargs)
90
+
91
+ alphas = 1.0 - betas
92
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
93
+ alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
94
+
95
+ # sampling related parameters
96
+ assert self.sampling_timesteps <= self.timesteps
97
+ self.is_ddim_sampling = self.sampling_timesteps < self.timesteps
98
+
99
+ # helper function to register buffer from float64 to float32
100
+ register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
101
+
102
+ register_buffer("betas", betas)
103
+ register_buffer("alphas_cumprod", alphas_cumprod)
104
+ register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
105
+
106
+ # calculations for diffusion q(x_t | x_{t-1}) and others
107
+
108
+ register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
109
+ register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
110
+ register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
111
+ register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
112
+ register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))
113
+
114
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
115
+
116
+ posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
117
+
118
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
119
+
120
+ register_buffer("posterior_variance", posterior_variance)
121
+
122
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
123
+
124
+ register_buffer(
125
+ "posterior_log_variance_clipped",
126
+ torch.log(posterior_variance.clamp(min=1e-20)),
127
+ )
128
+ register_buffer(
129
+ "posterior_mean_coef1",
130
+ betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
131
+ )
132
+ register_buffer(
133
+ "posterior_mean_coef2",
134
+ (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod),
135
+ )
136
+
137
+ # calculate p2 reweighting
138
+
139
+ # register_buffer(
140
+ # "p2_loss_weight",
141
+ # (self.p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod))
142
+ # ** -self.p2_loss_weight_gamma,
143
+ # )
144
+
145
+ # derive loss weight
146
+ # https://arxiv.org/abs/2303.09556
147
+ # snr: signal noise ratio
148
+ snr = alphas_cumprod / (1 - alphas_cumprod)
149
+ clipped_snr = snr.clone()
150
+ clipped_snr.clamp_(max=self.snr_clip)
151
+
152
+ register_buffer("clipped_snr", clipped_snr)
153
+ register_buffer("snr", snr)
154
+
155
+ def add_shape_channels(self, x):
156
+ return rearrange(x, f"... -> ...{' 1' * len(self.x_shape)}")
157
+
158
+ def model_predictions(self, x, t, action_cond=None, current_frame=None,
159
+ pose_cond=None, mode="training", reference_length=None, frame_idx=None):
160
+ x = x.permute(1,0,2,3,4)
161
+ action_cond = action_cond.permute(1,0,2)
162
+ if pose_cond is not None and pose_cond[0] is not None:
163
+ try:
164
+ pose_cond = pose_cond.permute(1,0,2)
165
+ except:
166
+ pass
167
+ t = t.permute(1,0)
168
+ model_output = self.model(x, t, action_cond, current_frame=current_frame, pose_cond=pose_cond,
169
+ mode=mode, reference_length=reference_length, frame_idx=frame_idx)
170
+ model_output = model_output.permute(1,0,2,3,4)
171
+ x = x.permute(1,0,2,3,4)
172
+ t = t.permute(1,0)
173
+
174
+ if self.objective == "pred_noise":
175
+ pred_noise = torch.clamp(model_output, -self.clip_noise, self.clip_noise)
176
+ x_start = self.predict_start_from_noise(x, t, pred_noise)
177
+
178
+ elif self.objective == "pred_x0":
179
+ x_start = model_output
180
+ pred_noise = self.predict_noise_from_start(x, t, x_start)
181
+
182
+ elif self.objective == "pred_v":
183
+ v = model_output
184
+ x_start = self.predict_start_from_v(x, t, v)
185
+ pred_noise = self.predict_noise_from_start(x, t, x_start)
186
+
187
+
188
+ return ModelPrediction(pred_noise, x_start, model_output)
189
+
190
+ def predict_start_from_noise(self, x_t, t, noise):
191
+ return (
192
+ extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
193
+ - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
194
+ )
195
+
196
+ def predict_noise_from_start(self, x_t, t, x0):
197
+ return (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / extract(
198
+ self.sqrt_recipm1_alphas_cumprod, t, x_t.shape
199
+ )
200
+
201
+ def predict_v(self, x_start, t, noise):
202
+ return (
203
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise
204
+ - extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
205
+ )
206
+
207
+ def predict_start_from_v(self, x_t, t, v):
208
+ return (
209
+ extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
210
+ - extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
211
+ )
212
+
213
+ def q_mean_variance(self, x_start, t):
214
+ mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
215
+ variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape)
216
+ log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
217
+ return mean, variance, log_variance
218
+
219
+ def q_posterior(self, x_start, x_t, t):
220
+ posterior_mean = (
221
+ extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
222
+ + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
223
+ )
224
+ posterior_variance = extract(self.posterior_variance, t, x_t.shape)
225
+ posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
226
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
227
+
228
+ def q_sample(self, x_start, t, noise=None):
229
+ if noise is None:
230
+ noise = torch.randn_like(x_start)
231
+ noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
232
+ return (
233
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
234
+ + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
235
+ )
236
+
237
+ def p_mean_variance(self, x, t, action_cond=None, pose_cond=None, reference_length=None, frame_idx=None):
238
+ model_pred = self.model_predictions(x=x, t=t, action_cond=action_cond,
239
+ pose_cond=pose_cond, reference_length=reference_length,
240
+ frame_idx=frame_idx)
241
+ x_start = model_pred.pred_x_start
242
+ return self.q_posterior(x_start=x_start, x_t=x, t=t)
243
+
244
+ def compute_loss_weights(self, noise_levels: torch.Tensor):
245
+
246
+ snr = self.snr[noise_levels]
247
+ clipped_snr = self.clipped_snr[noise_levels]
248
+ normalized_clipped_snr = clipped_snr / self.snr_clip
249
+ normalized_snr = snr / self.snr_clip
250
+
251
+ if not self.use_fused_snr:
252
+ # min SNR reweighting
253
+ match self.objective:
254
+ case "pred_noise":
255
+ return clipped_snr / snr
256
+ case "pred_x0":
257
+ return clipped_snr
258
+ case "pred_v":
259
+ return clipped_snr / (snr + 1)
260
+
261
+ cum_snr = torch.zeros_like(normalized_snr)
262
+ for t in range(0, noise_levels.shape[0]):
263
+ if t == 0:
264
+ cum_snr[t] = normalized_clipped_snr[t]
265
+ else:
266
+ cum_snr[t] = self.cum_snr_decay * cum_snr[t - 1] + (1 - self.cum_snr_decay) * normalized_clipped_snr[t]
267
+
268
+ cum_snr = F.pad(cum_snr[:-1], (0, 0, 1, 0), value=0.0)
269
+ clipped_fused_snr = 1 - (1 - cum_snr * self.cum_snr_decay) * (1 - normalized_clipped_snr)
270
+ fused_snr = 1 - (1 - cum_snr * self.cum_snr_decay) * (1 - normalized_snr)
271
+
272
+ match self.objective:
273
+ case "pred_noise":
274
+ return clipped_fused_snr / fused_snr
275
+ case "pred_x0":
276
+ return clipped_fused_snr * self.snr_clip
277
+ case "pred_v":
278
+ return clipped_fused_snr * self.snr_clip / (fused_snr * self.snr_clip + 1)
279
+ case _:
280
+ raise ValueError(f"unknown objective {self.objective}")
281
+
282
+ def forward(
283
+ self,
284
+ x: torch.Tensor,
285
+ action_cond: Optional[torch.Tensor],
286
+ pose_cond,
287
+ noise_levels: torch.Tensor,
288
+ reference_length,
289
+ frame_idx=None
290
+ ):
291
+ noise = torch.randn_like(x)
292
+ noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
293
+
294
+ noised_x = self.q_sample(x_start=x, t=noise_levels, noise=noise)
295
+
296
+ model_pred = self.model_predictions(x=noised_x, t=noise_levels, action_cond=action_cond,
297
+ pose_cond=pose_cond,reference_length=reference_length, frame_idx=frame_idx)
298
+
299
+ pred = model_pred.model_out
300
+ x_pred = model_pred.pred_x_start
301
+
302
+ if self.objective == "pred_noise":
303
+ target = noise
304
+ elif self.objective == "pred_x0":
305
+ target = x
306
+ elif self.objective == "pred_v":
307
+ target = self.predict_v(x, noise_levels, noise)
308
+ else:
309
+ raise ValueError(f"unknown objective {self.objective}")
310
+
311
+ # During training, each frame can be assigned an arbitrary (independent) noise level.
312
+ loss = F.mse_loss(pred, target.detach(), reduction="none")
313
+ loss_weight = self.compute_loss_weights(noise_levels)
314
+
315
+ loss_weight = loss_weight.view(*loss_weight.shape, *((1,) * (loss.ndim - 2)))
316
+
317
+ loss = loss * loss_weight
318
+
319
+ return x_pred, loss
320
+
321
+ def sample_step(
322
+ self,
323
+ x: torch.Tensor,
324
+ action_cond: Optional[torch.Tensor],
325
+ pose_cond,
326
+ curr_noise_level: torch.Tensor,
327
+ next_noise_level: torch.Tensor,
328
+ guidance_fn: Optional[Callable] = None,
329
+ current_frame=None,
330
+ mode="training",
331
+ reference_length=None,
332
+ frame_idx=None
333
+ ):
334
+ real_steps = torch.linspace(-1, self.timesteps - 1, steps=self.sampling_timesteps + 1, device=x.device).long()
335
+
336
+ # convert noise levels (0 ~ sampling_timesteps) to real noise levels (-1 ~ timesteps - 1)
337
+ curr_noise_level = real_steps[curr_noise_level]
338
+ next_noise_level = real_steps[next_noise_level]
339
+
340
+ if self.is_ddim_sampling:
341
+ return self.ddim_sample_step(
342
+ x=x,
343
+ action_cond=action_cond,
344
+ pose_cond=pose_cond,
345
+ curr_noise_level=curr_noise_level,
346
+ next_noise_level=next_noise_level,
347
+ guidance_fn=guidance_fn,
348
+ current_frame=current_frame,
349
+ mode=mode,
350
+ reference_length=reference_length,
351
+ frame_idx=frame_idx
352
+ )
353
+
354
+ # FIXME: temporary code for checking ddpm sampling
355
+ assert torch.all(
356
+ (curr_noise_level - 1 == next_noise_level) | ((curr_noise_level == -1) & (next_noise_level == -1))
357
+ ), "Wrong noise level given for ddpm sampling."
358
+
359
+ assert (
360
+ self.sampling_timesteps == self.timesteps
361
+ ), "sampling_timesteps should be equal to timesteps for ddpm sampling."
362
+
363
+ return self.ddpm_sample_step(
364
+ x=x,
365
+ action_cond=action_cond,
366
+ pose_cond=pose_cond,
367
+ curr_noise_level=curr_noise_level,
368
+ guidance_fn=guidance_fn,
369
+ reference_length=reference_length,
370
+ frame_idx=frame_idx
371
+ )
372
+
373
+ def ddpm_sample_step(
374
+ self,
375
+ x: torch.Tensor,
376
+ action_cond: Optional[torch.Tensor],
377
+ pose_cond,
378
+ curr_noise_level: torch.Tensor,
379
+ guidance_fn: Optional[Callable] = None,
380
+ reference_length=None,
381
+ frame_idx=None,
382
+ ):
383
+ clipped_curr_noise_level = torch.where(
384
+ curr_noise_level < 0,
385
+ torch.full_like(curr_noise_level, self.stabilization_level - 1, dtype=torch.long),
386
+ curr_noise_level,
387
+ )
388
+
389
+ # treating as stabilization would require us to scale with sqrt of alpha_cum
390
+ orig_x = x.clone().detach()
391
+ scaled_context = self.q_sample(
392
+ x,
393
+ clipped_curr_noise_level,
394
+ noise=torch.zeros_like(x),
395
+ )
396
+ x = torch.where(self.add_shape_channels(curr_noise_level < 0), scaled_context, orig_x)
397
+
398
+ if guidance_fn is not None:
399
+ raise NotImplementedError("Guidance function is not implemented for ddpm sampling yet.")
400
+
401
+ else:
402
+ model_mean, _, model_log_variance = self.p_mean_variance(
403
+ x=x,
404
+ t=clipped_curr_noise_level,
405
+ action_cond=action_cond,
406
+ pose_cond=pose_cond,
407
+ reference_length=reference_length,
408
+ frame_idx=frame_idx
409
+ )
410
+
411
+ noise = torch.where(
412
+ self.add_shape_channels(clipped_curr_noise_level > 0),
413
+ torch.randn_like(x),
414
+ 0,
415
+ )
416
+ noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
417
+ x_pred = model_mean + torch.exp(0.5 * model_log_variance) * noise
418
+
419
+ # only update frames where the noise level decreases
420
+ return torch.where(self.add_shape_channels(curr_noise_level == -1), orig_x, x_pred)
421
+
422
+ def ddim_sample_step(
423
+ self,
424
+ x: torch.Tensor,
425
+ action_cond: Optional[torch.Tensor],
426
+ pose_cond,
427
+ curr_noise_level: torch.Tensor,
428
+ next_noise_level: torch.Tensor,
429
+ guidance_fn: Optional[Callable] = None,
430
+ current_frame=None,
431
+ mode="training",
432
+ reference_length=None,
433
+ frame_idx=None
434
+ ):
435
+ # convert noise level -1 to self.stabilization_level - 1
436
+ clipped_curr_noise_level = torch.where(
437
+ curr_noise_level < 0,
438
+ torch.full_like(curr_noise_level, self.stabilization_level - 1, dtype=torch.long),
439
+ curr_noise_level,
440
+ )
441
+
442
+ # treating as stabilization would require us to scale with sqrt of alpha_cum
443
+ orig_x = x.clone().detach()
444
+ scaled_context = self.q_sample(
445
+ x,
446
+ clipped_curr_noise_level,
447
+ noise=torch.zeros_like(x),
448
+ )
449
+ x = torch.where(self.add_shape_channels(curr_noise_level < 0), scaled_context, orig_x)
450
+
451
+ alpha = self.alphas_cumprod[clipped_curr_noise_level]
452
+ alpha_next = torch.where(
453
+ next_noise_level < 0,
454
+ torch.ones_like(next_noise_level),
455
+ self.alphas_cumprod[next_noise_level],
456
+ )
457
+ sigma = torch.where(
458
+ next_noise_level < 0,
459
+ torch.zeros_like(next_noise_level),
460
+ self.ddim_sampling_eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt(),
461
+ )
462
+ c = (1 - alpha_next - sigma**2).sqrt()
463
+
464
+ alpha_next = self.add_shape_channels(alpha_next)
465
+ c = self.add_shape_channels(c)
466
+ sigma = self.add_shape_channels(sigma)
467
+
468
+ if guidance_fn is not None:
469
+ with torch.enable_grad():
470
+ x = x.detach().requires_grad_()
471
+
472
+ model_pred = self.model_predictions(
473
+ x=x,
474
+ t=clipped_curr_noise_level,
475
+ action_cond=action_cond,
476
+ pose_cond=pose_cond,
477
+ current_frame=current_frame,
478
+ mode=mode,
479
+ reference_length=reference_length,
480
+ frame_idx=frame_idx
481
+ )
482
+
483
+ guidance_loss = guidance_fn(model_pred.pred_x_start)
484
+ grad = -torch.autograd.grad(
485
+ guidance_loss,
486
+ x,
487
+ )[0]
488
+
489
+ pred_noise = model_pred.pred_noise + (1 - alpha_next).sqrt() * grad
490
+ x_start = self.predict_start_from_noise(x, clipped_curr_noise_level, pred_noise)
491
+
492
+ else:
493
+ # print(clipped_curr_noise_level)
494
+ model_pred = self.model_predictions(
495
+ x=x,
496
+ t=clipped_curr_noise_level,
497
+ action_cond=action_cond,
498
+ pose_cond=pose_cond,
499
+ current_frame=current_frame,
500
+ mode=mode,
501
+ reference_length=reference_length,
502
+ frame_idx=frame_idx
503
+ )
504
+ x_start = model_pred.pred_x_start
505
+ pred_noise = model_pred.pred_noise
506
+
507
+ noise = torch.randn_like(x)
508
+ noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
509
+
510
+ x_pred = x_start * alpha_next.sqrt() + pred_noise * c + sigma * noise
511
+
512
+ # only update frames where the noise level decreases
513
+ mask = curr_noise_level == next_noise_level
514
+ x_pred = torch.where(
515
+ self.add_shape_channels(mask),
516
+ orig_x,
517
+ x_pred,
518
+ )
519
+
520
+ return x_pred
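For orientation, the scheduler indices consumed by `sample_step` are plain integers in `[0, sampling_timesteps]` that are mapped onto real diffusion timesteps by the `linspace` above. A minimal, self-contained sketch of that mapping (illustrative only; the values for `timesteps` and `sampling_timesteps` are assumptions, and the actual per-frame scheduling lives elsewhere in this commit):

import torch

timesteps, sampling_timesteps = 1000, 20                       # assumed config values
real_steps = torch.linspace(-1, timesteps - 1, steps=sampling_timesteps + 1).long()
print(real_steps[:3], real_steps[-1])                          # tensor([-1, 49, 99]) ... tensor(999)
# Index 0 maps to -1 ("clean", handled via stabilization_level); index 20 maps to 999 (pure noise).
# A full denoise would call sample_step once per adjacent pair of indices,
# expanded to (frames, batch) tensors for curr_noise_level / next_noise_level:
schedule = [(k, k - 1) for k in range(sampling_timesteps, 0, -1)]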
algorithms/worldmem/models/dit.py ADDED
@@ -0,0 +1,572 @@
1
+ """
2
+ References:
3
+ - DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
4
+ - Diffusion Forcing: https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/unet3d.py
5
+ - Latte: https://github.com/Vchitect/Latte/blob/main/models/latte.py
6
+ """
7
+
8
+ from typing import Optional, Literal
9
+ import torch
10
+ from torch import nn
11
+ from .rotary_embedding_torch import RotaryEmbedding
12
+ from einops import rearrange
13
+ from .attention import SpatialAxialAttention, TemporalAxialAttention, MemTemporalAxialAttention, MemFullAttention
14
+ from timm.models.vision_transformer import Mlp
15
+ from timm.layers.helpers import to_2tuple
16
+ import math
17
+ from collections import namedtuple
18
+ from typing import Optional, Callable
19
+ from .cameractrl_module import SimpleCameraPoseEncoder
20
+
21
+ def modulate(x, shift, scale):
22
+ fixed_dims = [1] * len(shift.shape[1:])
23
+ shift = shift.repeat(x.shape[0] // shift.shape[0], *fixed_dims)
24
+ scale = scale.repeat(x.shape[0] // scale.shape[0], *fixed_dims)
25
+ while shift.dim() < x.dim():
26
+ shift = shift.unsqueeze(-2)
27
+ scale = scale.unsqueeze(-2)
28
+ return x * (1 + scale) + shift
29
+
30
+ def gate(x, g):
31
+ fixed_dims = [1] * len(g.shape[1:])
32
+ g = g.repeat(x.shape[0] // g.shape[0], *fixed_dims)
33
+ while g.dim() < x.dim():
34
+ g = g.unsqueeze(-2)
35
+ return g * x
36
+
37
+
38
+ class PatchEmbed(nn.Module):
39
+ """2D Image to Patch Embedding"""
40
+
41
+ def __init__(
42
+ self,
43
+ img_height=256,
44
+ img_width=256,
45
+ patch_size=16,
46
+ in_chans=3,
47
+ embed_dim=768,
48
+ norm_layer=None,
49
+ flatten=True,
50
+ ):
51
+ super().__init__()
52
+ img_size = (img_height, img_width)
53
+ patch_size = to_2tuple(patch_size)
54
+ self.img_size = img_size
55
+ self.patch_size = patch_size
56
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
57
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
58
+ self.flatten = flatten
59
+
60
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
61
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
62
+
63
+ def forward(self, x, random_sample=False):
64
+ B, C, H, W = x.shape
65
+ assert random_sample or (H == self.img_size[0] and W == self.img_size[1]), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
66
+
67
+ x = self.proj(x)
68
+ if self.flatten:
69
+ x = rearrange(x, "B C H W -> B (H W) C")
70
+ else:
71
+ x = rearrange(x, "B C H W -> B H W C")
72
+ x = self.norm(x)
73
+ return x
74
+
75
+
76
+ class TimestepEmbedder(nn.Module):
77
+ """
78
+ Embeds scalar timesteps into vector representations.
79
+ """
80
+
81
+ def __init__(self, hidden_size, frequency_embedding_size=256, freq_type='time_step'):
82
+ super().__init__()
83
+ self.mlp = nn.Sequential(
84
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True), # hidden_size is diffusion model hidden size
85
+ nn.SiLU(),
86
+ nn.Linear(hidden_size, hidden_size, bias=True),
87
+ )
88
+ self.frequency_embedding_size = frequency_embedding_size
89
+ self.freq_type = freq_type
90
+
91
+ @staticmethod
92
+ def timestep_embedding(t, dim, max_period=10000, freq_type='time_step'):
93
+ """
94
+ Create sinusoidal timestep embeddings.
95
+ :param t: a 1-D Tensor of N indices, one per batch element.
96
+ These may be fractional.
97
+ :param dim: the dimension of the output.
98
+ :param max_period: controls the minimum frequency of the embeddings.
99
+ :return: an (N, D) Tensor of positional embeddings.
100
+ """
101
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
102
+ half = dim // 2
103
+
104
+ if freq_type == 'time_step':
105
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
106
+ elif freq_type == 'spatial': # ~(-5 5)
107
+ freqs = torch.linspace(1.0, half, half).to(device=t.device) * torch.pi
108
+ elif freq_type == 'angle': # 0-360
109
+ freqs = torch.linspace(1.0, half, half).to(device=t.device) * torch.pi / 180
110
+
111
+
112
+ args = t[:, None].float() * freqs[None]
113
+
114
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
115
+ if dim % 2:
116
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
117
+ return embedding
118
+
119
+ def forward(self, t):
120
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size, freq_type=self.freq_type)
121
+ t_emb = self.mlp(t_freq)
122
+ return t_emb
123
+
124
+
125
+ class FinalLayer(nn.Module):
126
+ """
127
+ The final layer of DiT.
128
+ """
129
+
130
+ def __init__(self, hidden_size, patch_size, out_channels):
131
+ super().__init__()
132
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
133
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
134
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
135
+
136
+ def forward(self, x, c):
137
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
138
+ x = modulate(self.norm_final(x), shift, scale)
139
+ x = self.linear(x)
140
+ return x
141
+
142
+
143
+ class SpatioTemporalDiTBlock(nn.Module):
144
+ def __init__(
145
+ self,
146
+ hidden_size,
147
+ num_heads,
148
+ reference_length,
149
+ mlp_ratio=4.0,
150
+ is_causal=True,
151
+ spatial_rotary_emb: Optional[RotaryEmbedding] = None,
152
+ temporal_rotary_emb: Optional[RotaryEmbedding] = None,
153
+ reference_rotary_emb=None,
154
+ use_plucker=False,
155
+ relative_embedding=False,
156
+ state_embed_only_on_qk=False,
157
+ use_memory_attention=False,
158
+ ref_mode='sequential'
159
+ ):
160
+ super().__init__()
161
+ self.is_causal = is_causal
162
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
163
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
164
+
165
+ self.s_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
166
+ self.s_attn = SpatialAxialAttention(
167
+ hidden_size,
168
+ heads=num_heads,
169
+ dim_head=hidden_size // num_heads,
170
+ rotary_emb=spatial_rotary_emb
171
+ )
172
+ self.s_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
173
+ self.s_mlp = Mlp(
174
+ in_features=hidden_size,
175
+ hidden_features=mlp_hidden_dim,
176
+ act_layer=approx_gelu,
177
+ drop=0,
178
+ )
179
+ self.s_adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
180
+
181
+ self.t_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
182
+ self.t_attn = TemporalAxialAttention(
183
+ hidden_size,
184
+ heads=num_heads,
185
+ dim_head=hidden_size // num_heads,
186
+ is_causal=is_causal,
187
+ rotary_emb=temporal_rotary_emb,
188
+ reference_length=reference_length
189
+ )
190
+ self.t_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
191
+ self.t_mlp = Mlp(
192
+ in_features=hidden_size,
193
+ hidden_features=mlp_hidden_dim,
194
+ act_layer=approx_gelu,
195
+ drop=0,
196
+ )
197
+ self.t_adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
198
+
199
+ self.use_memory_attention = use_memory_attention
200
+ if self.use_memory_attention:
201
+ self.r_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
202
+ self.ref_type = "full_ref"
203
+ if self.ref_type == "temporal_ref":
204
+ self.r_attn = MemTemporalAxialAttention(
205
+ hidden_size,
206
+ heads=num_heads,
207
+ dim_head=hidden_size // num_heads,
208
+ is_causal=is_causal,
209
+ rotary_emb=None
210
+ )
211
+ elif self.ref_type == "full_ref":
212
+ self.r_attn = MemFullAttention(
213
+ hidden_size,
214
+ heads=num_heads,
215
+ dim_head=hidden_size // num_heads,
216
+ is_causal=is_causal,
217
+ rotary_emb=reference_rotary_emb,
218
+ reference_length=reference_length
219
+ )
220
+ self.r_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
221
+ self.r_mlp = Mlp(
222
+ in_features=hidden_size,
223
+ hidden_features=mlp_hidden_dim,
224
+ act_layer=approx_gelu,
225
+ drop=0,
226
+ )
227
+ self.r_adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
228
+
229
+ self.use_plucker = use_plucker
230
+ if use_plucker:
231
+ self.pose_cond_mlp = nn.Linear(hidden_size, hidden_size)
232
+ self.temporal_pose_cond_mlp = nn.Linear(hidden_size, hidden_size)
233
+
234
+ self.reference_length = reference_length
235
+ self.relative_embedding = relative_embedding
236
+ self.state_embed_only_on_qk = state_embed_only_on_qk
237
+
238
+ self.ref_mode = ref_mode
239
+
240
+ if self.ref_mode == 'parallel':
241
+ self.parallel_map = nn.Linear(hidden_size, hidden_size)
242
+
243
+ def forward(self, x, c, current_frame=None, timestep=None, is_last_block=False,
244
+ pose_cond=None, mode="training", c_action_cond=None, reference_length=None):
245
+ B, T, H, W, D = x.shape
246
+
247
+ # spatial block
248
+
249
+ s_shift_msa, s_scale_msa, s_gate_msa, s_shift_mlp, s_scale_mlp, s_gate_mlp = self.s_adaLN_modulation(c).chunk(6, dim=-1)
250
+ x = x + gate(self.s_attn(modulate(self.s_norm1(x), s_shift_msa, s_scale_msa)), s_gate_msa)
251
+ x = x + gate(self.s_mlp(modulate(self.s_norm2(x), s_shift_mlp, s_scale_mlp)), s_gate_mlp)
252
+
253
+ # temporal block
254
+ if c_action_cond is not None:
255
+ t_shift_msa, t_scale_msa, t_gate_msa, t_shift_mlp, t_scale_mlp, t_gate_mlp = self.t_adaLN_modulation(c_action_cond).chunk(6, dim=-1)
256
+ else:
257
+ t_shift_msa, t_scale_msa, t_gate_msa, t_shift_mlp, t_scale_mlp, t_gate_mlp = self.t_adaLN_modulation(c).chunk(6, dim=-1)
258
+
259
+ x_t = x + gate(self.t_attn(modulate(self.t_norm1(x), t_shift_msa, t_scale_msa)), t_gate_msa)
260
+ x_t = x_t + gate(self.t_mlp(modulate(self.t_norm2(x_t), t_shift_mlp, t_scale_mlp)), t_gate_mlp)
261
+
262
+ if self.ref_mode == 'sequential':
263
+ x = x_t
264
+
265
+ # memory block
266
+ relative_embedding = self.relative_embedding # and mode == "training"
267
+
268
+ if self.use_memory_attention:
269
+ r_shift_msa, r_scale_msa, r_gate_msa, r_shift_mlp, r_scale_mlp, r_gate_mlp = self.r_adaLN_modulation(c).chunk(6, dim=-1)
270
+
271
+ if pose_cond is not None:
272
+ if self.use_plucker:
273
+ input_cond = self.pose_cond_mlp(pose_cond)
274
+
275
+ if relative_embedding:
276
+ n_frames = x.shape[1] - reference_length
277
+ x1_relative_embedding = []
278
+ r_shift_msa_relative_embedding = []
279
+ r_scale_msa_relative_embedding = []
280
+ for i in range(n_frames):
281
+ x1_relative_embedding.append(torch.cat([x[:,i:i+1], x[:, -reference_length:]], dim=1).clone())
282
+ r_shift_msa_relative_embedding.append(torch.cat([r_shift_msa[:,i:i+1], r_shift_msa[:, -reference_length:]], dim=1).clone())
283
+ r_scale_msa_relative_embedding.append(torch.cat([r_scale_msa[:,i:i+1], r_scale_msa[:, -reference_length:]], dim=1).clone())
284
+ x1_zero_frame = torch.cat(x1_relative_embedding, dim=1)
285
+ r_shift_msa = torch.cat(r_shift_msa_relative_embedding, dim=1)
286
+ r_scale_msa = torch.cat(r_scale_msa_relative_embedding, dim=1)
287
+
288
+ # if current_frame == 18:
289
+ # import pdb;pdb.set_trace()
290
+
291
+ if self.state_embed_only_on_qk:
292
+ attn_input = x1_zero_frame
293
+ extra_condition = input_cond
294
+ else:
295
+ attn_input = input_cond + x1_zero_frame
296
+ extra_condition = None
297
+ else:
298
+ attn_input = input_cond + x
299
+ extra_condition = None
300
+ # print("input_cond2:", input_cond.abs().mean())
301
+ # print("c:", c.abs().mean())
302
+ # input_cond = x1
303
+
304
+ x = x + gate(self.r_attn(modulate(self.r_norm1(attn_input), r_shift_msa, r_scale_msa),
305
+ relative_embedding=relative_embedding,
306
+ extra_condition=extra_condition,
307
+ state_embed_only_on_qk=self.state_embed_only_on_qk,
308
+ reference_length=reference_length), r_gate_msa)
309
+ else:
310
+ # pose_cond *= 0
311
+ x = x + gate(self.r_attn(modulate(self.r_norm1(x+pose_cond[:,:,None, None]), r_shift_msa, r_scale_msa),
312
+ current_frame=current_frame, timestep=timestep,
313
+ is_last_block=is_last_block,
314
+ reference_length=reference_length), r_gate_msa)
315
+ else:
316
+ x = x + gate(self.r_attn(modulate(self.r_norm1(x), r_shift_msa, r_scale_msa), current_frame=current_frame, timestep=timestep,
317
+ is_last_block=is_last_block), r_gate_msa)
318
+
319
+ x = x + gate(self.r_mlp(modulate(self.r_norm2(x), r_shift_mlp, r_scale_mlp)), r_gate_mlp)
320
+
321
+ if self.ref_mode == 'parallel':
322
+ x = x_t + self.parallel_map(x)
323
+
324
+ return x
325
+
326
+ # print((x1-x2).abs().sum())
327
+ # r_shift_msa, r_scale_msa, r_gate_msa, r_shift_mlp, r_scale_mlp, r_gate_mlp = self.r_adaLN_modulation(c).chunk(6, dim=-1)
328
+ # x2 = x1 + gate(self.r_attn(modulate(self.r_norm1(x_), r_shift_msa, r_scale_msa)), r_gate_msa)
329
+ # x2 = gate(self.r_mlp(modulate(self.r_norm2(x2), r_shift_mlp, r_scale_mlp)), r_gate_mlp)
330
+ # x = x1 + x2
331
+
332
+ # print(x.mean())
333
+ # return x
334
+
335
+
336
+ class DiT(nn.Module):
337
+ """
338
+ Diffusion model with a Transformer backbone.
339
+ """
340
+
341
+ def __init__(
342
+ self,
343
+ input_h=18,
344
+ input_w=32,
345
+ patch_size=2,
346
+ in_channels=16,
347
+ hidden_size=1024,
348
+ depth=12,
349
+ num_heads=16,
350
+ mlp_ratio=4.0,
351
+ action_cond_dim=25,
352
+ pose_cond_dim=4,
353
+ max_frames=32,
354
+ reference_length=8,
355
+ use_plucker=False,
356
+ relative_embedding=False,
357
+ state_embed_only_on_qk=False,
358
+ use_memory_attention=False,
359
+ add_timestamp_embedding=False,
360
+ ref_mode='sequential'
361
+ ):
362
+ super().__init__()
363
+ self.in_channels = in_channels
364
+ self.out_channels = in_channels
365
+ self.patch_size = patch_size
366
+ self.num_heads = num_heads
367
+ self.max_frames = max_frames
368
+
369
+ self.x_embedder = PatchEmbed(input_h, input_w, patch_size, in_channels, hidden_size, flatten=False)
370
+ self.t_embedder = TimestepEmbedder(hidden_size)
371
+
372
+ self.add_timestamp_embedding = add_timestamp_embedding
373
+ if self.add_timestamp_embedding:
374
+ self.timestamp_embedding = TimestepEmbedder(hidden_size)
375
+
376
+ frame_h, frame_w = self.x_embedder.grid_size
377
+
378
+ self.spatial_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads // 2, freqs_for="pixel", max_freq=256)
379
+ self.temporal_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads)
380
+ # self.reference_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads // 2, freqs_for="pixel", max_freq=256)
381
+ self.reference_rotary_emb = None
382
+
383
+ self.external_cond = nn.Linear(action_cond_dim, hidden_size) if action_cond_dim > 0 else nn.Identity()
384
+
385
+ # self.pose_cond = nn.Linear(pose_cond_dim, hidden_size) if pose_cond_dim > 0 else nn.Identity()
386
+
387
+ self.use_plucker = use_plucker
388
+ if not self.use_plucker:
389
+ self.position_embedder = TimestepEmbedder(hidden_size, freq_type='spatial')
390
+ self.angle_embedder = TimestepEmbedder(hidden_size, freq_type='angle')
391
+ else:
392
+ self.pose_embedder = SimpleCameraPoseEncoder(c_in=6, c_out=hidden_size)
393
+
394
+ self.blocks = nn.ModuleList(
395
+ [
396
+ SpatioTemporalDiTBlock(
397
+ hidden_size,
398
+ num_heads,
399
+ mlp_ratio=mlp_ratio,
400
+ is_causal=True,
401
+ reference_length=reference_length,
402
+ spatial_rotary_emb=self.spatial_rotary_emb,
403
+ temporal_rotary_emb=self.temporal_rotary_emb,
404
+ reference_rotary_emb=self.reference_rotary_emb,
405
+ use_plucker=self.use_plucker,
406
+ relative_embedding=relative_embedding,
407
+ state_embed_only_on_qk=state_embed_only_on_qk,
408
+ use_memory_attention=use_memory_attention,
409
+ ref_mode=ref_mode
410
+ )
411
+ for _ in range(depth)
412
+ ]
413
+ )
414
+ self.use_memory_attention = use_memory_attention
415
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
416
+ self.initialize_weights()
417
+
418
+ def initialize_weights(self):
419
+ # Initialize transformer layers:
420
+ def _basic_init(module):
421
+ if isinstance(module, nn.Linear):
422
+ torch.nn.init.xavier_uniform_(module.weight)
423
+ if module.bias is not None:
424
+ nn.init.constant_(module.bias, 0)
425
+
426
+ self.apply(_basic_init)
427
+
428
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
429
+ w = self.x_embedder.proj.weight.data
430
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
431
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
432
+
433
+ # Initialize timestep embedding MLP:
434
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
435
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
436
+
437
+ if self.use_memory_attention:
438
+ if not self.use_plucker:
439
+ nn.init.normal_(self.position_embedder.mlp[0].weight, std=0.02)
440
+ nn.init.normal_(self.position_embedder.mlp[2].weight, std=0.02)
441
+
442
+ nn.init.normal_(self.angle_embedder.mlp[0].weight, std=0.02)
443
+ nn.init.normal_(self.angle_embedder.mlp[2].weight, std=0.02)
444
+
445
+ if self.add_timestamp_embedding:
446
+ nn.init.normal_(self.timestamp_embedding.mlp[0].weight, std=0.02)
447
+ nn.init.normal_(self.timestamp_embedding.mlp[2].weight, std=0.02)
448
+
449
+
450
+ # Zero-out adaLN modulation layers in DiT blocks:
451
+ for block in self.blocks:
452
+ nn.init.constant_(block.s_adaLN_modulation[-1].weight, 0)
453
+ nn.init.constant_(block.s_adaLN_modulation[-1].bias, 0)
454
+ nn.init.constant_(block.t_adaLN_modulation[-1].weight, 0)
455
+ nn.init.constant_(block.t_adaLN_modulation[-1].bias, 0)
456
+
457
+ if self.use_plucker and self.use_memory_attention:
458
+ nn.init.constant_(block.pose_cond_mlp.weight, 0)
459
+ nn.init.constant_(block.pose_cond_mlp.bias, 0)
460
+
461
+ # Zero-out output layers:
462
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
463
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
464
+ nn.init.constant_(self.final_layer.linear.weight, 0)
465
+ nn.init.constant_(self.final_layer.linear.bias, 0)
466
+
467
+ def unpatchify(self, x):
468
+ """
469
+ x: (N, H, W, patch_size**2 * C)
470
+ imgs: (N, H, W, C)
471
+ """
472
+ c = self.out_channels
473
+ p = self.x_embedder.patch_size[0]
474
+ h = x.shape[1]
475
+ w = x.shape[2]
476
+
477
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
478
+ x = torch.einsum("nhwpqc->nchpwq", x)
479
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
480
+ return imgs
481
+
482
+ def forward(self, x, t, action_cond=None, pose_cond=None, current_frame=None, mode=None,
483
+ reference_length=None, frame_idx=None):
484
+ """
485
+ Forward pass of DiT.
486
+ x: (B, T, C, H, W) tensor of spatial inputs (images or latent representations of images)
487
+ t: (B, T,) tensor of diffusion timesteps
488
+ """
489
+
490
+ B, T, C, H, W = x.shape
491
+
492
+ # add spatial embeddings
493
+ x = rearrange(x, "b t c h w -> (b t) c h w")
494
+
495
+ x = self.x_embedder(x) # (B*T, C, H, W) -> (B*T, H/2, W/2, D) , C = 16, D = d_model
496
+ # restore shape
497
+ x = rearrange(x, "(b t) h w d -> b t h w d", t=T)
498
+ # embed noise steps
499
+ t = rearrange(t, "b t -> (b t)")
500
+
501
+ c_t = self.t_embedder(t) # (N, D)
502
+ c = c_t.clone()
503
+ c = rearrange(c, "(b t) d -> b t d", t=T)
504
+
505
+ if torch.is_tensor(action_cond):
506
+ try:
507
+ c_action_cond = c + self.external_cond(action_cond)
508
+ except:
509
+ import pdb;pdb.set_trace()
510
+ else:
511
+ c_action_cond = None
512
+
513
+ if torch.is_tensor(pose_cond):
514
+ if not self.use_plucker:
515
+ pose_cond = pose_cond.to(action_cond.dtype)
516
+ b_, t_, d_ = pose_cond.shape
517
+ pos_emb = self.position_embedder(rearrange(pose_cond[...,:3], "b t d -> (b t d)"))
518
+ angle_emb = self.angle_embedder(rearrange(pose_cond[...,3:], "b t d -> (b t d)"))
519
+ pos_emb = rearrange(pos_emb, "(b t d) c -> b t d c", b=b_, t=t_, d=3).sum(-2)
520
+ angle_emb = rearrange(angle_emb, "(b t d) c -> b t d c", b=b_, t=t_, d=2).sum(-2)
521
+ pc = pos_emb + angle_emb
522
+ else:
523
+ pose_cond = pose_cond[:, :, ::40, ::40]
524
+ # pc = self.pose_embedder(pose_cond)[0]
525
+ # pc = pc.permute(0,2,3,4,1)
526
+ pc = self.pose_embedder(pose_cond)
527
+ pc = pc.permute(1,0,2,3,4)
528
+
529
+ if torch.is_tensor(frame_idx) and self.add_timestamp_embedding:
530
+ bb = frame_idx.shape[1]
531
+ frame_idx = rearrange(frame_idx, "t b -> (b t)")
532
+ frame_idx = self.timestamp_embedding(frame_idx)
533
+ frame_idx = rearrange(frame_idx, "(b t) d -> b t d", b=bb)
534
+ pc = pc + frame_idx[:, :, None, None]
535
+
536
+ # pc = pc + rearrange(c_t.clone(), "(b t) d -> b t d", t=T)[:,:,None,None] # add time condition for different timestep scaling
537
+ else:
538
+ pc = None
539
+
540
+ for i, block in enumerate(self.blocks):
541
+ x = block(x, c, current_frame=current_frame, timestep=t, is_last_block= (i+1 == len(self.blocks)),
542
+ pose_cond=pc, mode=mode, c_action_cond=c_action_cond, reference_length=reference_length) # (N, T, H, W, D)
543
+ x = self.final_layer(x, c) # (N, T, H, W, patch_size ** 2 * out_channels)
544
+ # unpatchify
545
+ x = rearrange(x, "b t h w d -> (b t) h w d")
546
+ x = self.unpatchify(x) # (N, out_channels, H, W)
547
+ x = rearrange(x, "(b t) c h w -> b t c h w", t=T)
548
+ return x
549
+
550
+
551
+ def DiT_S_2(action_cond_dim, pose_cond_dim, reference_length,
552
+ use_plucker, relative_embedding,
553
+ state_embed_only_on_qk, use_memory_attention, add_timestamp_embedding,
554
+ ref_mode):
555
+ return DiT(
556
+ patch_size=2,
557
+ hidden_size=1024,
558
+ depth=16,
559
+ num_heads=16,
560
+ action_cond_dim=action_cond_dim,
561
+ pose_cond_dim=pose_cond_dim,
562
+ reference_length=reference_length,
563
+ use_plucker=use_plucker,
564
+ relative_embedding=relative_embedding,
565
+ state_embed_only_on_qk=state_embed_only_on_qk,
566
+ use_memory_attention=use_memory_attention,
567
+ add_timestamp_embedding=add_timestamp_embedding,
568
+ ref_mode=ref_mode
569
+ )
570
+
571
+
572
+ DiT_models = {"DiT-S/2": DiT_S_2}
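The patch and timestep embedders above are self-contained, so their output shapes can be sanity-checked in isolation. A minimal sketch, assuming the module is importable under the path shown in this diff (the full `DiT` forward additionally depends on the attention modules defined in `attention.py`, added earlier in this commit):

import torch
from algorithms.worldmem.models.dit import PatchEmbed, TimestepEmbedder

emb = TimestepEmbedder(hidden_size=1024)
t = torch.randint(0, 1000, (8,))                   # (B*T,) flattened diffusion timesteps
print(emb(t).shape)                                # torch.Size([8, 1024])

patches = PatchEmbed(img_height=18, img_width=32, patch_size=2, in_chans=16,
                     embed_dim=1024, flatten=False)
z = torch.randn(8, 16, 18, 32)                     # (B*T, C, H, W) VAE latents
print(patches(z).shape)                            # torch.Size([8, 9, 16, 1024]) = (B*T, H/2, W/2, D)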
algorithms/worldmem/models/pose_prediction.py ADDED
@@ -0,0 +1,42 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class PosePredictionNet(nn.Module):
6
+ def __init__(self, img_channels=16, img_feat_dim=256, pose_dim=5, action_dim=25, hidden_dim=128):
7
+ super(PosePredictionNet, self).__init__()
8
+
9
+ self.cnn = nn.Sequential(
10
+ nn.Conv2d(img_channels, 32, kernel_size=3, stride=2, padding=1),
11
+ nn.ReLU(),
12
+ nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
13
+ nn.ReLU(),
14
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
15
+ nn.ReLU(),
16
+ nn.AdaptiveAvgPool2d((1, 1))
17
+ )
18
+
19
+ self.fc_img = nn.Linear(128, img_feat_dim)
20
+
21
+ self.mlp_motion = nn.Sequential(
22
+ nn.Linear(pose_dim + action_dim, hidden_dim),
23
+ nn.ReLU(),
24
+ nn.Linear(hidden_dim, hidden_dim),
25
+ nn.ReLU()
26
+ )
27
+
28
+ self.fc_out = nn.Sequential(
29
+ nn.Linear(img_feat_dim + hidden_dim, hidden_dim),
30
+ nn.ReLU(),
31
+ nn.Linear(hidden_dim, pose_dim)
32
+ )
33
+
34
+ def forward(self, img, action, pose):
35
+ img_feat = self.cnn(img).view(img.size(0), -1)
36
+ img_feat = self.fc_img(img_feat)
37
+
38
+ motion_feat = self.mlp_motion(torch.cat([pose, action], dim=1))
39
+ fused_feat = torch.cat([img_feat, motion_feat], dim=1)
40
+ pose_next_pred = self.fc_out(fused_feat)
41
+
42
+ return pose_next_pred
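PosePredictionNet only depends on torch, so a quick shape check with dummy inputs is straightforward. A minimal sketch (the semantic layout of the 5-dim pose, e.g. position plus pitch/yaw, is an assumption):

import torch
from algorithms.worldmem.models.pose_prediction import PosePredictionNet

net = PosePredictionNet()                          # defaults: 16-ch latent, 25-dim action, 5-dim pose
img = torch.randn(2, 16, 18, 32)                   # (B, C, H, W) latent frame
action = torch.randn(2, 25)                        # (B, action_dim)
pose = torch.randn(2, 5)                           # (B, pose_dim), current pose
print(net(img, action, pose).shape)                # torch.Size([2, 5]) -> predicted next pose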
algorithms/worldmem/models/rotary_embedding_torch.py ADDED
@@ -0,0 +1,302 @@
1
+ """
2
+ Adapted from https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
3
+ """
4
+
5
+ from __future__ import annotations
6
+ from math import pi, log
7
+
8
+ import torch
9
+ from torch.nn import Module, ModuleList
10
+ from torch.amp import autocast
11
+ from torch import nn, einsum, broadcast_tensors, Tensor
12
+
13
+ from einops import rearrange, repeat
14
+
15
+ from typing import Literal
16
+
17
+ # helper functions
18
+
19
+
20
+ def exists(val):
21
+ return val is not None
22
+
23
+
24
+ def default(val, d):
25
+ return val if exists(val) else d
26
+
27
+
28
+ # broadcat, as tortoise-tts was using it
29
+
30
+
31
+ def broadcat(tensors, dim=-1):
32
+ broadcasted_tensors = broadcast_tensors(*tensors)
33
+ return torch.cat(broadcasted_tensors, dim=dim)
34
+
35
+
36
+ # rotary embedding helper functions
37
+
38
+
39
+ def rotate_half(x):
40
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
41
+ x1, x2 = x.unbind(dim=-1)
42
+ x = torch.stack((-x2, x1), dim=-1)
43
+ return rearrange(x, "... d r -> ... (d r)")
44
+
45
+
46
+ @autocast("cuda", enabled=False)
47
+ def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
48
+ dtype = t.dtype
49
+
50
+ if t.ndim == 3:
51
+ seq_len = t.shape[seq_dim]
52
+ freqs = freqs[-seq_len:]
53
+
54
+ rot_dim = freqs.shape[-1]
55
+ end_index = start_index + rot_dim
56
+
57
+ assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
58
+
59
+ # Split t into three parts: left, middle (to be transformed), and right
60
+ t_left = t[..., :start_index]
61
+ t_middle = t[..., start_index:end_index]
62
+ t_right = t[..., end_index:]
63
+
64
+ # Apply rotary embeddings without modifying t in place
65
+ t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
66
+
67
+ out = torch.cat((t_left, t_transformed, t_right), dim=-1)
68
+
69
+ return out.type(dtype)
70
+
71
+
72
+ # learned rotation helpers
73
+
74
+
75
+ def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
76
+ if exists(freq_ranges):
77
+ rotations = einsum("..., f -> ... f", rotations, freq_ranges)
78
+ rotations = rearrange(rotations, "... r f -> ... (r f)")
79
+
80
+ rotations = repeat(rotations, "... n -> ... (n r)", r=2)
81
+ return apply_rotary_emb(rotations, t, start_index=start_index)
82
+
83
+
84
+ # classes
85
+
86
+
87
+ class RotaryEmbedding(Module):
88
+ def __init__(
89
+ self,
90
+ dim,
91
+ custom_freqs: Tensor | None = None,
92
+ freqs_for: Literal["lang", "pixel", "constant"] = "lang",
93
+ theta=10000,
94
+ max_freq=10,
95
+ num_freqs=1,
96
+ learned_freq=False,
97
+ use_xpos=False,
98
+ xpos_scale_base=512,
99
+ interpolate_factor=1.0,
100
+ theta_rescale_factor=1.0,
101
+ seq_before_head_dim=False,
102
+ cache_if_possible=True,
103
+ cache_max_seq_len=8192,
104
+ ):
105
+ super().__init__()
106
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
107
+ # has some connection to NTK literature
108
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
109
+
110
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
111
+
112
+ self.freqs_for = freqs_for
113
+
114
+ if exists(custom_freqs):
115
+ freqs = custom_freqs
116
+ elif freqs_for == "lang":
117
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
118
+ elif freqs_for == "pixel":
119
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
120
+ elif freqs_for == "spacetime":
121
+ time_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
122
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
123
+ elif freqs_for == "constant":
124
+ freqs = torch.ones(num_freqs).float()
125
+
126
+ if freqs_for == "spacetime":
127
+ self.time_freqs = nn.Parameter(time_freqs, requires_grad=learned_freq)
128
+ self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)
129
+
130
+ self.cache_if_possible = cache_if_possible
131
+ self.cache_max_seq_len = cache_max_seq_len
132
+
133
+ self.register_buffer("cached_freqs", torch.zeros(cache_max_seq_len, dim), persistent=False)
134
+ self.register_buffer("cached_freqs_seq_len", torch.tensor(0), persistent=False)
135
+
136
+ self.learned_freq = learned_freq
137
+
138
+ # dummy for device
139
+
140
+ self.register_buffer("dummy", torch.tensor(0), persistent=False)
141
+
142
+ # default sequence dimension
143
+
144
+ self.seq_before_head_dim = seq_before_head_dim
145
+ self.default_seq_dim = -3 if seq_before_head_dim else -2
146
+
147
+ # interpolation factors
148
+
149
+ assert interpolate_factor >= 1.0
150
+ self.interpolate_factor = interpolate_factor
151
+
152
+ # xpos
153
+
154
+ self.use_xpos = use_xpos
155
+
156
+ if not use_xpos:
157
+ return
158
+
159
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
160
+ self.scale_base = xpos_scale_base
161
+
162
+ self.register_buffer("scale", scale, persistent=False)
163
+ self.register_buffer("cached_scales", torch.zeros(cache_max_seq_len, dim), persistent=False)
164
+ self.register_buffer("cached_scales_seq_len", torch.tensor(0), persistent=False)
165
+
166
+ # add apply_rotary_emb as static method
167
+
168
+ self.apply_rotary_emb = staticmethod(apply_rotary_emb)
169
+
170
+ @property
171
+ def device(self):
172
+ return self.dummy.device
173
+
174
+ def get_seq_pos(self, seq_len, device, dtype, offset=0):
175
+ return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor
176
+
177
+ def rotate_queries_or_keys(self, t, freqs, seq_dim=None, offset=0, scale=None):
178
+ seq_dim = default(seq_dim, self.default_seq_dim)
179
+
180
+ assert not self.use_xpos or exists(scale), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"
181
+
182
+ device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
183
+
184
+ seq = self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset)
185
+
186
+ seq_freqs = self.forward(seq, freqs, seq_len=seq_len, offset=offset)
187
+
188
+ if seq_dim == -3:
189
+ seq_freqs = rearrange(seq_freqs, "n d -> n 1 d")
190
+
191
+ return apply_rotary_emb(seq_freqs, t, scale=default(scale, 1.0), seq_dim=seq_dim)
192
+
193
+ def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
194
+ dtype, device, seq_dim = (
195
+ q.dtype,
196
+ q.device,
197
+ default(seq_dim, self.default_seq_dim),
198
+ )
199
+
200
+ q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
201
+ assert q_len <= k_len
202
+
203
+ q_scale = k_scale = 1.0
204
+
205
+ if self.use_xpos:
206
+ seq = self.get_seq_pos(k_len, dtype=dtype, device=device)
207
+
208
+ q_scale = self.get_scale(seq[-q_len:]).type(dtype)
209
+ k_scale = self.get_scale(seq).type(dtype)
210
+
211
+ rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, scale=q_scale, offset=k_len - q_len + offset)
212
+ rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, scale=k_scale**-1)
213
+
214
+ rotated_q = rotated_q.type(q.dtype)
215
+ rotated_k = rotated_k.type(k.dtype)
216
+
217
+ return rotated_q, rotated_k
218
+
219
+ def rotate_queries_and_keys(self, q, k, freqs, seq_dim=None):
220
+ seq_dim = default(seq_dim, self.default_seq_dim)
221
+
222
+ assert self.use_xpos
223
+ device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
224
+
225
+ seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)
226
+
227
+ seq_freqs = self.forward(seq, freqs, seq_len=seq_len)
228
+ scale = self.get_scale(seq, seq_len=seq_len).to(dtype)
229
+
230
+ if seq_dim == -3:
231
+ seq_freqs = rearrange(seq_freqs, "n d -> n 1 d")
232
+ scale = rearrange(scale, "n d -> n 1 d")
233
+
234
+ rotated_q = apply_rotary_emb(seq_freqs, q, scale=scale, seq_dim=seq_dim)
235
+ rotated_k = apply_rotary_emb(seq_freqs, k, scale=scale**-1, seq_dim=seq_dim)
236
+
237
+ rotated_q = rotated_q.type(q.dtype)
238
+ rotated_k = rotated_k.type(k.dtype)
239
+
240
+ return rotated_q, rotated_k
241
+
242
+ def get_scale(self, t: Tensor, seq_len: int | None = None, offset=0):
243
+ assert self.use_xpos
244
+
245
+ should_cache = self.cache_if_possible and exists(seq_len) and (offset + seq_len) <= self.cache_max_seq_len
246
+
247
+ if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales_seq_len.item():
248
+ return self.cached_scales[offset : (offset + seq_len)]
249
+
250
+ scale = 1.0
251
+ if self.use_xpos:
252
+ power = (t - len(t) // 2) / self.scale_base
253
+ scale = self.scale ** rearrange(power, "n -> n 1")
254
+ scale = repeat(scale, "n d -> n (d r)", r=2)
255
+
256
+ if should_cache and offset == 0:
257
+ self.cached_scales[:seq_len] = scale.detach()
258
+ self.cached_scales_seq_len.copy_(seq_len)
259
+
260
+ return scale
261
+
262
+ def get_axial_freqs(self, *dims):
263
+ Colon = slice(None)
264
+ all_freqs = []
265
+
266
+ for ind, dim in enumerate(dims):
267
+ # only allow pixel freqs for last two dimensions
268
+ use_pixel = (self.freqs_for == "pixel" or self.freqs_for == "spacetime") and ind >= len(dims) - 2
269
+ if use_pixel:
270
+ pos = torch.linspace(-1, 1, steps=dim, device=self.device)
271
+ else:
272
+ pos = torch.arange(dim, device=self.device)
273
+
274
+ if self.freqs_for == "spacetime" and not use_pixel:
275
+ seq_freqs = self.forward(pos, self.time_freqs, seq_len=dim)
276
+ else:
277
+ seq_freqs = self.forward(pos, self.freqs, seq_len=dim)
278
+
279
+ all_axis = [None] * len(dims)
280
+ all_axis[ind] = Colon
281
+
282
+ new_axis_slice = (Ellipsis, *all_axis, Colon)
283
+ all_freqs.append(seq_freqs[new_axis_slice])
284
+
285
+ all_freqs = broadcast_tensors(*all_freqs)
286
+ return torch.cat(all_freqs, dim=-1)
287
+
288
+ @autocast("cuda", enabled=False)
289
+ def forward(self, t: Tensor, freqs: Tensor, seq_len=None, offset=0):
290
+ should_cache = self.cache_if_possible and not self.learned_freq and exists(seq_len) and self.freqs_for != "pixel" and (offset + seq_len) <= self.cache_max_seq_len
291
+
292
+ if should_cache and exists(self.cached_freqs) and (offset + seq_len) <= self.cached_freqs_seq_len.item():
293
+ return self.cached_freqs[offset : (offset + seq_len)].detach()
294
+
295
+ freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
296
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
297
+
298
+ if should_cache and offset == 0:
299
+ self.cached_freqs[:seq_len] = freqs.detach()
300
+ self.cached_freqs_seq_len.copy_(seq_len)
301
+
302
+ return freqs
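The attention layers in this commit use this module mainly through `get_axial_freqs` (2D "pixel" frequencies) and the module-level `apply_rotary_emb`; see, for example, the VAE attention below. A minimal sketch over a 9x16 latent grid (the head dimension and grid size are illustrative):

import torch
from algorithms.worldmem.models.rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb

rot = RotaryEmbedding(dim=16, freqs_for="pixel", max_freq=256)
freqs = rot.get_axial_freqs(9, 16)                 # (9, 16, 32): H and W frequencies concatenated
q = torch.randn(1, 8, 9, 16, 64)                   # (B, heads, H, W, head_dim)
print(apply_rotary_emb(freqs, q).shape)            # torch.Size([1, 8, 9, 16, 64])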
algorithms/worldmem/models/utils.py ADDED
@@ -0,0 +1,163 @@
1
+ """
2
+ Adapted from https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/utils.py
3
+ Action format derived from VPT https://github.com/openai/Video-Pre-Training
4
+ Adapted from https://github.com/etched-ai/open-oasis/blob/master/utils.py
5
+ """
6
+
7
+ import math
8
+ import torch
9
+ from torch import nn
10
+ from torchvision.io import read_image, read_video
11
+ from torchvision.transforms.functional import resize
12
+ from einops import rearrange
13
+ from typing import Mapping, Sequence
14
+ from einops import rearrange, parse_shape
15
+
16
+
17
+ def exists(val):
18
+ return val is not None
19
+
20
+
21
+ def default(val, d):
22
+ if exists(val):
23
+ return val
24
+ return d() if callable(d) else d
25
+
26
+
27
+ def extract(a, t, x_shape):
28
+ f, b = t.shape
29
+ out = a[t]
30
+ return out.reshape(f, b, *((1,) * (len(x_shape) - 2)))
31
+
32
+
33
+ def linear_beta_schedule(timesteps):
34
+ """
35
+ linear schedule, proposed in original ddpm paper
36
+ """
37
+ scale = 1000 / timesteps
38
+ beta_start = scale * 0.0001
39
+ beta_end = scale * 0.02
40
+ return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
41
+
42
+
43
+ def cosine_beta_schedule(timesteps, s=0.008):
44
+ """
45
+ cosine schedule
46
+ as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
47
+ """
48
+ steps = timesteps + 1
49
+ t = torch.linspace(0, timesteps, steps, dtype=torch.float64) / timesteps
50
+ alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
51
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
52
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
53
+ return torch.clip(betas, 0, 0.999)
54
+
55
+
56
+
57
+ def sigmoid_beta_schedule(timesteps, start=-3, end=3, tau=1, clamp_min=1e-5):
58
+ """
59
+ sigmoid schedule
60
+ proposed in https://arxiv.org/abs/2212.11972 - Figure 8
61
+ better for images > 64x64, when used during training
62
+ """
63
+ steps = timesteps + 1
64
+ t = torch.linspace(0, timesteps, steps, dtype=torch.float64) / timesteps
65
+ v_start = torch.tensor(start / tau).sigmoid()
66
+ v_end = torch.tensor(end / tau).sigmoid()
67
+ alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
68
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
69
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
70
+ return torch.clip(betas, 0, 0.999)
71
+
72
+
73
+ ACTION_KEYS = [
74
+ "inventory",
75
+ "ESC",
76
+ "hotbar.1",
77
+ "hotbar.2",
78
+ "hotbar.3",
79
+ "hotbar.4",
80
+ "hotbar.5",
81
+ "hotbar.6",
82
+ "hotbar.7",
83
+ "hotbar.8",
84
+ "hotbar.9",
85
+ "forward",
86
+ "back",
87
+ "left",
88
+ "right",
89
+ "cameraX",
90
+ "cameraY",
91
+ "jump",
92
+ "sneak",
93
+ "sprint",
94
+ "swapHands",
95
+ "attack",
96
+ "use",
97
+ "pickItem",
98
+ "drop",
99
+ ]
100
+
101
+
102
+ def one_hot_actions(actions: Sequence[Mapping[str, int]]) -> torch.Tensor:
103
+ actions_one_hot = torch.zeros(len(actions), len(ACTION_KEYS))
104
+ for i, current_actions in enumerate(actions):
105
+ for j, action_key in enumerate(ACTION_KEYS):
106
+ if action_key.startswith("camera"):
107
+ if action_key == "cameraX":
108
+ value = current_actions["camera"][0]
109
+ elif action_key == "cameraY":
110
+ value = current_actions["camera"][1]
111
+ else:
112
+ raise ValueError(f"Unknown camera action key: {action_key}")
113
+ max_val = 20
114
+ bin_size = 0.5
115
+ num_buckets = int(max_val / bin_size)
116
+ value = (value - num_buckets) / num_buckets
117
+ assert -1 - 1e-3 <= value <= 1 + 1e-3, f"Camera action value must be in [-1, 1], got {value}"
118
+ else:
119
+ value = current_actions[action_key]
120
+ assert 0 <= value <= 1, f"Action value must be in [0, 1] got {value}"
121
+ actions_one_hot[i, j] = value
122
+
123
+ return actions_one_hot
124
+
125
+
126
+ IMAGE_EXTENSIONS = {"png", "jpg", "jpeg"}
127
+ VIDEO_EXTENSIONS = {"mp4"}
128
+
129
+
130
+ def load_prompt(path, video_offset=None, n_prompt_frames=1):
131
+ if path.lower().split(".")[-1] in IMAGE_EXTENSIONS:
132
+ print("prompt is image; ignoring video_offset and n_prompt_frames")
133
+ prompt = read_image(path)
134
+ # add frame dimension
135
+ prompt = rearrange(prompt, "c h w -> 1 c h w")
136
+ elif path.lower().split(".")[-1] in VIDEO_EXTENSIONS:
137
+ prompt = read_video(path, pts_unit="sec")[0]
138
+ if video_offset is not None:
139
+ prompt = prompt[video_offset:]
140
+ prompt = prompt[:n_prompt_frames]
141
+ else:
142
+ raise ValueError(f"unrecognized prompt file extension; expected one in {IMAGE_EXTENSIONS} or {VIDEO_EXTENSIONS}")
143
+ assert prompt.shape[0] == n_prompt_frames, f"input prompt {path} had less than n_prompt_frames={n_prompt_frames} frames"
144
+ prompt = resize(prompt, (360, 640))
145
+ # add batch dimension
146
+ prompt = rearrange(prompt, "t c h w -> 1 t c h w")
147
+ prompt = prompt.float() / 255.0
148
+ return prompt
149
+
150
+
151
+ def load_actions(path, action_offset=None):
152
+ if path.endswith(".actions.pt"):
153
+ actions = one_hot_actions(torch.load(path))
154
+ elif path.endswith(".one_hot_actions.pt"):
155
+ actions = torch.load(path, weights_only=True)
156
+ else:
157
+ raise ValueError("unrecognized action file extension; expected '*.actions.pt' or '*.one_hot_actions.pt'")
158
+ if action_offset is not None:
159
+ actions = actions[action_offset:]
160
+ actions = torch.cat([torch.zeros_like(actions[:1]), actions], dim=0)
161
+ # add batch dimension
162
+ actions = rearrange(actions, "t d -> 1 t d")
163
+ return actions
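Buffers referenced in `diffusion.py` above (such as `sqrt_alphas_cumprod`) are conventionally derived from a beta schedule like the ones defined here, and `extract` broadcasts per-frame coefficients against frame-first latents. A minimal, illustrative sketch:

import torch
from algorithms.worldmem.models.utils import cosine_beta_schedule, extract

betas = cosine_beta_schedule(timesteps=1000)                    # (1000,), float64
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sqrt_alphas_cumprod = alphas_cumprod.sqrt()                     # as used by q_sample / predict_v

t = torch.randint(0, 1000, (4, 2))                              # (frames, batch) noise levels
coef = extract(sqrt_alphas_cumprod, t, (4, 2, 16, 18, 32))      # -> (4, 2, 1, 1, 1)
print(coef.shape)                                               # broadcasts against (T, B, C, H, W) latents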
algorithms/worldmem/models/vae.py ADDED
@@ -0,0 +1,359 @@
1
+ """
2
+ References:
3
+ - VQGAN: https://github.com/CompVis/taming-transformers
4
+ - MAE: https://github.com/facebookresearch/mae
5
+ """
6
+
7
+ import numpy as np
8
+ import math
9
+ import functools
10
+ from collections import namedtuple
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from einops import rearrange
15
+ from timm.models.vision_transformer import Mlp
16
+ from timm.layers.helpers import to_2tuple
17
+ from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
18
+ from .dit import PatchEmbed
19
+
20
+
21
+ class DiagonalGaussianDistribution(object):
22
+ def __init__(self, parameters, deterministic=False, dim=1):
23
+ self.parameters = parameters
24
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
25
+ if dim == 1:
26
+ self.dims = [1, 2, 3]
27
+ elif dim == 2:
28
+ self.dims = [1, 2]
29
+ else:
30
+ raise NotImplementedError
31
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
32
+ self.deterministic = deterministic
33
+ self.std = torch.exp(0.5 * self.logvar)
34
+ self.var = torch.exp(self.logvar)
35
+ if self.deterministic:
36
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
37
+
38
+ def sample(self):
39
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
40
+ return x
41
+
42
+ def mode(self):
43
+ return self.mean
44
+
45
+
46
+ class Attention(nn.Module):
47
+ def __init__(
48
+ self,
49
+ dim,
50
+ num_heads,
51
+ frame_height,
52
+ frame_width,
53
+ qkv_bias=False,
54
+ ):
55
+ super().__init__()
56
+ self.num_heads = num_heads
57
+ head_dim = dim // num_heads
58
+ self.frame_height = frame_height
59
+ self.frame_width = frame_width
60
+
61
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
62
+ self.proj = nn.Linear(dim, dim)
63
+
64
+ rotary_freqs = RotaryEmbedding(
65
+ dim=head_dim // 4,
66
+ freqs_for="pixel",
67
+ max_freq=frame_height * frame_width,
68
+ ).get_axial_freqs(frame_height, frame_width)
69
+ self.register_buffer("rotary_freqs", rotary_freqs, persistent=False)
70
+
71
+ def forward(self, x):
72
+ B, N, C = x.shape
73
+ assert N == self.frame_height * self.frame_width
74
+
75
+ q, k, v = self.qkv(x).chunk(3, dim=-1)
76
+
77
+ q = rearrange(
78
+ q,
79
+ "b (H W) (h d) -> b h H W d",
80
+ H=self.frame_height,
81
+ W=self.frame_width,
82
+ h=self.num_heads,
83
+ )
84
+ k = rearrange(
85
+ k,
86
+ "b (H W) (h d) -> b h H W d",
87
+ H=self.frame_height,
88
+ W=self.frame_width,
89
+ h=self.num_heads,
90
+ )
91
+ v = rearrange(
92
+ v,
93
+ "b (H W) (h d) -> b h H W d",
94
+ H=self.frame_height,
95
+ W=self.frame_width,
96
+ h=self.num_heads,
97
+ )
98
+
99
+ q = apply_rotary_emb(self.rotary_freqs, q)
100
+ k = apply_rotary_emb(self.rotary_freqs, k)
101
+
102
+ q = rearrange(q, "b h H W d -> b h (H W) d")
103
+ k = rearrange(k, "b h H W d -> b h (H W) d")
104
+ v = rearrange(v, "b h H W d -> b h (H W) d")
105
+
106
+ x = F.scaled_dot_product_attention(q, k, v)
107
+ x = rearrange(x, "b h N d -> b N (h d)")
108
+
109
+ x = self.proj(x)
110
+ return x
111
+
112
+
113
+ class AttentionBlock(nn.Module):
114
+ def __init__(
115
+ self,
116
+ dim,
117
+ num_heads,
118
+ frame_height,
119
+ frame_width,
120
+ mlp_ratio=4.0,
121
+ qkv_bias=False,
122
+ attn_causal=False,
123
+ act_layer=nn.GELU,
124
+ norm_layer=nn.LayerNorm,
125
+ ):
126
+ super().__init__()
127
+ self.norm1 = norm_layer(dim)
128
+ self.attn = Attention(
129
+ dim,
130
+ num_heads,
131
+ frame_height,
132
+ frame_width,
133
+ qkv_bias=qkv_bias,
134
+ )
135
+ self.norm2 = norm_layer(dim)
136
+ mlp_hidden_dim = int(dim * mlp_ratio)
137
+ self.mlp = Mlp(
138
+ in_features=dim,
139
+ hidden_features=mlp_hidden_dim,
140
+ act_layer=act_layer,
141
+ )
142
+
143
+ def forward(self, x):
144
+ x = x + self.attn(self.norm1(x))
145
+ x = x + self.mlp(self.norm2(x))
146
+ return x
147
+
148
+
149
+ class AutoencoderKL(nn.Module):
150
+ def __init__(
151
+ self,
152
+ latent_dim,
153
+ input_height=256,
154
+ input_width=256,
155
+ patch_size=16,
156
+ enc_dim=768,
157
+ enc_depth=6,
158
+ enc_heads=12,
159
+ dec_dim=768,
160
+ dec_depth=6,
161
+ dec_heads=12,
162
+ mlp_ratio=4.0,
163
+ norm_layer=functools.partial(nn.LayerNorm, eps=1e-6),
164
+ use_variational=True,
165
+ **kwargs,
166
+ ):
167
+ super().__init__()
168
+ self.input_height = input_height
169
+ self.input_width = input_width
170
+ self.patch_size = patch_size
171
+ self.seq_h = input_height // patch_size
172
+ self.seq_w = input_width // patch_size
173
+ self.seq_len = self.seq_h * self.seq_w
174
+ self.patch_dim = 3 * patch_size**2
175
+
176
+ self.latent_dim = latent_dim
177
+ self.enc_dim = enc_dim
178
+ self.dec_dim = dec_dim
179
+
180
+ # patch
181
+ self.patch_embed = PatchEmbed(input_height, input_width, patch_size, 3, enc_dim)
182
+
183
+ # encoder
184
+ self.encoder = nn.ModuleList(
185
+ [
186
+ AttentionBlock(
187
+ enc_dim,
188
+ enc_heads,
189
+ self.seq_h,
190
+ self.seq_w,
191
+ mlp_ratio,
192
+ qkv_bias=True,
193
+ norm_layer=norm_layer,
194
+ )
195
+ for i in range(enc_depth)
196
+ ]
197
+ )
198
+ self.enc_norm = norm_layer(enc_dim)
199
+
200
+ # bottleneck
201
+ self.use_variational = use_variational
202
+ mult = 2 if self.use_variational else 1
203
+ self.quant_conv = nn.Linear(enc_dim, mult * latent_dim)
204
+ self.post_quant_conv = nn.Linear(latent_dim, dec_dim)
205
+
206
+ # decoder
207
+ self.decoder = nn.ModuleList(
208
+ [
209
+ AttentionBlock(
210
+ dec_dim,
211
+ dec_heads,
212
+ self.seq_h,
213
+ self.seq_w,
214
+ mlp_ratio,
215
+ qkv_bias=True,
216
+ norm_layer=norm_layer,
217
+ )
218
+ for i in range(dec_depth)
219
+ ]
220
+ )
221
+ self.dec_norm = norm_layer(dec_dim)
222
+ self.predictor = nn.Linear(dec_dim, self.patch_dim) # decoder to patch
223
+
224
+ # initialize this weight first
225
+ self.initialize_weights()
226
+
227
+ def initialize_weights(self):
228
+ # initialization
229
+ # initialize nn.Linear and nn.LayerNorm
230
+ self.apply(self._init_weights)
231
+
232
+ # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
233
+ w = self.patch_embed.proj.weight.data
234
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
235
+
236
+ def _init_weights(self, m):
237
+ if isinstance(m, nn.Linear):
238
+ # we use xavier_uniform following official JAX ViT:
239
+ nn.init.xavier_uniform_(m.weight)
240
+ if isinstance(m, nn.Linear) and m.bias is not None:
241
+ nn.init.constant_(m.bias, 0.0)
242
+ elif isinstance(m, nn.LayerNorm):
243
+ nn.init.constant_(m.bias, 0.0)
244
+ nn.init.constant_(m.weight, 1.0)
245
+
246
+ def patchify(self, x):
247
+ # patchify
248
+ bsz, _, h, w = x.shape
249
+ x = x.reshape(
250
+ bsz,
251
+ 3,
252
+ self.seq_h,
253
+ self.patch_size,
254
+ self.seq_w,
255
+ self.patch_size,
256
+ ).permute([0, 1, 3, 5, 2, 4]) # [b, c, h, p, w, p] --> [b, c, p, p, h, w]
257
+ x = x.reshape(bsz, self.patch_dim, self.seq_h, self.seq_w) # --> [b, cxpxp, h, w]
258
+ x = x.permute([0, 2, 3, 1]).reshape(bsz, self.seq_len, self.patch_dim) # --> [b, hxw, cxpxp]
259
+ return x
260
+
261
+ def unpatchify(self, x):
262
+ bsz = x.shape[0]
263
+ # unpatchify
264
+ x = x.reshape(bsz, self.seq_h, self.seq_w, self.patch_dim).permute([0, 3, 1, 2]) # [b, h, w, cxpxp] --> [b, cxpxp, h, w]
265
+ x = x.reshape(
266
+ bsz,
267
+ 3,
268
+ self.patch_size,
269
+ self.patch_size,
270
+ self.seq_h,
271
+ self.seq_w,
272
+ ).permute([0, 1, 4, 2, 5, 3]) # [b, c, p, p, h, w] --> [b, c, h, p, w, p]
273
+ x = x.reshape(
274
+ bsz,
275
+ 3,
276
+ self.input_height,
277
+ self.input_width,
278
+ ) # [b, c, hxp, wxp]
279
+ return x
280
+
281
+ def encode(self, x):
282
+ # patchify
283
+ x = self.patch_embed(x)
284
+
285
+ # encoder
286
+ for blk in self.encoder:
287
+ x = blk(x)
288
+ x = self.enc_norm(x)
289
+
290
+ # bottleneck
291
+ moments = self.quant_conv(x)
292
+ if not self.use_variational:
293
+ moments = torch.cat((moments, torch.zeros_like(moments)), 2)
294
+ posterior = DiagonalGaussianDistribution(moments, deterministic=(not self.use_variational), dim=2)
295
+ return posterior
296
+
297
+ def decode(self, z):
298
+ # bottleneck
299
+ z = self.post_quant_conv(z)
300
+
301
+ # decoder
302
+ for blk in self.decoder:
303
+ z = blk(z)
304
+ z = self.dec_norm(z)
305
+
306
+ # predictor
307
+ z = self.predictor(z)
308
+
309
+ # unpatchify
310
+ dec = self.unpatchify(z)
311
+ return dec
312
+
313
+ def autoencode(self, input, sample_posterior=True):
314
+ posterior = self.encode(input)
315
+ if self.use_variational and sample_posterior:
316
+ z = posterior.sample()
317
+ else:
318
+ z = posterior.mode()
319
+ dec = self.decode(z)
320
+ return dec, posterior, z
321
+
322
+ def get_input(self, batch, k):
323
+ x = batch[k]
324
+ if len(x.shape) == 3:
325
+ x = x[..., None]
326
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
327
+ return x
328
+
329
+ def forward(self, inputs, labels, split="train"):
330
+ rec, post, latent = self.autoencode(inputs)
331
+ return rec, post, latent
332
+
333
+ def get_last_layer(self):
334
+ return self.predictor.weight
335
+
336
+
337
+ def ViT_L_20_Shallow_Encoder(**kwargs):
338
+ if "latent_dim" in kwargs:
339
+ latent_dim = kwargs.pop("latent_dim")
340
+ else:
341
+ latent_dim = 16
342
+ return AutoencoderKL(
343
+ latent_dim=latent_dim,
344
+ patch_size=20,
345
+ enc_dim=1024,
346
+ enc_depth=6,
347
+ enc_heads=16,
348
+ dec_dim=1024,
349
+ dec_depth=12,
350
+ dec_heads=16,
351
+ input_height=360,
352
+ input_width=640,
353
+ **kwargs,
354
+ )
355
+
356
+
357
+ VAE_models = {
358
+ "vit-l-20-shallow-encoder": ViT_L_20_Shallow_Encoder,
359
+ }
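
Note (not part of this commit): a minimal shape-level sketch of exercising the VAE registered above, assuming the file is importable as algorithms.worldmem.models.vae. Weights are randomly initialized here, so it only checks tensor shapes, not reconstruction quality.

import torch
from algorithms.worldmem.models.vae import VAE_models

vae = VAE_models["vit-l-20-shallow-encoder"](latent_dim=16).eval()
img = torch.rand(1, 3, 360, 640)            # one RGB frame in [0, 1]
with torch.no_grad():
    posterior = vae.encode(img * 2 - 1)     # DiagonalGaussianDistribution over patch latents
    z = posterior.mean                      # (1, 18*32, 16): one latent per 20x20 patch
    rec = vae.decode(z)                     # (1, 3, 360, 640)
print(z.shape, rec.shape)
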
algorithms/worldmem/pose_prediction.py ADDED
@@ -0,0 +1,374 @@
1
+ from omegaconf import DictConfig
2
+ import torch
3
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
4
+ from algorithms.common.metrics import (
5
+ FrechetInceptionDistance,
6
+ LearnedPerceptualImagePatchSimilarity,
7
+ FrechetVideoDistance,
8
+ )
9
+ from .df_base import DiffusionForcingBase
10
+ from utils.logging_utils import log_video, get_validation_metrics_for_videos
11
+ from .models.vae import VAE_models
12
+ from .models.dit import DiT_models
13
+ from einops import rearrange
14
+ from torch import autocast
15
+ import numpy as np
16
+ from tqdm import tqdm
17
+ import torch.nn.functional as F
18
+ from .models.pose_prediction import PosePredictionNet
19
+ import torchvision.transforms.functional as TF
20
+ import random
21
+ from torchvision.transforms import InterpolationMode
22
+ from PIL import Image
23
+ import math
24
+ from packaging import version as pver
25
+ import torch.distributed as dist
26
+ import matplotlib.pyplot as plt
27
+
28
+ import torch
29
+ import math
30
+ import wandb
31
+
32
+ import torch.nn as nn
33
+ from algorithms.common.base_pytorch_algo import BasePytorchAlgo
34
+
35
+ class PosePrediction(BasePytorchAlgo):
36
+
37
+ def __init__(self, cfg: DictConfig):
38
+
39
+ super().__init__(cfg)
40
+
41
+ def _build_model(self):
42
+ self.pose_prediction_model = PosePredictionNet()
43
+ vae = VAE_models["vit-l-20-shallow-encoder"]()
44
+ self.vae = vae.eval()
45
+
46
+ def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
47
+ xs, conditions, pose_conditions= batch
48
+ pose_conditions[:,:,3:] = pose_conditions[:,:,3:] // 15
49
+ xs = self.encode(xs)
50
+
51
+ b,f,c,h,w = xs.shape
52
+ xs = xs[:,:-1].reshape(-1, c, h, w)
53
+ conditions = conditions[:,1:].reshape(-1, 25)
54
+ offset_gt = pose_conditions[:,1:] - pose_conditions[:,:-1]
55
+ pose_conditions = pose_conditions[:,:-1].reshape(-1, 5)
56
+ offset_gt = offset_gt.reshape(-1, 5)
57
+ offset_gt[:, 3][offset_gt[:, 3]==23] = -1
58
+ offset_gt[:, 3][offset_gt[:, 3]==-23] = 1
59
+ offset_gt[:, 4][offset_gt[:, 4]==23] = -1
60
+ offset_gt[:, 4][offset_gt[:, 4]==-23] = 1
61
+
62
+ offset_pred = self.pose_prediction_model(xs, conditions, pose_conditions)
63
+ criterion = nn.MSELoss()
64
+ loss = criterion(offset_pred, offset_gt)
65
+ if batch_idx % 200 == 0:
66
+ self.log("training/loss", loss.cpu())
67
+ output_dict = {
68
+ "loss": loss}
69
+ return output_dict
70
+
71
+ def encode(self, x):
72
+ # vae encoding
73
+ B = x.shape[1]
74
+ T = x.shape[0]
75
+ H, W = x.shape[-2:]
76
+ scaling_factor = 0.07843137255
77
+
78
+ x = rearrange(x, "t b c h w -> (t b) c h w")
79
+ with torch.no_grad():
80
+ with autocast("cuda", dtype=torch.half):
81
+ x = self.vae.encode(x * 2 - 1).mean * scaling_factor
82
+ x = rearrange(x, "(t b) (h w) c -> t b c h w", t=T, h=H // self.vae.patch_size, w=W // self.vae.patch_size)
83
+ # x = x[:, :n_prompt_frames]
84
+ return x
85
+
86
+ def decode(self, x):
87
+ total_frames = x.shape[0]
88
+ scaling_factor = 0.07843137255
89
+ x = rearrange(x, "t b c h w -> (t b) (h w) c")
90
+ with torch.no_grad():
91
+ with autocast("cuda", dtype=torch.half):
92
+ x = (self.vae.decode(x / scaling_factor) + 1) / 2
93
+
94
+ x = rearrange(x, "(t b) c h w-> t b c h w", t=total_frames)
95
+ return x
96
+
97
+ def validation_step(self, batch, batch_idx, namespace="validation") -> STEP_OUTPUT:
98
+ xs, conditions, pose_conditions= batch
99
+ pose_conditions[:,:,3:] = pose_conditions[:,:,3:] // 15
100
+ xs = self.encode(xs)
101
+
102
+ b,f,c,h,w = xs.shape
103
+ xs = xs[:,:-1].reshape(-1, c, h, w)
104
+ conditions = conditions[:,1:].reshape(-1, 25)
105
+ offset_gt = pose_conditions[:,1:] - pose_conditions[:,:-1]
106
+ pose_conditions = pose_conditions[:,:-1].reshape(-1, 5)
107
+ offset_gt = offset_gt.reshape(-1, 5)
108
+ offset_gt[:, 3][offset_gt[:, 3]==23] = -1
109
+ offset_gt[:, 3][offset_gt[:, 3]==-23] = 1
110
+ offset_gt[:, 4][offset_gt[:, 4]==23] = -1
111
+ offset_gt[:, 4][offset_gt[:, 4]==-23] = 1
112
+
113
+ offset_pred = self.pose_prediction_model(xs, conditions, pose_conditions)
114
+
115
+ criterion = nn.MSELoss()
116
+ loss = criterion(offset_pred, offset_gt)
117
+
118
+ if batch_idx % 200 == 0:
119
+ self.log("validation/loss", loss.cpu())
120
+ output_dict = {
121
+ "loss": loss}
122
+ return
123
+
124
+ @torch.no_grad()
125
+ def interactive(self, batch, context_frames, device):
126
+ with torch.cuda.amp.autocast():
127
+ condition_similar_length = self.condition_similar_length
128
+ # xs_raw, conditions, pose_conditions, c2w_mat, masks, frame_idx = self._preprocess_batch(batch)
129
+
130
+ first_frame, new_conditions, new_pose_conditions, new_c2w_mat, new_frame_idx = batch
131
+
132
+ if self.frames is None:
133
+ first_frame_encode = self.encode(first_frame[None, None].to(device))
134
+ self.frames = first_frame_encode.to(device)
135
+ self.actions = new_conditions[None, None].to(device)
136
+ self.poses = new_pose_conditions[None, None].to(device)
137
+ self.memory_c2w = new_c2w_mat[None, None].to(device)
138
+ self.frame_idx = torch.tensor([[new_frame_idx]]).to(device)
139
+ return first_frame
140
+ else:
141
+ self.actions = torch.cat([self.actions, new_conditions[None, None].to(device)])
142
+ self.poses = torch.cat([self.poses, new_pose_conditions[None, None].to(device)])
143
+ self.memory_c2w = torch.cat([self.memory_c2w, new_c2w_mat[None, None].to(device)])
144
+ self.frame_idx = torch.cat([self.frame_idx, torch.tensor([[new_frame_idx]]).to(device)])
145
+
146
+ conditions = self.actions.clone()
147
+ pose_conditions = self.poses.clone()
148
+ c2w_mat = self.memory_c2w.clone()
149
+ frame_idx = self.frame_idx.clone()
150
+
151
+
152
+ curr_frame = 0
153
+ horizon = 1
154
+ batch_size = 1
155
+ n_frames = curr_frame + horizon
156
+ # context
157
+ n_context_frames = context_frames // self.frame_stack
158
+ xs_pred = self.frames[:n_context_frames].clone()
159
+ curr_frame += n_context_frames
160
+
161
+ pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
162
+
163
+ # generation on frame
164
+ scheduling_matrix = self._generate_scheduling_matrix(horizon)
165
+ chunk = torch.randn((horizon, batch_size, *xs_pred.shape[2:])).to(xs_pred.device)
166
+ chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise)
167
+
168
+ xs_pred = torch.cat([xs_pred, chunk], 0)
169
+
170
+ # sliding window: only input the last n_tokens frames
171
+ start_frame = max(0, curr_frame + horizon - self.n_tokens)
172
+
173
+ pbar.set_postfix(
174
+ {
175
+ "start": start_frame,
176
+ "end": curr_frame + horizon,
177
+ }
178
+ )
179
+
180
+ if condition_similar_length:
181
+
182
+ if curr_frame < condition_similar_length:
183
+ random_idx = [i for i in range(curr_frame)] + [0] * (condition_similar_length-curr_frame)
184
+ random_idx = np.repeat(np.array(random_idx)[:,None], xs_pred.shape[1], -1)
185
+ else:
186
+ num_samples = 10000
187
+ radius = 30
188
+ samples = torch.rand((num_samples, 1), device=pose_conditions.device)
189
+ angles = 2 * np.pi * torch.rand((num_samples,), device=pose_conditions.device)
190
+ # points = radius * torch.sqrt(samples) * torch.stack((torch.cos(angles), torch.sin(angles)), dim=1)
191
+
192
+ points = generate_points_in_sphere(num_samples, radius).to(pose_conditions.device)
193
+ points = points[:, None].repeat(1, pose_conditions.shape[1], 1)
194
+ points += pose_conditions[curr_frame, :, :3][None]
195
+ fov_half_h = torch.tensor(105/2, device=pose_conditions.device)
196
+ fov_half_v = torch.tensor(75/2, device=pose_conditions.device)
197
+ # in_fov1 = is_inside_fov(points, pose_conditions[curr_frame, :, [0, 2]], pose_conditions[curr_frame, :, -1], fov_half)
198
+
199
+ in_fov1 = is_inside_fov_3d_hv(points, pose_conditions[curr_frame, :, :3],
200
+ pose_conditions[curr_frame, :, -2], pose_conditions[curr_frame, :, -1],
201
+ fov_half_h, fov_half_v)
202
+
203
+ in_fov_list = []
204
+ for pc in pose_conditions[:curr_frame]:
205
+ in_fov_list.append(is_inside_fov_3d_hv(points, pc[:, :3], pc[:, -2], pc[:, -1],
206
+ fov_half_h, fov_half_v))
207
+
208
+ in_fov_list = torch.stack(in_fov_list)
209
+ # v3
210
+ random_idx = []
211
+
212
+ for csl in range(self.condition_similar_length // 2):
213
+ overlap_ratio = ((in_fov1[None].bool() & in_fov_list).sum(1))/in_fov1.sum()
214
+ # mask = distance > (in_fov1.bool().sum(0) / 4)
215
+ #_, r_idx = torch.topk(overlap_ratio / tensor_max_with_number((frame_idx[curr_frame] - frame_idx[:curr_frame]), 10), k=1, dim=0)
216
+
217
+ # if csl > self.condition_similar_length:
218
+ # _, r_idx = torch.topk(overlap_ratio, k=1, dim=0)
219
+ # else:
220
+ # _, r_idx = torch.topk(overlap_ratio / tensor_max_with_number((frame_idx[curr_frame] - frame_idx[:curr_frame]), 10), k=1, dim=0)
221
+
222
+ _, r_idx = torch.topk(overlap_ratio, k=1, dim=0)
223
+ # _, r_idx = torch.topk(overlap_ratio / tensor_max_with_number((frame_idx[curr_frame] - frame_idx[:curr_frame]), 10), k=1, dim=0)
224
+
225
+ # if curr_frame >=93:
226
+ # import pdb;pdb.set_trace()
227
+
228
+ # start_time = time.time()
229
+ cos_sim = F.cosine_similarity(xs_pred.to(r_idx.device)[r_idx[:, range(in_fov1.shape[1])],
230
+ range(in_fov1.shape[1])], xs_pred.to(r_idx.device)[:curr_frame], dim=2)
231
+ cos_sim = cos_sim.mean((-2,-1))
232
+
233
+ mask_sim = cos_sim>0.9
234
+ in_fov_list = in_fov_list & ~mask_sim[:,None].to(in_fov_list.device)
235
+
236
+ random_idx.append(r_idx)
237
+
238
+ for bi in range(conditions.shape[1]):
239
+ if len(torch.nonzero(conditions[:,bi,24] == 1))==0:
240
+ pass
241
+ else:
242
+ last_idx = torch.nonzero(conditions[:,bi,24] == 1)[-1]
243
+ in_fov_list[:last_idx,:,bi] = False
244
+
245
+ for csl in range(self.condition_similar_length // 2):
246
+ overlap_ratio = ((in_fov1[None].bool() & in_fov_list).sum(1))/in_fov1.sum()
247
+ # mask = distance > (in_fov1.bool().sum(0) / 4)
248
+ #_, r_idx = torch.topk(overlap_ratio / tensor_max_with_number((frame_idx[curr_frame] - frame_idx[:curr_frame]), 10), k=1, dim=0)
249
+
250
+ # if csl > self.condition_similar_length:
251
+ # _, r_idx = torch.topk(overlap_ratio, k=1, dim=0)
252
+ # else:
253
+ # _, r_idx = torch.topk(overlap_ratio / tensor_max_with_number((frame_idx[curr_frame] - frame_idx[:curr_frame]), 10), k=1, dim=0)
254
+
255
+ _, r_idx = torch.topk(overlap_ratio, k=1, dim=0)
256
+ # _, r_idx = torch.topk(overlap_ratio / tensor_max_with_number((frame_idx[curr_frame] - frame_idx[:curr_frame]), 10), k=1, dim=0)
257
+
258
+ # if curr_frame >=93:
259
+ # import pdb;pdb.set_trace()
260
+
261
+ # start_time = time.time()
262
+ cos_sim = F.cosine_similarity(xs_pred.to(r_idx.device)[r_idx[:, range(in_fov1.shape[1])],
263
+ range(in_fov1.shape[1])], xs_pred.to(r_idx.device)[:curr_frame], dim=2)
264
+ cos_sim = cos_sim.mean((-2,-1))
265
+
266
+ mask_sim = cos_sim>0.9
267
+ in_fov_list = in_fov_list & ~mask_sim[:,None].to(in_fov_list.device)
268
+
269
+ random_idx.append(r_idx)
270
+
271
+ random_idx = torch.cat(random_idx).cpu()
272
+ condition_similar_length = len(random_idx)
273
+
274
+ xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0)
275
+
276
+ if condition_similar_length:
277
+ # import pdb;pdb.set_trace()
278
+ padding = torch.zeros((condition_similar_length,) + conditions.shape[1:], device=conditions.device, dtype=conditions.dtype)
279
+ input_condition = torch.cat([conditions[start_frame : curr_frame + horizon], padding], dim=0)
280
+ if self.pose_cond_dim:
281
+ # if not self.use_plucker:
282
+ input_pose_condition = torch.cat([pose_conditions[start_frame : curr_frame + horizon], pose_conditions[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])]], dim=0).clone()
283
+
284
+ if self.use_plucker:
285
+ if self.all_zero_frame:
286
+ frame_idx_list = []
287
+ input_pose_condition = []
288
+ for i in range(start_frame, curr_frame + horizon):
289
+ input_pose_condition.append(convert_to_plucker(torch.cat([c2w_mat[i:i+1],c2w_mat[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])]]).clone(), 0, focal_length=self.focal_length, is_old_setting=self.old_setting).to(xs_pred.dtype))
290
+ frame_idx_list.append(torch.cat([frame_idx[i:i+1]-frame_idx[i:i+1], frame_idx[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])]-frame_idx[i:i+1]]))
291
+ input_pose_condition = torch.cat(input_pose_condition)
292
+ frame_idx_list = torch.cat(frame_idx_list)
293
+
294
+ # print(frame_idx_list[:,0])
295
+ else:
296
+ # print(curr_frame-start_frame)
297
+ # input_pose_condition = torch.cat([c2w_mat[start_frame : curr_frame + horizon], c2w_mat[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])]], dim=0).clone()
298
+ # import pdb;pdb.set_trace()
299
+ if self.last_frame_refer:
300
+ input_pose_condition = torch.cat([c2w_mat[start_frame : curr_frame + horizon], c2w_mat[-1:]], dim=0).clone()
301
+ else:
302
+ input_pose_condition = torch.cat([c2w_mat[start_frame : curr_frame + horizon], c2w_mat[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])]], dim=0).clone()
303
+
304
+ if self.zero_curr:
305
+ # print("="*50)
306
+ input_pose_condition = convert_to_plucker(input_pose_condition, curr_frame-start_frame, focal_length=self.focal_length, is_old_setting=self.old_setting)
307
+ # input_pose_condition[:curr_frame-start_frame] = input_pose_condition[curr_frame-start_frame:curr_frame-start_frame+1]
308
+ # input_pose_condition = convert_to_plucker(input_pose_condition, -self.condition_similar_length-1, focal_length=self.focal_length)
309
+ else:
310
+ input_pose_condition = convert_to_plucker(input_pose_condition, -condition_similar_length, focal_length=self.focal_length, is_old_setting=self.old_setting)
311
+ frame_idx_list = None
312
+ else:
313
+ input_pose_condition = torch.cat([pose_conditions[start_frame : curr_frame + horizon], pose_conditions[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])]], dim=0).clone()
314
+ frame_idx_list = None
315
+ else:
316
+ input_condition = conditions[start_frame : curr_frame + horizon]
317
+ input_pose_condition = None
318
+ frame_idx_list = None
319
+
320
+ for m in range(scheduling_matrix.shape[0] - 1):
321
+ from_noise_levels = np.concatenate((np.zeros((curr_frame,), dtype=np.int64), scheduling_matrix[m]))[
322
+ :, None
323
+ ].repeat(batch_size, axis=1)
324
+ to_noise_levels = np.concatenate(
325
+ (
326
+ np.zeros((curr_frame,), dtype=np.int64),
327
+ scheduling_matrix[m + 1],
328
+ )
329
+ )[
330
+ :, None
331
+ ].repeat(batch_size, axis=1)
332
+
333
+ if condition_similar_length:
334
+ from_noise_levels = np.concatenate([from_noise_levels, np.zeros((condition_similar_length,from_noise_levels.shape[-1]), dtype=np.int32)], axis=0)
335
+ to_noise_levels = np.concatenate([to_noise_levels, np.zeros((condition_similar_length,from_noise_levels.shape[-1]), dtype=np.int32)], axis=0)
336
+
337
+ from_noise_levels = torch.from_numpy(from_noise_levels).to(self.device)
338
+ to_noise_levels = torch.from_numpy(to_noise_levels).to(self.device)
339
+
340
+
341
+ if input_pose_condition is not None:
342
+ input_pose_condition = input_pose_condition.to(xs_pred.dtype)
343
+
344
+ xs_pred[start_frame:] = self.diffusion_model.sample_step(
345
+ xs_pred[start_frame:],
346
+ input_condition,
347
+ input_pose_condition,
348
+ from_noise_levels[start_frame:],
349
+ to_noise_levels[start_frame:],
350
+ current_frame=curr_frame,
351
+ mode="validation",
352
+ reference_length=condition_similar_length,
353
+ frame_idx=frame_idx_list
354
+ )
355
+
356
+ # if curr_frame > 14:
357
+ # import pdb;pdb.set_trace()
358
+
359
+ # if xs_pred_back is not None:
360
+ # xs_pred = torch.cat([xs_pred[:6], xs_pred_back[6:12], xs_pred[6:]], dim=0)
361
+
362
+ # import pdb;pdb.set_trace()
363
+ if condition_similar_length: # and curr_frame+1!=n_frames:
364
+ xs_pred = xs_pred[:-condition_similar_length]
365
+
366
+ curr_frame += horizon
367
+ pbar.update(horizon)
368
+
369
+ self.frames = torch.cat([self.frames, xs_pred[n_context_frames:]])
370
+
371
+ xs_pred = self.decode(xs_pred[n_context_frames:])
372
+
373
+ return xs_pred[-1,0].cpu()
374
+
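
Note (not part of this commit): in training_step above, yaw/pitch are bucketed by the integer division by 15, so (assuming raw angles in degrees in [0, 360)) they live in 24 bins, and a raw per-frame delta of ±23 bins is a wrap-around of ∓1 bin — which is what the ==23 / ==-23 replacements encode. A tiny standalone check with made-up values:

import torch

prev_bin = torch.tensor([23., 0., 11.])   # e.g. yaw bins of consecutive frames
next_bin = torch.tensor([0., 23., 12.])
delta = next_bin - prev_bin               # tensor([-23., 23., 1.])
delta[delta == -23] = 1                   # 345 deg -> 0 deg is really +1 bin
delta[delta == 23] = -1                   # 0 deg -> 345 deg is really -1 bin
print(delta)                              # tensor([ 1., -1., 1.])
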
app.py ADDED
@@ -0,0 +1,576 @@
1
+ import gradio as gr
2
+ import time
3
+
4
+ import sys
5
+ import subprocess
6
+ import time
7
+ from pathlib import Path
8
+
9
+ import hydra
10
+ from omegaconf import DictConfig, OmegaConf
11
+ from omegaconf.omegaconf import open_dict
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torchvision.transforms as transforms
16
+ import cv2
17
+ import subprocess
18
+ from PIL import Image
19
+ from datetime import datetime
20
+ import spaces
21
+ from algorithms.worldmem import WorldMemMinecraft
22
+ from huggingface_hub import hf_hub_download
23
+ import tempfile
24
+ import os
25
+ import requests
26
+ from huggingface_hub import model_info
27
+
28
+ from experiments.exp_base import load_custom_checkpoint
29
+
30
+ torch.set_float32_matmul_precision("high")
31
+
32
+ def download_assets_if_needed():
33
+ ASSETS_URL_BASE = "https://huggingface.co/spaces/yslan/worldmem/resolve/main/assets/examples"
34
+ ASSETS_DIR = "assets/examples"
35
+ ASSETS = ['case1.npz', 'case2.npz', 'case3.npz', 'case4.npz']
36
+
37
+ if not os.path.exists(ASSETS_DIR):
38
+ os.makedirs(ASSETS_DIR)
39
+
40
+ # Download assets if they don't exist (total 4 files)
41
+ for filename in ASSETS:
42
+ filepath = os.path.join(ASSETS_DIR, filename)
43
+ if not os.path.exists(filepath):
44
+ print(f"Downloading {filename}...")
45
+ url = f"{ASSETS_URL_BASE}/{filename}"
46
+ response = requests.get(url)
47
+ if response.status_code == 200:
48
+ with open(filepath, "wb") as f:
49
+ f.write(response.content)
50
+ else:
51
+ print(f"Failed to download {filename}: {response.status_code}")
52
+
53
+ def parse_input_to_tensor(input_str):
54
+ """
55
+ Convert an input string into a (sequence_length, 25) tensor, where each row is a one-hot representation
56
+ of the corresponding action key.
57
+
58
+ Args:
59
+ input_str (str): A string consisting of "WASD" characters (e.g., "WASDWS").
60
+
61
+ Returns:
62
+ torch.Tensor: A tensor of shape (sequence_length, 25), where each row is a one-hot encoded action.
63
+ """
64
+ # Get the length of the input sequence
65
+ seq_len = len(input_str)
66
+
67
+ # Initialize a zero tensor of shape (seq_len, 25)
68
+ action_tensor = torch.zeros((seq_len, 25))
69
+
70
+ # Iterate through the input string and update the corresponding positions
71
+ for i, char in enumerate(input_str):
72
+ action, value = KEY_TO_ACTION.get(char.upper(), (None, 0)) # Convert to uppercase for case-insensitive lookup; default avoids unpacking None for unknown keys
73
+ if action and action in ACTION_KEYS:
74
+ index = ACTION_KEYS.index(action)
75
+ action_tensor[i, index] = value # Set the corresponding action index to 1
76
+
77
+ return action_tensor
78
+
79
+ def load_image_as_tensor(image_path: str) -> torch.Tensor:
80
+ """
81
+ Load an image and convert it to a 0-1 normalized tensor.
82
+
83
+ Args:
84
+ image_path (str): Path to the image file.
85
+
86
+ Returns:
87
+ torch.Tensor: Image tensor of shape (C, H, W), normalized to [0,1].
88
+ """
89
+ if isinstance(image_path, str):
90
+ image = Image.open(image_path).convert("RGB") # Ensure it's RGB
91
+ else:
92
+ image = image_path
93
+ transform = transforms.Compose([
94
+ transforms.ToTensor(), # Converts to tensor and normalizes to [0,1]
95
+ ])
96
+ return transform(image)
97
+
98
+ def enable_amp(model, precision="16-mixed"):
99
+ original_forward = model.forward
100
+
101
+ def amp_forward(*args, **kwargs):
102
+ with torch.autocast("cuda", dtype=torch.float16 if precision == "16-mixed" else torch.bfloat16):
103
+ return original_forward(*args, **kwargs)
104
+
105
+ model.forward = amp_forward
106
+ return model
107
+
108
+ download_assets_if_needed()
109
+
110
+ ACTION_KEYS = [
111
+ "inventory",
112
+ "ESC",
113
+ "hotbar.1",
114
+ "hotbar.2",
115
+ "hotbar.3",
116
+ "hotbar.4",
117
+ "hotbar.5",
118
+ "hotbar.6",
119
+ "hotbar.7",
120
+ "hotbar.8",
121
+ "hotbar.9",
122
+ "forward",
123
+ "back",
124
+ "left",
125
+ "right",
126
+ "cameraY",
127
+ "cameraX",
128
+ "jump",
129
+ "sneak",
130
+ "sprint",
131
+ "swapHands",
132
+ "attack",
133
+ "use",
134
+ "pickItem",
135
+ "drop",
136
+ ]
137
+
138
+ # Mapping of input keys to action names
139
+ KEY_TO_ACTION = {
140
+ "Q": ("forward", 1),
141
+ "E": ("back", 1),
142
+ "W": ("cameraY", -1),
143
+ "S": ("cameraY", 1),
144
+ "A": ("cameraX", -1),
145
+ "D": ("cameraX", 1),
146
+ "U": ("drop", 1),
147
+ "N": ("noop", 1),
148
+ "1": ("hotbar.1", 1),
149
+ }
150
+
151
+ example_images = [
152
+ ["1", "assets/ice_plains.png", "turn rightgo backward→look up→turn left→look down→turn right→go forward→turn left", 20, 3, 8],
153
+ ["2", "assets/place.png", "put item→go backward→put item→go backward→go around", 20, 3, 8],
154
+ ["3", "assets/rain_sunflower_plains.png", "turn right→look up→turn right→look down→turn left→go backward→turn left", 20, 3, 8],
155
+ ["4", "assets/desert.png", "turn 360 degree→turn right→go forward→turn left", 20, 3, 8],
156
+ ]
157
+
158
+ video_frames = []
159
+ input_history = ""
160
+ ICE_PLAINS_IMAGE = "assets/ice_plains.png"
161
+ DESERT_IMAGE = "assets/desert.png"
162
+ SAVANNA_IMAGE = "assets/savanna.png"
163
+ PLAINS_IMAGE = "assets/plains.png"
164
+ PLACE_IMAGE = "assets/place.png"
165
+ SUNFLOWERS_IMAGE = "assets/sunflower_plains.png"
166
+ SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"
167
+
168
+ device = torch.device('cuda')
169
+
170
+ def save_video(frames, path="output.mp4", fps=10):
171
+ temp_path = path[:-4] + "_temp.mp4"
172
+ h, w, _ = frames[0].shape
173
+
174
+ out = cv2.VideoWriter(temp_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
175
+ for frame in frames:
176
+ out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
177
+ out.release()
178
+
179
+ ffmpeg_cmd = [
180
+ "ffmpeg", "-y", "-i", temp_path,
181
+ "-c:v", "libx264", "-crf", "23", "-preset", "medium",
182
+ path
183
+ ]
184
+ subprocess.run(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
185
+ os.remove(temp_path)
186
+
187
+ cfg = OmegaConf.load("configurations/huggingface.yaml")
188
+ worldmem = WorldMemMinecraft(cfg)
189
+ load_custom_checkpoint(algo=worldmem.diffusion_model, checkpoint_path=cfg.diffusion_path)
190
+ load_custom_checkpoint(algo=worldmem.vae, checkpoint_path=cfg.vae_path)
191
+ load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.pose_predictor_path)
192
+ worldmem.to("cuda").eval()
193
+ # worldmem = enable_amp(worldmem, precision="16-mixed")
194
+
195
+ actions = np.zeros((1, 25), dtype=np.float32)
196
+ poses = np.zeros((1, 5), dtype=np.float32)
197
+
198
+
199
+
200
+ def get_duration_single_image_to_long_video(first_frame, action, first_pose, device, memory_latent_frames, memory_actions,
201
+ memory_poses, memory_c2w, memory_frame_idx):
202
+ return 5 * len(action) if memory_actions is not None else 5
203
+
204
+ @spaces.GPU(duration=get_duration_single_image_to_long_video)
205
+ def run_interactive(first_frame, action, first_pose, device, memory_latent_frames, memory_actions,
206
+ memory_poses, memory_c2w, memory_frame_idx):
207
+ new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = worldmem.interactive(first_frame,
208
+ action,
209
+ first_pose,
210
+ device=device,
211
+ memory_latent_frames=memory_latent_frames,
212
+ memory_actions=memory_actions,
213
+ memory_poses=memory_poses,
214
+ memory_c2w=memory_c2w,
215
+ memory_frame_idx=memory_frame_idx)
216
+
217
+ return new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
218
+
219
+ def set_denoising_steps(denoising_steps, sampling_timesteps_state):
220
+ worldmem.sampling_timesteps = denoising_steps
221
+ worldmem.diffusion_model.sampling_timesteps = denoising_steps
222
+ sampling_timesteps_state = denoising_steps
223
+ print("set denoising steps to", worldmem.sampling_timesteps)
224
+ return sampling_timesteps_state
225
+
226
+ def set_context_length(context_length, sampling_context_length_state):
227
+ worldmem.n_tokens = context_length
228
+ sampling_context_length_state = context_length
229
+ print("set context length to", worldmem.n_tokens)
230
+ return sampling_context_length_state
231
+
232
+ def set_memory_condition_length(memory_condition_length, sampling_memory_condition_length_state):
233
+ worldmem.memory_condition_length = memory_condition_length
234
+ sampling_memory_condition_length_state = memory_condition_length
235
+ print("set memory length to", worldmem.memory_condition_length)
236
+ return sampling_memory_condition_length_state
237
+
238
+ def set_next_frame_length(next_frame_length, sampling_next_frame_length_state):
239
+ worldmem.next_frame_length = next_frame_length
240
+ sampling_next_frame_length_state = next_frame_length
241
+ print("set next frame length to", worldmem.next_frame_length)
242
+ return sampling_next_frame_length_state
243
+
244
+ def generate(keys, input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx):
245
+ input_actions = parse_input_to_tensor(keys)
246
+
247
+ if memory_latent_frames is None:
248
+ new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = run_interactive(video_frames[0],
249
+ actions[0],
250
+ poses[0],
251
+ device=device,
252
+ memory_latent_frames=memory_latent_frames,
253
+ memory_actions=memory_actions,
254
+ memory_poses=memory_poses,
255
+ memory_c2w=memory_c2w,
256
+ memory_frame_idx=memory_frame_idx)
257
+
258
+ new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = run_interactive(video_frames[0],
259
+ input_actions,
260
+ None,
261
+ device=device,
262
+ memory_latent_frames=memory_latent_frames,
263
+ memory_actions=memory_actions,
264
+ memory_poses=memory_poses,
265
+ memory_c2w=memory_c2w,
266
+ memory_frame_idx=memory_frame_idx)
267
+
268
+ video_frames = np.concatenate([video_frames, new_frame[:,0]])
269
+
270
+
271
+ out_video = video_frames.transpose(0,2,3,1).copy()
272
+ out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
273
+ out_video = (out_video * 255).astype(np.uint8)
274
+
275
+ last_frame = out_video[-1].copy()
276
+ border_thickness = 2
277
+ out_video[-len(new_frame):, :border_thickness, :, :] = [255, 0, 0]
278
+ out_video[-len(new_frame):, -border_thickness:, :, :] = [255, 0, 0]
279
+ out_video[-len(new_frame):, :, :border_thickness, :] = [255, 0, 0]
280
+ out_video[-len(new_frame):, :, -border_thickness:, :] = [255, 0, 0]
281
+
282
+ temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
283
+ save_video(out_video, temporal_video_path)
284
+ input_history += keys
285
+
286
+
287
+ # now = datetime.now()
288
+ # folder_name = now.strftime("%Y-%m-%d_%H-%M-%S")
289
+ # folder_path = os.path.join("/mnt/xiaozeqi/worldmem/output_material", folder_name)
290
+ # os.makedirs(folder_path, exist_ok=True)
291
+ # data_dict = {
292
+ # "input_history": input_history,
293
+ # "video_frames": video_frames,
294
+ # "memory_latent_frames": memory_latent_frames,
295
+ # "memory_actions": memory_actions,
296
+ # "memory_poses": memory_poses,
297
+ # "memory_c2w": memory_c2w,
298
+ # "memory_frame_idx": memory_frame_idx,
299
+ # }
300
+
301
+ # np.savez(os.path.join(folder_path, "data_bundle.npz"), **data_dict)
302
+
303
+ return last_frame, temporal_video_path, input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
304
+
305
+ def reset(selected_image):
306
+ memory_latent_frames = None
307
+ memory_poses = None
308
+ memory_actions = None
309
+ memory_c2w = None
310
+ memory_frame_idx = None
311
+ video_frames = load_image_as_tensor(selected_image).numpy()[None]
312
+ input_history = ""
313
+
314
+ new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = run_interactive(video_frames[0],
315
+ actions[0],
316
+ poses[0],
317
+ device=device,
318
+ memory_latent_frames=memory_latent_frames,
319
+ memory_actions=memory_actions,
320
+ memory_poses=memory_poses,
321
+ memory_c2w=memory_c2w,
322
+ memory_frame_idx=memory_frame_idx,
323
+ )
324
+
325
+ return input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
326
+
327
+ def on_image_click(selected_image):
328
+ input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = reset(selected_image)
329
+ return input_history, selected_image, selected_image, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
330
+
331
+ def set_memory(examples_case):
332
+ if examples_case == '1':
333
+ data_bundle = np.load("assets/examples/case1.npz")
334
+ input_history = data_bundle['input_history'].item()
335
+ video_frames = data_bundle['memory_frames']
336
+ memory_latent_frames = data_bundle['self_frames']
337
+ memory_actions = data_bundle['self_actions']
338
+ memory_poses = data_bundle['self_poses']
339
+ memory_c2w = data_bundle['self_memory_c2w']
340
+ memory_frame_idx = data_bundle['self_frame_idx']
341
+ elif examples_case == '2':
342
+ data_bundle = np.load("assets/examples/case2.npz")
343
+ input_history = data_bundle['input_history'].item()
344
+ video_frames = data_bundle['memory_frames']
345
+ memory_latent_frames = data_bundle['self_frames']
346
+ memory_actions = data_bundle['self_actions']
347
+ memory_poses = data_bundle['self_poses']
348
+ memory_c2w = data_bundle['self_memory_c2w']
349
+ memory_frame_idx = data_bundle['self_frame_idx']
350
+ elif examples_case == '3':
351
+ data_bundle = np.load("assets/examples/case3.npz")
352
+ input_history = data_bundle['input_history'].item()
353
+ video_frames = data_bundle['memory_frames']
354
+ memory_latent_frames = data_bundle['self_frames']
355
+ memory_actions = data_bundle['self_actions']
356
+ memory_poses = data_bundle['self_poses']
357
+ memory_c2w = data_bundle['self_memory_c2w']
358
+ memory_frame_idx = data_bundle['self_frame_idx']
359
+ elif examples_case == '4':
360
+ data_bundle = np.load("assets/examples/case4.npz")
361
+ input_history = data_bundle['input_history'].item()
362
+ video_frames = data_bundle['memory_frames']
363
+ memory_latent_frames = data_bundle['self_frames']
364
+ memory_actions = data_bundle['self_actions']
365
+ memory_poses = data_bundle['self_poses']
366
+ memory_c2w = data_bundle['self_memory_c2w']
367
+ memory_frame_idx = data_bundle['self_frame_idx']
368
+
369
+ out_video = video_frames.transpose(0,2,3,1)
370
+ out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
371
+ out_video = (out_video * 255).astype(np.uint8)
372
+
373
+ temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
374
+ save_video(out_video, temporal_video_path)
375
+
376
+ return input_history, out_video[-1], temporal_video_path, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
377
+
378
+ css = """
379
+ h1 {
380
+ text-align: center;
381
+ display:block;
382
+ }
383
+ """
384
+
385
+ with gr.Blocks(css=css) as demo:
386
+ gr.Markdown(
387
+ """
388
+ # WORLDMEM: Long-term Consistent World Simulation with Memory
389
+ """
390
+ )
391
+
392
+ gr.Markdown(
393
+ """
394
+ ## 🚀 How to Explore WorldMem
395
+
396
+ Follow these simple steps to get started:
397
+
398
+ 1. **Choose a scene**.
399
+ 2. **Input your action sequence**.
400
+ 3. **Click "Generate"**.
401
+
402
+ - You can continuously click **"Generate"** to **extend the video** and observe how well the world maintains consistency over time.
403
+ - For best performance, we recommend **running locally** (1s/frame on H100) instead of Spaces (5s/frame).
404
+ - ⭐️ If you like this project, please [give it a star on GitHub]()!
405
+ - 💬 For questions or feedback, feel free to open an issue or email me at **zeqixiao1@gmail.com**.
406
+
407
+ Happy exploring! 🌍
408
+ """
409
+ )
410
+ # <div style="text-align: center;">
411
+ # <!-- Public Website -->
412
+ # <a style="display:inline-block" href="https://nirvanalan.github.io/projects/GA/">
413
+ # <img src="https://img.shields.io/badge/public_website-8A2BE2">
414
+ # </a>
415
+
416
+ # <!-- GitHub Stars -->
417
+ # <a style="display:inline-block; margin-left: .5em" href="https://github.com/NIRVANALAN/GaussianAnything">
418
+ # <img src="https://img.shields.io/github/stars/NIRVANALAN/GaussianAnything?style=social">
419
+ # </a>
420
+
421
+ # <!-- Project Page -->
422
+ # <a style="display:inline-block; margin-left: .5em" href="https://nirvanalan.github.io/projects/GA/">
423
+ # <img src="https://img.shields.io/badge/project_page-blue">
424
+ # </a>
425
+
426
+ # <!-- arXiv Paper -->
427
+ # <a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/XXXX.XXXXX">
428
+ # <img src="https://img.shields.io/badge/arXiv-paper-red">
429
+ # </a>
430
+ # </div>
431
+
432
+ example_actions = {"turn left→turn right": "AAAAAAAAAAAADDDDDDDDDDDD",
433
+ "turn 360 degree": "AAAAAAAAAAAAAAAAAAAAAAAA",
434
+ "turn right→go backward→look up→turn left→look down": "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
435
+ "turn right→go forward→turn right": "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
436
+ "turn right→look up→turn right→look down": "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSS",
437
+ "put item→go backward→put item→go backward":"SSUNNWWEEEEEEEEEAAASSUNNWWEEEEEEEEE"}
438
+
439
+ selected_image = gr.State(ICE_PLAINS_IMAGE)
440
+
441
+ with gr.Row(variant="panel"):
442
+ with gr.Column():
443
+ gr.Markdown("🖼️ Start from this frame.")
444
+ image_display = gr.Image(value=selected_image.value, interactive=False, label="Current Frame")
445
+ with gr.Column():
446
+ gr.Markdown("🎞️ Generated videos. New contents are marked in red box.")
447
+ video_display = gr.Video(autoplay=True, loop=True)
448
+
449
+ gr.Markdown("### 🏞️ Choose a scene and start generation.")
450
+
451
+ with gr.Row():
452
+ image_display_1 = gr.Image(value=SUNFLOWERS_IMAGE, interactive=False, label="Sunflower Plains")
453
+ image_display_2 = gr.Image(value=DESERT_IMAGE, interactive=False, label="Desert")
454
+ image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
455
+ image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
456
+ image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
457
+ image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
458
+
459
+
460
+ with gr.Row(variant="panel"):
461
+ with gr.Column(scale=2):
462
+ gr.Markdown("### 🕹️ Input action sequences for interaction.")
463
+ input_box = gr.Textbox(label="Action Sequences", placeholder="Enter action sequences here, e.g. (AAAAAAAAAAAADDDDDDDDDDDD)", lines=1, max_lines=1)
464
+ log_output = gr.Textbox(label="History Sequences", interactive=False)
465
+ gr.Markdown(
466
+ """
467
+ ### 💡 Action Key Guide
468
+
469
+ <pre style="font-family: monospace; font-size: 14px; line-height: 1.6;">
470
+ W: Turn up S: Turn down A: Turn left D: Turn right
471
+ Q: Go forward E: Go backward N: No-op U: Use item
472
+ </pre>
473
+ """
474
+ )
475
+ gr.Markdown("### 👇 Click to quickly set action sequence examples.")
476
+ with gr.Row():
477
+ buttons = []
478
+ for action_key in list(example_actions.keys())[:2]:
479
+ with gr.Column(scale=len(action_key)):
480
+ buttons.append(gr.Button(action_key))
481
+ with gr.Row():
482
+ for action_key in list(example_actions.keys())[2:4]:
483
+ with gr.Column(scale=len(action_key)):
484
+ buttons.append(gr.Button(action_key))
485
+ with gr.Row():
486
+ for action_key in list(example_actions.keys())[4:6]:
487
+ with gr.Column(scale=len(action_key)):
488
+ buttons.append(gr.Button(action_key))
489
+
490
+ with gr.Column(scale=1):
491
+ submit_button = gr.Button("🎬 Generate!", variant="primary")
492
+ reset_btn = gr.Button("🔄 Reset")
493
+
494
+ # gr.Markdown("<div style='flex-grow:1; height: 100px'></div>")
495
+
496
+ gr.Markdown("### ⚙️ Advanced Settings")
497
+
498
+ slider_denoising_step = gr.Slider(
499
+ minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1,
500
+ label="Denoising Steps",
501
+ info="Higher values yield better quality but slower speed"
502
+ )
503
+ slider_context_length = gr.Slider(
504
+ minimum=2, maximum=10, value=worldmem.n_tokens, step=1,
505
+ label="Context Length",
506
+ info="How many previous frames in temporal context window."
507
+ )
508
+ slider_memory_condition_length = gr.Slider(
509
+ minimum=4, maximum=16, value=worldmem.memory_condition_length, step=1,
510
+ label="Memory Length",
511
+ info="How many previous frames in memory window. (Recommended: 1, multi-frame generation is not stable yet)"
512
+ )
513
+ slider_next_frame_length = gr.Slider(
514
+ minimum=1, maximum=5, value=worldmem.next_frame_length, step=1,
515
+ label="Next Frame Length",
516
+ info="How many next frames to generate at once."
517
+ )
518
+
519
+ sampling_timesteps_state = gr.State(worldmem.sampling_timesteps)
520
+ sampling_context_length_state = gr.State(worldmem.n_tokens)
521
+ sampling_memory_condition_length_state = gr.State(worldmem.memory_condition_length)
522
+ sampling_next_frame_length_state = gr.State(worldmem.next_frame_length)
523
+
524
+ video_frames = gr.State(load_image_as_tensor(selected_image.value)[None].numpy())
525
+ memory_latent_frames = gr.State()
526
+ memory_actions = gr.State()
527
+ memory_poses = gr.State()
528
+ memory_c2w = gr.State()
529
+ memory_frame_idx = gr.State()
530
+
531
+ def set_action(action):
532
+ return action
533
+
534
+
535
+
536
+ for button, action_key in zip(buttons, list(example_actions.keys())):
537
+ button.click(set_action, inputs=[gr.State(value=example_actions[action_key])], outputs=input_box)
538
+
539
+ gr.Markdown("### 👇 Click to review generated examples, and continue generation based on them.")
540
+
541
+ example_case = gr.Textbox(label="Case", visible=False)
542
+ image_output = gr.Image(visible=False)
543
+
544
+ examples = gr.Examples(
545
+ examples=example_images,
546
+ inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_condition_length],
547
+ cache_examples=False
548
+ )
549
+
550
+ example_case.change(
551
+ fn=set_memory,
552
+ inputs=[example_case],
553
+ outputs=[log_output, image_display, video_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx]
554
+ )
555
+
556
+ submit_button.click(generate, inputs=[input_box, log_output, video_frames,
557
+ memory_latent_frames, memory_actions, memory_poses,
558
+ memory_c2w, memory_frame_idx],
559
+ outputs=[image_display, video_display, log_output,
560
+ video_frames, memory_latent_frames, memory_actions, memory_poses,
561
+ memory_c2w, memory_frame_idx])
562
+
563
+ reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
564
+ image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
565
+ image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
566
+ image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
567
+ image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
568
+ image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
569
+ image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=[log_output, selected_image,image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
570
+
571
+ slider_denoising_step.change(fn=set_denoising_steps, inputs=[slider_denoising_step, sampling_timesteps_state], outputs=sampling_timesteps_state)
572
+ slider_context_length.change(fn=set_context_length, inputs=[slider_context_length, sampling_context_length_state], outputs=sampling_context_length_state)
573
+ slider_memory_condition_length.change(fn=set_memory_condition_length, inputs=[slider_memory_condition_length, sampling_memory_condition_length_state], outputs=sampling_memory_condition_length_state)
574
+ slider_next_frame_length.change(fn=set_next_frame_length, inputs=[slider_next_frame_length, sampling_next_frame_length_state], outputs=sampling_next_frame_length_state)
575
+
576
+ demo.launch(share=True)
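
Note (not part of this commit): a small self-contained illustration of the action encoding that parse_input_to_tensor builds from a key string. The ACTION_KEYS / KEY_TO_ACTION tables are restated here so the snippet runs without importing app.py (which would launch the demo).

import torch

ACTION_KEYS = ["inventory", "ESC"] + [f"hotbar.{i}" for i in range(1, 10)] + [
    "forward", "back", "left", "right", "cameraY", "cameraX", "jump", "sneak",
    "sprint", "swapHands", "attack", "use", "pickItem", "drop"]          # 25 entries
KEY_TO_ACTION = {"Q": ("forward", 1), "E": ("back", 1), "W": ("cameraY", -1),
                 "S": ("cameraY", 1), "A": ("cameraX", -1), "D": ("cameraX", 1),
                 "U": ("drop", 1), "N": ("noop", 1), "1": ("hotbar.1", 1)}

keys = "AAD"                                   # turn left, left, right
seq = torch.zeros(len(keys), 25)
for i, ch in enumerate(keys):
    action, value = KEY_TO_ACTION.get(ch.upper(), (None, 0))
    if action in ACTION_KEYS:
        seq[i, ACTION_KEYS.index(action)] = value
print(seq.shape)                               # torch.Size([3, 25])
print(seq[:, ACTION_KEYS.index("cameraX")])    # tensor([-1., -1.,  1.])
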
assets/desert.png ADDED

Git LFS Details

  • SHA256: 3b85899ba8b3d111370fbcc25079d661a04d80563ecb43e55eb0c36f36c44b76
  • Pointer size: 131 Bytes
  • Size of remote file: 298 kB
assets/ice_plains.png ADDED

Git LFS Details

  • SHA256: ced8ab54ebb2c8c34b6fd10340dde905dc0f6a3096109521a08ee880688ae9cc
  • Pointer size: 131 Bytes
  • Size of remote file: 238 kB
assets/place.png ADDED

Git LFS Details

  • SHA256: d4a1630a6f3e73c38e0dfec88bd902a5cf08bc8f857768e94199ea850d7eff81
  • Pointer size: 131 Bytes
  • Size of remote file: 212 kB
assets/plains.png ADDED

Git LFS Details

  • SHA256: adf5ad62acc998e35fec82c8e53b2559e26a7d78bb91a0f1cf8039a8610c3c78
  • Pointer size: 131 Bytes
  • Size of remote file: 263 kB
assets/rain_sunflower_plains.png ADDED

Git LFS Details

  • SHA256: 2488d19febab9dac852b5d0b6e6894ac276f48a1788220f2a8d38c7030cf7a98
  • Pointer size: 131 Bytes
  • Size of remote file: 387 kB
assets/savanna.png ADDED

Git LFS Details

  • SHA256: 5f8df1e988d84cd40f1af49eee73ef42d11f29d2c37a66bb5fda12d5b3278a55
  • Pointer size: 131 Bytes
  • Size of remote file: 339 kB
assets/sunflower_plains.png ADDED

Git LFS Details

  • SHA256: 98d828eb41fc7fb53909b66083db07208feffd66e88f4ae07092cc482e4e20df
  • Pointer size: 131 Bytes
  • Size of remote file: 283 kB
assets/worldmem_logo.png ADDED

Git LFS Details

  • SHA256: 8a1c0133cd1c20a557b800e5067ba97787c90e71c33f6b9695ecc44d78238426
  • Pointer size: 131 Bytes
  • Size of remote file: 313 kB
calculate_fid.py ADDED
@@ -0,0 +1,277 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Calculate FID (Fréchet Inception Distance) between predicted and ground truth videos.
4
+
5
+ Usage:
6
+ python calculate_fid.py --videos_dir /path/to/videos
7
+ python calculate_fid.py --videos_dir /path/to/videos --batch_size 32
8
+ """
9
+
10
+ import torch
11
+ import numpy as np
12
+ from pathlib import Path
13
+ from tqdm import tqdm
14
+ import argparse
15
+ import cv2
16
+ from torchmetrics.image.fid import FrechetInceptionDistance
17
+
18
+
19
+ def load_video_frames(video_path, max_frames=None):
20
+ """
21
+ Load frames from a video file.
22
+
23
+ Args:
24
+ video_path: Path to the video file
25
+ max_frames: Maximum number of frames to load (None = all frames)
26
+
27
+ Returns:
28
+ torch.Tensor: Video frames with shape (T, C, H, W) in range [0, 255]
29
+ """
30
+ cap = cv2.VideoCapture(str(video_path))
31
+ frames = []
32
+ frame_count = 0
33
+
34
+ while True:
35
+ ret, frame = cap.read()
36
+ if not ret:
37
+ break
38
+
39
+ # Convert BGR to RGB
40
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
41
+ frames.append(frame)
42
+ frame_count += 1
43
+
44
+ if max_frames and frame_count >= max_frames:
45
+ break
46
+
47
+ cap.release()
48
+
49
+ if len(frames) == 0:
50
+ raise ValueError(f"No frames loaded from {video_path}")
51
+
52
+ # Convert to tensor: (T, H, W, C) -> (T, C, H, W)
53
+ frames = np.stack(frames, axis=0)
54
+ frames = torch.from_numpy(frames).permute(0, 3, 1, 2)
55
+
56
+ return frames
57
+
58
+
59
+ def load_videos_from_directory(video_dir, max_frames_per_video=None, max_videos=None):
60
+ """
61
+ Load all videos from a directory.
62
+
63
+ Args:
64
+ video_dir: Directory containing .mp4 files
65
+ max_frames_per_video: Maximum frames to load per video
66
+ max_videos: Maximum number of videos to load
67
+
68
+ Returns:
69
+ torch.Tensor: All frames concatenated with shape (N, C, H, W)
70
+ """
71
+ video_dir = Path(video_dir)
72
+ video_paths = sorted(list(video_dir.glob("**/*.mp4")))
73
+
74
+ if max_videos:
75
+ video_paths = video_paths[:max_videos]
76
+
77
+ all_frames = []
78
+
79
+ print(f"Loading videos from {video_dir}")
80
+ print(f"Found {len(video_paths)} videos")
81
+
82
+ for video_path in tqdm(video_paths, desc="Loading videos"):
83
+ try:
84
+ frames = load_video_frames(video_path, max_frames=max_frames_per_video)
85
+ all_frames.append(frames)
86
+ except Exception as e:
87
+ print(f"\nWarning: Failed to load {video_path.name}: {e}")
88
+ continue
89
+
90
+ if len(all_frames) == 0:
91
+ raise ValueError(f"No videos loaded from {video_dir}")
92
+
93
+ # Concatenate all frames: (N_videos, T, C, H, W) -> (N_total_frames, C, H, W)
94
+ all_frames = torch.cat(all_frames, dim=0)
95
+
96
+ print(f"Loaded {all_frames.shape[0]} frames total")
97
+ print(f"Frame shape: {all_frames.shape[1:]}")
98
+
99
+ return all_frames
100
+
101
+
102
+ def calculate_fid(pred_dir, gt_dir, batch_size=32, device='cuda',
103
+ max_frames_per_video=None, max_videos=None):
104
+ """
105
+ Calculate FID between predicted and ground truth videos.
106
+
107
+ Args:
108
+ pred_dir: Directory containing predicted videos
109
+ gt_dir: Directory containing ground truth videos
110
+ batch_size: Batch size for FID calculation
111
+ device: Device to use ('cuda' or 'cpu')
112
+ max_frames_per_video: Maximum frames to load per video
113
+ max_videos: Maximum number of videos to load from each directory
114
+
115
+ Returns:
116
+ float: FID score
117
+ """
118
+ print("="*60)
119
+ print("FID Calculation")
120
+ print("="*60)
121
+ print(f"Pred directory: {pred_dir}")
122
+ print(f"GT directory: {gt_dir}")
123
+ print(f"Device: {device}")
124
+ print(f"Batch size: {batch_size}")
125
+ print("="*60 + "\n")
126
+
127
+ # Check if directories exist
128
+ pred_dir = Path(pred_dir)
129
+ gt_dir = Path(gt_dir)
130
+
131
+ if not pred_dir.exists():
132
+ raise ValueError(f"Pred directory does not exist: {pred_dir}")
133
+ if not gt_dir.exists():
134
+ raise ValueError(f"GT directory does not exist: {gt_dir}")
135
+
136
+ # Load videos
137
+ print("\n[1/3] Loading predicted videos...")
138
+ pred_frames = load_videos_from_directory(
139
+ pred_dir,
140
+ max_frames_per_video=max_frames_per_video,
141
+ max_videos=max_videos
142
+ )
143
+
144
+ print("\n[2/3] Loading ground truth videos...")
145
+ gt_frames = load_videos_from_directory(
146
+ gt_dir,
147
+ max_frames_per_video=max_frames_per_video,
148
+ max_videos=max_videos
149
+ )
150
+
151
+ # Initialize FID model
152
+ print("\n[3/3] Calculating FID...")
153
+ fid_model = FrechetInceptionDistance(normalize=True).to(device)
154
+
155
+ # Process pred frames in batches
156
+ print("Processing predicted frames...")
157
+ num_pred_frames = pred_frames.shape[0]
158
+ for i in tqdm(range(0, num_pred_frames, batch_size)):
159
+ batch = pred_frames[i:i+batch_size]
160
+ batch = batch.to(device).float() / 255.0 # uint8 [0, 255] -> float [0, 1], as FrechetInceptionDistance(normalize=True) expects
161
+ fid_model.update(batch, real=False)
162
+
163
+ # Process gt frames in batches
164
+ print("Processing ground truth frames...")
165
+ num_gt_frames = gt_frames.shape[0]
166
+ for i in tqdm(range(0, num_gt_frames, batch_size)):
167
+ batch = gt_frames[i:i+batch_size]
168
+ batch = batch.to(device).float() / 255.0 # uint8 [0, 255] -> float [0, 1], as FrechetInceptionDistance(normalize=True) expects
169
+ fid_model.update(batch, real=True)
170
+
171
+ # Compute FID
172
+ fid_score = fid_model.compute().item()
173
+
174
+ return fid_score
175
+
176
+
177
+ def main():
178
+ parser = argparse.ArgumentParser(
179
+ description="Calculate FID between predicted and ground truth videos"
180
+ )
181
+ parser.add_argument(
182
+ "--videos_dir",
183
+ type=str,
184
+ default="/mnt/worldmem_valid/outputs/2025-12-01/08-09-46/videos/test_vis",
185
+ help="Base directory containing 'pred' and 'gt' subdirectories"
186
+ )
187
+ parser.add_argument(
188
+ "--pred_dir",
189
+ type=str,
190
+ default=None,
191
+ help="Override pred directory (default: {videos_dir}/pred)"
192
+ )
193
+ parser.add_argument(
194
+ "--gt_dir",
195
+ type=str,
196
+ default=None,
197
+ help="Override gt directory (default: {videos_dir}/gt)"
198
+ )
199
+ parser.add_argument(
200
+ "--batch_size",
201
+ type=int,
202
+ default=32,
203
+ help="Batch size for FID calculation (default: 32)"
204
+ )
205
+ parser.add_argument(
206
+ "--device",
207
+ type=str,
208
+ default="cuda" if torch.cuda.is_available() else "cpu",
209
+ help="Device to use (default: cuda if available)"
210
+ )
211
+ parser.add_argument(
212
+ "--max_frames_per_video",
213
+ type=int,
214
+ default=None,
215
+ help="Maximum frames to load per video (default: None, load all)"
216
+ )
217
+ parser.add_argument(
218
+ "--max_videos",
219
+ type=int,
220
+ default=50,
221
+ help="Maximum number of videos to load (default: None, load all)"
222
+ )
223
+
224
+ args = parser.parse_args()
225
+
226
+ # Determine pred and gt directories
227
+ videos_dir = Path(args.videos_dir)
228
+
229
+ if args.pred_dir:
230
+ pred_dir = Path(args.pred_dir)
231
+ else:
232
+ pred_dir = videos_dir / "pred"
233
+
234
+ if args.gt_dir:
235
+ gt_dir = Path(args.gt_dir)
236
+ else:
237
+ gt_dir = videos_dir / "gt"
238
+
239
+ # Calculate FID
240
+ try:
241
+ fid_score = calculate_fid(
242
+ pred_dir=pred_dir,
243
+ gt_dir=gt_dir,
244
+ batch_size=args.batch_size,
245
+ device=args.device,
246
+ max_frames_per_video=args.max_frames_per_video,
247
+ max_videos=args.max_videos
248
+ )
249
+
250
+ # Print results
251
+ print("\n" + "="*60)
252
+ print("RESULTS")
253
+ print("="*60)
254
+ print(f"FID Score: {fid_score:.4f}")
255
+ print("="*60)
256
+
257
+ # Save results to file
258
+ output_file = videos_dir / "fid_results.txt"
259
+ with open(output_file, 'w') as f:
260
+ f.write(f"FID Score: {fid_score:.4f}\n")
261
+ f.write(f"Pred directory: {pred_dir}\n")
262
+ f.write(f"GT directory: {gt_dir}\n")
263
+
264
+ print(f"\nResults saved to: {output_file}")
265
+
266
+ except Exception as e:
267
+ print(f"\n✗ Error: {e}")
268
+ import traceback
269
+ traceback.print_exc()
270
+ return 1
271
+
272
+ return 0
273
+
274
+
275
+ if __name__ == "__main__":
276
+ exit(main())
277
+
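For reference, a minimal self-contained sketch of the torchmetrics FID workflow that `calculate_fid()` above relies on. The random tensors are placeholders for decoded video frames, and the batching mirrors the script's default `--batch_size 32`; this is an illustration, not the repo's pipeline.

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

device = "cuda" if torch.cuda.is_available() else "cpu"
# normalize=True tells torchmetrics to expect float images in [0, 1]
fid = FrechetInceptionDistance(normalize=True).to(device)

# Placeholder frames with shape (N, 3, H, W); real usage feeds decoded video frames.
pred_frames = torch.rand(64, 3, 128, 128)
gt_frames = torch.rand(64, 3, 128, 128)

batch_size = 32
for i in range(0, pred_frames.shape[0], batch_size):
    fid.update(pred_frames[i:i + batch_size].to(device), real=False)
for i in range(0, gt_frames.shape[0], batch_size):
    fid.update(gt_frames[i:i + batch_size].to(device), real=True)

print(f"FID: {fid.compute().item():.4f}")  # large for random noise; meaningful only on real frames
```

When the full script is invoked with only `--videos_dir`, it reads frames from the `pred/` and `gt/` subdirectories of that path and writes the score to `fid_results.txt` alongside them, as the argparse help above describes.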
configurations/algorithm/base_algo.yaml ADDED
@@ -0,0 +1,3 @@
1
+ # This will be passed as the cfg to Algo.__init__(cfg) of your algorithm class
2
+
3
+ debug: ${debug} # inherited from configurations/config.yaml
configurations/algorithm/base_pytorch_algo.yaml ADDED
@@ -0,0 +1,4 @@
1
+ defaults:
2
+ - base_algo # inherits from configurations/algorithm/base_algo.yaml
3
+
4
+ lr: ${experiment.training.lr}
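The `${...}` references in these configuration files are OmegaConf interpolations that resolve when the composed config is accessed at runtime. A minimal sketch (not the repo's actual loader) of how a value like `lr: ${experiment.training.lr}` is resolved; the nested structure below is hand-built for illustration.

```python
from omegaconf import OmegaConf

# Hand-built config mimicking the composed structure; the real config is assembled by Hydra.
cfg = OmegaConf.create({
    "debug": False,
    "experiment": {"training": {"lr": 2e-5}},
    "algorithm": {"debug": "${debug}", "lr": "${experiment.training.lr}"},
})

print(cfg.algorithm.lr)                      # 2e-05 -- interpolation resolved lazily on access
print(OmegaConf.to_yaml(cfg, resolve=True))  # dump with all interpolations resolved
```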
configurations/algorithm/df_base.yaml ADDED
@@ -0,0 +1,42 @@
1
+ defaults:
2
+ - base_pytorch_algo
3
+
4
+ # dataset-dependent configurations
5
+ x_shape: ${dataset.observation_shape}
6
+ frame_stack: 1
7
+ frame_skip: 1
8
+ data_mean: ${dataset.data_mean}
9
+ data_std: ${dataset.data_std}
10
+ external_cond_dim: 0 #${dataset.action_dim}
11
+ context_frames: ${dataset.context_length}
12
+ # training hyperparameters
13
+ weight_decay: 1e-4
14
+ warmup_steps: 10000
15
+ optimizer_beta: [0.9, 0.999]
16
+ # diffusion-related
17
+ uncertainty_scale: 1
18
+ guidance_scale: 0.0
19
+ chunk_size: 1 # -1 for full trajectory diffusion, number to specify diffusion chunk size
20
+ scheduling_matrix: autoregressive
21
+ noise_level: random_all
22
+ causal: True
23
+
24
+ diffusion:
25
+ # training
26
+ objective: pred_x0
27
+ beta_schedule: cosine
28
+ schedule_fn_kwargs: {}
29
+ clip_noise: 20.0
30
+ use_snr: False
31
+ use_cum_snr: False
32
+ use_fused_snr: False
33
+ snr_clip: 5.0
34
+ cum_snr_decay: 0.98
35
+ timesteps: 1000
36
+ # sampling
37
+ sampling_timesteps: 50 # FIXME: number of diffusion sampling steps; should be increased
38
+ ddim_sampling_eta: 1.0
39
+ stabilization_level: 10
40
+ # architecture
41
+ architecture:
42
+ network_size: 64
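`beta_schedule: cosine` with `timesteps: 1000` usually refers to the cosine noise schedule of Nichol & Dhariwal (2021). A hedged sketch of that standard schedule follows; the repo's own implementation may differ in clipping or offset details.

```python
import torch

def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> torch.Tensor:
    """Standard cosine schedule: betas derived from a squared-cosine alpha_bar curve."""
    t = torch.linspace(0, timesteps, timesteps + 1) / timesteps
    alphas_cumprod = torch.cos((t + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1.0 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    return betas.clamp(0.0, 0.999)

betas = cosine_beta_schedule(1000)  # matches `timesteps: 1000` above
print(betas.shape, betas[0].item(), betas[-1].item())
```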
configurations/algorithm/df_video_worldmemminecraft.yaml ADDED
@@ -0,0 +1,38 @@
1
+ defaults:
2
+ - df_base
3
+
4
+ n_frames: ${dataset.n_frames}
5
+ frame_skip: ${dataset.frame_skip}
6
+ metadata: ${dataset.metadata}
7
+
8
+ # training hyperparameters
9
+ weight_decay: 2e-3
10
+ warmup_steps: 1000
11
+ optimizer_beta: [0.9, 0.99]
12
+ action_cond_dim: 25
13
+ use_plucker: true
14
+
15
+ diffusion:
16
+ # training
17
+ beta_schedule: sigmoid
18
+ objective: pred_v
19
+ use_fused_snr: True
20
+ cum_snr_decay: 0.96
21
+ clip_noise: 20.
22
+ # sampling
23
+ sampling_timesteps: 20
24
+ ddim_sampling_eta: 0.0
25
+ stabilization_level: 15
26
+ # architecture
27
+ architecture:
28
+ network_size: 64
29
+ attn_heads: 4
30
+ attn_dim_head: 64
31
+ dim_mults: [1, 2, 4, 8]
32
+ resolution: ${dataset.resolution}
33
+ attn_resolutions: [16, 32, 64, 128]
34
+ use_init_temporal_attn: True
35
+ use_linear_attn: True
36
+ time_emb_type: rotary
37
+
38
+ _name: df_video_worldmemminecraft
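Because this file lists `df_base` in its `defaults`, its values override the base ones when the configs are composed (e.g. `sampling_timesteps` drops from 50 to 20 and `beta_schedule` changes from cosine to sigmoid). A minimal sketch of that override behaviour using OmegaConf directly, with the file paths assumed relative to the repo root:

```python
from omegaconf import OmegaConf

base = OmegaConf.load("configurations/algorithm/df_base.yaml")                      # sampling_timesteps: 50
child = OmegaConf.load("configurations/algorithm/df_video_worldmemminecraft.yaml")  # sampling_timesteps: 20
merged = OmegaConf.merge(base, child)  # later configs win on conflicting keys

print(merged.diffusion.sampling_timesteps)  # 20
print(merged.diffusion.beta_schedule)       # sigmoid (overrides the base "cosine")
```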
configurations/dataset/base_dataset.yaml ADDED
@@ -0,0 +1,3 @@
1
+ # This will be passed as the cfg to Dataset.__init__(cfg) of your dataset class
2
+
3
+ debug: ${debug} # inherited from configurations/config.yaml
configurations/dataset/base_video.yaml ADDED
@@ -0,0 +1,14 @@
1
+ defaults:
2
+ - base_dataset
3
+
4
+ metadata: "data/${dataset.name}/metadata.json"
5
+ data_mean: "data/${dataset.name}/data_mean.npy"
6
+ data_std: "data/${dataset.name}/data_std.npy"
7
+ save_dir: ???
8
+ n_frames: 32
9
+ context_length: 4
10
+ resolution: 128
11
+ observation_shape: [3, "${dataset.resolution}", "${dataset.resolution}"]
12
+ external_cond_dim: 0
13
+ validation_multiplier: 1
14
+ frame_skip: 1
configurations/dataset/video_minecraft.yaml ADDED
@@ -0,0 +1,14 @@
1
+ defaults:
2
+ - base_video
3
+
4
+ save_dir: data/minecraft_simple_backforward
5
+ n_frames: 16 # TODO: increase later
6
+ resolution: 128
7
+ data_mean: 0.5
8
+ data_std: 0.5
9
+ action_cond_dim: 25
10
+ context_length: 1
11
+ frame_skip: 1
12
+ validation_multiplier: 1
13
+
14
+ _name: video_minecraft_oasis
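With `data_mean: 0.5` and `data_std: 0.5`, frames in [0, 1] are mapped to [-1, 1] before entering the model and mapped back afterwards. A small illustrative snippet; the dataset code itself may apply this normalization elsewhere in the pipeline.

```python
import torch

data_mean, data_std = 0.5, 0.5

frames = torch.rand(16, 3, 128, 128)          # raw frames in [0, 1], resolution 128 as configured
normalized = (frames - data_mean) / data_std  # now in [-1, 1]
restored = normalized * data_std + data_mean  # back to [0, 1]

print(normalized.min().item(), normalized.max().item())
print(torch.allclose(frames, restored))       # True
```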
configurations/experiment/base_experiment.yaml ADDED
@@ -0,0 +1,2 @@
1
+ debug: ${debug} # inherited from configurations/config.yaml
2
+ tasks: [main] # tasks to run sequentially, such as [training, test]; useful when your project has multiple stages and you want to run only a subset of them.
configurations/experiment/base_pytorch.yaml ADDED
@@ -0,0 +1,50 @@
1
+ # inherits from base_experiment.yaml
2
+ # most of the options have docs at https://lightning.ai/docs/pytorch/stable/common/trainer.html
3
+
4
+ defaults:
5
+ - base_experiment
6
+
7
+ tasks: [training] # tasks to run sequentially; change when your project has multiple stages and you want to run only a subset of them.
8
+ num_nodes: 1 # number of gpu servers used in large scale distributed training
9
+
10
+ training:
11
+ precision: 16-mixed # set float precision, 16-mixed is faster while 32 is more stable
12
+ compile: False # whether to compile the model with torch.compile
13
+ lr: 0.001 # learning rate
14
+ batch_size: 16 # training batch size; effective batch size is this number * gpus * nodes when using distributed training
15
+ max_epochs: 1000 # set to -1 to train forever
16
+ max_steps: -1 # set to -1 to train forever, will override max_epochs
17
+ max_time: null # set to something like "00:12:00:00" to enable
18
+ data:
19
+ num_workers: 4 # number of CPU threads for data preprocessing.
20
+ shuffle: True # whether training data will be shuffled
21
+ optim:
22
+ accumulate_grad_batches: 1 # accumulate gradients for n batches before backprop
23
+ gradient_clip_val: 0 # clip gradients with norm above this value, set to 0 to disable
24
+ checkpointing:
25
+ # these are arguments to pytorch lightning's callback, `ModelCheckpoint` class
26
+ every_n_train_steps: 5000 # save a checkpoint every n train steps
27
+ every_n_epochs: null # mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``
28
+ train_time_interval: null # in format of "00:12:00:00", mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
29
+ enable_version_counter: False # If this is ``False``, later checkpoints will overwrite previous ones.
30
+
31
+ validation:
32
+ precision: 16-mixed
33
+ compile: False # whether to compile the model with torch.compile
34
+ batch_size: 16 # validation batch size per GPU; effective batch size is this number * gpus * nodes when using distributed training
35
+ val_every_n_step: 2000 # controls how frequently validation runs; can be a float (fraction of an epoch), an int (steps), or null (if val_every_n_epoch is set)
36
+ val_every_n_epoch: null # if you want to run validation every n epochs; requires val_every_n_step to be null.
37
+ limit_batch: null # if null, run through validation set. Otherwise limit the number of batches to use for validation.
38
+ inference_mode: True # whether to run validation in inference mode (enable_grad won't work!)
39
+ data:
40
+ num_workers: 4 # number of CPU threads for data preprocessing, for validation.
41
+ shuffle: False # whether validation data will be shuffled
42
+
43
+ test:
44
+ precision: 16-mixed
45
+ compile: False # whether to compile the model with torch.compile
46
+ batch_size: 4 # test batch size per GPU; effective batch size is this number * gpus * nodes when using distributed training
47
+ limit_batch: null # if null, run through test set. Otherwise limit the number of batches to use for test.
48
+ data:
49
+ num_workers: 4 # number of CPU threads for data preprocessing, for test.
50
+ shuffle: False # whether test data will be shuffled
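Most of these keys map directly onto PyTorch Lightning `Trainer` and `ModelCheckpoint` arguments. A hedged sketch of how such a config might be wired up; the repo's actual experiment builder is not part of this diff, and the precision guard is only there so the snippet also constructs on CPU-only machines.

```python
import torch
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint

# Values mirror the training / checkpointing / validation sections above.
checkpoint_cb = ModelCheckpoint(
    every_n_train_steps=5000,
    enable_version_counter=False,  # later checkpoints overwrite previous ones
)
trainer = pl.Trainer(
    num_nodes=1,
    precision="16-mixed" if torch.cuda.is_available() else "32-true",
    max_epochs=1000,
    max_steps=-1,
    accumulate_grad_batches=1,
    gradient_clip_val=0,
    val_check_interval=2000,   # corresponds to validation.val_every_n_step
    limit_val_batches=None,    # corresponds to validation.limit_batch (null = full validation set)
    callbacks=[checkpoint_cb],
)
# trainer.fit(model, datamodule) would then launch training.
```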
configurations/experiment/exp_video.yaml ADDED
@@ -0,0 +1,31 @@
1
+ defaults:
2
+ - base_pytorch
3
+
4
+ tasks: [training]
5
+
6
+ training:
7
+ lr: 2e-5
8
+ precision: 16-mixed
9
+ batch_size: 4
10
+ max_epochs: -1
11
+ max_steps: 2000005
12
+ checkpointing:
13
+ every_n_train_steps: 2500
14
+ optim:
15
+ gradient_clip_val: 1.0
16
+
17
+ validation:
18
+ val_every_n_step: 2500
19
+ val_every_n_epoch: null
20
+ batch_size: 4
21
+ limit_batch: 1
22
+
23
+ test:
24
+ limit_batch: 1
25
+ batch_size: 1
26
+
27
+ logging:
28
+ metrics:
29
+ # - fvd
30
+ # - fid
31
+ # - lpips
configurations/huggingface.yaml ADDED
@@ -0,0 +1,60 @@
1
+ n_tokens: 3
2
+ pose_cond_dim: 5
3
+ use_plucker: true
4
+ focal_length: 0.35
5
+ customized_validation: true
6
+ memory_condition_length: 8
7
+ log_video: true
8
+ relative_embedding: true
9
+ state_embed_only_on_qk: true
10
+ use_domain_adapter: false
11
+ use_memory_attention: true
12
+ add_timestamp_embedding: true
13
+ use_pose_prediction: true
14
+ require_pose_prediction: true
15
+ is_interactive: true
16
+ diffusion:
17
+ sampling_timesteps: 20
18
+ beta_schedule: sigmoid
19
+ objective: pred_v
20
+ use_fused_snr: True
21
+ cum_snr_decay: 0.96
22
+ clip_noise: 20.
23
+ ddim_sampling_eta: 0.0
24
+ stabilization_level: 15
25
+ schedule_fn_kwargs: {}
26
+ use_snr: False
27
+ use_cum_snr: False
28
+ snr_clip: 5.0
29
+ timesteps: 1000
30
+ # architecture
31
+ architecture:
32
+ network_size: 64
33
+ attn_heads: 4
34
+ attn_dim_head: 64
35
+ dim_mults: [1, 2, 4, 8]
36
+ resolution: ${dataset.resolution}
37
+ attn_resolutions: [16, 32, 64, 128]
38
+ use_init_temporal_attn: True
39
+ use_linear_attn: True
40
+ time_emb_type: rotary
41
+
42
+ weight_decay: 2e-3
43
+ warmup_steps: 10000
44
+ optimizer_beta: [0.9, 0.99]
45
+ action_cond_dim: 25
46
+ n_frames: 8
47
+ frame_skip: 1
48
+ frame_stack: 1
49
+ uncertainty_scale: 1
50
+ guidance_scale: 0.0
51
+ chunk_size: 1 # -1 for full trajectory diffusion, number to specify diffusion chunk size
52
+ scheduling_matrix: full_sequence
53
+ noise_level: random_all
54
+ causal: True
55
+ x_shape: [3, 360, 640]
56
+ context_frames: 1
57
+ diffusion_path: zeqixiao/worldmem_checkpoints/diffusion_only.ckpt
58
+ vae_path: zeqixiao/worldmem_checkpoints/vae_only.ckpt
59
+ pose_predictor_path: zeqixiao/worldmem_checkpoints/pose_prediction_model_only.ckpt
60
+ next_frame_length: 1
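The three checkpoint paths at the bottom look like Hugging Face Hub references in `<repo_id>/<filename>` form. One way such files are typically fetched, sketched with `huggingface_hub`; the split into repo id and filename is an assumption, and the repo's own loading code may resolve these paths differently.

```python
from huggingface_hub import hf_hub_download

# Assumed split of "zeqixiao/worldmem_checkpoints/diffusion_only.ckpt" into repo_id + filename.
diffusion_ckpt = hf_hub_download(
    repo_id="zeqixiao/worldmem_checkpoints",
    filename="diffusion_only.ckpt",
)
vae_ckpt = hf_hub_download(repo_id="zeqixiao/worldmem_checkpoints", filename="vae_only.ckpt")
print(diffusion_ckpt, vae_ckpt)  # local cache paths to the downloaded checkpoints
```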
configurations/training.yaml ADDED
@@ -0,0 +1,16 @@
1
+ # configuration parsing starts here
2
+ defaults:
3
+ - experiment: exp_video # experiment yaml file name in configurations/experiment folder [fixme]
4
+ - dataset: video_minecraft # dataset yaml file name in configurations/dataset folder [fixme]
5
+ - algorithm: df_video_worldmemminecraft # algorithm yaml file name in configurations/algorithm folder [fixme]
6
+ - cluster: null # optional, cluster yaml file name in configurations/cluster folder. Leave null for local compute
7
+
8
+ debug: false # global debug flag will be passed into configuration of experiment, dataset and algorithm
9
+
10
+ wandb:
11
+ entity: xizaoqu # wandb account name / organization name [fixme]
12
+ project: worldmem # wandb project name; if not provided, defaults to root folder name [fixme]
13
+ mode: online # set wandb logging to online, offline or dryrun
14
+
15
+ resume: null # wandb run id to resume logging and loading checkpoint from
16
+ load: null # wandb run id containing checkpoint or a path to a checkpoint file