chenyangqi committed
Commit 3060b7e
1 Parent(s): 8094e3b

add FateZero code

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. FateZero +0 -1
  2. FateZero/.gitignore +176 -0
  3. FateZero/LICENSE.md +21 -0
  4. FateZero/README.md +393 -0
  5. FateZero/colab_fatezero.ipynb +0 -0
  6. FateZero/config/.gitignore +1 -0
  7. FateZero/config/attribute/bear_tiger_lion_leopard.yaml +108 -0
  8. FateZero/config/attribute/bus_gpu.yaml +100 -0
  9. FateZero/config/attribute/cat_tiger_leopard_grass.yaml +112 -0
  10. FateZero/config/attribute/dog_robotic_corgi.yaml +103 -0
  11. FateZero/config/attribute/duck_rubber.yaml +99 -0
  12. FateZero/config/attribute/fox_wolf_snow.yaml +107 -0
  13. FateZero/config/attribute/rabbit_straberry_leaves_flowers.yaml +114 -0
  14. FateZero/config/attribute/squ_carrot_robot_eggplant.yaml +123 -0
  15. FateZero/config/attribute/swan_swa.yaml +102 -0
  16. FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml +83 -0
  17. FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps_disk_store.yaml +84 -0
  18. FateZero/config/style/jeep_watercolor.yaml +94 -0
  19. FateZero/config/style/lily_monet.yaml +93 -0
  20. FateZero/config/style/rabit_pokemon.yaml +92 -0
  21. FateZero/config/style/sun_flower_van_gogh.yaml +86 -0
  22. FateZero/config/style/surf_ukiyo.yaml +90 -0
  23. FateZero/config/style/swan_cartoon.yaml +101 -0
  24. FateZero/config/style/train_shinkai.yaml +97 -0
  25. FateZero/config/teaser/jeep_posche.yaml +93 -0
  26. FateZero/config/teaser/jeep_watercolor.yaml +94 -0
  27. FateZero/data/.gitignore +4 -0
  28. FateZero/data/teaser_car-turn/00000.png +0 -0
  29. FateZero/data/teaser_car-turn/00001.png +0 -0
  30. FateZero/data/teaser_car-turn/00002.png +0 -0
  31. FateZero/data/teaser_car-turn/00003.png +0 -0
  32. FateZero/data/teaser_car-turn/00004.png +0 -0
  33. FateZero/data/teaser_car-turn/00005.png +0 -0
  34. FateZero/data/teaser_car-turn/00006.png +0 -0
  35. FateZero/data/teaser_car-turn/00007.png +0 -0
  36. FateZero/docs/EditingGuidance.md +65 -0
  37. FateZero/docs/OpenSans-Regular.ttf +0 -0
  38. FateZero/requirements.txt +17 -0
  39. FateZero/test_fatezero.py +290 -0
  40. FateZero/test_fatezero_dataset.py +52 -0
  41. FateZero/test_install.py +23 -0
  42. FateZero/train_tune_a_video.py +426 -0
  43. FateZero/video_diffusion/common/image_util.py +203 -0
  44. FateZero/video_diffusion/common/instantiate_from_config.py +33 -0
  45. FateZero/video_diffusion/common/logger.py +17 -0
  46. FateZero/video_diffusion/common/set_seed.py +28 -0
  47. FateZero/video_diffusion/common/util.py +73 -0
  48. FateZero/video_diffusion/data/dataset.py +158 -0
  49. FateZero/video_diffusion/data/transform.py +48 -0
  50. FateZero/video_diffusion/models/attention.py +482 -0
FateZero DELETED
@@ -1 +0,0 @@
- Subproject commit 6992d238770f464c03a0a74cbcec4f99da4635ec
FateZero/.gitignore ADDED
@@ -0,0 +1,176 @@
+ start_hold
+ chenyangqi
+ trash/**
+ runs*/**
+ result/**
+ ckpt/**
+ ckpt
+ **.whl
+ stable-diffusion-v1-4
+ trash
+ # data/**
+
+ # Initially taken from Github's Python gitignore file
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # tests and logs
+ tests/fixtures/cached_*_text.txt
+ logs/
+ lightning_logs/
+ lang_code_data/
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # vscode
+ .vs
+ .vscode
+
+ # Pycharm
+ .idea
+
+ # TF code
+ tensorflow_code
+
+ # Models
+ proc_data
+
+ # examples
+ runs
+ /runs_old
+ /wandb
+ /examples/runs
+ /examples/**/*.args
+ /examples/rag/sweep
+
+ # emacs
+ *.*~
+ debug.env
+
+ # vim
+ .*.swp
+
+ #ctags
+ tags
+
+ # pre-commit
+ .pre-commit*
+
+ # .lock
+ *.lock
+
+ # DS_Store (MacOS)
+ .DS_Store
+ # RL pipelines may produce mp4 outputs
+ *.mp4
FateZero/LICENSE.md ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Chenyang QI
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
FateZero/README.md ADDED
@@ -0,0 +1,393 @@
+ ## FateZero: Fusing Attentions for Zero-shot Text-based Video Editing
+
+ [Chenyang Qi](https://chenyangqiqi.github.io/), [Xiaodong Cun](http://vinthony.github.io/), [Yong Zhang](https://yzhang2016.github.io), [Chenyang Lei](https://chenyanglei.github.io/), [Xintao Wang](https://xinntao.github.io/), [Ying Shan](https://scholar.google.com/citations?hl=zh-CN&user=4oXBp9UAAAAJ), and [Qifeng Chen](https://cqf.io)
+
+ <a href='https://arxiv.org/abs/2303.09535'><img src='https://img.shields.io/badge/ArXiv-2303.09535-red'></a>
+ <a href='https://fate-zero-edit.github.io/'><img src='https://img.shields.io/badge/Project-Page-Green'></a> [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ChenyangQiQi/FateZero/blob/main/colab_fatezero.ipynb)
+ [![GitHub](https://img.shields.io/github/stars/ChenyangQiQi/FateZero?style=social)](https://github.com/ChenyangQiQi/FateZero)
+
+
+ <!-- ![fatezero_demo](./docs/teaser.png) -->
+
+ <table class="center">
+ <td><img src="docs/gif_results/17_car_posche_01_concat_result.gif"></td>
+ <td><img src="docs/gif_results/3_sunflower_vangogh_conat_result.gif"></td>
+ <tr>
+ <td width=25% style="text-align:center;">"silver jeep ➜ posche car"</td>
+ <td width=25% style="text-align:center;">"+ Van Gogh style"</td>
+ <!-- <td width=25% style="text-align:center;">"Wonder Woman, wearing a cowboy hat, is skiing"</td>
+ <td width=25% style="text-align:center;">"A man, wearing pink clothes, is skiing at sunset"</td> -->
+ </tr>
+ </table>
+
+ ## Abstract
+ <b>TL;DR: FateZero edits your video with pretrained diffusion models, without any training.</b>
+
+ <details><summary>CLICK for full abstract</summary>
+
+ > Diffusion-based generative models have achieved remarkable success in text-based image generation. However, because the generation process contains enormous randomness, it is still challenging to apply such models to real-world visual content editing, especially in videos. In this paper, we propose FateZero, a zero-shot text-based editing method for real-world videos that requires no per-prompt training or user-specific mask. To edit videos consistently, we propose several techniques based on pre-trained models. First, in contrast to the straightforward DDIM inversion technique, our approach captures intermediate attention maps during inversion, which effectively retain both structural and motion information. These maps are directly fused in the editing process rather than generated during denoising. To further minimize semantic leakage from the source video, we then fuse self-attentions with a blending mask obtained from the cross-attention features of the source prompt. Furthermore, we reform the self-attention mechanism in the denoising UNet by introducing spatial-temporal attention to ensure frame consistency. Though succinct, our method is the first to show the ability of zero-shot text-driven video style and local attribute editing from a trained text-to-image model. We also achieve better zero-shot shape-aware editing based on a text-to-video model. Extensive experiments demonstrate superior temporal consistency and editing capability compared with previous works.
+ </details>
+
+ ## Changelog
+ - 2023.03.27 Release the [`attribute editing config`](config/attribute) and
+ <!-- [`data`](https://hkustconnect-my.sharepoint.com/:u:/g/personal/cqiaa_connect_ust_hk/Ee7J2IzZuaVGkefh-ZRp1GwB7RCUYU7MVJCKqeNWmOIpfg?e=dcOwb7) -->
+ the [`data`](https://github.com/ChenyangQiQi/FateZero/releases/download/v0.0.1/attribute.zip) used in the paper.
+ - 2023.03.22 Upload a `colab notebook` [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ChenyangQiQi/FateZero/blob/main/colab_fatezero.ipynb). Enjoy zero-shot video editing freely!
+ - 2023.03.22 Release the [`style editing config`](config/style) and
+ <!--[`data`](https://hkustconnect-my.sharepoint.com/:u:/g/personal/cqiaa_connect_ust_hk/EaTqRAuW0eJLj0z_JJrURkcBZCC3Zvgsdo6zsXHhpyHhHQ?e=FzuiNG) -->
+ the [`data`](https://github.com/ChenyangQiQi/FateZero/releases/download/v0.0.1/style.zip) used in the paper.
+ - 2023.03.21 [Editing guidance](docs/EditingGuidance.md) is provided to help users edit in-the-wild videos. Welcome to play with it and give feedback!
+ - 2023.03.21 Update the `codebase and configuration`. It can now run with lower resources (a 16 GB GPU and less than 16 GB of CPU RAM) using the [new configuration](config/low_resource_teaser) in `config/low_resource_teaser`.
+ <!-- A new option stores all the attentions on the hard disk, which requires less RAM. -->
+ - 2023.03.17 Release code and paper!
+
+ ## Todo
+
+ - [x] Release the edit config for the teaser
+ - [x] Memory and runtime profiling
+ - [x] Hands-on guidance for hyperparameter tuning
+ - [x] Colab
+ - [x] Release configs for other results and the in-the-wild dataset
+ <!-- - [x] Style editing: done
+ - [-] Attribute editing: in progress -->
+ - [-] Hugging Face demo: in progress
+ - [ ] Tune-A-Video optimization and shape-editing configs
+ - [ ] Release more applications
+
+ ## Setup Environment
+ Our method is tested with CUDA 11, fp16 (via accelerate), and xformers on a single A100 or 3090 GPU.
+
+ ```bash
+ conda create -n fatezero38 python=3.8
+ conda activate fatezero38
+
+ pip install -r requirements.txt
+ ```
+
+ `xformers` is recommended on A100 GPUs to save memory and running time.
+
+ <details><summary>Click for xformers installation </summary>
+
+ We find its installation is not always stable. You may try the following wheel:
+ ```bash
+ wget https://github.com/ShivamShrirao/xformers-wheels/releases/download/4c06c79/xformers-0.0.15.dev0+4c06c79.d20221201-cp38-cp38-linux_x86_64.whl
+ pip install xformers-0.0.15.dev0+4c06c79.d20221201-cp38-cp38-linux_x86_64.whl
+ ```
+
+ </details>
+
+ Validate the installation by running
+ ```
+ python test_install.py
+ ```
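For reference, the kind of thing such a check needs to cover is sketched below. This is an illustrative script, not necessarily what the repository's `test_install.py` does; it only verifies that PyTorch sees a GPU and that the optional `xformers` package imports.

```python
# Illustrative environment check (not the repository's test_install.py).
import torch

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

try:
    import xformers  # optional, recommended on A100 GPUs
    print("xformers:", xformers.__version__)
except ImportError:
    print("xformers not installed (optional)")
```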
+
+ Our environment is similar to Tune-A-Video ([official](https://github.com/showlab/Tune-A-Video), [unofficial](https://github.com/bryandlee/Tune-A-Video)) and [prompt-to-prompt](https://github.com/google/prompt-to-prompt/). You may check them for more details.
+
+
+ ## FateZero Editing
+
+ #### Style and Attribute Editing in Teaser
+
+ Download [Stable Diffusion v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) (or another image diffusion model of interest) and put it at `./ckpt/stable-diffusion-v1-4`.
+
+ <details><summary>Click for bash command: </summary>
+
+ ```
+ mkdir ./ckpt
+ # download from Hugging Face; takes about 20 GB of space
+ git lfs install
+ git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
+ cd ./ckpt
+ ln -s ../stable-diffusion-v1-4 .
+ ```
+ </details>
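Before launching an edit, it can save time to confirm that the checkpoint directory actually loads. The snippet below is a hypothetical sanity check using the `diffusers` `StableDiffusionPipeline`; FateZero's own pipeline (see `test_pipeline_config` in the configs) reads the same `./ckpt/stable-diffusion-v1-4` path.

```python
# Hypothetical check that the Stable Diffusion v1-4 weights are in place.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "./ckpt/stable-diffusion-v1-4", torch_dtype=torch.float16
)
n_params = sum(p.numel() for p in pipe.unet.parameters()) / 1e6
print(f"UNet loaded with {n_params:.0f}M parameters")
```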
+
+ Then, you can reproduce the style and shape editing results in our teaser by running:
+
+ ```bash
+ accelerate launch test_fatezero.py --config config/teaser/jeep_watercolor.yaml
+ # or CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/teaser/jeep_watercolor.yaml
+ ```
+
+ <details><summary>The result is saved at `./result`. (Click for directory structure) </summary>
+
+ ```
+ result
+ ├── teaser
+ │   ├── jeep_posche
+ │   ├── jeep_watercolor
+ │   │   ├── cross-attention  # visualization of cross-attention during inversion
+ │   │   ├── sample           # result
+ │   │   ├── train_samples    # the input video
+ ```
+
+ </details>
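If you want to flip through a folder of saved frames (for example the input frames under `train_samples`) as a single GIF, a small helper like the one below works, assuming `imageio` is available. The exact folder names inside `result` can vary by run, so treat the path as a placeholder.

```python
# Assemble a folder of PNG frames into a GIF for quick inspection.
# The path is a placeholder; point it at any frame folder produced by a run.
from pathlib import Path

import imageio

frames_dir = Path("result/teaser/jeep_watercolor/train_samples")
frames = [imageio.imread(p) for p in sorted(frames_dir.glob("*.png"))]
imageio.mimsave("preview.gif", frames, duration=0.125)  # ~8 fps
```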
+
+ Editing 8 frames on an NVIDIA 3090 uses about `100 GB of CPU memory and 12 GB of GPU memory`. We also provide some [`low-cost settings`](config/low_resource_teaser) for style editing with different hyperparameters on a 16 GB GPU.
+ You may try these low-cost settings on Colab.
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ChenyangQiQi/FateZero/blob/main/colab_fatezero.ipynb)
+
+ More speed and hardware benchmarks are listed [here](docs/EditingGuidance.md#ddim-hyperparameters).
+
+ #### Shape and large motion editing with Tune-A-Video
+
+ Besides the style and attribute editing above, we also provide a `Tune-A-Video` [checkpoint](https://hkustconnect-my.sharepoint.com/:f:/g/personal/cqiaa_connect_ust_hk/EviSTWoAOs1EmHtqZruq50kBZu1E8gxDknCPigSvsS96uQ?e=492khj). You may download it and move it to `./ckpt/jeep_tuned_200/`.
+ <!-- We provide the [Tune-a-Video](https://drive.google.com/file/d/166eNbabM6TeJVy7hxol2gL1kUGKHi3Do/view?usp=share_link), you could download the data, unzip and put it to `data`. : -->
+
+ <details><summary>The directory structure should look like this: (Click for directory structure) </summary>
+
+ ```
+ ckpt
+ ├── stable-diffusion-v1-4
+ ├── jeep_tuned_200
+ ...
+ data
+ ├── car-turn
+ │   ├── 00000000.png
+ │   ├── 00000001.png
+ │   ├── ...
+ video_diffusion
+ ```
+ </details>
+
+ You can reproduce the shape editing result in our teaser by running:
+
+ ```bash
+ accelerate launch test_fatezero.py --config config/teaser/jeep_posche.yaml
+ ```
+
+
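If you are unsure whether everything landed in the right place, a small optional pre-flight check like the sketch below can catch missing folders before a long run. The paths are taken from the directory sketch above; adjust them if your layout differs.

```python
# Optional pre-flight check that the folders described above are in place.
from pathlib import Path

expected = [
    "ckpt/stable-diffusion-v1-4",
    "ckpt/jeep_tuned_200",
    "data/car-turn",
]
for p in expected:
    status = "ok" if Path(p).exists() else "MISSING"
    print(f"{p}: {status}")
```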
+ ### Reproduce other results in the paper (in progress)
+ <!-- Download the data of [style editing](https://hkustconnect-my.sharepoint.com/:u:/g/personal/cqiaa_connect_ust_hk/EaTqRAuW0eJLj0z_JJrURkcBZCC3Zvgsdo6zsXHhpyHhHQ?e=FzuiNG) and [attribute editing](https://hkustconnect-my.sharepoint.com/:u:/g/personal/cqiaa_connect_ust_hk/Ee7J2IzZuaVGkefh-ZRp1GwB7RCUYU7MVJCKqeNWmOIpfg?e=dcOwb7)
+ -->
+ Download the data for style editing and attribute editing
+ from [OneDrive](https://hkustconnect-my.sharepoint.com/:f:/g/personal/cqiaa_connect_ust_hk/EkIeHj3CQiBNhm6iEEhJQZwBEBJNCGt3FsANmyqeAYbuXQ?e=FxYtJk) or from the GitHub [Release](https://github.com/ChenyangQiQi/FateZero/releases/tag/v0.0.1).
+ <details><summary>Click for wget bash command: </summary>
+
+ ```
+ wget https://github.com/ChenyangQiQi/FateZero/releases/download/v0.0.1/attribute.zip
+ wget https://github.com/ChenyangQiQi/FateZero/releases/download/v0.0.1/style.zip
+ ```
+ </details>
+
+ Unzip and place the data in ['./data'](data). Then use the commands in ['config/style'](config/style) and ['config/attribute'](config/attribute) to get the results, as in the sketch below.
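If you prefer doing this step from Python (for example inside the Colab notebook), an equivalent of the `wget`/`unzip` commands is sketched below. The archive names match the release assets above; the script assumes each archive contains a top-level `attribute/` or `style/` folder, as the configs expect paths like `data/attribute/...`.

```python
# Download and unpack the released editing data into ./data
# (equivalent to the wget + unzip steps above).
import urllib.request
import zipfile
from pathlib import Path

base = "https://github.com/ChenyangQiQi/FateZero/releases/download/v0.0.1/"
Path("data").mkdir(exist_ok=True)
for name in ["attribute.zip", "style.zip"]:
    urllib.request.urlretrieve(base + name, name)
    with zipfile.ZipFile(name) as zf:
        zf.extractall("data")  # assumes a top-level attribute/ or style/ folder inside
```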
+
+ The configs for our Tune-A-Video checkpoints will be updated later.
+
+ ## Tuning guidance to edit YOUR video
+ We provide tuning guidance for editing in-the-wild videos [here](./docs/EditingGuidance.md). The work is still in progress; feedback via issues is welcome.
+
+ ## Style Editing Results with Stable Diffusion
+ We show the difference between the source prompt and the target prompt in the box below each video.
+
+ Note that the mp4 and GIF files on this GitHub page are compressed.
+ Please check our [Project Page](https://fate-zero-edit.github.io/) for the mp4 files of the original video editing results.
+ <table class="center">
+
+ <tr>
+ <td><img src="docs/gif_results/style/1_surf_ukiyo_01_concat_result.gif"></td>
+ <td><img src="docs/gif_results/style/2_car_watercolor_01_concat_result.gif"></td>
+ <td><img src="docs/gif_results/style/6_lily_monet_01_concat_result.gif"></td>
+ <!-- <td><img src="https://tuneavideo.github.io/assets/results/tuneavideo/man-skiing/wonder-woman.gif"></td>
+ <td><img src="https://tuneavideo.github.io/assets/results/tuneavideo/man-skiing/pink-sunset.gif"></td> -->
+ </tr>
+ <tr>
+ <td width=25% style="text-align:center;">"+ Ukiyo-e style"</td>
+ <td width=25% style="text-align:center;">"+ watercolor painting"</td>
+ <td width=25% style="text-align:center;">"+ Monet style"</td>
+ </tr>
+
+ <tr>
+ <td><img src="docs/gif_results/style/4_rabit_pokemon_01_concat_result.gif"></td>
+ <td><img src="docs/gif_results/style/5_train_shikai_01_concat_result.gif"></td>
+ <td><img src="docs/gif_results/style/7_swan_carton_01_concat_result.gif"></td>
+ </tr>
+ <tr>
+ <td width=25% style="text-align:center;">"+ Pokémon cartoon style"</td>
+ <td width=25% style="text-align:center;">"+ Makoto Shinkai style"</td>
+ <td width=25% style="text-align:center;">"+ cartoon style"</td>
+ </tr>
+ </table>
+
+ ## Attribute Editing Results with Stable Diffusion
+ <table class="center">
+
+ <tr>
+ <td><img src="docs/gif_results/attri/15_rabbit_eat_01_concat_result.gif"></td>
+ <td><img src="docs/gif_results/attri/15_rabbit_eat_02_concat_result.gif"></td>
+ <td><img src="docs/gif_results/attri/15_rabbit_eat_04_concat_result.gif"></td>
+ </tr>
+ <tr>
+ <td width=25% style="text-align:center;">"rabbit, strawberry ➜ white rabbit, flower"</td>
+ <td width=25% style="text-align:center;">"rabbit, strawberry ➜ squirrel, carrot"</td>
+ <td width=25% style="text-align:center;">"rabbit, strawberry ➜ white rabbit, leaves"</td>
+ </tr>
+
+ <tr>
+ <td><img src="docs/gif_results/attri/16_sq_eat_04_concat_result.gif"></td>
+ <td><img src="docs/gif_results/attri/16_sq_eat_02_concat_result.gif"></td>
+ <td><img src="docs/gif_results/attri/16_sq_eat_03_concat_result.gif"></td>
+ </tr>
+ <tr>
+ <td width=25% style="text-align:center;">"squirrel ➜ robot squirrel"</td>
+ <td width=25% style="text-align:center;">"squirrel, Carrot ➜ rabbit, eggplant"</td>
+ <td width=25% style="text-align:center;">"squirrel, Carrot ➜ robot mouse, screwdriver"</td>
+ </tr>
+
+ <tr>
+ <td><img src="docs/gif_results/attri/13_bear_tiger_leopard_lion_01_concat_result.gif"></td>
+ <td><img src="docs/gif_results/attri/13_bear_tiger_leopard_lion_02_concat_result.gif"></td>
+ <td><img src="docs/gif_results/attri/13_bear_tiger_leopard_lion_03_concat_result.gif"></td>
+ </tr>
+ <tr>
+ <td width=25% style="text-align:center;">"bear ➜ a red tiger"</td>
+ <td width=25% style="text-align:center;">"bear ➜ a yellow leopard"</td>
+ <td width=25% style="text-align:center;">"bear ➜ a brown lion"</td>
+ </tr>
+
+ <tr>
+ <td><img src="docs/gif_results/attri/14_cat_grass_tiger_corgin_02_concat_result.gif"></td>
+ <td><img src="docs/gif_results/attri/14_cat_grass_tiger_corgin_03_concat_result.gif"></td>
+ <td><img src="docs/gif_results/attri/14_cat_grass_tiger_corgin_04_concat_result.gif"></td>
+ </tr>
+ <tr>
+ <td width=25% style="text-align:center;">"cat ➜ black cat, grass..."</td>
+ <td width=25% style="text-align:center;">"cat ➜ red tiger"</td>
+ <td width=25% style="text-align:center;">"cat ➜ Shiba-Inu"</td>
+ </tr>
+
+ <tr>
+ <td><img src="docs/gif_results/attri/10_bus_gpu_01_concat_result.gif"></td>
+ <td><img src="docs/gif_results/attri/11_dog_robotic_corgin_01_concat_result.gif"></td>
+ <td><img src="docs/gif_results/attri/11_dog_robotic_corgin_02_concat_result.gif"></td>
+ </tr>
+ <tr>
+ <td width=25% style="text-align:center;">"bus ➜ GPU"</td>
+ <td width=25% style="text-align:center;">"gray dog ➜ yellow corgi"</td>
+ <td width=25% style="text-align:center;">"gray dog ➜ robotic dog"</td>
+ </tr>
+
+ <tr>
+ <td><img src="docs/gif_results/attri/9_duck_rubber_01_concat_result.gif"></td>
+ <td><img src="docs/gif_results/attri/12_fox_snow_wolf_01_concat_result.gif"></td>
+ <td><img src="docs/gif_results/attri/12_fox_snow_wolf_02_concat_result.gif"></td>
+ </tr>
+ <tr>
+ <td width=25% style="text-align:center;">"white duck ➜ yellow rubber duck"</td>
+ <td width=25% style="text-align:center;">"grass ➜ snow"</td>
+ <td width=25% style="text-align:center;">"white fox ➜ grey wolf"</td>
+ </tr>
+
+ </table>
+
+ ## Shape and large motion editing with Tune-A-Video
+ <table class="center">
+
+ <tr>
+ <td><img src="docs/gif_results/shape/17_car_posche_01_concat_result.gif"></td>
+ <td><img src="docs/gif_results/shape/18_swan_01_concat_result.gif"></td>
+ <td><img src="docs/gif_results/shape/18_swan_02_concat_result.gif"></td>
+ <!-- <td><img src="https://tuneavideo.github.io/assets/results/tuneavideo/man-skiing/wonder-woman.gif"></td>
+ <td><img src="https://tuneavideo.github.io/assets/results/tuneavideo/man-skiing/pink-sunset.gif"></td> -->
+ </tr>
+ <tr>
+ <td width=25% style="text-align:center;">"silver jeep ➜ posche car"</td>
+ <td width=25% style="text-align:center;">"Swan ➜ White Duck"</td>
+ <td width=25% style="text-align:center;">"Swan ➜ Pink flamingo"</td>
+ </tr>
+
+ <tr>
+ <td><img src="docs/gif_results/shape/19_man_wonder_01_concat_result.gif"></td>
+ <td><img src="docs/gif_results/shape/19_man_wonder_02_concat_result.gif"></td>
+ <td><img src="docs/gif_results/shape/19_man_wonder_03_concat_result.gif"></td>
+ </tr>
+ <tr>
+ <td width=25% style="text-align:center;">"A man ➜ A Batman"</td>
+ <td width=25% style="text-align:center;">"A man ➜ A Wonder Woman, With cowboy hat"</td>
+ <td width=25% style="text-align:center;">"A man ➜ A Spider-Man"</td>
+ </tr>
+ </table>
+
+
+ ## Demo Video
+
+ https://user-images.githubusercontent.com/45789244/225698509-79c14793-3153-4bba-9d6e-ede7d811d7f8.mp4
+
+ The video here is compressed due to GitHub's file size limit.
+ The original full-resolution video is [here](https://hkustconnect-my.sharepoint.com/:v:/g/personal/cqiaa_connect_ust_hk/EXKDI_nahEhKtiYPvvyU9SkBDTG2W4G1AZ_vkC7ekh3ENw?e=Xhgtmk).
+
+
+ ## Citation
+
+ ```
+ @misc{qi2023fatezero,
+       title={FateZero: Fusing Attentions for Zero-shot Text-based Video Editing},
+       author={Chenyang Qi and Xiaodong Cun and Yong Zhang and Chenyang Lei and Xintao Wang and Ying Shan and Qifeng Chen},
+       year={2023},
+       eprint={2303.09535},
+       archivePrefix={arXiv},
+       primaryClass={cs.CV}
+ }
+ ```
+
+
+ ## Acknowledgements
+
+ This repository borrows heavily from [Tune-A-Video](https://github.com/showlab/Tune-A-Video) and [prompt-to-prompt](https://github.com/google/prompt-to-prompt/). Thanks to the authors for sharing their code and models.
+
+ ## Maintenance
+
+ This is the codebase for our research work. We are still working hard to update this repo, and more details will be added in the coming days. If you have any questions or ideas to discuss, feel free to contact [Chenyang Qi](cqiaa@connect.ust.hk) or [Xiaodong Cun](vinthony@gmail.com).
+
FateZero/colab_fatezero.ipynb ADDED
The diff for this file is too large to render.
FateZero/config/.gitignore ADDED
@@ -0,0 +1 @@
+ # debug/**
FateZero/config/attribute/bear_tiger_lion_leopard.yaml ADDED
@@ -0,0 +1,108 @@
+ # CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/bear_tiger_lion_leopard.yaml
+
+ pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
+
+
+ train_dataset:
+   path: "data/attribute/bear_tiger_lion_leopard"
+   prompt: "a brown bear walking on the rock against a wall"
+   n_sample_frame: 8
+   # n_sample_frame: 22
+   sampling_rate: 1
+   stride: 80
+   offset:
+     left: 0
+     right: 0
+     top: 0
+     bottom: 0
+
+ validation_sample_logger_config:
+   use_train_latents: True
+   use_inversion_attention: True
+   guidance_scale: 7.5
+   prompts: [
+     # source prompt
+     a brown bear walking on the rock against a wall,
+
+     # foreground texture style
+     a red tiger walking on the rock against a wall,
+     a yellow leopard walking on the rock against a wall,
+     a brown lion walking on the rock against a wall,
+   ]
+   p2p_config:
+     0:
+       # Whether to directly copy the cross-attention from the source.
+       # True: directly copy, better for object replacement
+       # False: keep source attention, better for style
+       is_replace_controller: False
+
+       # Semantic preserving and replacement
+       cross_replace_steps:
+         default_: 0.8
+
+       # Source background structure preserving, in [0, 1].
+       # e.g., 0.6 replaces the self-attention in the first 60% of steps
+       self_replace_steps: 0.6
+
+       # Amplify the cross-attention of the target words; a larger value moves the result closer to the target
+       eq_params:
+         words: ["silver", "sculpture"]
+         values: [2,2]
+
+       # Target structure-divergence hyperparameters.
+       # If you change the shape of the object, it is better to use all three lines below; otherwise they are not needed.
+       # Without the following three lines, all self-attention will be replaced.
+       blend_words: [['cat',], ["cat",]]
+       masked_self_attention: True
+       # masked_latents: False # performance not so good in our case, needs debugging
+       bend_th: [2, 2]
+       # Preserve the source structure of blend_words, in [0, 1].
+       # The default bend_th: [2, 2] preserves all source self-attention.
+       # bend_th: [0.0, 0.0] sets the mask to 1, i.e. more generated attention and less source attention.
+
+     1:
+       is_replace_controller: true
+       cross_replace_steps:
+         default_: 0.7
+       self_replace_steps: 0.7
+     2:
+       is_replace_controller: true
+       cross_replace_steps:
+         default_: 0.7
+       self_replace_steps: 0.7
+     3:
+       is_replace_controller: true
+       cross_replace_steps:
+         default_: 0.7
+       self_replace_steps: 0.7
+
+   clip_length: "${..train_dataset.n_sample_frame}"
+   sample_seeds: [0]
+   val_all_frames: False
+
+   num_inference_steps: 50
+   prompt2prompt_edit: True
+
+
+ model_config:
+   lora: 160
+   # temporal_downsample_time: 4
+   SparseCausalAttention_index: ['mid']
+   least_sc_channel: 640
+   # least_sc_channel: 100000
+
+ test_pipeline_config:
+   target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
+   num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
+
+ epsilon: 1e-5
+ train_steps: 10
+ seed: 0
+ learning_rate: 1e-5
+ train_temporal_conv: False
+ guidance_scale: "${validation_sample_logger_config.guidance_scale}"
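The comments in this config describe the main prompt-to-prompt knobs (`cross_replace_steps`, `self_replace_steps`, `eq_params`, `blend_words`/`bend_th`). As a rough illustration of how such a config can be inspected, the sketch below loads it with OmegaConf, whose `${..}` interpolation syntax the config appears to use; the repository's own loading code in `test_fatezero.py` may differ.

```python
# Illustrative only: inspect a FateZero-style config with OmegaConf.
# The actual loading code in test_fatezero.py may differ.
from omegaconf import OmegaConf

cfg = OmegaConf.load("config/attribute/bear_tiger_lion_leopard.yaml")

print("source prompt:", cfg.train_dataset.prompt)
print("target prompts:", list(cfg.validation_sample_logger_config.prompts))

# p2p_config appears to hold one block per entry in `prompts`
# (index 0 pairs with the source prompt, 1..N with the edited prompts).
for idx, p2p in cfg.validation_sample_logger_config.p2p_config.items():
    print(idx, "cross:", p2p.cross_replace_steps.default_, "self:", p2p.self_replace_steps)
```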
FateZero/config/attribute/bus_gpu.yaml ADDED
@@ -0,0 +1,100 @@
1
+ # CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/bus_gpu.yaml
2
+
3
+ pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
4
+
5
+
6
+ train_dataset:
7
+ path: "data/attribute/bus_gpu"
8
+ prompt: "a white and blue bus on the road"
9
+ n_sample_frame: 8
10
+ # n_sample_frame: 22
11
+ sampling_rate: 1
12
+ stride: 80
13
+ offset:
14
+ left: 0
15
+ right: 0
16
+ top: 0
17
+ bottom: 0
18
+
19
+ validation_sample_logger_config:
20
+ use_train_latents: True
21
+ use_inversion_attention: True
22
+ guidance_scale: 7.5
23
+ prompts: [
24
+ # source prompt
25
+ a white and blue bus on the road,
26
+
27
+ # foreground texture style
28
+ a black and green GPU on the road
29
+ ]
30
+ p2p_config:
31
+ 0:
32
+ # Whether to directly copy the cross attention from source
33
+ # True: directly copy, better for object replacement
34
+ # False: keep source attention, better for style
35
+ is_replace_controller: False
36
+
37
+ # Semantic preserving and replacement Debug me
38
+ cross_replace_steps:
39
+ default_: 0.8
40
+
41
+ # Source background structure preserving, in [0, 1].
42
+ # e.g., =0.6 Replace the first 60% steps self-attention
43
+ self_replace_steps: 0.6
44
+
45
+
46
+ # Amplify the target-words cross attention, larger value, more close to target
47
+ eq_params:
48
+ words: ["silver", "sculpture"]
49
+ values: [2,2]
50
+
51
+ # Target structure-divergence hyperparames
52
+ # If you change the shape of object better to use all three line, otherwise, no need.
53
+ # Without following three lines, all self-attention will be replaced
54
+ blend_words: [['cat',], ["cat",]]
55
+ masked_self_attention: True
56
+ # masked_latents: False # performance not so good in our case, need debug
57
+ bend_th: [2, 2]
58
+ # preserve source structure of blend_words , [0, 1]
59
+ # default is bend_th: [2, 2] # preserve all source self-attention
60
+ # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
61
+
62
+
63
+ 1:
64
+ is_replace_controller: true
65
+ cross_replace_steps:
66
+ default_: 0.1
67
+ self_replace_steps: 0.1
68
+
69
+ eq_params:
70
+ words: ["Nvidia", "GPU"]
71
+ values: [10, 10] # amplify attention to the word "tiger" by *2
72
+
73
+
74
+
75
+
76
+ clip_length: "${..train_dataset.n_sample_frame}"
77
+ sample_seeds: [0]
78
+ val_all_frames: False
79
+
80
+ num_inference_steps: 50
81
+ prompt2prompt_edit: True
82
+
83
+
84
+ model_config:
85
+ lora: 160
86
+ # temporal_downsample_time: 4
87
+ SparseCausalAttention_index: ['mid']
88
+ least_sc_channel: 640
89
+ # least_sc_channel: 100000
90
+
91
+ test_pipeline_config:
92
+ target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
93
+ num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
94
+
95
+ epsilon: 1e-5
96
+ train_steps: 10
97
+ seed: 0
98
+ learning_rate: 1e-5
99
+ train_temporal_conv: False
100
+ guidance_scale: "${validation_sample_logger_config.guidance_scale}"
FateZero/config/attribute/cat_tiger_leopard_grass.yaml ADDED
@@ -0,0 +1,112 @@
1
+ # CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/cat_tiger_leopard_grass.yaml
2
+
3
+ pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
4
+
5
+
6
+ train_dataset:
7
+ path: "data/attribute/cat_tiger_leopard_grass"
8
+ prompt: "A black cat walking on the floor next to a wall"
9
+ n_sample_frame: 8
10
+ # n_sample_frame: 22
11
+ sampling_rate: 1
12
+ stride: 80
13
+ offset:
14
+ left: 0
15
+ right: 0
16
+ top: 0
17
+ bottom: 0
18
+
19
+ validation_sample_logger_config:
20
+ use_train_latents: True
21
+ use_inversion_attention: True
22
+ guidance_scale: 7.5
23
+ prompts: [
24
+ # source prompt
25
+ A black cat walking on the floor next to a wall,
26
+ A black cat walking on the grass next to a wall,
27
+ A red tiger walking on the floor next to a wall,
28
+ a yellow cute Shiba-Inu walking on the floor next to a wall,
29
+ a yellow cute leopard walking on the floor next to a wall,
30
+ ]
31
+ p2p_config:
32
+ 0:
33
+ # Whether to directly copy the cross attention from source
34
+ # True: directly copy, better for object replacement
35
+ # False: keep source attention, better for style
36
+ is_replace_controller: False
37
+
38
+ # Semantic preserving and replacement Debug me
39
+ cross_replace_steps:
40
+ default_: 0.8
41
+
42
+ # Source background structure preserving, in [0, 1].
43
+ # e.g., =0.6 Replace the first 60% steps self-attention
44
+ self_replace_steps: 0.6
45
+
46
+
47
+ # Amplify the target-words cross attention, larger value, more close to target
48
+ eq_params:
49
+ words: ["silver", "sculpture"]
50
+ values: [2,2]
51
+
52
+ # Target structure-divergence hyperparames
53
+ # If you change the shape of object better to use all three line, otherwise, no need.
54
+ # Without following three lines, all self-attention will be replaced
55
+ blend_words: [['cat',], ["cat",]]
56
+ masked_self_attention: True
57
+ # masked_latents: False # performance not so good in our case, need debug
58
+ bend_th: [2, 2]
59
+ # preserve source structure of blend_words , [0, 1]
60
+ # default is bend_th: [2, 2] # preserve all source self-attention
61
+ # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
62
+
63
+
64
+ 1:
65
+ is_replace_controller: false
66
+ cross_replace_steps:
67
+ default_: 0.5
68
+ self_replace_steps: 0.5
69
+ 2:
70
+ is_replace_controller: false
71
+ cross_replace_steps:
72
+ default_: 0.5
73
+ self_replace_steps: 0.5
74
+ 3:
75
+ is_replace_controller: false
76
+ cross_replace_steps:
77
+ default_: 0.5
78
+ self_replace_steps: 0.5
79
+ 4:
80
+ is_replace_controller: false
81
+ cross_replace_steps:
82
+ default_: 0.7
83
+ self_replace_steps: 0.7
84
+
85
+
86
+
87
+
88
+ clip_length: "${..train_dataset.n_sample_frame}"
89
+ sample_seeds: [0]
90
+ val_all_frames: False
91
+
92
+ num_inference_steps: 50
93
+ prompt2prompt_edit: True
94
+
95
+
96
+ model_config:
97
+ lora: 160
98
+ # temporal_downsample_time: 4
99
+ SparseCausalAttention_index: ['mid']
100
+ least_sc_channel: 640
101
+ # least_sc_channel: 100000
102
+
103
+ test_pipeline_config:
104
+ target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
105
+ num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
106
+
107
+ epsilon: 1e-5
108
+ train_steps: 10
109
+ seed: 0
110
+ learning_rate: 1e-5
111
+ train_temporal_conv: False
112
+ guidance_scale: "${validation_sample_logger_config.guidance_scale}"
FateZero/config/attribute/dog_robotic_corgi.yaml ADDED
@@ -0,0 +1,103 @@
1
+ # CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/dog_robotic_corgi.yaml
2
+
3
+ pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
4
+
5
+ train_dataset:
6
+ path: "data/attribute/gray_dog"
7
+ prompt: "A gray dog sitting on the mat"
8
+ n_sample_frame: 8
9
+ # n_sample_frame: 22
10
+ sampling_rate: 1
11
+ stride: 80
12
+ offset:
13
+ left: 0
14
+ right: 0
15
+ top: 0
16
+ bottom: 0
17
+
18
+ validation_sample_logger_config:
19
+ use_train_latents: True
20
+ use_inversion_attention: True
21
+ guidance_scale: 7.5
22
+ prompts: [
23
+ # source prompt
24
+ A gray dog sitting on the mat,
25
+
26
+ # foreground texture style
27
+ A robotic dog sitting on the mat,
28
+ A yellow corgi sitting on the mat
29
+ ]
30
+ p2p_config:
31
+ 0:
32
+ # Whether to directly copy the cross attention from source
33
+ # True: directly copy, better for object replacement
34
+ # False: keep source attention, better for style
35
+ is_replace_controller: False
36
+
37
+ # Semantic preserving and replacement Debug me
38
+ cross_replace_steps:
39
+ default_: 0.8
40
+
41
+ # Source background structure preserving, in [0, 1].
42
+ # e.g., =0.6 Replace the first 60% steps self-attention
43
+ self_replace_steps: 0.6
44
+
45
+
46
+ # Amplify the target-words cross attention, larger value, more close to target
47
+ eq_params:
48
+ words: ["silver", "sculpture"]
49
+ values: [2,2]
50
+
51
+ # Target structure-divergence hyperparames
52
+ # If you change the shape of object better to use all three line, otherwise, no need.
53
+ # Without following three lines, all self-attention will be replaced
54
+ blend_words: [['cat',], ["cat",]]
55
+ masked_self_attention: True
56
+ # masked_latents: False # performance not so good in our case, need debug
57
+ bend_th: [2, 2]
58
+ # preserve source structure of blend_words , [0, 1]
59
+ # default is bend_th: [2, 2] # preserve all source self-attention
60
+ # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
61
+
62
+
63
+ 1:
64
+ is_replace_controller: false
65
+ cross_replace_steps:
66
+ default_: 0.5
67
+ self_replace_steps: 0.5
68
+
69
+ eq_params:
70
+ words: ["robotic"]
71
+ values: [10] # amplify attention to the word "tiger" by *2
72
+
73
+ 2:
74
+ is_replace_controller: false
75
+ cross_replace_steps:
76
+ default_: 0.5
77
+ self_replace_steps: 0.5
78
+
79
+ clip_length: "${..train_dataset.n_sample_frame}"
80
+ sample_seeds: [0]
81
+ val_all_frames: False
82
+
83
+ num_inference_steps: 50
84
+ prompt2prompt_edit: True
85
+
86
+
87
+ model_config:
88
+ lora: 160
89
+ # temporal_downsample_time: 4
90
+ SparseCausalAttention_index: ['mid']
91
+ least_sc_channel: 640
92
+ # least_sc_channel: 100000
93
+
94
+ test_pipeline_config:
95
+ target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
96
+ num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
97
+
98
+ epsilon: 1e-5
99
+ train_steps: 10
100
+ seed: 0
101
+ learning_rate: 1e-5
102
+ train_temporal_conv: False
103
+ guidance_scale: "${validation_sample_logger_config.guidance_scale}"
FateZero/config/attribute/duck_rubber.yaml ADDED
@@ -0,0 +1,99 @@
1
+ # CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/duck_rubber.yaml
2
+
3
+ pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
4
+
5
+ train_dataset:
6
+ path: "data/attribute/duck_rubber"
7
+ prompt: "a sleepy white duck"
8
+ n_sample_frame: 8
9
+ # n_sample_frame: 22
10
+ sampling_rate: 1
11
+ stride: 80
12
+ offset:
13
+ left: 0
14
+ right: 0
15
+ top: 0
16
+ bottom: 0
17
+
18
+ validation_sample_logger_config:
19
+ use_train_latents: True
20
+ use_inversion_attention: True
21
+ guidance_scale: 7.5
22
+ prompts: [
23
+ # source prompt
24
+ a sleepy white duck,
25
+
26
+ # foreground texture style
27
+ a sleepy yellow rubber duck
28
+ ]
29
+ p2p_config:
30
+ 0:
31
+ # Whether to directly copy the cross attention from source
32
+ # True: directly copy, better for object replacement
33
+ # False: keep source attention, better for style
34
+ is_replace_controller: False
35
+
36
+ # Semantic preserving and replacement Debug me
37
+ cross_replace_steps:
38
+ default_: 0.8
39
+
40
+ # Source background structure preserving, in [0, 1].
41
+ # e.g., =0.6 Replace the first 60% steps self-attention
42
+ self_replace_steps: 0.6
43
+
44
+
45
+ # Amplify the target-words cross attention, larger value, more close to target
46
+ eq_params:
47
+ words: ["silver", "sculpture"]
48
+ values: [2,2]
49
+
50
+ # Target structure-divergence hyperparames
51
+ # If you change the shape of object better to use all three line, otherwise, no need.
52
+ # Without following three lines, all self-attention will be replaced
53
+ blend_words: [['cat',], ["cat",]]
54
+ masked_self_attention: True
55
+ # masked_latents: False # performance not so good in our case, need debug
56
+ bend_th: [2, 2]
57
+ # preserve source structure of blend_words , [0, 1]
58
+ # default is bend_th: [2, 2] # preserve all source self-attention
59
+ # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
60
+
61
+
62
+ 1:
63
+ is_replace_controller: False
64
+ cross_replace_steps:
65
+ default_: 0.7
66
+ self_replace_steps: 0.7
67
+
68
+ # eq_params:
69
+ # words: ["yellow", "rubber"]
70
+ # values: [10, 10] # amplify attention to the word "tiger" by *2
71
+
72
+
73
+
74
+
75
+ clip_length: "${..train_dataset.n_sample_frame}"
76
+ sample_seeds: [0]
77
+ val_all_frames: False
78
+
79
+ num_inference_steps: 50
80
+ prompt2prompt_edit: True
81
+
82
+
83
+ model_config:
84
+ lora: 160
85
+ # temporal_downsample_time: 4
86
+ SparseCausalAttention_index: ['mid']
87
+ least_sc_channel: 640
88
+ # least_sc_channel: 100000
89
+
90
+ test_pipeline_config:
91
+ target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
92
+ num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
93
+
94
+ epsilon: 1e-5
95
+ train_steps: 10
96
+ seed: 0
97
+ learning_rate: 1e-5
98
+ train_temporal_conv: False
99
+ guidance_scale: "${validation_sample_logger_config.guidance_scale}"
FateZero/config/attribute/fox_wolf_snow.yaml ADDED
@@ -0,0 +1,107 @@
1
+ # CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/fox_wolf_snow.yaml
2
+
3
+ pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
4
+
5
+ train_dataset:
6
+ path: "data/attribute/fox_wolf_snow"
7
+ prompt: "a white fox sitting in the grass"
8
+ n_sample_frame: 8
9
+ # n_sample_frame: 22
10
+ sampling_rate: 1
11
+ stride: 80
12
+ offset:
13
+ left: 0
14
+ right: 0
15
+ top: 0
16
+ bottom: 0
17
+
18
+ validation_sample_logger_config:
19
+ use_train_latents: True
20
+ use_inversion_attention: True
21
+ guidance_scale: 7.5
22
+ prompts: [
23
+ # source prompt
24
+ a white fox sitting in the grass,
25
+
26
+ # foreground texture style
27
+ a grey wolf sitting in the grass,
28
+ a white fox sitting in the snow
29
+ ]
30
+ p2p_config:
31
+ 0:
32
+ # Whether to directly copy the cross attention from source
33
+ # True: directly copy, better for object replacement
34
+ # False: keep source attention, better for style
35
+ is_replace_controller: False
36
+
37
+ # Semantic preserving and replacement Debug me
38
+ cross_replace_steps:
39
+ default_: 0.8
40
+
41
+ # Source background structure preserving, in [0, 1].
42
+ # e.g., =0.6 Replace the first 60% steps self-attention
43
+ self_replace_steps: 0.6
44
+
45
+
46
+ # Amplify the target-words cross attention, larger value, more close to target
47
+ eq_params:
48
+ words: ["silver", "sculpture"]
49
+ values: [2,2]
50
+
51
+ # Target structure-divergence hyperparames
52
+ # If you change the shape of object better to use all three line, otherwise, no need.
53
+ # Without following three lines, all self-attention will be replaced
54
+ blend_words: [['cat',], ["cat",]]
55
+ masked_self_attention: True
56
+ # masked_latents: False # performance not so good in our case, need debug
57
+ bend_th: [2, 2]
58
+ # preserve source structure of blend_words , [0, 1]
59
+ # default is bend_th: [2, 2] # preserve all source self-attention
60
+ # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
61
+
62
+
63
+ 1:
64
+ is_replace_controller: false
65
+ cross_replace_steps:
66
+ default_: 0.5
67
+ self_replace_steps: 0.5
68
+
69
+ eq_params:
70
+ words: ["robotic"]
71
+ values: [10] # amplify attention to the word "tiger" by *2
72
+
73
+ 2:
74
+ is_replace_controller: false
75
+ cross_replace_steps:
76
+ default_: 0.5
77
+ self_replace_steps: 0.5
78
+ eq_params:
79
+ words: ["snow"]
80
+ values: [10] # amplify attention to the word "tiger" by *2
81
+
82
+
83
+ clip_length: "${..train_dataset.n_sample_frame}"
84
+ sample_seeds: [0]
85
+ val_all_frames: False
86
+
87
+ num_inference_steps: 50
88
+ prompt2prompt_edit: True
89
+
90
+
91
+ model_config:
92
+ lora: 160
93
+ # temporal_downsample_time: 4
94
+ SparseCausalAttention_index: ['mid']
95
+ least_sc_channel: 640
96
+ # least_sc_channel: 100000
97
+
98
+ test_pipeline_config:
99
+ target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
100
+ num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
101
+
102
+ epsilon: 1e-5
103
+ train_steps: 10
104
+ seed: 0
105
+ learning_rate: 1e-5
106
+ train_temporal_conv: False
107
+ guidance_scale: "${validation_sample_logger_config.guidance_scale}"
FateZero/config/attribute/rabbit_straberry_leaves_flowers.yaml ADDED
@@ -0,0 +1,114 @@
1
+ # CUDA_VISIBLE_DEVICES=1 python test_fatezero.py --config config/attribute/rabbit_straberry_leaves_flowers.yaml
2
+
3
+ pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
4
+
5
+
6
+ train_dataset:
7
+ path: "data/attribute/rabbit_strawberry"
8
+ prompt: "A rabbit is eating strawberries"
9
+ n_sample_frame: 8
10
+ # n_sample_frame: 22
11
+ sampling_rate: 1
12
+ stride: 80
13
+ offset:
14
+ left: 0
15
+ right: 0
16
+ top: 0
17
+ bottom: 0
18
+
19
+ validation_sample_logger_config:
20
+ use_train_latents: True
21
+ use_inversion_attention: True
22
+ guidance_scale: 7.5
23
+ prompts: [
24
+ # source prompt
25
+ A rabbit is eating strawberries,
26
+
27
+ # foreground texture style
28
+ A white rabbit is eating leaves,
29
+ A white rabbit is eating flower,
30
+ A white rabbit is eating orange,
31
+
32
+ # a brown lion walking on the rock against a wall,
33
+ ]
34
+ p2p_config:
35
+ 0:
36
+ # Whether to directly copy the cross attention from source
37
+ # True: directly copy, better for object replacement
38
+ # False: keep source attention, better for style
39
+ is_replace_controller: False
40
+
41
+ # Semantic preserving and replacement Debug me
42
+ cross_replace_steps:
43
+ default_: 0.8
44
+
45
+ # Source background structure preserving, in [0, 1].
46
+ # e.g., =0.6 Replace the first 60% steps self-attention
47
+ self_replace_steps: 0.6
48
+
49
+
50
+ # Amplify the target-words cross attention, larger value, more close to target
51
+ eq_params:
52
+ words: ["silver", "sculpture"]
53
+ values: [2,2]
54
+
55
+ # Target structure-divergence hyperparames
56
+ # If you change the shape of object better to use all three line, otherwise, no need.
57
+ # Without following three lines, all self-attention will be replaced
58
+ blend_words: [['cat',], ["cat",]]
59
+ masked_self_attention: True
60
+ # masked_latents: False # performance not so good in our case, need debug
61
+ bend_th: [2, 2]
62
+ # preserve source structure of blend_words , [0, 1]
63
+ # default is bend_th: [2, 2] # preserve all source self-attention
64
+ # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
65
+ 1:
66
+ is_replace_controller: false
67
+ cross_replace_steps:
68
+ default_: 0.5
69
+ self_replace_steps: 0.5
70
+ eq_params:
71
+ words: ["leaves"]
72
+ values: [10]
73
+ 2:
74
+ is_replace_controller: false
75
+ cross_replace_steps:
76
+ default_: 0.5
77
+ self_replace_steps: 0.5
78
+ eq_params:
79
+ words: ["flower"]
80
+ values: [10]
81
+ 3:
82
+ is_replace_controller: false
83
+ cross_replace_steps:
84
+ default_: 0.5
85
+ self_replace_steps: 0.5
86
+ eq_params:
87
+ words: ["orange"]
88
+ values: [10]
89
+
90
+ clip_length: "${..train_dataset.n_sample_frame}"
91
+ sample_seeds: [0]
92
+ val_all_frames: False
93
+
94
+ num_inference_steps: 50
95
+ prompt2prompt_edit: True
96
+
97
+
98
+ model_config:
99
+ lora: 160
100
+ # temporal_downsample_time: 4
101
+ SparseCausalAttention_index: ['mid']
102
+ least_sc_channel: 640
103
+ # least_sc_channel: 100000
104
+
105
+ test_pipeline_config:
106
+ target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
107
+ num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
108
+
109
+ epsilon: 1e-5
110
+ train_steps: 10
111
+ seed: 0
112
+ learning_rate: 1e-5
113
+ train_temporal_conv: False
114
+ guidance_scale: "${validation_sample_logger_config.guidance_scale}"
FateZero/config/attribute/squ_carrot_robot_eggplant.yaml ADDED
@@ -0,0 +1,123 @@
1
+ # CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/squ_carrot_robot_eggplant.yaml
2
+
3
+ pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
4
+
5
+
6
+ train_dataset:
7
+ path: "data/attribute/squirrel_carrot"
8
+ prompt: "A squirrel is eating a carrot"
9
+ n_sample_frame: 8
10
+ # n_sample_frame: 22
11
+ sampling_rate: 1
12
+ stride: 80
13
+ offset:
14
+ left: 0
15
+ right: 0
16
+ top: 0
17
+ bottom: 0
18
+
19
+ validation_sample_logger_config:
20
+ use_train_latents: True
21
+ use_inversion_attention: True
22
+ guidance_scale: 7.5
23
+ prompts: [
24
+ # source prompt
25
+ A squirrel is eating a carrot,
26
+ A robot squirrel is eating a carrot,
27
+ A rabbit is eating a eggplant,
28
+ A robot mouse is eating a screwdriver,
29
+ A white mouse is eating a peanut,
30
+ ]
31
+ p2p_config:
32
+ 0:
33
+ # Whether to directly copy the cross attention from source
34
+ # True: directly copy, better for object replacement
35
+ # False: keep source attention, better for style
36
+ is_replace_controller: False
37
+
38
+ # Semantic preserving and replacement Debug me
39
+ cross_replace_steps:
40
+ default_: 0.8
41
+
42
+ # Source background structure preserving, in [0, 1].
43
+ # e.g., =0.6 Replace the first 60% steps self-attention
44
+ self_replace_steps: 0.6
45
+
46
+
47
+ # Amplify the target-words cross attention, larger value, more close to target
48
+ eq_params:
49
+ words: ["silver", "sculpture"]
50
+ values: [2,2]
51
+
52
+ # Target structure-divergence hyperparames
53
+ # If you change the shape of object better to use all three line, otherwise, no need.
54
+ # Without following three lines, all self-attention will be replaced
55
+ blend_words: [['cat',], ["cat",]]
56
+ masked_self_attention: True
57
+ # masked_latents: False # performance not so good in our case, need debug
58
+ bend_th: [2, 2]
59
+ # preserve source structure of blend_words , [0, 1]
60
+ # default is bend_th: [2, 2] # preserve all source self-attention
61
+ # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
62
+
63
+
64
+ 1:
65
+ is_replace_controller: false
66
+ cross_replace_steps:
67
+ default_: 0.5
68
+ self_replace_steps: 0.4
69
+ eq_params:
70
+ words: ["rabbit", "mouse", "robot", "eggplant", "peanut", "screwdriver"]
71
+ values: [10, 10, 20, 10, 10, 10]
72
+ 2:
73
+ is_replace_controller: false
74
+ cross_replace_steps:
75
+ default_: 0.5
76
+ self_replace_steps: 0.5
77
+ eq_params:
78
+ words: ["rabbit", "mouse", "robot", "eggplant", "peanut", "screwdriver"]
79
+ values: [10, 10, 20, 10, 10, 10]
80
+ 3:
81
+ is_replace_controller: false
82
+ cross_replace_steps:
83
+ default_: 0.5
84
+ self_replace_steps: 0.5
85
+ eq_params:
86
+ words: ["rabbit", "mouse", "robot", "eggplant", "peanut", "screwdriver"]
87
+ values: [10, 10, 20, 10, 10, 10]
88
+ 4:
89
+ is_replace_controller: false
90
+ cross_replace_steps:
91
+ default_: 0.5
92
+ self_replace_steps: 0.5
93
+ eq_params:
94
+ words: ["rabbit", "mouse", "robot", "eggplant", "peanut", "screwdriver"]
95
+ values: [10, 10, 20, 10, 10, 10]
96
+
97
+
98
+
99
+ clip_length: "${..train_dataset.n_sample_frame}"
100
+ sample_seeds: [0]
101
+ val_all_frames: False
102
+
103
+ num_inference_steps: 50
104
+ prompt2prompt_edit: True
105
+
106
+
107
+ model_config:
108
+ lora: 160
109
+ # temporal_downsample_time: 4
110
+ SparseCausalAttention_index: ['mid']
111
+ least_sc_channel: 640
112
+ # least_sc_channel: 100000
113
+
114
+ test_pipeline_config:
115
+ target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
116
+ num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
117
+
118
+ epsilon: 1e-5
119
+ train_steps: 10
120
+ seed: 0
121
+ learning_rate: 1e-5
122
+ train_temporal_conv: False
123
+ guidance_scale: "${validation_sample_logger_config.guidance_scale}"
FateZero/config/attribute/swan_swa.yaml ADDED
@@ -0,0 +1,102 @@
1
+ # CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/swan_swa.yaml
2
+
3
+ pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
4
+
5
+
6
+ train_dataset:
7
+ path: "data/attribute/swan_swarov"
8
+ prompt: "a black swan with a red beak swimming in a river near a wall and bushes,"
9
+ n_sample_frame: 8
10
+ # n_sample_frame: 22
11
+ sampling_rate: 1
12
+ stride: 80
13
+ offset:
14
+ left: 0
15
+ right: 0
16
+ top: 0
17
+ bottom: 0
18
+
19
+ use_train_latents: True
20
+
21
+ validation_sample_logger_config:
22
+ use_train_latents: True
23
+ use_inversion_attention: True
24
+ guidance_scale: 7.5
25
+ prompts: [
26
+ # source prompt
27
+ a black swan with a red beak swimming in a river near a wall and bushes,
28
+
29
+ # foreground texture style
30
+ a Swarovski crystal swan with a red beak swimming in a river near a wall and bushes,
31
+ ]
32
+ p2p_config:
33
+ 0:
34
+ # Whether to directly copy the cross attention from source
35
+ # True: directly copy, better for object replacement
36
+ # False: keep source attention, better for style
37
+ is_replace_controller: False
38
+
39
+ # Semantic layout preserving and replacement
40
+ cross_replace_steps:
41
+ default_: 0.8
42
+
43
+ # Source background structure preserving, in [0, 1].
44
+ # e.g., =0.6 Replace the first 60% steps self-attention
45
+ self_replace_steps: 0.6
46
+
47
+
48
+ # Amplify the target-words cross attention, larger value, more close to target
49
+ eq_params:
50
+ words: ["silver", "sculpture"]
51
+ values: [2,2]
52
+
53
+ # Target structure-divergence hyperparameters
54
+ # If you change the shape of the object, it is better to use all three lines; otherwise, there is no need.
55
+ # Without following three lines, all self-attention will be replaced
56
+ blend_words: [['cat',], ["cat",]]
57
+ masked_self_attention: True
58
+ # masked_latents: False # performance not so good in our case, need debug
59
+ bend_th: [2, 2]
60
+ # preserve source structure of blend_words , [0, 1]
61
+ # default is bend_th: [2, 2] # preserve all source self-attention
62
+ # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source attention
63
+
64
+
65
+ 1:
66
+ is_replace_controller: False
67
+ cross_replace_steps:
68
+ default_: 0.8
69
+ self_replace_steps: 0.6
70
+
71
+ eq_params:
72
+ words: ["Swarovski", "crystal"]
73
+ values: [5, 5] # amplify attention to the words "Swarovski" and "crystal" by *5
74
+ use_inversion_attention: True
75
+
76
+
77
+
78
+ clip_length: "${..train_dataset.n_sample_frame}"
79
+ sample_seeds: [0]
80
+ val_all_frames: False
81
+
82
+ num_inference_steps: 50
83
+ prompt2prompt_edit: True
84
+
85
+
86
+ model_config:
87
+ lora: 160
88
+ # temporal_downsample_time: 4
89
+ SparseCausalAttention_index: ['mid']
90
+ least_sc_channel: 1280
91
+ # least_sc_channel: 100000
92
+
93
+ test_pipeline_config:
94
+ target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
95
+ num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
96
+
97
+ epsilon: 1e-5
98
+ train_steps: 10
99
+ seed: 0
100
+ learning_rate: 1e-5
101
+ train_temporal_conv: False
102
+ guidance_scale: "${validation_sample_logger_config.guidance_scale}"
FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml ADDED
@@ -0,0 +1,83 @@
1
+ # CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml
2
+
3
+ pretrained_model_path: "FateZero/ckpt/stable-diffusion-v1-4"
4
+
5
+ train_dataset:
6
+ path: "FateZero/data/teaser_car-turn"
7
+ prompt: "a silver jeep driving down a curvy road in the countryside"
8
+ n_sample_frame: 8
9
+ sampling_rate: 1
10
+ stride: 80
11
+ offset:
12
+ left: 0
13
+ right: 0
14
+ top: 0
15
+ bottom: 0
16
+
17
+
18
+ validation_sample_logger_config:
19
+ use_train_latents: true
20
+ use_inversion_attention: true
21
+ guidance_scale: 7.5
22
+ source_prompt: "${train_dataset.prompt}"
23
+ prompts: [
24
+ # a silver jeep driving down a curvy road in the countryside,
25
+ watercolor painting of a silver jeep driving down a curvy road in the countryside,
26
+ ]
27
+ p2p_config:
28
+ 0:
29
+ # Whether to directly copy the cross attention from source
30
+ # True: directly copy, better for object replacement
31
+ # False: keep source attention, better for style
32
+
33
+ is_replace_controller: False
34
+
35
+ # Semantic layout preserving. High steps, replace more cross attention to preserve semantic layout
36
+ cross_replace_steps:
37
+ default_: 0.8
38
+
39
+ # Source background structure preserving, in [0, 1].
40
+ # e.g., =0.6 Replace the first 60% steps self-attention
41
+ self_replace_steps: 0.8
42
+
43
+
44
+ # Amplify the target-words cross attention, larger value, more close to target
45
+ eq_params:
46
+ words: ["watercolor"]
47
+ values: [10,10]
48
+
49
+ # Target structure-divergence hyperparameters
50
+ # If you change the shape of the object, it is better to use all three lines; otherwise, there is no need.
51
+ # Without following three lines, all self-attention will be replaced
52
+ # blend_words: [['jeep',], ["car",]]
53
+ # masked_self_attention: True
54
+ # masked_latents: False # performance not so good in our case, need debug
55
+ # bend_th: [2, 2]
56
+ # preserve source structure of blend_words , [0, 1]
57
+ # default is bend_th: [2, 2] # replace full-resolution edit source with self-attention
58
+ # bend_th-> [0.0, 0.0], mask -> 1, use more edit self-attention, more generated shape, less source attention
59
+
60
+
61
+ clip_length: "${..train_dataset.n_sample_frame}"
62
+ sample_seeds: [0]
63
+
64
+ num_inference_steps: 10
65
+ prompt2prompt_edit: True
66
+
67
+ model_config:
68
+ lora: 160
69
+ # temporal_downsample_time: 4
70
+ SparseCausalAttention_index: ['mid']
71
+ least_sc_channel: 640
72
+ # least_sc_channel: 100000
73
+
74
+ test_pipeline_config:
75
+ target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
76
+ num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
77
+
78
+ epsilon: 1e-5
79
+ train_steps: 10
80
+ seed: 0
81
+ learning_rate: 1e-5
82
+ train_temporal_conv: False
83
+ guidance_scale: "${validation_sample_logger_config.guidance_scale}"
FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps_disk_store.yaml ADDED
@@ -0,0 +1,84 @@
1
+ # CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/low_resource_teaser/jeep_watercolor_ddim_10_steps_disk_store.yaml
2
+
3
+ pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
4
+
5
+ train_dataset:
6
+ path: "data/teaser_car-turn"
7
+ prompt: "a silver jeep driving down a curvy road in the countryside"
8
+ n_sample_frame: 8
9
+ sampling_rate: 1
10
+ stride: 80
11
+ offset:
12
+ left: 0
13
+ right: 0
14
+ top: 0
15
+ bottom: 0
16
+
17
+
18
+ validation_sample_logger_config:
19
+ use_train_latents: true
20
+ use_inversion_attention: true
21
+ guidance_scale: 7.5
22
+ source_prompt: "${train_dataset.prompt}"
23
+ prompts: [
24
+ # a silver jeep driving down a curvy road in the countryside,
25
+ watercolor painting of a silver jeep driving down a curvy road in the countryside,
26
+ ]
27
+ p2p_config:
28
+ 0:
29
+ # Whether to directly copy the cross attention from source
30
+ # True: directly copy, better for object replacement
31
+ # False: keep source attention, better for style
32
+
33
+ is_replace_controller: False
34
+
35
+ # Semantic layout preserving. High steps, replace more cross attention to preserve semantic layout
36
+ cross_replace_steps:
37
+ default_: 0.8
38
+
39
+ # Source background structure preserving, in [0, 1].
40
+ # e.g., =0.6 Replace the first 60% steps self-attention
41
+ self_replace_steps: 0.8
42
+
43
+
44
+ # Amplify the target-words cross attention, larger value, more close to target
45
+ eq_params:
46
+ words: ["watercolor"]
47
+ values: [10,10]
48
+
49
+ # Target structure-divergence hyperparameters
50
+ # If you change the shape of the object, it is better to use all three lines; otherwise, there is no need.
51
+ # Without following three lines, all self-attention will be replaced
52
+ # blend_words: [['jeep',], ["car",]]
53
+ # masked_self_attention: True
54
+ # masked_latents: False # performance not so good in our case, need debug
55
+ # bend_th: [2, 2]
56
+ # preserve source structure of blend_words , [0, 1]
57
+ # default is bend_th: [2, 2] # replace full-resolution edit source with self-attention
58
+ # bend_th-> [0.0, 0.0], mask -> 1, use more edit self-attention, more generated shape, less source attention
59
+
60
+
61
+ clip_length: "${..train_dataset.n_sample_frame}"
62
+ sample_seeds: [0]
63
+
64
+ num_inference_steps: 10
65
+ prompt2prompt_edit: True
66
+
67
+ disk_store: True
68
+ model_config:
69
+ lora: 160
70
+ # temporal_downsample_time: 4
71
+ SparseCausalAttention_index: ['mid']
72
+ least_sc_channel: 640
73
+ # least_sc_channel: 100000
74
+
75
+ test_pipeline_config:
76
+ target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
77
+ num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
78
+
79
+ epsilon: 1e-5
80
+ train_steps: 10
81
+ seed: 0
82
+ learning_rate: 1e-5
83
+ train_temporal_conv: False
84
+ guidance_scale: "${validation_sample_logger_config.guidance_scale}"
FateZero/config/style/jeep_watercolor.yaml ADDED
@@ -0,0 +1,94 @@
1
+ # CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/style/jeep_watercolor.yaml
2
+
3
+ pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
4
+
5
+ train_dataset:
6
+ path: "data/teaser_car-turn"
7
+ prompt: "a silver jeep driving down a curvy road in the countryside"
8
+ n_sample_frame: 8
9
+ sampling_rate: 1
10
+ stride: 80
11
+ offset:
12
+ left: 0
13
+ right: 0
14
+ top: 0
15
+ bottom: 0
16
+
17
+
18
+ validation_sample_logger_config:
19
+ use_train_latents: true
20
+ use_inversion_attention: true
21
+ guidance_scale: 7.5
22
+ prompts: [
23
+ a silver jeep driving down a curvy road in the countryside,
24
+ watercolor painting of a silver jeep driving down a curvy road in the countryside,
25
+ ]
26
+ p2p_config:
27
+ 0:
28
+ # Whether to directly copy the cross attention from source
29
+ # True: directly copy, better for object replacement
30
+ # False: keep source attention, better for style
31
+ is_replace_controller: False
32
+
33
+ # Semantic layout preserving. High steps, replace more cross attention to preserve semantic layout
34
+ cross_replace_steps:
35
+ default_: 0.8
36
+
37
+ # Source background structure preserving, in [0, 1].
38
+ # e.g., =0.6 Replace the first 60% steps self-attention
39
+ self_replace_steps: 0.9
40
+
41
+
42
+ # Amplify the target-words cross attention, larger value, more close to target
43
+ # eq_params:
44
+ # words: ["", ""]
45
+ # values: [10,10]
46
+
47
+ # Target structure-divergence hyperparameters
48
+ # If you change the shape of the object, it is better to use all three lines; otherwise, there is no need.
49
+ # Without following three lines, all self-attention will be replaced
50
+ # blend_words: [['jeep',], ["car",]]
51
+ masked_self_attention: True
52
+ # masked_latents: False # Directly copy the latents, performance not so good in our case
53
+ bend_th: [2, 2]
54
+ # preserve source structure of blend_words , [0, 1]
55
+ # default is bend_th: [2, 2] # replace full-resolution edit source with self-attention
56
+ # bend_th-> [0.0, 0.0], mask -> 1, use more edit self-attention, more generated shape, less source attention
57
+
58
+
59
+ 1:
60
+ cross_replace_steps:
61
+ default_: 0.8
62
+ self_replace_steps: 0.8
63
+
64
+ eq_params:
65
+ words: ["watercolor"]
66
+ values: [10] # amplify attention to the word "watercolor" by *10
67
+ use_inversion_attention: True
68
+ is_replace_controller: False
69
+
70
+
71
+ clip_length: "${..train_dataset.n_sample_frame}"
72
+ sample_seeds: [0]
73
+
74
+ num_inference_steps: 50
75
+ prompt2prompt_edit: True
76
+
77
+
78
+ model_config:
79
+ lora: 160
80
+ # temporal_downsample_time: 4
81
+ SparseCausalAttention_index: ['mid']
82
+ least_sc_channel: 640
83
+ # least_sc_channel: 100000
84
+
85
+ test_pipeline_config:
86
+ target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
87
+ num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
88
+
89
+ epsilon: 1e-5
90
+ train_steps: 10
91
+ seed: 0
92
+ learning_rate: 1e-5
93
+ train_temporal_conv: False
94
+ guidance_scale: "${validation_sample_logger_config.guidance_scale}"
FateZero/config/style/lily_monet.yaml ADDED
@@ -0,0 +1,93 @@
1
+ pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
2
+
3
+
4
+ train_dataset:
5
+ path: "data/style/red_water_lily_opening"
6
+ prompt: "a pink water lily"
7
+ start_sample_frame: 1
8
+ n_sample_frame: 8
9
+ # n_sample_frame: 22
10
+ sampling_rate: 20
11
+ stride: 8000
12
+ # offset:
13
+ # left: 300
14
+ # right: 0
15
+ # top: 0
16
+ # bottom: 0
17
+
18
+ validation_sample_logger_config:
19
+ use_train_latents: True
20
+ use_inversion_attention: True
21
+ guidance_scale: 7.5
22
+ prompts: [
23
+ a pink water lily,
24
+ Claude Monet painting of a pink water lily,
25
+ ]
26
+ p2p_config:
27
+ 0:
28
+ # Whether to directly copy the cross attention from source
29
+ # True: directly copy, better for object replacement
30
+ # False: keep source attention, better for style
31
+ is_replace_controller: False
32
+
33
+ # Semantic layout preserving and replacement
34
+ cross_replace_steps:
35
+ default_: 0.7
36
+
37
+ # Source background structure preserving, in [0, 1].
38
+ # e.g., =0.6 Replace the first 60% steps self-attention
39
+ self_replace_steps: 0.7
40
+
41
+
42
+ # Amplify the target-words cross attention, larger value, more close to target
43
+ eq_params:
44
+ words: ["silver", "sculpture"]
45
+ values: [2,2]
46
+
47
+ # Target structure-divergence hyperparameters
48
+ # If you change the shape of the object, it is better to use all three lines; otherwise, there is no need.
49
+ # Without following three lines, all self-attention will be replaced
50
+ blend_words: [['cat',], ["cat",]]
51
+ masked_self_attention: True
52
+ # masked_latents: False # performance not so good in our case, need debug
53
+ bend_th: [2, 2]
54
+ # preserve source structure of blend_words , [0, 1]
55
+ # default is bend_th: [2, 2] # preserve all source self-attention
56
+ # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source attention
57
+
58
+
59
+ 1:
60
+ is_replace_controller: False
61
+ cross_replace_steps:
62
+ default_: 0.5
63
+ self_replace_steps: 0.5
64
+
65
+ eq_params:
66
+ words: ["Monet"]
67
+ values: [10]
68
+
69
+ clip_length: "${..train_dataset.n_sample_frame}"
70
+ sample_seeds: [0]
71
+ val_all_frames: False
72
+
73
+ num_inference_steps: 50
74
+ prompt2prompt_edit: True
75
+
76
+
77
+ model_config:
78
+ lora: 160
79
+ # temporal_downsample_time: 4
80
+ SparseCausalAttention_index: ['mid']
81
+ least_sc_channel: 1280
82
+ # least_sc_channel: 100000
83
+
84
+ test_pipeline_config:
85
+ target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
86
+ num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
87
+
88
+ epsilon: 1e-5
89
+ train_steps: 10
90
+ seed: 0
91
+ learning_rate: 1e-5
92
+ train_temporal_conv: False
93
+ guidance_scale: "${validation_sample_logger_config.guidance_scale}"
FateZero/config/style/rabit_pokemon.yaml ADDED
@@ -0,0 +1,92 @@
1
+ pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
2
+
3
+
4
+ train_dataset:
5
+ path: "data/style/rabit"
6
+ prompt: "A rabbit is eating a watermelon"
7
+ n_sample_frame: 8
8
+ # n_sample_frame: 22
9
+ sampling_rate: 3
10
+ stride: 80
11
+
12
+
13
+ validation_sample_logger_config:
14
+ use_train_latents: True
15
+ use_inversion_attention: True
16
+ guidance_scale: 7.5
17
+ prompts: [
18
+ # source prompt
19
+ A rabbit is eating a watermelon,
20
+ # overall style
21
+ pokemon cartoon of A rabbit is eating a watermelon,
22
+ ]
23
+ p2p_config:
24
+ 0:
25
+ # Whether to directly copy the cross attention from source
26
+ # True: directly copy, better for object replacement
27
+ # False: keep source attention, better for style
28
+ is_replace_controller: False
29
+
30
+ # Semantic layout preserving and replacement
31
+ cross_replace_steps:
32
+ default_: 0.8
33
+
34
+ # Source background structure preserving, in [0, 1].
35
+ # e.g., =0.6 Replace the first 60% steps self-attention
36
+ self_replace_steps: 0.6
37
+
38
+
39
+ # Amplify the target-words cross attention, larger value, more close to target
40
+ eq_params:
41
+ words: ["silver", "sculpture"]
42
+ values: [2,2]
43
+
44
+ # Target structure-divergence hyperparameters
45
+ # If you change the shape of the object, it is better to use all three lines; otherwise, there is no need.
46
+ # Without following three lines, all self-attention will be replaced
47
+ blend_words: [['cat',], ["cat",]]
48
+ masked_self_attention: True
49
+ # masked_latents: False # performance not so good in our case, need debug
50
+ bend_th: [2, 2]
51
+ # preserve source structure of blend_words , [0, 1]
52
+ # default is bend_th: [2, 2] # preserve all source self-attention
53
+ # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source attention
54
+
55
+
56
+ 1:
57
+ is_replace_controller: False
58
+ cross_replace_steps:
59
+ default_: 0.7
60
+ self_replace_steps: 0.7
61
+
62
+ eq_params:
63
+ words: ["pokemon", "cartoon"]
64
+ values: [3, 3] # amplify attention to the words "pokemon" and "cartoon" by *3
65
+
66
+
67
+
68
+ clip_length: "${..train_dataset.n_sample_frame}"
69
+ sample_seeds: [0]
70
+ val_all_frames: False
71
+
72
+ num_inference_steps: 50
73
+ prompt2prompt_edit: True
74
+
75
+
76
+ model_config:
77
+ # lora: 160
78
+ # temporal_downsample_time: 4
79
+ # SparseCausalAttention_index: ['mid']
80
+ # least_sc_channel: 640
81
+ # least_sc_channel: 100000
82
+
83
+ test_pipeline_config:
84
+ target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
85
+ num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
86
+
87
+ epsilon: 1e-5
88
+ train_steps: 50
89
+ seed: 0
90
+ learning_rate: 1e-5
91
+ train_temporal_conv: False
92
+ guidance_scale: "${validation_sample_logger_config.guidance_scale}"
FateZero/config/style/sun_flower_van_gogh.yaml ADDED
@@ -0,0 +1,86 @@
1
+ pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
2
+
3
+ train_dataset:
4
+ path: "data/style/sunflower"
5
+ prompt: "a yellow sunflower"
6
+ start_sample_frame: 0
7
+ n_sample_frame: 8
8
+ sampling_rate: 1
9
+
10
+
11
+ validation_sample_logger_config:
12
+ use_train_latents: True
13
+ use_inversion_attention: True
14
+ guidance_scale: 7.5
15
+ prompts: [
16
+ a yellow sunflower,
17
+ van gogh style painting of a yellow sunflower,
18
+ ]
19
+ p2p_config:
20
+ 0:
21
+ # Whether to directly copy the cross attention from source
22
+ # True: directly copy, better for object replacement
23
+ # False: keep source attention, better for style
24
+ is_replace_controller: False
25
+
26
+ # Semantic layout preserving and replacement
27
+ cross_replace_steps:
28
+ default_: 0.7
29
+
30
+ # Source background structure preserving, in [0, 1].
31
+ # e.g., =0.6 Replace the first 60% steps self-attention
32
+ self_replace_steps: 0.7
33
+
34
+
35
+ # Amplify the target-words cross attention, larger value, more close to target
36
+ eq_params:
37
+ words: ["silver", "sculpture"]
38
+ values: [2,2]
39
+
40
+ # Target structure-divergence hyperparameters
41
+ # If you change the shape of the object, it is better to use all three lines; otherwise, there is no need.
42
+ # Without following three lines, all self-attention will be replaced
43
+ blend_words: [['cat',], ["cat",]]
44
+ masked_self_attention: True
45
+ # masked_latents: False # performance not so good in our case, need debug
46
+ bend_th: [2, 2]
47
+ # preserve source structure of blend_words , [0, 1]
48
+ # default is bend_th: [2, 2] # preserve all source self-attention
49
+ # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source attention
50
+
51
+
52
+ 1:
53
+ is_replace_controller: False
54
+ cross_replace_steps:
55
+ default_: 0.5
56
+ self_replace_steps: 0.5
57
+
58
+ eq_params:
59
+ words: ["van", "gogh"]
60
+ values: [10, 10] # amplify attention to the words "van" and "gogh" by *10
61
+
62
+ clip_length: "${..train_dataset.n_sample_frame}"
63
+ sample_seeds: [0]
64
+ val_all_frames: False
65
+
66
+ num_inference_steps: 50
67
+ prompt2prompt_edit: True
68
+
69
+
70
+ model_config:
71
+ lora: 160
72
+ # temporal_downsample_time: 4
73
+ SparseCausalAttention_index: ['mid']
74
+ least_sc_channel: 640
75
+ # least_sc_channel: 100000
76
+
77
+ test_pipeline_config:
78
+ target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
79
+ num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
80
+
81
+ epsilon: 1e-5
82
+ train_steps: 10
83
+ seed: 0
84
+ learning_rate: 1e-5
85
+ train_temporal_conv: False
86
+ guidance_scale: "${validation_sample_logger_config.guidance_scale}"
FateZero/config/style/surf_ukiyo.yaml ADDED
@@ -0,0 +1,90 @@
1
+ pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
2
+
3
+ train_dataset:
4
+ path: "data/style/surf"
5
+ prompt: "a man with round helmet surfing on a white wave in blue ocean with a rope"
6
+ n_sample_frame: 1
7
+
8
+ sampling_rate: 8
9
+
10
+
11
+ # use_train_latents: True
12
+
13
+ validation_sample_logger_config:
14
+ use_train_latents: true
15
+ use_inversion_attention: true
16
+ guidance_scale: 7.5
17
+ prompts: [
18
+ a man with round helmet surfing on a white wave in blue ocean with a rope,
19
+ The Ukiyo-e style painting of a man with round helmet surfing on a white wave in blue ocean with a rope
20
+ ]
21
+ p2p_config:
22
+ 0:
23
+ # Whether to directly copy the cross attention from source
24
+ # True: directly copy, better for object replacement
25
+ # False: keep source attention, better for style
26
+ is_replace_controller: False
27
+
28
+ # Semantic layout preserving and replacement
29
+ cross_replace_steps:
30
+ default_: 0.8
31
+
32
+ # Source background structure preserving, in [0, 1].
33
+ # e.g., =0.6 Replace the first 60% steps self-attention
34
+ self_replace_steps: 0.8
35
+
36
+
37
+ # Amplify the target-words cross attention, larger value, more close to target
38
+ eq_params:
39
+ words: ["silver", "sculpture"]
40
+ values: [2,2]
41
+
42
+ # Target structure-divergence hyperparameters
43
+ # If you change the shape of the object, it is better to use all three lines; otherwise, there is no need.
44
+ # Without following three lines, all self-attention will be replaced
45
+ blend_words: [['cat',], ["cat",]]
46
+ masked_self_attention: True
47
+ # masked_latents: False # performance not so good in our case, need debug
48
+ bend_th: [2, 2]
49
+ # preserve source structure of blend_words , [0, 1]
50
+ # default is bend_th: [2, 2] # preserve all source self-attention
51
+ # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source attention
52
+
53
+ 1:
54
+ is_replace_controller: False
55
+ cross_replace_steps:
56
+ default_: 0.9
57
+ self_replace_steps: 0.9
58
+
59
+ eq_params:
60
+ words: ["Ukiyo-e"]
61
+ values: [10, 10] # amplify attention to the word "Ukiyo-e" by *10
62
+
63
+
64
+
65
+
66
+ clip_length: "${..train_dataset.n_sample_frame}"
67
+ sample_seeds: [0]
68
+ val_all_frames: False
69
+
70
+ num_inference_steps: 50
71
+ prompt2prompt_edit: True
72
+
73
+
74
+ model_config:
75
+ # lora: 160
76
+ # temporal_downsample_time: 4
77
+ SparseCausalAttention_index: ['mid']
78
+ least_sc_channel: 640
79
+ # least_sc_channel: 100000
80
+
81
+ test_pipeline_config:
82
+ target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
83
+ num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
84
+
85
+ epsilon: 1e-5
86
+ train_steps: 50
87
+ seed: 0
88
+ learning_rate: 1e-5
89
+ train_temporal_conv: False
90
+ guidance_scale: "${validation_sample_logger_config.guidance_scale}"
FateZero/config/style/swan_cartoon.yaml ADDED
@@ -0,0 +1,101 @@
1
+ pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
2
+
3
+
4
+ train_dataset:
5
+ path: "data/style/blackswan"
6
+ prompt: "a black swan with a red beak swimming in a river near a wall and bushes,"
7
+ n_sample_frame: 8
8
+ # n_sample_frame: 22
9
+ sampling_rate: 6
10
+ stride: 80
11
+ offset:
12
+ left: 0
13
+ right: 0
14
+ top: 0
15
+ bottom: 0
16
+
17
+ # use_train_latents: True
18
+
19
+ validation_sample_logger_config:
20
+ use_train_latents: true
21
+ use_inversion_attention: true
22
+ guidance_scale: 7.5
23
+ prompts: [
24
+ # source prompt
25
+ a black swan with a red beak swimming in a river near a wall and bushes,
26
+ cartoon photo of a black swan with a red beak swimming in a river near a wall and bushes,
27
+ ]
28
+ p2p_config:
29
+ 0:
30
+ # Whether to directly copy the cross attention from source
31
+ # True: directly copy, better for object replacement
32
+ # False: keep source attention, better for style
33
+ is_replace_controller: False
34
+
35
+ # Semantic layout preserving and replacement
36
+ cross_replace_steps:
37
+ default_: 0.8
38
+
39
+ # Source background structure preserving, in [0, 1].
40
+ # e.g., =0.6 Replace the first 60% steps self-attention
41
+ self_replace_steps: 0.6
42
+
43
+
44
+ # Amplify the target-words cross attention, larger value, more close to target
45
+ eq_params:
46
+ words: ["silver", "sculpture"]
47
+ values: [2,2]
48
+
49
+ # Target structure-divergence hyperparameters
50
+ # If you change the shape of the object, it is better to use all three lines; otherwise, there is no need.
51
+ # Without following three lines, all self-attention will be replaced
52
+ blend_words: [['cat',], ["cat",]]
53
+ masked_self_attention: True
54
+ # masked_latents: False # performance not so good in our case, need debug
55
+ bend_th: [2, 2]
56
+ # preserve source structure of blend_words , [0, 1]
57
+ # default is bend_th: [2, 2] # preserve all source self-attention
58
+ # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source attention
59
+
60
+ # Fixed hyperparams
61
+ use_inversion_attention: True
62
+
63
+ 1:
64
+ is_replace_controller: False
65
+ cross_replace_steps:
66
+ default_: 0.8
67
+ self_replace_steps: 0.7
68
+
69
+ eq_params:
70
+ words: ["cartoon"]
71
+ values: [10] # amplify attention to the word "cartoon" by *10
72
+ use_inversion_attention: True
73
+
74
+
75
+
76
+ clip_length: "${..train_dataset.n_sample_frame}"
77
+ sample_seeds: [0]
78
+ val_all_frames: False
79
+
80
+ num_inference_steps: 50
81
+ # guidance_scale: 7.5
82
+ prompt2prompt_edit: True
83
+
84
+
85
+ model_config:
86
+ lora: 160
87
+ # temporal_downsample_time: 4
88
+ SparseCausalAttention_index: ['mid']
89
+ least_sc_channel: 640
90
+ # least_sc_channel: 100000
91
+
92
+ test_pipeline_config:
93
+ target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
94
+ num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
95
+
96
+ epsilon: 1e-5
97
+ train_steps: 10
98
+ seed: 0
99
+ learning_rate: 1e-5
100
+ train_temporal_conv: False
101
+ guidance_scale: "${validation_sample_logger_config.guidance_scale}"
FateZero/config/style/train_shinkai.yaml ADDED
@@ -0,0 +1,97 @@
1
+ pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
2
+
3
+ train_dataset:
4
+ path: "data/style/train"
5
+ prompt: "a train traveling down tracks next to a forest filled with trees and flowers and a man on the side of the track"
6
+ n_sample_frame: 32
7
+ # n_sample_frame: 22
8
+ sampling_rate: 7
9
+ stride: 80
10
+ # offset:
11
+ # left: 300
12
+ # right: 0
13
+ # top: 0
14
+ # bottom: 0
15
+
16
+ use_train_latents: True
17
+
18
+ validation_sample_logger_config:
19
+ use_train_latents: True
20
+ use_inversion_attention: True
21
+ guidance_scale: 7.5
22
+ prompts: [
23
+ a train traveling down tracks next to a forest filled with trees and flowers and a man on the side of the track,
24
+ a train traveling down tracks next to a forest filled with trees and flowers and a man on the side of the track Makoto Shinkai style
25
+
26
+ ]
27
+ p2p_config:
28
+ 0:
29
+ # Whether to directly copy the cross attention from source
30
+ # True: directly copy, better for object replacement
31
+ # False: keep source attention, better for style
32
+ is_replace_controller: False
33
+
34
+ # Semantic layout preserving and replacement
35
+ cross_replace_steps:
36
+ default_: 1.0
37
+
38
+ # Source background structure preserving, in [0, 1].
39
+ # e.g., =0.6 Replace the first 60% steps self-attention
40
+ self_replace_steps: 1.0
41
+
42
+
43
+ # Amplify the target-words cross attention, larger value, more close to target
44
+ # eq_params:
45
+ # words: ["silver", "sculpture"]
46
+ # values: [2,2]
47
+
48
+ # Target structure-divergence hyperparameters
49
+ # If you change the shape of the object, it is better to use all three lines; otherwise, there is no need.
50
+ # Without following three lines, all self-attention will be replaced
51
+ # blend_words: [['cat',], ["cat",]]
52
+ # masked_self_attention: True
53
+ # # masked_latents: False # performance not so good in our case, need debug
54
+ # bend_th: [2, 2]
55
+ # preserve source structure of blend_words , [0, 1]
56
+ # default is bend_th: [2, 2] # preserve all source self-attention
57
+ # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source attention
58
+
59
+
60
+ 1:
61
+ is_replace_controller: False
62
+ cross_replace_steps:
63
+ default_: 1.0
64
+ self_replace_steps: 0.9
65
+
66
+ eq_params:
67
+ words: ["Makoto", "Shinkai"]
68
+ values: [10, 10] # amplify attention to the words "Makoto" and "Shinkai" by *10
69
+
70
+
71
+
72
+
73
+ clip_length: "${..train_dataset.n_sample_frame}"
74
+ sample_seeds: [0]
75
+ val_all_frames: False
76
+
77
+ num_inference_steps: 50
78
+ prompt2prompt_edit: True
79
+
80
+
81
+ model_config:
82
+ lora: 160
83
+ # temporal_downsample_time: 4
84
+ SparseCausalAttention_index: ['mid']
85
+ least_sc_channel: 1280
86
+ # least_sc_channel: 100000
87
+
88
+ test_pipeline_config:
89
+ target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
90
+ num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
91
+
92
+ epsilon: 1e-5
93
+ train_steps: 10
94
+ seed: 0
95
+ learning_rate: 1e-5
96
+ train_temporal_conv: False
97
+ guidance_scale: "${validation_sample_logger_config.guidance_scale}"
FateZero/config/teaser/jeep_posche.yaml ADDED
@@ -0,0 +1,93 @@
1
+ # CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/teaser/jeep_posche.yaml
2
+
3
+ pretrained_model_path: "./ckpt/jeep_tuned_200"
4
+
5
+ train_dataset:
6
+ path: "data/teaser_car-turn"
7
+ prompt: "a silver jeep driving down a curvy road in the countryside,"
8
+ n_sample_frame: 8
9
+ sampling_rate: 1
10
+ stride: 80
11
+ offset:
12
+ left: 0
13
+ right: 0
14
+ top: 0
15
+ bottom: 0
16
+
17
+
18
+ validation_sample_logger_config:
19
+ use_train_latents: true
20
+ use_inversion_attention: true
21
+ guidance_scale: 7.5
22
+ prompts: [
23
+ a silver jeep driving down a curvy road in the countryside,
24
+ a Porsche car driving down a curvy road in the countryside,
25
+ ]
26
+ p2p_config:
27
+ 0:
28
+ # Whether to directly copy the cross attention from source
29
+ # True: directly copy, better for object replacement
30
+ # False: keep source attention, better for style
31
+ is_replace_controller: False
32
+
33
+ # Semantic layout preserving. High steps, replace more cross attention to preserve semantic layout
34
+ cross_replace_steps:
35
+ default_: 0.8
36
+
37
+ # Source background structure preserving, in [0, 1].
38
+ # e.g., =0.6 Replace the first 60% steps self-attention
39
+ self_replace_steps: 0.9
40
+
41
+
42
+ # Amplify the target-words cross attention, larger value, more close to target
43
+ # Useful in style editing
44
+ eq_params:
45
+ words: ["watercolor", "painting"]
46
+ values: [10,10]
47
+
48
+ # Target structure-divergence hyperparameters
49
+ # If you change the shape of the object, it is better to use all three lines; otherwise, there is no need.
50
+ # Without following three lines, all self-attention will be replaced
51
+ # Useful in shape editing
52
+ blend_words: [['jeep',], ["car",]]
53
+ masked_self_attention: True
54
+ # masked_latents: False # Directly copy the latents, performance not so good in our case
55
+
56
+ # preserve source structure of blend_words , [0, 1]
57
+ # bend_th-> [1.0, 1.0], mask -> 0, use inversion-time attention, the structure is similar to the input
58
+ # bend_th-> [0.0, 0.0], mask -> 1, use more edit self-attention, more generated shape, less source attention
59
+ bend_th: [0.3, 0.3]
60
+
61
+ 1:
62
+ cross_replace_steps:
63
+ default_: 0.5
64
+ self_replace_steps: 0.5
65
+
66
+ use_inversion_attention: True
67
+ is_replace_controller: True
68
+
69
+ blend_words: [['silver', 'jeep'], ["Porsche", 'car']] # for local edit. If it is not local yet - use only the source object: blend_word = ((('cat',), ("cat",))).
70
+ masked_self_attention: True
71
+ bend_th: [0.3, 0.3]
72
+
73
+ clip_length: "${..train_dataset.n_sample_frame}"
74
+ sample_seeds: [0]
75
+
76
+ num_inference_steps: 50
77
+ prompt2prompt_edit: True
78
+
79
+
80
+ model_config:
81
+ lora: 160
82
+
83
+
84
+ test_pipeline_config:
85
+ target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
86
+ num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
87
+
88
+ epsilon: 1e-5
89
+ train_steps: 10
90
+ seed: 0
91
+ learning_rate: 1e-5
92
+ train_temporal_conv: False
93
+ guidance_scale: "${validation_sample_logger_config.guidance_scale}"
FateZero/config/teaser/jeep_watercolor.yaml ADDED
@@ -0,0 +1,94 @@
1
+ # CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/teaser/jeep_watercolor.yaml
2
+
3
+ pretrained_model_path: "FateZero/ckpt/stable-diffusion-v1-4"
4
+
5
+ train_dataset:
6
+ path: "FateZero/data/teaser_car-turn"
7
+ prompt: "a silver jeep driving down a curvy road in the countryside"
8
+ n_sample_frame: 8
9
+ sampling_rate: 1
10
+ stride: 80
11
+ offset:
12
+ left: 0
13
+ right: 0
14
+ top: 0
15
+ bottom: 0
16
+
17
+
18
+ validation_sample_logger_config:
19
+ use_train_latents: true
20
+ use_inversion_attention: true
21
+ guidance_scale: 7.5
22
+ prompts: [
23
+ a silver jeep driving down a curvy road in the countryside,
24
+ watercolor painting of a silver jeep driving down a curvy road in the countryside,
25
+ ]
26
+ p2p_config:
27
+ 0:
28
+ # Whether to directly copy the cross attention from source
29
+ # True: directly copy, better for object replacement
30
+ # False: keep source attention, better for style
31
+ is_replace_controller: False
32
+
33
+ # Semantic layout preserving. High steps, replace more cross attention to preserve semantic layout
34
+ cross_replace_steps:
35
+ default_: 0.8
36
+
37
+ # Source background structure preserving, in [0, 1].
38
+ # e.g., =0.6 Replace the first 60% steps self-attention
39
+ self_replace_steps: 0.9
40
+
41
+
42
+ # Amplify the target-words cross attention, larger value, more close to target
43
+ # eq_params:
44
+ # words: ["", ""]
45
+ # values: [10,10]
46
+
47
+ # Target structure-divergence hyperparameters
48
+ # If you change the shape of the object, it is better to use all three lines; otherwise, there is no need.
49
+ # Without following three lines, all self-attention will be replaced
50
+ # blend_words: [['jeep',], ["car",]]
51
+ masked_self_attention: True
52
+ # masked_latents: False # Directly copy the latents, performance not so good in our case
53
+ bend_th: [2, 2]
54
+ # preserve source structure of blend_words , [0, 1]
55
+ # default is bend_th: [2, 2] # replace full-resolution edit source with self-attention
56
+ # bend_th-> [0.0, 0.0], mask -> 1, use more edit self-attention, more generated shape, less source attention
57
+
58
+
59
+ 1:
60
+ cross_replace_steps:
61
+ default_: 0.8
62
+ self_replace_steps: 0.8
63
+
64
+ eq_params:
65
+ words: ["watercolor"]
66
+ values: [10] # amplify attention to the word "watercolor" by *10
67
+ use_inversion_attention: True
68
+ is_replace_controller: False
69
+
70
+
71
+ clip_length: "${..train_dataset.n_sample_frame}"
72
+ sample_seeds: [0]
73
+
74
+ num_inference_steps: 50
75
+ prompt2prompt_edit: True
76
+
77
+
78
+ model_config:
79
+ lora: 160
80
+ # temporal_downsample_time: 4
81
+ SparseCausalAttention_index: ['mid']
82
+ least_sc_channel: 640
83
+ # least_sc_channel: 100000
84
+
85
+ test_pipeline_config:
86
+ target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
87
+ num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
88
+
89
+ epsilon: 1e-5
90
+ train_steps: 10
91
+ seed: 0
92
+ learning_rate: 1e-5
93
+ train_temporal_conv: False
94
+ guidance_scale: "${validation_sample_logger_config.guidance_scale}"
FateZero/data/.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ *
2
+ !teaser_car-turn
3
+ !teaser_car-turn/*
4
+ !.gitignore
FateZero/data/teaser_car-turn/00000.png ADDED
FateZero/data/teaser_car-turn/00001.png ADDED
FateZero/data/teaser_car-turn/00002.png ADDED
FateZero/data/teaser_car-turn/00003.png ADDED
FateZero/data/teaser_car-turn/00004.png ADDED
FateZero/data/teaser_car-turn/00005.png ADDED
FateZero/data/teaser_car-turn/00006.png ADDED
FateZero/data/teaser_car-turn/00007.png ADDED
FateZero/docs/EditingGuidance.md ADDED
@@ -0,0 +1,65 @@
1
+ # EditingGuidance
2
+
3
+ ## Prompt Engineering
4
+ For the results in the paper and on the webpage, we obtain the source prompt with the BLIP model embedded in the [Stable Diffusion WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui/).
5
+
6
+ Click the "interrogate CLIP", and we will get a source prompt automatically. Then, we remove the last few useless words.
7
+
8
+ <img src="../docs/blip.png" height="220px"/>
9
+
10
+ During stylization, you may use a very simple source prompt such as "A photo" as a baseline if your input video is too complicated to describe in one sentence.
11
+
12
+ ### Validate the prompt
13
+
14
+ - Put the source prompt into Stable Diffusion. If the generated image is close to the input video, it is likely a good source prompt (see the sketch below).
15
+ - A good prompt describes each frame and most objects in the video. In particular, it should contain the object or attribute that we want to edit or preserve.
16
+ - Put the target prompt into Stable Diffusion to check the upper bound of the editing effect. A reasonable combination of video and target prompt achieves better results (e.g., the "sunflower" video with a "Van Gogh" prompt works better than "sunflower" with "Monet").
17
+
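A minimal editorial sketch of this check (not part of the FateZero code), assuming the Stable Diffusion v1-4 weights have already been downloaded to `./ckpt/stable-diffusion-v1-4` as in the shipped configs:

```python
# Editorial sketch: render one image from a candidate prompt with the same
# SD v1-4 checkpoint that the FateZero configs point to.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "./ckpt/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

prompt = "a silver jeep driving down a curvy road in the countryside"  # prompt to validate
image = pipe(prompt).images[0]
image.save("prompt_check.png")
```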
18
+
19
+
20
+
21
+
22
+
23
+ ## FateZero hyperparameters
24
+ We give a simple analysis of the involved hyperparameters as follows:
25
+ ``` yaml
26
+ # Whether to directly copy the cross attention from source
27
+ # True: directly copy, better for object replacement
28
+ # False: keep source attention, better for style
29
+ is_replace_controller: False
30
+
31
+ # Semantic layout preserving. Higher steps replace more cross attention and preserve more of the semantic layout
32
+ cross_replace_steps:
33
+ default_: 0.8
34
+
35
+ # Source background structure preserving, in [0, 1].
36
+ # e.g., 0.6 replaces self-attention in the first 60% of steps
37
+ self_replace_steps: 0.8
38
+
39
+
40
+ # Amplify the cross attention of the target words; a larger value is closer to the target
41
+ # eq_params:
42
+ # words: ["", ""]
43
+ # values: [10,10]
44
+
45
+ # Target structure-divergence hyperparameters
46
+ # If you change the shape of the object, it is better to use all three lines; otherwise, there is no need.
47
+ # Without the following three lines, all self-attention will be replaced
48
+ blend_words: [['jeep',], ["car",]]
49
+ masked_self_attention: True
50
+ # masked_latents: False # Directly copy the latents, performance not so good in our case
51
+ bend_th: [2, 2]
52
+ # preserve source structure of blend_words in [0, 1]
53
+ # default is bend_th: [2, 2] # replace full-resolution edit source with self-attention
54
+ # bend_th-> [0.0, 0.0], mask -> 1, use more edit self-attention, more generated shape, less source attention
55
+ ```
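These values live under `validation_sample_logger_config.p2p_config` in every shipped YAML config, with one entry per prompt in `prompts`. A hedged editorial sketch of inspecting them with OmegaConf (mirroring how `test_fatezero_dataset.py` iterates over `p2p_config`):

```python
# Editorial sketch: print the per-prompt Prompt-to-Prompt hyperparameters of a shipped config.
from omegaconf import OmegaConf

cfg = OmegaConf.load("config/teaser/jeep_watercolor.yaml")
p2p = cfg["validation_sample_logger_config"]["p2p_config"]
for prompt_idx, params in p2p.items():
    print(prompt_idx,
          params.get("cross_replace_steps"),
          params.get("self_replace_steps"),
          params.get("eq_params"))
```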
56
+
57
+ ## DDIM hyperparameters
58
+
59
+ We profile the cost of editing 8 frames on an NVIDIA 3090 with fp16 mixed precision (accelerate) and xformers enabled.
60
+
61
+ | Configs | Attention location | DDIM Inver. Step | CPU memory | GPU memory | Inversion time | Editing time | Quality
62
+ |------------------|------------------ |------------------|------------------|------------------|------------------|----| ---- |
63
+ | [basic](../config/teaser/jeep_watercolor.yaml) | RAM | 50 | 100G | 12G | 60s | 40s | Full support
64
+ | [low cost](../config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml) | RAM | 10 | 15G | 12G | 10s | 10s | OK for style; does not work for shape
65
+ | [lower cost](../config/low_resource_teaser/jeep_watercolor_ddim_10_steps_disk_store.yaml) | DISK | 10 | 6G | 12G | 33s | 100s | OK for style; does not work for shape
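Each row corresponds to one of the shipped configs. As an editorial sketch (equivalent to the CLI commands at the top of each config), the three profiles can also be launched through the `run()` helper defined in `test_fatezero.py`:

```python
# Editorial sketch: drive the three cost profiles through test_fatezero.run().
from test_fatezero import run

run("config/teaser/jeep_watercolor.yaml")                                        # basic: 50 DDIM steps, attention kept in RAM
run("config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml")             # low cost: 10 DDIM steps
run("config/low_resource_teaser/jeep_watercolor_ddim_10_steps_disk_store.yaml")  # lower cost: attention stored on disk
```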
FateZero/docs/OpenSans-Regular.ttf ADDED
Binary file (148 kB).
FateZero/requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu113
2
+ torch==1.12.1+cu113 # --index-url https://download.pytorch.org/whl/cu113
3
+ torchvision==0.13.1+cu113 # --index-url https://download.pytorch.org/whl/cu113
4
+ diffusers[torch]==0.11.1
5
+ accelerate==0.15.0
6
+ transformers==4.25.1
7
+ bitsandbytes==0.35.4
8
+ einops
9
+ omegaconf
10
+ ftfy
11
+ tensorboard
12
+ modelcards
13
+ imageio
14
+ triton
15
+ click
16
+ opencv-python
17
+ imageio[ffmpeg]
FateZero/test_fatezero.py ADDED
@@ -0,0 +1,290 @@
1
+ import os
2
+ from glob import glob
3
+ import copy
4
+ from typing import Optional,Dict
5
+ from tqdm.auto import tqdm
6
+ from omegaconf import OmegaConf
7
+ import click
8
+
9
+ import torch
10
+ import torch.utils.data
11
+ import torch.utils.checkpoint
12
+
13
+ from accelerate import Accelerator
14
+ from accelerate.logging import get_logger
15
+ from accelerate.utils import set_seed
16
+ from diffusers import (
17
+ AutoencoderKL,
18
+ DDIMScheduler,
19
+ )
20
+ from diffusers.utils.import_utils import is_xformers_available
21
+ from transformers import AutoTokenizer, CLIPTextModel
22
+ from einops import rearrange
23
+
24
+ import sys
25
+ sys.path.append('FateZero')
26
+ from video_diffusion.models.unet_3d_condition import UNetPseudo3DConditionModel
27
+ from video_diffusion.data.dataset import ImageSequenceDataset
28
+ from video_diffusion.common.util import get_time_string, get_function_args
29
+ from video_diffusion.common.image_util import log_train_samples
30
+ from video_diffusion.common.instantiate_from_config import instantiate_from_config
31
+ from video_diffusion.pipelines.p2pvalidation_loop import p2pSampleLogger
32
+
33
+ logger = get_logger(__name__)
34
+
35
+
36
+ def collate_fn(examples):
37
+ """Concat a batch of sampled image in dataloader
38
+ """
39
+ batch = {
40
+ "prompt_ids": torch.cat([example["prompt_ids"] for example in examples], dim=0),
41
+ "images": torch.stack([example["images"] for example in examples]),
42
+ }
43
+ return batch
44
+
45
+
46
+
47
+ def test(
48
+ config: str,
49
+ pretrained_model_path: str,
50
+ train_dataset: Dict,
51
+ logdir: str = None,
52
+ validation_sample_logger_config: Optional[Dict] = None,
53
+ test_pipeline_config: Optional[Dict] = None,
54
+ gradient_accumulation_steps: int = 1,
55
+ seed: Optional[int] = None,
56
+ mixed_precision: Optional[str] = "fp16",
57
+ train_batch_size: int = 1,
58
+ model_config: dict={},
59
+ verbose: bool=True,
60
+ **kwargs
61
+
62
+ ):
63
+ args = get_function_args()
64
+
65
+ time_string = get_time_string()
66
+ if logdir is None:
67
+ logdir = config.replace('config', 'result').replace('.yml', '').replace('.yaml', '')
68
+ logdir += f"_{time_string}"
69
+
70
+ accelerator = Accelerator(
71
+ gradient_accumulation_steps=gradient_accumulation_steps,
72
+ mixed_precision=mixed_precision,
73
+ )
74
+ if accelerator.is_main_process:
75
+ os.makedirs(logdir, exist_ok=True)
76
+ OmegaConf.save(args, os.path.join(logdir, "config.yml"))
77
+
78
+ if seed is not None:
79
+ set_seed(seed)
80
+
81
+ # Load the tokenizer
82
+ tokenizer = AutoTokenizer.from_pretrained(
83
+ pretrained_model_path,
84
+ subfolder="tokenizer",
85
+ use_fast=False,
86
+ )
87
+
88
+ # Load models and create wrapper for stable diffusion
89
+ text_encoder = CLIPTextModel.from_pretrained(
90
+ pretrained_model_path,
91
+ subfolder="text_encoder",
92
+ )
93
+
94
+ vae = AutoencoderKL.from_pretrained(
95
+ pretrained_model_path,
96
+ subfolder="vae",
97
+ )
98
+
99
+ unet = UNetPseudo3DConditionModel.from_2d_model(
100
+ os.path.join(pretrained_model_path, "unet"), model_config=model_config
101
+ )
102
+
103
+ if 'target' not in test_pipeline_config:
104
+ test_pipeline_config['target'] = 'video_diffusion.pipelines.stable_diffusion.SpatioTemporalStableDiffusionPipeline'
105
+
106
+ pipeline = instantiate_from_config(
107
+ test_pipeline_config,
108
+ vae=vae,
109
+ text_encoder=text_encoder,
110
+ tokenizer=tokenizer,
111
+ unet=unet,
112
+ scheduler=DDIMScheduler.from_pretrained(
113
+ pretrained_model_path,
114
+ subfolder="scheduler",
115
+ ),
116
+ disk_store=kwargs.get('disk_store', False)
117
+ )
118
+ pipeline.scheduler.set_timesteps(validation_sample_logger_config['num_inference_steps'])
119
+ pipeline.set_progress_bar_config(disable=True)
120
+
121
+
122
+ if is_xformers_available():
123
+ try:
124
+ pipeline.enable_xformers_memory_efficient_attention()
125
+ except Exception as e:
126
+ logger.warning(
127
+ "Could not enable memory efficient attention. Make sure xformers is installed"
128
+ f" correctly and a GPU is available: {e}"
129
+ )
130
+
131
+ vae.requires_grad_(False)
132
+ unet.requires_grad_(False)
133
+ text_encoder.requires_grad_(False)
134
+ prompt_ids = tokenizer(
135
+ train_dataset["prompt"],
136
+ truncation=True,
137
+ padding="max_length",
138
+ max_length=tokenizer.model_max_length,
139
+ return_tensors="pt",
140
+ ).input_ids
141
+ train_dataset = ImageSequenceDataset(**train_dataset, prompt_ids=prompt_ids)
142
+
143
+ train_dataloader = torch.utils.data.DataLoader(
144
+ train_dataset,
145
+ batch_size=train_batch_size,
146
+ shuffle=True,
147
+ num_workers=4,
148
+ collate_fn=collate_fn,
149
+ )
150
+ train_sample_save_path = os.path.join(logdir, "train_samples.gif")
151
+ log_train_samples(save_path=train_sample_save_path, train_dataloader=train_dataloader)
152
+
153
+ unet, train_dataloader = accelerator.prepare(
154
+ unet, train_dataloader
155
+ )
156
+
157
+ weight_dtype = torch.float32
158
+ if accelerator.mixed_precision == "fp16":
159
+ weight_dtype = torch.float16
160
+ print('use fp16')
161
+ elif accelerator.mixed_precision == "bf16":
162
+ weight_dtype = torch.bfloat16
163
+
164
+ # Move text_encode and vae to gpu.
165
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
166
+ # These models are only used for inference, keeping weights in full precision is not required.
167
+ vae.to(accelerator.device, dtype=weight_dtype)
168
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
169
+
170
+
171
+ # We need to initialize the trackers we use, and also store our configuration.
172
+ # The trackers initializes automatically on the main process.
173
+ if accelerator.is_main_process:
174
+ accelerator.init_trackers("video") # , config=vars(args))
175
+ logger.info("***** wait to fix the logger path *****")
176
+
177
+ if validation_sample_logger_config is not None and accelerator.is_main_process:
178
+ validation_sample_logger = p2pSampleLogger(**validation_sample_logger_config, logdir=logdir)
179
+ # validation_sample_logger.log_sample_images(
180
+ # pipeline=pipeline,
181
+ # device=accelerator.device,
182
+ # step=0,
183
+ # )
184
+ def make_data_yielder(dataloader):
185
+ while True:
186
+ for batch in dataloader:
187
+ yield batch
188
+ accelerator.wait_for_everyone()
189
+
190
+ train_data_yielder = make_data_yielder(train_dataloader)
191
+
192
+
193
+ batch = next(train_data_yielder)
194
+ if validation_sample_logger_config.get('use_train_latents', False):
195
+ # Precompute the latents for this video to align the initial latents in training and test
196
+ assert batch["images"].shape[0] == 1, "Only support, overfiting on a single video"
197
+ # we only run inference to get the latents, no training
198
+ vae.eval()
199
+ text_encoder.eval()
200
+ unet.eval()
201
+
202
+ text_embeddings = pipeline._encode_prompt(
203
+ train_dataset.prompt,
204
+ device = accelerator.device,
205
+ num_images_per_prompt = 1,
206
+ do_classifier_free_guidance = True,
207
+ negative_prompt=None
208
+ )
209
+
210
+ use_inversion_attention = validation_sample_logger_config.get('use_inversion_attention', False)
211
+ batch['latents_all_step'] = pipeline.prepare_latents_ddim_inverted(
212
+ rearrange(batch["images"].to(dtype=weight_dtype), "b c f h w -> (b f) c h w"),
213
+ batch_size = 1,
214
+ num_images_per_prompt = 1, # not sure how to use it
215
+ text_embeddings = text_embeddings,
216
+ prompt = train_dataset.prompt,
217
+ store_attention=use_inversion_attention,
218
+ LOW_RESOURCE = True, # not classifier-free guidance
219
+ save_path = logdir if verbose else None
220
+ )
221
+
222
+ batch['ddim_init_latents'] = batch['latents_all_step'][-1]
223
+
224
+ else:
225
+ batch['ddim_init_latents'] = None
226
+
227
+ vae.eval()
228
+ text_encoder.eval()
229
+ unet.eval()
230
+
231
+ # with accelerator.accumulate(unet):
232
+ # Convert images to latent space
233
+ images = batch["images"].to(dtype=weight_dtype)
234
+ images = rearrange(images, "b c f h w -> (b f) c h w")
235
+
236
+
237
+ if accelerator.is_main_process:
238
+
239
+ if validation_sample_logger is not None:
240
+ unet.eval()
241
+ samples_all, save_path = validation_sample_logger.log_sample_images(
242
+ image=images, # torch.Size([8, 3, 512, 512])
243
+ pipeline=pipeline,
244
+ device=accelerator.device,
245
+ step=0,
246
+ latents = batch['ddim_init_latents'],
247
+ save_dir = logdir if verbose else None
248
+ )
249
+ # accelerator.log(logs, step=step)
250
+ print('accelerator.end_training()')
251
+ accelerator.end_training()
252
+ return save_path
253
+
254
+
255
+ # @click.command()
256
+ # @click.option("--config", type=str, default="FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml")
257
+ def run(config='FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml'):
258
+ print(f'in run function {config}')
259
+ Omegadict = OmegaConf.load(config)
260
+ if 'unet' in os.listdir(Omegadict['pretrained_model_path']):
261
+ test(config=config, **Omegadict)
262
+ print('test finished')
263
+ return '/home/cqiaa/diffusion/hugging_face/Tune-A-Video-inference/FateZero/result/low_resource_teaser/jeep_watercolor_ddim_10_steps_230327-200651/sample/step_0_0_0.mp4'
264
+ else:
265
+ # Go through all ckpt if possible
266
+ checkpoint_list = sorted(glob(os.path.join(Omegadict['pretrained_model_path'], 'checkpoint_*')))
267
+ print('checkpoint to evaluate:')
268
+ for checkpoint in checkpoint_list:
269
+ epoch = checkpoint.split('_')[-1]
270
+
271
+ for checkpoint in tqdm(checkpoint_list):
272
+ epoch = checkpoint.split('_')[-1]
273
+ if 'pretrained_epoch_list' not in Omegadict or int(epoch) in Omegadict['pretrained_epoch_list']:
274
+ print(f'Evaluate {checkpoint}')
275
+ # Update saving dir and ckpt
276
+ Omegadict_checkpoint = copy.deepcopy(Omegadict)
277
+ Omegadict_checkpoint['pretrained_model_path'] = checkpoint
278
+
279
+ if 'logdir' not in Omegadict_checkpoint:
280
+ logdir = config.replace('config', 'result').replace('.yml', '').replace('.yaml', '')
281
+ logdir += f"/{os.path.basename(checkpoint)}"
282
+
283
+ Omegadict_checkpoint['logdir'] = logdir
284
+ print(f'Saving at {logdir}')
285
+
286
+ test(config=config, **Omegadict_checkpoint)
287
+
288
+
289
+ if __name__ == "__main__":
290
+ run('FateZero/config/teaser/jeep_watercolor.yaml')
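For reference, `run()` essentially loads the YAML with OmegaConf and forwards it to `test()` (or iterates over checkpoints). A hedged editorial sketch of calling `test()` directly with an overridden `logdir`, following the same pattern used in `test_fatezero_dataset.py`:

```python
# Editorial sketch: bypass run() and call test() with an in-memory config override.
from omegaconf import OmegaConf
from test_fatezero import test

config_path = "FateZero/config/teaser/jeep_watercolor.yaml"
cfg = OmegaConf.load(config_path)
cfg["logdir"] = "result/teaser/jeep_watercolor_custom"  # hypothetical output dir; test() derives one if omitted
test(config=config_path, **cfg)
```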
FateZero/test_fatezero_dataset.py ADDED
@@ -0,0 +1,52 @@
1
+
2
+
3
+ from test_fatezero import *
4
+ from glob import glob
5
+ import copy
6
+
7
+ @click.command()
8
+ @click.option("--edit_config", type=str, default="config/supp/style/0313_style_edit_warp_640.yaml")
9
+ @click.option("--dataset_config", type=str, default="data/supp_edit_dataset/dataset_prompt.yaml")
10
+ def run(edit_config, dataset_config):
11
+ Omegadict_edit_config = OmegaConf.load(edit_config)
12
+ Omegadict_dataset_config = OmegaConf.load(dataset_config)
13
+
14
+ # Go through all data samples
15
+ data_sample_list = sorted(Omegadict_dataset_config.keys())
16
+ print(f'Datasample to evaluate: {data_sample_list}')
17
+ dataset_time_string = get_time_string()
18
+ for data_sample in data_sample_list:
19
+ print(f'Evaluate {data_sample}')
20
+
21
+ for p2p_config_index, p2p_config in Omegadict_edit_config['validation_sample_logger_config']['p2p_config'].items():
22
+ edit_config_now = copy.deepcopy(Omegadict_edit_config)
23
+ edit_config_now['train_dataset'] = copy.deepcopy(Omegadict_dataset_config[data_sample])
24
+ edit_config_now['train_dataset'].pop('target')
25
+ if 'eq_params' in edit_config_now['train_dataset']:
26
+ edit_config_now['train_dataset'].pop('eq_params')
27
+ # edit_config_now['train_dataset']['prompt'] = Omegadict_dataset_config[data_sample]['source']
28
+
29
+ edit_config_now['validation_sample_logger_config']['prompts'] \
30
+ = copy.deepcopy( [Omegadict_dataset_config[data_sample]['prompt'],]+ OmegaConf.to_object(Omegadict_dataset_config[data_sample]['target']))
31
+ p2p_config_now = dict()
32
+ for i in range(len(edit_config_now['validation_sample_logger_config']['prompts'])):
33
+ p2p_config_now[i] = p2p_config
34
+ if 'eq_params' in Omegadict_dataset_config[data_sample]:
35
+ p2p_config_now[i]['eq_params'] = Omegadict_dataset_config[data_sample]['eq_params']
36
+
37
+ edit_config_now['validation_sample_logger_config']['p2p_config'] = copy.deepcopy(p2p_config_now)
38
+ edit_config_now['validation_sample_logger_config']['source_prompt'] = Omegadict_dataset_config[data_sample]['prompt']
39
+ # edit_config_now['validation_sample_logger_config']['source_prompt'] = Omegadict_dataset_config[data_sample]['eq_params']
40
+
41
+
42
+ # if 'logdir' not in edit_config_now:
43
+ logdir = edit_config.replace('config', 'result').replace('.yml', '').replace('.yaml', '')+f'_config_{p2p_config_index}'+f'_{os.path.basename(dataset_config)[:-5]}'+f'_{dataset_time_string}'
44
+ logdir += f"/{data_sample}"
45
+ edit_config_now['logdir'] = logdir
46
+ print(f'Saving at {logdir}')
47
+
48
+ test(config=edit_config, **edit_config_now)
49
+
50
+
51
+ if __name__ == "__main__":
52
+ run()
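For reference, a sketch of the dataset config structure this loop expects, inferred from the keys accessed above (`prompt`, `target`, optional `eq_params`, plus whatever `train_dataset` fields remain after popping those two). The extra dataset fields and the layout inside `eq_params` are assumptions for illustration, not taken from the actual `dataset_prompt.yaml`:

```python
from omegaconf import OmegaConf

# Hypothetical contents of a dataset_prompt.yaml, built in code for illustration.
dataset_config = OmegaConf.create({
    "teaser_car_turn": {
        "path": "data/teaser_car-turn",           # assumed ImageSequenceDataset field
        "prompt": "a silver jeep driving down a curvy road",
        "target": [
            "watercolor painting of a silver jeep driving down a curvy road",
        ],
        "eq_params": {"words": ["watercolor"], "values": [10]},  # optional, assumed layout
    },
})

for data_sample in sorted(dataset_config.keys()):
    entry = dataset_config[data_sample]
    train_dataset = OmegaConf.to_object(entry)
    train_dataset.pop("target")                   # edited prompts are not dataset fields
    train_dataset.pop("eq_params", None)
    print(data_sample, "->", [entry["prompt"], *entry["target"]])
```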
FateZero/test_install.py ADDED
@@ -0,0 +1,23 @@
1
+ import torch
2
+ import os
3
+
4
+ import sys
5
+ print(f"python version {sys.version}")
6
+ print(f"torch version {torch.__version__}")
7
+ print(f"validate gpu status:")
8
+ print( torch.tensor(1.0).cuda()*2)
9
+ os.system("nvcc --version")
10
+
11
+ import diffusers
12
+ print(diffusers.__version__)
13
+ print(diffusers.__file__)
14
+
15
+ try:
16
+ import bitsandbytes
17
+ print(bitsandbytes.__file__)
18
+ except ImportError:
19
+ print("fail to import bitsandbytes")
20
+
21
+ os.system("accelerate env")
22
+
23
+ os.system("python -m xformers.info")
FateZero/train_tune_a_video.py ADDED
@@ -0,0 +1,426 @@
1
+ import os,copy
2
+ import inspect
3
+ from typing import Optional, List, Dict, Union
4
+ import PIL
5
+ import click
6
+ from omegaconf import OmegaConf
7
+
8
+ import torch
9
+ import torch.utils.data
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+
13
+ from accelerate import Accelerator
14
+ from accelerate.utils import set_seed
15
+ from diffusers import (
16
+ AutoencoderKL,
17
+ DDPMScheduler,
18
+ DDIMScheduler,
19
+ UNet2DConditionModel,
20
+ )
21
+ from diffusers.optimization import get_scheduler
22
+ from diffusers.utils.import_utils import is_xformers_available
23
+ from diffusers.pipeline_utils import DiffusionPipeline
24
+
25
+ from tqdm.auto import tqdm
26
+ from transformers import AutoTokenizer, CLIPTextModel
27
+ from einops import rearrange
28
+
29
+ from video_diffusion.models.unet_3d_condition import UNetPseudo3DConditionModel
30
+ from video_diffusion.data.dataset import ImageSequenceDataset
31
+ from video_diffusion.common.util import get_time_string, get_function_args
32
+ from video_diffusion.common.logger import get_logger_config_path
33
+ from video_diffusion.common.image_util import log_train_samples, log_train_reg_samples
34
+ from video_diffusion.common.instantiate_from_config import instantiate_from_config, get_obj_from_str
35
+ from video_diffusion.pipelines.validation_loop import SampleLogger
36
+
37
+
38
+ def collate_fn(examples):
39
+ batch = {
40
+ "prompt_ids": torch.cat([example["prompt_ids"] for example in examples], dim=0),
41
+ "images": torch.stack([example["images"] for example in examples]),
42
+
43
+ }
44
+ if "class_images" in examples[0]:
45
+ batch["class_prompt_ids"] = torch.cat([example["class_prompt_ids"] for example in examples], dim=0)
46
+ batch["class_images"] = torch.stack([example["class_images"] for example in examples])
47
+ return batch
48
+
49
+
50
+
51
+ def train(
52
+ config: str,
53
+ pretrained_model_path: str,
54
+ train_dataset: Dict,
55
+ logdir: str = None,
56
+ train_steps: int = 300,
57
+ validation_steps: int = 1000,
58
+ validation_sample_logger_config: Optional[Dict] = None,
59
+ test_pipeline_config: Optional[Dict] = dict(),
60
+ trainer_pipeline_config: Optional[Dict] = dict(),
61
+ gradient_accumulation_steps: int = 1,
62
+ seed: Optional[int] = None,
63
+ mixed_precision: Optional[str] = "fp16",
64
+ enable_xformers: bool = True,
65
+ train_batch_size: int = 1,
66
+ learning_rate: float = 3e-5,
67
+ scale_lr: bool = False,
68
+ lr_scheduler: str = "constant", # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
69
+ lr_warmup_steps: int = 0,
70
+ use_8bit_adam: bool = True,
71
+ adam_beta1: float = 0.9,
72
+ adam_beta2: float = 0.999,
73
+ adam_weight_decay: float = 1e-2,
74
+ adam_epsilon: float = 1e-08,
75
+ max_grad_norm: float = 1.0,
76
+ gradient_checkpointing: bool = False,
77
+ train_temporal_conv: bool = False,
78
+ checkpointing_steps: int = 1000,
79
+ model_config: dict={},
80
+ # use_train_latents: bool=False,
81
+ # kwr
82
+ # **kwargs
83
+ ):
84
+ args = get_function_args()
85
+ # args.update(kwargs)
86
+ train_dataset_config = copy.deepcopy(train_dataset)
87
+ time_string = get_time_string()
88
+ if logdir is None:
89
+ logdir = config.replace('config', 'result').replace('.yml', '').replace('.yaml', '')
90
+ logdir += f"_{time_string}"
91
+
92
+ accelerator = Accelerator(
93
+ gradient_accumulation_steps=gradient_accumulation_steps,
94
+ mixed_precision=mixed_precision,
95
+ )
96
+ if accelerator.is_main_process:
97
+ os.makedirs(logdir, exist_ok=True)
98
+ OmegaConf.save(args, os.path.join(logdir, "config.yml"))
99
+ logger = get_logger_config_path(logdir)
100
+ if seed is not None:
101
+ set_seed(seed)
102
+
103
+ # Load the tokenizer
104
+ tokenizer = AutoTokenizer.from_pretrained(
105
+ pretrained_model_path,
106
+ subfolder="tokenizer",
107
+ use_fast=False,
108
+ )
109
+
110
+ # Load models and create wrapper for stable diffusion
111
+ text_encoder = CLIPTextModel.from_pretrained(
112
+ pretrained_model_path,
113
+ subfolder="text_encoder",
114
+ )
115
+
116
+ vae = AutoencoderKL.from_pretrained(
117
+ pretrained_model_path,
118
+ subfolder="vae",
119
+ )
120
+
121
+ unet = UNetPseudo3DConditionModel.from_2d_model(
122
+ os.path.join(pretrained_model_path, "unet"), model_config=model_config
123
+ )
124
+
125
+
126
+ if 'target' not in test_pipeline_config:
127
+ test_pipeline_config['target'] = 'video_diffusion.pipelines.stable_diffusion.SpatioTemporalStableDiffusionPipeline'
128
+
129
+ pipeline = instantiate_from_config(
130
+ test_pipeline_config,
131
+ vae=vae,
132
+ text_encoder=text_encoder,
133
+ tokenizer=tokenizer,
134
+ unet=unet,
135
+ scheduler=DDIMScheduler.from_pretrained(
136
+ pretrained_model_path,
137
+ subfolder="scheduler",
138
+ ),
139
+ )
140
+ pipeline.scheduler.set_timesteps(validation_sample_logger_config['num_inference_steps'])
141
+ pipeline.set_progress_bar_config(disable=True)
142
+
143
+
144
+ if is_xformers_available() and enable_xformers:
145
+ # if False: # Disable xformers for null inversion
146
+ try:
147
+ pipeline.enable_xformers_memory_efficient_attention()
148
+ print('enable xformers in the training and testing')
149
+ except Exception as e:
150
+ logger.warning(
151
+ "Could not enable memory efficient attention. Make sure xformers is installed"
152
+ f" correctly and a GPU is available: {e}"
153
+ )
154
+
155
+ vae.requires_grad_(False)
156
+ unet.requires_grad_(False)
157
+ text_encoder.requires_grad_(False)
158
+
159
+ # Start of config trainable parameters in Unet and optimizer
160
+ trainable_modules = ("attn_temporal", ".to_q")
161
+ if train_temporal_conv:
162
+ trainable_modules += ("conv_temporal",)
163
+ for name, module in unet.named_modules():
164
+ if name.endswith(trainable_modules):
165
+ for params in module.parameters():
166
+ params.requires_grad = True
167
+
168
+
169
+ if gradient_checkpointing:
170
+ print('enable gradient checkpointing in the training and testing')
171
+ unet.enable_gradient_checkpointing()
172
+
173
+ if scale_lr:
174
+ learning_rate = (
175
+ learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
176
+ )
177
+
178
+ # Use 8-bit Adam for lower memory usage, e.g. to fine-tune the model on 16GB GPUs
179
+ if use_8bit_adam:
180
+ try:
181
+ import bitsandbytes as bnb
182
+ except ImportError:
183
+ raise ImportError(
184
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
185
+ )
186
+
187
+ optimizer_class = bnb.optim.AdamW8bit
188
+ else:
189
+ optimizer_class = torch.optim.AdamW
190
+
191
+ params_to_optimize = unet.parameters()
192
+ num_trainable_modules = 0
193
+ num_trainable_params = 0
194
+ num_unet_params = 0
195
+ for params in params_to_optimize:
196
+ num_unet_params += params.numel()
197
+ if params.requires_grad == True:
198
+ num_trainable_modules +=1
199
+ num_trainable_params += params.numel()
200
+
201
+ logger.info(f"Num of trainable modules: {num_trainable_modules}")
202
+ logger.info(f"Num of trainable params: {num_trainable_params/(1024*1024):.2f} M")
203
+ logger.info(f"Num of unet params: {num_unet_params/(1024*1024):.2f} M ")
204
+
205
+
206
+ params_to_optimize = unet.parameters()
207
+ optimizer = optimizer_class(
208
+ params_to_optimize,
209
+ lr=learning_rate,
210
+ betas=(adam_beta1, adam_beta2),
211
+ weight_decay=adam_weight_decay,
212
+ eps=adam_epsilon,
213
+ )
214
+ # End of config trainable parameters in Unet and optimizer
215
+
216
+
217
+ prompt_ids = tokenizer(
218
+ train_dataset["prompt"],
219
+ truncation=True,
220
+ padding="max_length",
221
+ max_length=tokenizer.model_max_length,
222
+ return_tensors="pt",
223
+ ).input_ids
224
+
225
+ if 'class_data_root' in train_dataset_config:
226
+ if 'class_data_prompt' not in train_dataset_config:
227
+ train_dataset_config['class_data_prompt'] = train_dataset_config['prompt']
228
+ class_prompt_ids = tokenizer(
229
+ train_dataset_config["class_data_prompt"],
230
+ truncation=True,
231
+ padding="max_length",
232
+ max_length=tokenizer.model_max_length,
233
+ return_tensors="pt",
234
+ ).input_ids
235
+ else:
236
+ class_prompt_ids = None
237
+ train_dataset = ImageSequenceDataset(**train_dataset, prompt_ids=prompt_ids, class_prompt_ids=class_prompt_ids)
238
+
239
+ train_dataloader = torch.utils.data.DataLoader(
240
+ train_dataset,
241
+ batch_size=train_batch_size,
242
+ shuffle=True,
243
+ num_workers=16,
244
+ collate_fn=collate_fn,
245
+ )
246
+
247
+ train_sample_save_path = os.path.join(logdir, "train_samples.gif")
248
+ log_train_samples(save_path=train_sample_save_path, train_dataloader=train_dataloader)
249
+ if 'class_data_root' in train_dataset_config:
250
+ log_train_reg_samples(save_path=train_sample_save_path.replace('train_samples', 'class_data_samples'), train_dataloader=train_dataloader)
251
+
252
+ # Prepare learning rate scheduler in accelerate config
253
+ lr_scheduler = get_scheduler(
254
+ lr_scheduler,
255
+ optimizer=optimizer,
256
+ num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
257
+ num_training_steps=train_steps * gradient_accumulation_steps,
258
+ )
259
+
260
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
261
+ unet, optimizer, train_dataloader, lr_scheduler
262
+ )
263
+ accelerator.register_for_checkpointing(lr_scheduler)
264
+
265
+ weight_dtype = torch.float32
266
+ if accelerator.mixed_precision == "fp16":
267
+ weight_dtype = torch.float16
268
+ print('enable float16 in the training and testing')
269
+ elif accelerator.mixed_precision == "bf16":
270
+ weight_dtype = torch.bfloat16
271
+
272
+ # Move text_encode and vae to gpu.
273
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
274
+ # as these models are only used for inference, keeping weights in full precision is not required.
275
+ vae.to(accelerator.device, dtype=weight_dtype)
276
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
277
+
278
+
279
+ # We need to initialize the trackers we use, and also store our configuration.
280
+ # The trackers initializes automatically on the main process.
281
+ if accelerator.is_main_process:
282
+ accelerator.init_trackers("video") # , config=vars(args))
283
+
284
+ # Start of config trainer
285
+ trainer = instantiate_from_config(
286
+ trainer_pipeline_config,
287
+ vae=vae,
288
+ text_encoder=text_encoder,
289
+ tokenizer=tokenizer,
290
+ unet=unet,
291
+ scheduler= DDPMScheduler.from_pretrained(
292
+ pretrained_model_path,
293
+ subfolder="scheduler",
294
+ ),
295
+ # training hyperparams
296
+ weight_dtype=weight_dtype,
297
+ accelerator=accelerator,
298
+ optimizer=optimizer,
299
+ max_grad_norm=max_grad_norm,
300
+ lr_scheduler=lr_scheduler,
301
+ prior_preservation=None
302
+ )
303
+ trainer.print_pipeline(logger)
304
+ # Train!
305
+ total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
306
+ logger.info("***** Running training *****")
307
+ logger.info(f" Num examples = {len(train_dataset)}")
308
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
309
+ logger.info(f" Instantaneous batch size per device = {train_batch_size}")
310
+ logger.info(
311
+ f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
312
+ )
313
+ logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
314
+ logger.info(f" Total optimization steps = {train_steps}")
315
+ step = 0
316
+ # End of config trainer
317
+
318
+ if validation_sample_logger_config is not None and accelerator.is_main_process:
319
+ validation_sample_logger = SampleLogger(**validation_sample_logger_config, logdir=logdir)
320
+
321
+
322
+ # Only show the progress bar once on each machine.
323
+ progress_bar = tqdm(
324
+ range(step, train_steps),
325
+ disable=not accelerator.is_local_main_process,
326
+ )
327
+ progress_bar.set_description("Steps")
328
+
329
+ def make_data_yielder(dataloader):
330
+ while True:
331
+ for batch in dataloader:
332
+ yield batch
333
+ accelerator.wait_for_everyone()
334
+
335
+ train_data_yielder = make_data_yielder(train_dataloader)
336
+
337
+
338
+ assert train_dataset.overfit_length == 1, "Only support overfitting on a single video"
339
+ # batch = next(train_data_yielder)
340
+
341
+
342
+ while step < train_steps:
343
+ batch = next(train_data_yielder)
344
+ """************************* start of an iteration*******************************"""
345
+ loss = trainer.step(batch)
346
+ # torch.cuda.empty_cache()
347
+
348
+ """************************* end of an iteration*******************************"""
349
+ # Checks if the accelerator has performed an optimization step behind the scenes
350
+ if accelerator.sync_gradients:
351
+ progress_bar.update(1)
352
+ step += 1
353
+
354
+ if accelerator.is_main_process:
355
+
356
+ if validation_sample_logger is not None and (step % validation_steps == 0):
357
+ unet.eval()
358
+
359
+ val_image = rearrange(batch["images"].to(dtype=weight_dtype), "b c f h w -> (b f) c h w")
360
+
361
+ # The UNet changes across iterations, so we invert the latents online
362
+ if validation_sample_logger_config.get('use_train_latents', False):
363
+ # Precompute the latents for this video to align the initial latents in training and test
364
+ assert batch["images"].shape[0] == 1, "Only support, overfiting on a single video"
365
+ # inference only, to obtain the latents; no training here
366
+ vae.eval()
367
+ text_encoder.eval()
368
+ unet.eval()
369
+
370
+ text_embeddings = pipeline._encode_prompt(
371
+ train_dataset.prompt,
372
+ device = accelerator.device,
373
+ num_images_per_prompt = 1,
374
+ do_classifier_free_guidance = True,
375
+ negative_prompt=None
376
+ )
377
+ batch['latents_all_step'] = pipeline.prepare_latents_ddim_inverted(
378
+ rearrange(batch["images"].to(dtype=weight_dtype), "b c f h w -> (b f) c h w"),
379
+ batch_size = 1 ,
380
+ num_images_per_prompt = 1, # not sure how to use it
381
+ text_embeddings = text_embeddings
382
+ )
383
+ batch['ddim_init_latents'] = batch['latents_all_step'][-1]
384
+ else:
385
+ batch['ddim_init_latents'] = None
386
+
387
+
388
+
389
+ validation_sample_logger.log_sample_images(
390
+ # image=rearrange(train_dataset.get_all()["images"].to(accelerator.device, dtype=weight_dtype), "c f h w -> f c h w"), # torch.Size([8, 3, 512, 512])
391
+ image= val_image, # torch.Size([8, 3, 512, 512])
392
+ pipeline=pipeline,
393
+ device=accelerator.device,
394
+ step=step,
395
+ latents = batch['ddim_init_latents'],
396
+ )
397
+ torch.cuda.empty_cache()
398
+ unet.train()
399
+
400
+ if step % checkpointing_steps == 0:
401
+ accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
402
+ inspect.signature(accelerator.unwrap_model).parameters.keys()
403
+ )
404
+ extra_args = {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
405
+ pipeline_save = get_obj_from_str(test_pipeline_config["target"]).from_pretrained(
406
+ pretrained_model_path,
407
+ unet=accelerator.unwrap_model(unet, **extra_args),
408
+ )
409
+ checkpoint_save_path = os.path.join(logdir, f"checkpoint_{step}")
410
+ pipeline_save.save_pretrained(checkpoint_save_path)
411
+
412
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
413
+ progress_bar.set_postfix(**logs)
414
+ accelerator.log(logs, step=step)
415
+
416
+ accelerator.end_training()
417
+
418
+
419
+ @click.command()
420
+ @click.option("--config", type=str, default="config/sample.yml")
421
+ def run(config):
422
+ train(config=config, **OmegaConf.load(config))
423
+
424
+
425
+ if __name__ == "__main__":
426
+ run()
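A sketch of the kind of config `run()` consumes, restricted to keys that appear in the `train()` signature above. Paths, prompts, and the trainer target are placeholders; the real values should come from the repo's `config/*.yaml` files rather than from this example:

```python
from omegaconf import OmegaConf

# Illustrative only: values below are placeholders, not a working recipe.
cfg = OmegaConf.create({
    "pretrained_model_path": "./ckpt/stable-diffusion-v1-4",
    "train_steps": 300,
    "validation_steps": 100,
    "gradient_checkpointing": True,
    "use_8bit_adam": True,
    "enable_xformers": True,
    "train_dataset": {
        "path": "data/teaser_car-turn",
        "prompt": "a silver jeep driving down a curvy road",
        "n_sample_frame": 8,
        "sampling_rate": 1,
    },
    "validation_sample_logger_config": {
        "use_train_latents": True,
        "num_inference_steps": 50,
        "prompts": ["watercolor painting of a jeep driving down a curvy road"],
    },
    "trainer_pipeline_config": {
        "target": "...",   # trainer class path; see the provided YAML configs
    },
})
# train(config="config/teaser/jeep_watercolor.yaml", **OmegaConf.to_object(cfg))
```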
FateZero/video_diffusion/common/image_util.py ADDED
@@ -0,0 +1,203 @@
1
+ import os
2
+ import math
3
+ import textwrap
4
+
5
+ import imageio
6
+ import numpy as np
7
+ from typing import Sequence
8
+ import requests
9
+ import cv2
10
+ from PIL import Image, ImageDraw, ImageFont
11
+
12
+ import torch
13
+ from torchvision import transforms
14
+ from einops import rearrange
15
+
16
+
17
+
18
+
19
+
20
+
21
+ IMAGE_EXTENSION = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
22
+
23
+ FONT_URL = "https://raw.github.com/googlefonts/opensans/main/fonts/ttf/OpenSans-Regular.ttf"
24
+ FONT_PATH = "./docs/OpenSans-Regular.ttf"
25
+
26
+
27
+ def pad(image: Image.Image, top=0, right=0, bottom=0, left=0, color=(255, 255, 255)) -> Image.Image:
28
+ new_image = Image.new(image.mode, (image.width + right + left, image.height + top + bottom), color)
29
+ new_image.paste(image, (left, top))
30
+ return new_image
31
+
32
+
33
+ def download_font_opensans(path=FONT_PATH):
34
+ font_url = FONT_URL
35
+ response = requests.get(font_url)
36
+ os.makedirs(os.path.dirname(path), exist_ok=True)
37
+ with open(path, "wb") as f:
38
+ f.write(response.content)
39
+
40
+
41
+ def annotate_image_with_font(image: Image.Image, text: str, font: ImageFont.FreeTypeFont) -> Image.Image:
42
+ image_w = image.width
43
+ _, _, text_w, text_h = font.getbbox(text)
44
+ line_size = math.floor(len(text) * image_w / text_w)
45
+
46
+ lines = textwrap.wrap(text, width=line_size)
47
+ padding = text_h * len(lines)
48
+ image = pad(image, top=padding + 3)
49
+
50
+ ImageDraw.Draw(image).text((0, 0), "\n".join(lines), fill=(0, 0, 0), font=font)
51
+ return image
52
+
53
+
54
+ def annotate_image(image: Image.Image, text: str, font_size: int = 15):
55
+ if not os.path.isfile(FONT_PATH):
56
+ download_font_opensans()
57
+ font = ImageFont.truetype(FONT_PATH, size=font_size)
58
+ return annotate_image_with_font(image=image, text=text, font=font)
59
+
60
+
61
+ def make_grid(images: Sequence[Image.Image], rows=None, cols=None) -> Image.Image:
62
+ if isinstance(images[0], np.ndarray):
63
+ images = [Image.fromarray(i) for i in images]
64
+
65
+ if rows is None:
66
+ assert cols is not None
67
+ rows = math.ceil(len(images) / cols)
68
+ else:
69
+ cols = math.ceil(len(images) / rows)
70
+
71
+ w, h = images[0].size
72
+ grid = Image.new("RGB", size=(cols * w, rows * h))
73
+ for i, image in enumerate(images):
74
+ if image.size != (w, h):
75
+ image = image.resize((w, h))
76
+ grid.paste(image, box=(i % cols * w, i // cols * h))
77
+ return grid
78
+
79
+
80
+ def save_images_as_gif(
81
+ images: Sequence[Image.Image],
82
+ save_path: str,
83
+ loop=0,
84
+ duration=100,
85
+ optimize=False,
86
+ ) -> None:
87
+
88
+ images[0].save(
89
+ save_path,
90
+ save_all=True,
91
+ append_images=images[1:],
92
+ optimize=optimize,
93
+ loop=loop,
94
+ duration=duration,
95
+ )
96
+
97
+ def save_images_as_mp4(
98
+ images: Sequence[Image.Image],
99
+ save_path: str,
100
+ ) -> None:
101
+ # images[0].save(
102
+ # save_path,
103
+ # save_all=True,
104
+ # append_images=images[1:],
105
+ # optimize=optimize,
106
+ # loop=loop,
107
+ # duration=duration,
108
+ # )
109
+ writer_edit = imageio.get_writer(
110
+ save_path,
111
+ fps=10)
112
+ for i in images:
113
+ init_image = i.convert("RGB")
114
+ writer_edit.append_data(np.array(init_image))
115
+ writer_edit.close()
116
+
117
+
118
+
119
+ def save_images_as_folder(
120
+ images: Sequence[Image.Image],
121
+ save_path: str,
122
+ ) -> None:
123
+ os.makedirs(save_path, exist_ok=True)
124
+ for index, image in enumerate(images):
125
+ init_image = image
126
+ if len(np.array(init_image).shape) == 3:
127
+ cv2.imwrite(os.path.join(save_path, f"{index:05d}.png"), np.array(init_image)[:, :, ::-1])
128
+ else:
129
+ cv2.imwrite(os.path.join(save_path, f"{index:05d}.png"), np.array(init_image))
130
+
131
+ def log_train_samples(
132
+ train_dataloader,
133
+ save_path,
134
+ num_batch: int = 4,
135
+ ):
136
+ train_samples = []
137
+ for idx, batch in enumerate(train_dataloader):
138
+ if idx >= num_batch:
139
+ break
140
+ train_samples.append(batch["images"])
141
+
142
+ train_samples = torch.cat(train_samples).numpy()
143
+ train_samples = rearrange(train_samples, "b c f h w -> b f h w c")
144
+ train_samples = (train_samples * 0.5 + 0.5).clip(0, 1)
145
+ train_samples = numpy_batch_seq_to_pil(train_samples)
146
+ train_samples = [make_grid(images, cols=int(np.ceil(np.sqrt(len(train_samples))))) for images in zip(*train_samples)]
147
+ # save_images_as_gif(train_samples, save_path)
148
+ save_gif_mp4_folder_type(train_samples, save_path)
149
+
150
+ def log_train_reg_samples(
151
+ train_dataloader,
152
+ save_path,
153
+ num_batch: int = 4,
154
+ ):
155
+ train_samples = []
156
+ for idx, batch in enumerate(train_dataloader):
157
+ if idx >= num_batch:
158
+ break
159
+ train_samples.append(batch["class_images"])
160
+
161
+ train_samples = torch.cat(train_samples).numpy()
162
+ train_samples = rearrange(train_samples, "b c f h w -> b f h w c")
163
+ train_samples = (train_samples * 0.5 + 0.5).clip(0, 1)
164
+ train_samples = numpy_batch_seq_to_pil(train_samples)
165
+ train_samples = [make_grid(images, cols=int(np.ceil(np.sqrt(len(train_samples))))) for images in zip(*train_samples)]
166
+ # save_images_as_gif(train_samples, save_path)
167
+ save_gif_mp4_folder_type(train_samples, save_path)
168
+
169
+
170
+ def save_gif_mp4_folder_type(images, save_path, save_gif=False):
171
+
172
+ if isinstance(images[0], np.ndarray):
173
+ images = [Image.fromarray(i) for i in images]
174
+ elif isinstance(images[0], torch.Tensor):
175
+ images = [transforms.ToPILImage()(i.cpu().clone()[0]) for i in images]
176
+ save_path_mp4 = save_path.replace('gif', 'mp4')
177
+ save_path_folder = save_path.replace('.gif', '')
178
+ if save_gif: save_images_as_gif(images, save_path)
179
+ save_images_as_mp4(images, save_path_mp4)
180
+ save_images_as_folder(images, save_path_folder)
181
+
182
+ # copy from video_diffusion/pipelines/stable_diffusion.py
183
+ def numpy_seq_to_pil(images):
184
+ """
185
+ Convert a numpy image or a batch of images to a PIL image.
186
+ """
187
+ if images.ndim == 3:
188
+ images = images[None, ...]
189
+ images = (images * 255).round().astype("uint8")
190
+ if images.shape[-1] == 1:
191
+ # special case for grayscale (single channel) images
192
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
193
+ else:
194
+ pil_images = [Image.fromarray(image) for image in images]
195
+
196
+ return pil_images
197
+
198
+ # copy from diffusers-0.11.1/src/diffusers/pipeline_utils.py
199
+ def numpy_batch_seq_to_pil(images):
200
+ pil_images = []
201
+ for sequence in images:
202
+ pil_images.append(numpy_seq_to_pil(sequence))
203
+ return pil_images
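A small usage sketch for the helpers above (run in the same module, or after importing them): it tiles a handful of dummy frames into a grid and writes the mp4/folder outputs; the file names are arbitrary.

```python
from PIL import Image

# Eight dummy 64x64 frames with varying colors.
frames = [Image.new("RGB", (64, 64), color=(i * 30 % 256, 100, 200)) for i in range(8)]

grid = make_grid(frames, cols=4)          # 2 rows x 4 cols -> a 256x128 contact sheet
grid.save("preview_grid.png")

# Writes preview.mp4 and a preview/ folder of per-frame PNGs;
# pass save_gif=True to also keep preview.gif.
save_gif_mp4_folder_type(frames, "preview.gif", save_gif=False)
```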
FateZero/video_diffusion/common/instantiate_from_config.py ADDED
@@ -0,0 +1,33 @@
1
+ """
2
+ Copied from the Stable Diffusion codebase
3
+ """
4
+ import importlib
5
+
6
+
7
+ def instantiate_from_config(config:dict, **args_from_code):
8
+ """Util funciton to decompose differenct modules using config
9
+
10
+ Args:
11
+ config (dict): dict with keys "target" and "params", typically loaded from a YAML file
+ args_from_code: additional keyword arguments supplied from code and merged with "params"
14
+
15
+
16
+ Returns:
17
+ the instantiated object, e.g. a validation/training pipeline or a model module
18
+ """
19
+ if not "target" in config:
20
+ if config == '__is_first_stage__':
21
+ return None
22
+ elif config == "__is_unconditional__":
23
+ return None
24
+ raise KeyError("Expected key `target` to instantiate.")
25
+ return get_obj_from_str(config["target"])(**config.get("params", dict()), **args_from_code)
26
+
27
+
28
+ def get_obj_from_str(string, reload=False):
29
+ module, cls = string.rsplit(".", 1)
30
+ if reload:
31
+ module_imp = importlib.import_module(module)
32
+ importlib.reload(module_imp)
33
+ return getattr(importlib.import_module(module, package=None), cls)
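A minimal usage sketch of the two helpers above: `target` is a dotted import path, `params` are constructor kwargs from the config, and keyword arguments passed in code are merged on top (this is how the training script builds its pipeline and trainer). The `torch.nn.Linear` target here is just an illustrative class:

```python
import torch  # only needed for this illustrative target

layer = instantiate_from_config(
    {"target": "torch.nn.Linear", "params": {"in_features": 4, "out_features": 2}},
    bias=False,  # args_from_code, merged with the params from the config
)
print(layer)                                                   # Linear(in_features=4, out_features=2, bias=False)
print(get_obj_from_str("torch.nn.Linear") is torch.nn.Linear)  # True
```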
FateZero/video_diffusion/common/logger.py ADDED
@@ -0,0 +1,17 @@
1
+ import os
2
+ import logging, logging.handlers
3
+ from accelerate.logging import get_logger
4
+
5
+ def get_logger_config_path(logdir):
6
+ # accelerate handles the logger in multiprocessing
7
+ logger = get_logger(__name__)
8
+ logging.basicConfig(
9
+ level=logging.INFO,
10
+ format='%(asctime)s:%(levelname)s : %(message)s',
11
+ datefmt='%a, %d %b %Y %H:%M:%S',
12
+ filename=os.path.join(logdir, 'log.log'),
13
+ filemode='w')
14
+ chlr = logging.StreamHandler()
15
+ chlr.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s : %(message)s'))
16
+ logger.logger.addHandler(chlr)
17
+ return logger
FateZero/video_diffusion/common/set_seed.py ADDED
@@ -0,0 +1,28 @@
1
+ import os
2
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
3
+
4
+ import torch
5
+ import numpy as np
6
+ import random
7
+
8
+ from accelerate.utils import set_seed
9
+
10
+
11
+ def video_set_seed(seed: int):
12
+ """
13
+ Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
14
+
15
+ Args:
16
+ seed (`int`): The seed to set.
17
19
+ """
20
+ set_seed(seed)
21
+ random.seed(seed)
22
+ np.random.seed(seed)
23
+ torch.manual_seed(seed)
24
+ torch.cuda.manual_seed_all(seed)
25
+ torch.backends.cudnn.benchmark = False
26
+ # torch.use_deterministic_algorithms(True, warn_only=True)
27
+ # [W Context.cpp:82] Warning: efficient_attention_forward_cutlass does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True, warn_only=True)'. You can file an issue at https://github.com/pytorch/pytorch/issues to help us prioritize adding deterministic support for this operation. (function alertNotDeterministic)
28
+
FateZero/video_diffusion/common/util.py ADDED
@@ -0,0 +1,73 @@
1
+ import os
2
+ import sys
3
+ import copy
4
+ import inspect
5
+ import datetime
6
+ from typing import List, Tuple, Optional, Dict
7
+
8
+
9
+ def glob_files(
10
+ root_path: str,
11
+ extensions: Tuple[str],
12
+ recursive: bool = True,
13
+ skip_hidden_directories: bool = True,
14
+ max_directories: Optional[int] = None,
15
+ max_files: Optional[int] = None,
16
+ relative_path: bool = False,
17
+ ) -> Tuple[List[str], bool, bool]:
18
+ """glob files with specified extensions
19
+
20
+ Args:
21
+ root_path (str): _description_
22
+ extensions (Tuple[str]): _description_
23
+ recursive (bool, optional): _description_. Defaults to True.
24
+ skip_hidden_directories (bool, optional): _description_. Defaults to True.
25
+ max_directories (Optional[int], optional): max number of directories to search. Defaults to None.
26
+ max_files (Optional[int], optional): max file number limit. Defaults to None.
27
+ relative_path (bool, optional): _description_. Defaults to False.
28
+
29
+ Returns:
30
+ Tuple[List[str], bool, bool]: _description_
31
+ """
32
+ paths = []
33
+ hit_max_directories = False
34
+ hit_max_files = False
35
+ for directory_idx, (directory, _, fnames) in enumerate(os.walk(root_path, followlinks=True)):
36
+ if skip_hidden_directories and os.path.basename(directory).startswith("."):
37
+ continue
38
+
39
+ if max_directories is not None and directory_idx >= max_directories:
40
+ hit_max_directories = True
41
+ break
42
+
43
+ paths += [
44
+ os.path.join(directory, fname)
45
+ for fname in sorted(fnames)
46
+ if fname.lower().endswith(extensions)
47
+ ]
48
+
49
+ if not recursive:
50
+ break
51
+
52
+ if max_files is not None and len(paths) > max_files:
53
+ hit_max_files = True
54
+ paths = paths[:max_files]
55
+ break
56
+
57
+ if relative_path:
58
+ paths = [os.path.relpath(p, root_path) for p in paths]
59
+
60
+ return paths, hit_max_directories, hit_max_files
61
+
62
+
63
+ def get_time_string() -> str:
64
+ x = datetime.datetime.now()
65
+ return f"{(x.year - 2000):02d}{x.month:02d}{x.day:02d}-{x.hour:02d}{x.minute:02d}{x.second:02d}"
66
+
67
+
68
+ def get_function_args() -> Dict:
69
+ frame = sys._getframe(1)
70
+ args, _, _, values = inspect.getargvalues(frame)
71
+ args_dict = copy.deepcopy({arg: values[arg] for arg in args})
72
+
73
+ return args_dict
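A quick sketch of how the two helpers at the end behave; `train()` above uses `get_function_args()` to snapshot its own keyword arguments into the saved `config.yml`, and `get_time_string()` to suffix the log directory. The `example` function below is hypothetical:

```python
def example(a, b=2, c="x"):
    # get_function_args() captures the caller's own argument names and values as a dict.
    return get_function_args()

print(get_time_string())      # e.g. "240321-153045"  (YYMMDD-HHMMSS)
print(example(1, c="video"))  # {'a': 1, 'b': 2, 'c': 'video'}
```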
FateZero/video_diffusion/data/dataset.py ADDED
@@ -0,0 +1,158 @@
1
+ import os
2
+
3
+ import numpy as np
4
+ from PIL import Image
5
+ from einops import rearrange
6
+ from pathlib import Path
7
+
8
+ import torch
9
+ from torch.utils.data import Dataset
10
+
11
+ from .transform import short_size_scale, random_crop, center_crop, offset_crop
12
+ from ..common.image_util import IMAGE_EXTENSION
13
+
14
+ import sys
15
+ sys.path.append('FateZero')
16
+
17
+ class ImageSequenceDataset(Dataset):
18
+ def __init__(
19
+ self,
20
+ path: str,
21
+ prompt_ids: torch.Tensor,
22
+ prompt: str,
23
+ start_sample_frame: int=0,
24
+ n_sample_frame: int = 8,
25
+ sampling_rate: int = 1,
26
+ stride: int = 1,
27
+ image_mode: str = "RGB",
28
+ image_size: int = 512,
29
+ crop: str = "center",
30
+
31
+ class_data_root: str = None,
32
+ class_prompt_ids: torch.Tensor = None,
33
+
34
+ offset: dict = {
35
+ "left": 0,
36
+ "right": 0,
37
+ "top": 0,
38
+ "bottom": 0
39
+ }
40
+ ):
41
+ self.path = path
42
+ self.images = self.get_image_list(path)
43
+ self.n_images = len(self.images)
44
+ self.offset = offset
45
+
46
+ if n_sample_frame < 0:
47
+ n_sample_frame = len(self.images)
48
+ self.start_sample_frame = start_sample_frame
49
+
50
+ self.n_sample_frame = n_sample_frame
51
+ self.sampling_rate = sampling_rate
52
+
53
+ self.sequence_length = (n_sample_frame - 1) * sampling_rate + 1
54
+ if self.n_images < self.sequence_length:
55
+ raise ValueError("self.n_images < self.sequence_length")
56
+ self.stride = stride
57
+
58
+ self.image_mode = image_mode
59
+ self.image_size = image_size
60
+ crop_methods = {
61
+ "center": center_crop,
62
+ "random": random_crop,
63
+ }
64
+ if crop not in crop_methods:
65
+ raise ValueError(f"crop must be one of {list(crop_methods)}, got {crop!r}")
66
+ self.crop = crop_methods[crop]
67
+
68
+ self.prompt = prompt
69
+ self.prompt_ids = prompt_ids
70
+ self.overfit_length = (self.n_images - self.sequence_length) // self.stride + 1
71
+ # Class (regularization) images for prior preservation
72
+ if class_data_root is not None:
73
+ self.class_data_root = Path(class_data_root)
74
+ self.class_images_path = sorted(list(self.class_data_root.iterdir()))
75
+ self.num_class_images = len(self.class_images_path)
76
+ self.class_prompt_ids = class_prompt_ids
77
+
78
+ self.video_len = (self.n_images - self.sequence_length) // self.stride + 1
79
+
80
+ def __len__(self):
81
+ max_len = (self.n_images - self.sequence_length) // self.stride + 1
82
+
83
+ if hasattr(self, 'num_class_images'):
84
+ max_len = max(max_len, self.num_class_images)
85
+ # return (self.n_images - self.sequence_length) // self.stride + 1
86
+ return max_len
87
+
88
+ def __getitem__(self, index):
89
+ return_batch = {}
90
+ frame_indices = self.get_frame_indices(index%self.video_len)
91
+ frames = [self.load_frame(i) for i in frame_indices]
92
+ frames = self.transform(frames)
93
+
94
+ return_batch.update(
95
+ {
96
+ "images": frames,
97
+ "prompt_ids": self.prompt_ids,
98
+ }
99
+ )
100
+
101
+ if hasattr(self, 'class_data_root'):
102
+ class_index = index % (self.num_class_images - self.n_sample_frame)
103
+ class_indices = self.get_class_indices(class_index)
104
+ frames = [self.load_class_frame(i) for i in class_indices]
105
+ return_batch["class_images"] = self.tensorize_frames(frames)
106
+ return_batch["class_prompt_ids"] = self.class_prompt_ids
107
+ return return_batch
108
+
109
+ def get_all(self, val_length=None):
110
+ if val_length is None:
111
+ val_length = len(self.images)
112
+ frame_indices = (i for i in range(val_length))
113
+ frames = [self.load_frame(i) for i in frame_indices]
114
+ frames = self.transform(frames)
115
+
116
+ return {
117
+ "images": frames,
118
+ "prompt_ids": self.prompt_ids,
119
+ }
120
+
121
+ def transform(self, frames):
122
+ frames = self.tensorize_frames(frames)
123
+ frames = offset_crop(frames, **self.offset)
124
+ frames = short_size_scale(frames, size=self.image_size)
125
+ frames = self.crop(frames, height=self.image_size, width=self.image_size)
126
+ return frames
127
+
128
+ @staticmethod
129
+ def tensorize_frames(frames):
130
+ frames = rearrange(np.stack(frames), "f h w c -> c f h w")
131
+ return torch.from_numpy(frames).div(255) * 2 - 1
132
+
133
+ def load_frame(self, index):
134
+ image_path = os.path.join(self.path, self.images[index])
135
+ return Image.open(image_path).convert(self.image_mode)
136
+
137
+ def load_class_frame(self, index):
138
+ image_path = self.class_images_path[index]
139
+ return Image.open(image_path).convert(self.image_mode)
140
+
141
+ def get_frame_indices(self, index):
142
+ if self.start_sample_frame is not None:
143
+ frame_start = self.start_sample_frame + self.stride * index
144
+ else:
145
+ frame_start = self.stride * index
146
+ return (frame_start + i * self.sampling_rate for i in range(self.n_sample_frame))
147
+
148
+ def get_class_indices(self, index):
149
+ frame_start = index
150
+ return (frame_start + i for i in range(self.n_sample_frame))
151
+
152
+ @staticmethod
153
+ def get_image_list(path):
154
+ images = []
155
+ for file in sorted(os.listdir(path)):
156
+ if file.endswith(IMAGE_EXTENSION):
157
+ images.append(file)
158
+ return images
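A worked sketch of the clip-sampling arithmetic in `ImageSequenceDataset`: with `n_sample_frame=8`, `sampling_rate=2`, `stride=1` and 20 source frames, each item spans `sequence_length = (8 - 1) * 2 + 1 = 15` consecutive frames and there are 6 valid start positions. The numbers below mirror `get_frame_indices` without touching any image files (the frame count is made up):

```python
n_sample_frame, sampling_rate, stride, start_sample_frame = 8, 2, 1, 0
n_images = 20                                                # frames on disk (hypothetical)

sequence_length = (n_sample_frame - 1) * sampling_rate + 1   # 15
video_len = (n_images - sequence_length) // stride + 1       # number of valid clips

def frame_indices(index):
    frame_start = start_sample_frame + stride * index
    return [frame_start + i * sampling_rate for i in range(n_sample_frame)]

print(video_len)           # 6
print(frame_indices(0))    # [0, 2, 4, 6, 8, 10, 12, 14]
print(frame_indices(5))    # [5, 7, 9, 11, 13, 15, 17, 19]
```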
FateZero/video_diffusion/data/transform.py ADDED
@@ -0,0 +1,48 @@
1
+ import random
2
+
3
+ import torch
4
+
5
+
6
+ def short_size_scale(images, size):
7
+ h, w = images.shape[-2:]
8
+ short, long = (h, w) if h < w else (w, h)
9
+
10
+ scale = size / short
11
+ long_target = int(scale * long)
12
+
13
+ target_size = (size, long_target) if h < w else (long_target, size)
14
+
15
+ return torch.nn.functional.interpolate(
16
+ input=images, size=target_size, mode="bilinear", antialias=True
17
+ )
18
+
19
+
20
+ def random_short_side_scale(images, size_min, size_max):
21
+ size = random.randint(size_min, size_max)
22
+ return short_size_scale(images, size)
23
+
24
+
25
+ def random_crop(images, height, width):
26
+ image_h, image_w = images.shape[-2:]
27
+ h_start = random.randint(0, image_h - height)
28
+ w_start = random.randint(0, image_w - width)
29
+ return images[:, :, h_start : h_start + height, w_start : w_start + width]
30
+
31
+
32
+ def center_crop(images, height, width):
33
+ # offset_crop(images, 0,0, 200, 0)
34
+ image_h, image_w = images.shape[-2:]
35
+ h_start = (image_h - height) // 2
36
+ w_start = (image_w - width) // 2
37
+ return images[:, :, h_start : h_start + height, w_start : w_start + width]
38
+
39
+ def offset_crop(image, left=0, right=0, top=200, bottom=0):
40
+
41
+ n, c, h, w = image.shape
42
+ left = min(left, w-1)
43
+ right = min(right, w - left - 1)
44
+ top = min(top, h - 1)
45
+ bottom = min(bottom, h - top - 1)
46
+ image = image[:, :, top:h-bottom, left:w-right]
47
+
48
+ return image
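A quick shape check for the resize/crop helpers above, on a random clip laid out as `c f h w` (the layout produced by `ImageSequenceDataset.tensorize_frames`); only the last two, spatial, dimensions are touched. The input resolution is arbitrary:

```python
import torch

clip = torch.rand(3, 8, 480, 852)                            # c f h w

clip = offset_crop(clip, left=0, right=0, top=0, bottom=0)   # explicit no-op crop
clip = short_size_scale(clip, size=512)                      # 480x852 -> 512x908 (short side to 512)
clip = center_crop(clip, height=512, width=512)
print(clip.shape)                                            # torch.Size([3, 8, 512, 512])
```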
FateZero/video_diffusion/models/attention.py ADDED
@@ -0,0 +1,482 @@
1
+ # code mostly taken from https://github.com/huggingface/diffusers
2
+ from dataclasses import dataclass
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.modeling_utils import ModelMixin
10
+ from diffusers.models.attention import FeedForward, CrossAttention, AdaLayerNorm
11
+ from diffusers.utils import BaseOutput
12
+ from diffusers.utils.import_utils import is_xformers_available
13
+
14
+ from einops import rearrange
15
+
16
+
17
+ @dataclass
18
+ class SpatioTemporalTransformerModelOutput(BaseOutput):
19
+ """torch.FloatTensor of shape [batch x channel x frames x height x width]"""
20
+
21
+ sample: torch.FloatTensor
22
+
23
+
24
+ if is_xformers_available():
25
+ import xformers
26
+ import xformers.ops
27
+ else:
28
+ xformers = None
29
+
30
+
31
+ class SpatioTemporalTransformerModel(ModelMixin, ConfigMixin):
32
+ @register_to_config
33
+ def __init__(
34
+ self,
35
+ num_attention_heads: int = 16,
36
+ attention_head_dim: int = 88,
37
+ in_channels: Optional[int] = None,
38
+ num_layers: int = 1,
39
+ dropout: float = 0.0,
40
+ norm_num_groups: int = 32,
41
+ cross_attention_dim: Optional[int] = None,
42
+ attention_bias: bool = False,
43
+ activation_fn: str = "geglu",
44
+ num_embeds_ada_norm: Optional[int] = None,
45
+ use_linear_projection: bool = False,
46
+ only_cross_attention: bool = False,
47
+ upcast_attention: bool = False,
48
+ model_config: dict = {},
49
+ **transformer_kwargs,
50
+ ):
51
+ super().__init__()
52
+ self.use_linear_projection = use_linear_projection
53
+ self.num_attention_heads = num_attention_heads
54
+ self.attention_head_dim = attention_head_dim
55
+ inner_dim = num_attention_heads * attention_head_dim
56
+
57
+ # Define input layers
58
+ self.in_channels = in_channels
59
+
60
+ self.norm = torch.nn.GroupNorm(
61
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
62
+ )
63
+ if use_linear_projection:
64
+ self.proj_in = nn.Linear(in_channels, inner_dim)
65
+ else:
66
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
67
+
68
+ # Define transformers blocks
69
+ self.transformer_blocks = nn.ModuleList(
70
+ [
71
+ SpatioTemporalTransformerBlock(
72
+ inner_dim,
73
+ num_attention_heads,
74
+ attention_head_dim,
75
+ dropout=dropout,
76
+ cross_attention_dim=cross_attention_dim,
77
+ activation_fn=activation_fn,
78
+ num_embeds_ada_norm=num_embeds_ada_norm,
79
+ attention_bias=attention_bias,
80
+ only_cross_attention=only_cross_attention,
81
+ upcast_attention=upcast_attention,
82
+ model_config=model_config,
83
+ **transformer_kwargs,
84
+ )
85
+ for d in range(num_layers)
86
+ ]
87
+ )
88
+
89
+ # Define output layers
90
+ if use_linear_projection:
91
+ self.proj_out = nn.Linear(in_channels, inner_dim)
92
+ else:
93
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
94
+
95
+ def forward(
96
+ self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True
97
+ ):
98
+ # 1. Input
99
+ clip_length = None
100
+ is_video = hidden_states.ndim == 5
101
+ if is_video:
102
+ clip_length = hidden_states.shape[2]
103
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
104
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(clip_length, 0)
105
+ else:
106
+ # Adapt to classifier-free guidance, where encoder_hidden_states comes with batch size 2
107
+ batch_size = hidden_states.shape[0]//encoder_hidden_states.shape[0]
108
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(batch_size, 0)
109
+ *_, h, w = hidden_states.shape
110
+ residual = hidden_states
111
+
112
+ hidden_states = self.norm(hidden_states)
113
+ if not self.use_linear_projection:
114
+ hidden_states = self.proj_in(hidden_states)
115
+ hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c") # (bf) (hw) c
116
+ else:
117
+ hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")
118
+ hidden_states = self.proj_in(hidden_states)
119
+
120
+ # 2. Blocks
121
+ for block in self.transformer_blocks:
122
+ hidden_states = block(
123
+ hidden_states, # [16, 4096, 320]
124
+ encoder_hidden_states=encoder_hidden_states, # ([1, 77, 768]
125
+ timestep=timestep,
126
+ clip_length=clip_length,
127
+ )
128
+
129
+ # 3. Output
130
+ if not self.use_linear_projection:
131
+ hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=h, w=w).contiguous()
132
+ hidden_states = self.proj_out(hidden_states)
133
+ else:
134
+ hidden_states = self.proj_out(hidden_states)
135
+ hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=h, w=w).contiguous()
136
+
137
+ output = hidden_states + residual
138
+ if is_video:
139
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=clip_length)
140
+
141
+ if not return_dict:
142
+ return (output,)
143
+
144
+ return SpatioTemporalTransformerModelOutput(sample=output)
145
+
146
+ import copy
147
+ class SpatioTemporalTransformerBlock(nn.Module):
148
+ def __init__(
149
+ self,
150
+ dim: int,
151
+ num_attention_heads: int,
152
+ attention_head_dim: int,
153
+ dropout=0.0,
154
+ cross_attention_dim: Optional[int] = None,
155
+ activation_fn: str = "geglu",
156
+ num_embeds_ada_norm: Optional[int] = None,
157
+ attention_bias: bool = False,
158
+ only_cross_attention: bool = False,
159
+ upcast_attention: bool = False,
160
+ use_sparse_causal_attention: bool = True,
161
+ temporal_attention_position: str = "after_feedforward",
162
+ model_config: dict = {}
163
+ ):
164
+ super().__init__()
165
+
166
+ self.only_cross_attention = only_cross_attention
167
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
168
+ self.use_sparse_causal_attention = use_sparse_causal_attention
169
+ # For safety, freeze the model_config
170
+ self.model_config = copy.deepcopy(model_config)
171
+ if 'least_sc_channel' in model_config:
172
+ if dim< model_config['least_sc_channel']:
173
+ self.model_config['SparseCausalAttention_index'] = []
174
+
175
+ self.temporal_attention_position = temporal_attention_position
176
+ temporal_attention_positions = ["after_spatial", "after_cross", "after_feedforward"]
177
+ if temporal_attention_position not in temporal_attention_positions:
178
+ raise ValueError(
179
+ f"`temporal_attention_position` must be one of {temporal_attention_positions}"
180
+ )
181
+
182
+ # 1. Spatial-Attn
183
+ spatial_attention = SparseCausalAttention if use_sparse_causal_attention else CrossAttention
184
+ self.attn1 = spatial_attention(
185
+ query_dim=dim,
186
+ heads=num_attention_heads,
187
+ dim_head=attention_head_dim,
188
+ dropout=dropout,
189
+ bias=attention_bias,
190
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
191
+ upcast_attention=upcast_attention,
192
+ ) # is a self-attention
193
+ self.norm1 = (
194
+ AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
195
+ )
196
+
197
+ # 2. Cross-Attn
198
+ if cross_attention_dim is not None:
199
+ self.attn2 = CrossAttention(
200
+ query_dim=dim,
201
+ cross_attention_dim=cross_attention_dim,
202
+ heads=num_attention_heads,
203
+ dim_head=attention_head_dim,
204
+ dropout=dropout,
205
+ bias=attention_bias,
206
+ upcast_attention=upcast_attention,
207
+ ) # is self-attn if encoder_hidden_states is none
208
+ self.norm2 = (
209
+ AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
210
+ )
211
+ else:
212
+ self.attn2 = None
213
+ self.norm2 = None
214
+
215
+ # 3. Temporal-Attn
216
+ self.attn_temporal = CrossAttention(
217
+ query_dim=dim,
218
+ heads=num_attention_heads,
219
+ dim_head=attention_head_dim,
220
+ dropout=dropout,
221
+ bias=attention_bias,
222
+ upcast_attention=upcast_attention,
223
+ )
224
+ nn.init.zeros_(self.attn_temporal.to_out[0].weight.data) # initialize as an identity function
225
+ self.norm_temporal = (
226
+ AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
227
+ )
228
+ # efficient_attention_backward_cutlass is not implemented for large channels
229
+ self.use_xformers = (dim <= 320) or "3090" not in torch.cuda.get_device_name(0)
230
+
231
+ # 4. Feed-forward
232
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
233
+ self.norm3 = nn.LayerNorm(dim)
234
+
235
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
236
+ if not is_xformers_available():
237
+ print("Here is how to install it")
238
+ raise ModuleNotFoundError(
239
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
240
+ " xformers",
241
+ name="xformers",
242
+ )
243
+ elif not torch.cuda.is_available():
244
+ raise ValueError(
245
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
246
+ " available for GPU "
247
+ )
248
+ else:
249
+ try:
250
+ # Make sure we can run the memory efficient attention
251
+ if use_memory_efficient_attention_xformers is True:
252
+
253
+ _ = xformers.ops.memory_efficient_attention(
254
+ torch.randn((1, 2, 40), device="cuda"),
255
+ torch.randn((1, 2, 40), device="cuda"),
256
+ torch.randn((1, 2, 40), device="cuda"),
257
+ )
258
+ else:
259
+
260
+ pass
261
+ except Exception as e:
262
+ raise e
263
+ # self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
264
+ # self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
265
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers and self.use_xformers
266
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers and self.use_xformers
267
+ # self.attn_temporal._use_memory_efficient_attention_xformers = (
268
+ # use_memory_efficient_attention_xformers
269
+ # ), # FIXME: enabling this raises CUDA ERROR. Gotta dig in.
270
+
271
+ def forward(
272
+ self,
273
+ hidden_states,
274
+ encoder_hidden_states=None,
275
+ timestep=None,
276
+ attention_mask=None,
277
+ clip_length=None,
278
+ ):
279
+ # 1. Self-Attention
280
+ norm_hidden_states = (
281
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
282
+ )
283
+
284
+ kwargs = dict(
285
+ hidden_states=norm_hidden_states,
286
+ attention_mask=attention_mask,
287
+ )
288
+ if self.only_cross_attention:
289
+ kwargs.update(encoder_hidden_states=encoder_hidden_states)
290
+ if self.use_sparse_causal_attention:
291
+ kwargs.update(clip_length=clip_length)
292
+ if 'SparseCausalAttention_index' in self.model_config.keys():
293
+ kwargs.update(SparseCausalAttention_index = self.model_config['SparseCausalAttention_index'])
294
+
295
+ hidden_states = hidden_states + self.attn1(**kwargs)
296
+
297
+ if clip_length is not None and self.temporal_attention_position == "after_spatial":
298
+ hidden_states = self.apply_temporal_attention(hidden_states, timestep, clip_length)
299
+
300
+ if self.attn2 is not None:
301
+ # 2. Cross-Attention
302
+ norm_hidden_states = (
303
+ self.norm2(hidden_states, timestep)
304
+ if self.use_ada_layer_norm
305
+ else self.norm2(hidden_states)
306
+ )
307
+ hidden_states = (
308
+ self.attn2(
309
+ norm_hidden_states, # [16, 4096, 320]
310
+ encoder_hidden_states=encoder_hidden_states, # [1, 77, 768]
311
+ attention_mask=attention_mask,
312
+ )
313
+ + hidden_states
314
+ )
315
+
316
+ if clip_length is not None and self.temporal_attention_position == "after_cross":
317
+ hidden_states = self.apply_temporal_attention(hidden_states, timestep, clip_length)
318
+
319
+ # 3. Feed-forward
320
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
321
+
322
+ if clip_length is not None and self.temporal_attention_position == "after_feedforward":
323
+ hidden_states = self.apply_temporal_attention(hidden_states, timestep, clip_length)
324
+
325
+ return hidden_states
326
+
327
+ def apply_temporal_attention(self, hidden_states, timestep, clip_length):
328
+ d = hidden_states.shape[1]
329
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=clip_length)
330
+ norm_hidden_states = (
331
+ self.norm_temporal(hidden_states, timestep)
332
+ if self.use_ada_layer_norm
333
+ else self.norm_temporal(hidden_states)
334
+ )
335
+ hidden_states = self.attn_temporal(norm_hidden_states) + hidden_states
336
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
337
+ return hidden_states
338
+
339
+
340
+ class SparseCausalAttention(CrossAttention):
341
+ def forward(
342
+ self,
343
+ hidden_states,
344
+ encoder_hidden_states=None,
345
+ attention_mask=None,
346
+ clip_length: int = None,
347
+ SparseCausalAttention_index: list = [-1, 'first']
348
+ ):
349
+ if (
350
+ self.added_kv_proj_dim is not None
351
+ or encoder_hidden_states is not None
352
+ or attention_mask is not None
353
+ ):
354
+ raise NotImplementedError
355
+
356
+ if self.group_norm is not None:
357
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
358
+
359
+ query = self.to_q(hidden_states)
360
+ dim = query.shape[-1]
361
+ query = self.reshape_heads_to_batch_dim(query)
362
+
363
+ key = self.to_k(hidden_states)
364
+ value = self.to_v(hidden_states)
365
+
366
+ if clip_length is not None:
367
+ key = rearrange(key, "(b f) d c -> b f d c", f=clip_length)
368
+ value = rearrange(value, "(b f) d c -> b f d c", f=clip_length)
369
+
370
+
371
+ # ***********************Start of SparseCausalAttention_index**********
372
+ frame_index_list = []
373
+ # print(f'SparseCausalAttention_index {str(SparseCausalAttention_index)}')
374
+ if len(SparseCausalAttention_index) > 0:
375
+ for index in SparseCausalAttention_index:
376
+ if isinstance(index, str):
377
+ if index == 'first':
378
+ frame_index = [0] * clip_length
379
+ if index == 'last':
380
+ frame_index = [clip_length-1] * clip_length
381
+ if (index == 'mid') or (index == 'middle'):
382
+ frame_index = [int(clip_length-1)//2] * clip_length
383
+ else:
384
+ assert isinstance(index, int), 'relative index must be int'
385
+ frame_index = torch.arange(clip_length) + index
386
+ frame_index = frame_index.clip(0, clip_length-1)
387
+
388
+ frame_index_list.append(frame_index)
389
+
390
+ key = torch.cat([ key[:, frame_index] for frame_index in frame_index_list
391
+ ], dim=2)
392
+ value = torch.cat([ value[:, frame_index] for frame_index in frame_index_list
393
+ ], dim=2)
394
+
395
+
396
+ # ***********************End of SparseCausalAttention_index**********
397
+ key = rearrange(key, "b f d c -> (b f) d c", f=clip_length)
398
+ value = rearrange(value, "b f d c -> (b f) d c", f=clip_length)
399
+
400
+
401
+ key = self.reshape_heads_to_batch_dim(key)
402
+ value = self.reshape_heads_to_batch_dim(value)
403
+
404
+ # attention, what we cannot get enough of
405
+ if self._use_memory_efficient_attention_xformers:
406
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
407
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
408
+ hidden_states = hidden_states.to(query.dtype)
409
+ else:
410
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
411
+ hidden_states = self._attention(query, key, value, attention_mask)
412
+ else:
413
+ hidden_states = self._sliced_attention(
414
+ query, key, value, hidden_states.shape[1], dim, attention_mask
415
+ )
416
+
417
+ # linear proj
418
+ hidden_states = self.to_out[0](hidden_states)
419
+
420
+ # dropout
421
+ hidden_states = self.to_out[1](hidden_states)
422
+ return hidden_states
423
+
424
+ # FIXME
425
+ class SparseCausalAttention_fixme(CrossAttention):
426
+ def forward(
427
+ self,
428
+ hidden_states,
429
+ encoder_hidden_states=None,
430
+ attention_mask=None,
431
+ clip_length: int = None,
432
+ ):
433
+ if (
434
+ self.added_kv_proj_dim is not None
435
+ or encoder_hidden_states is not None
436
+ or attention_mask is not None
437
+ ):
438
+ raise NotImplementedError
439
+
440
+ if self.group_norm is not None:
441
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
442
+
443
+ query = self.to_q(hidden_states)
444
+ dim = query.shape[-1]
445
+ query = self.reshape_heads_to_batch_dim(query)
446
+
447
+ key = self.to_k(hidden_states)
448
+ value = self.to_v(hidden_states)
449
+
450
+ prev_frame_index = torch.arange(clip_length) - 1
451
+ prev_frame_index[0] = 0
452
+
453
+ key = rearrange(key, "(b f) d c -> b f d c", f=clip_length)
454
+ key = torch.cat([key[:, [0] * clip_length], key[:, prev_frame_index]], dim=2)
455
+ key = rearrange(key, "b f d c -> (b f) d c", f=clip_length)
456
+
457
+ value = rearrange(value, "(b f) d c -> b f d c", f=clip_length)
458
+ value = torch.cat([value[:, [0] * clip_length], value[:, prev_frame_index]], dim=2)
459
+ value = rearrange(value, "b f d c -> (b f) d c", f=clip_length)
460
+
461
+ key = self.reshape_heads_to_batch_dim(key)
462
+ value = self.reshape_heads_to_batch_dim(value)
463
+
464
+
465
+ if self._use_memory_efficient_attention_xformers:
466
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
467
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
468
+ hidden_states = hidden_states.to(query.dtype)
469
+ else:
470
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
471
+ hidden_states = self._attention(query, key, value, attention_mask)
472
+ else:
473
+ hidden_states = self._sliced_attention(
474
+ query, key, value, hidden_states.shape[1], dim, attention_mask
475
+ )
476
+
477
+ # linear proj
478
+ hidden_states = self.to_out[0](hidden_states)
479
+
480
+ # dropout
481
+ hidden_states = self.to_out[1](hidden_states)
482
+ return hidden_states
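Finally, a self-contained sketch of the key/value gathering inside `SparseCausalAttention.forward` above: with `SparseCausalAttention_index = [-1, 'first']`, every frame attends to its previous frame (clamped at the clip start) and to the first frame, doubling the attended tokens per frame. The tensor sizes are arbitrary:

```python
import torch

clip_length = 4
SparseCausalAttention_index = [-1, 'first']

frame_index_list = []
for index in SparseCausalAttention_index:
    if isinstance(index, str):
        if index == 'first':
            frame_index = torch.tensor([0] * clip_length)
        elif index == 'last':
            frame_index = torch.tensor([clip_length - 1] * clip_length)
        else:  # 'mid' / 'middle'
            frame_index = torch.tensor([(clip_length - 1) // 2] * clip_length)
    else:
        frame_index = (torch.arange(clip_length) + index).clip(0, clip_length - 1)
    frame_index_list.append(frame_index)

print(frame_index_list[0].tolist())   # [0, 0, 1, 2]  <- previous frame per query frame
print(frame_index_list[1].tolist())   # [0, 0, 0, 0]  <- first frame for every query frame

# key/value of shape (b, f, d, c) are gathered and concatenated along the token dim.
key = torch.rand(1, clip_length, 16, 8)
key = torch.cat([key[:, fi] for fi in frame_index_list], dim=2)
print(key.shape)                      # torch.Size([1, 4, 32, 8])
```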